Project Icon

petastorm

开源数据访问库,支持单机或分布式训练和评估深度学习模型,直接从Apache Parquet格式数据集中读取数据

Petastorm是一个开源数据访问库,支持单机或分布式训练和评估深度学习模型,直接从Apache Parquet格式数据集中读取数据。该库兼容Tensorflow、PyTorch和PySpark等主流Python机器学习框架,也可用于纯Python代码。Petastorm支持多种数据压缩格式,提供方便的API用于数据生成和读取,并支持列选择、并行读取、行过滤等功能。用户可以轻松在单机或Spark集群上生成数据集,是构建高效机器学习管道的理想工具。

项目介绍

Petastorm 是一个由 Uber ATG 开发的开源数据访问库。这个库的主要功能是在单机或者分布式环境中,直接从 Apache Parquet 格式的数据集中进行深度学习模型的训练和评估。它支持多种流行的基于 Python 的机器学习框架,如 TensorFlow、PyTorch 和 PySpark,也可以被纯 Python 代码使用。

安装方法

Petastorm 可以通过 Python 的 pip 工具进行安装:

pip install petastorm

用户可以根据需要选择安装其他附加依赖项,例如,为安装 GPU 版本的 TensorFlow 和 opencv:

pip install petastorm[opencv,tf_gpu]

生成数据集

使用 Petastorm 创建的数据集存储在 Apache Parquet 格式中。Petastorm 在 Parquet 架构之上还储存更高层次的架构信息,使多维数组成为其数据集的一部分。用户可以通过 PySpark 生成数据集,因为 PySpark 对 Parquet 格式的原生支持,让其无论是在单机还是 Spark 计算集群上运行都非常简单。

以下是一个使用 PySpark 生成随机数据表的简单示例:

import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

from petastorm.codecs import ScalarCodec, CompressedImageCodec, NdarrayCodec
from petastorm.etl.dataset_metadata import materialize_dataset
from petastorm.unischema import dict_to_spark_row, Unischema, UnischemaField

HelloWorldSchema = Unischema('HelloWorldSchema', [
    UnischemaField('id', np.int32, (), ScalarCodec(IntegerType()), False),
    UnischemaField('image1', np.uint8, (128, 256, 3), CompressedImageCodec('png'), False),
    UnischemaField('array_4d', np.uint8, (None, 128, 30, None), NdarrayCodec(), False),
])

def row_generator(x):
    return {'id': x,
            'image1': np.random.randint(0, 255, dtype=np.uint8, size=(128, 256, 3)),
            'array_4d': np.random.randint(0, 255, dtype=np.uint8, size=(4, 128, 30, 3))}

def generate_petastorm_dataset(output_url='file:///tmp/hello_world_dataset'):
    rowgroup_size_mb = 256
    spark = SparkSession.builder.config('spark.driver.memory', '2g').master('local[2]').getOrCreate()
    sc = spark.sparkContext
    rows_count = 10
    with materialize_dataset(spark, output_url, HelloWorldSchema, rowgroup_size_mb):
        rows_rdd = sc.parallelize(range(rows_count))\
            .map(row_generator)\
            .map(lambda x: dict_to_spark_row(HelloWorldSchema, x))
        spark.createDataFrame(rows_rdd, HelloWorldSchema.as_spark_schema()) \
            .coalesce(10) \
            .write \
            .mode('overwrite') \
            .parquet(output_url)

此过程通过定义数据模型架构为一个 Unischema 实例,将数据字段转化为适合其目标框架的格式。采用 PySpark 来写入 Parquet 文件,并通过上下文管理器 materialize_dataset 来完成数据集的生成和保存。

直接在 Python 中使用

petastorm.reader.Reader 类是用户从 TensorFlow 或 PyTorch 等机器学习框架访问数据的主要入口。它具备如下特性:

  • 选择性读取列
  • 多种并行策略:线程、进程和单线程(用于调试)
  • 支持 N-grams 读取
  • 行过滤
  • 随机打乱
  • 多 GPU 训练的分区
  • 本地缓存

使用 petastorm.make_reader 工厂方法创建 Reader 非常简单:

from petastorm import make_reader

with make_reader('hdfs://myhadoop/some_dataset') as reader:
   for row in reader:
       print(row)

Reader 一旦被实例化,可以作为迭代器使用。

TensorFlow 和 PyTorch API 接入

Petastorm 提供了与 TensorFlow 和 PyTorch 的接入方法。

TensorFlow

使用 tf_tensors 函数将数据读入 TensorFlow 图:

from petastorm.tf_utils import tf_tensors

with make_reader('file:///some/localpath/a_dataset') as reader:
    row_tensors = tf_tensors(reader)
    with tf.Session() as session:
        for _ in range(3):
            print(session.run(row_tensors))

也可以使用 tf.data.Dataset API:

from petastorm.tf_utils import make_petastorm_dataset

with make_reader('file:///some/localpath/a_dataset') as reader:
    dataset = make_petastorm_dataset(reader)
    iterator = dataset.make_one_shot_iterator()
    tensor = iterator.get_next()
    with tf.Session() as sess:
        sample = sess.run(tensor)
        print(sample.id)

PyTorch

Petastorm 使用 petastorm.pytorch.DataLoader 进行数据读取,它支持自定义的 Pytorch 聚合函数和转换:

import torch
from petastorm.pytorch import DataLoader

torch.manual_seed(1)
device = torch.device('cpu')
model = Net().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

def _transform_row(mnist_row):
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    return (transform(mnist_row['image']), mnist_row['digit'])

transform = TransformSpec(_transform_row, removed_fields=['idx'])

with DataLoader(make_reader('file:///localpath/mnist/train', num_epochs=10,
                            transform_spec=transform, seed=1, shuffle_rows=True), batch_size=64) as train_loader:
    train(model, device, train_loader, 10, optimizer, 1)
with DataLoader(make_reader('file:///localpath/mnist/test', num_epochs=10,
                            transform_spec=transform), batch_size=1000) as test_loader:
    test(model, device, test_loader)

Spark 数据集转换 API

Spark 数据集转换 API 简化了从 Spark 到 TensorFlow 或 PyTorch 的数据转换过程。输入的 Spark DataFrame 首先以 Parquet 格式存储,然后加载为 tf.data.Datasettorch.utils.data.DataLoader

from petastorm.spark import SparkDatasetConverter, make_spark_converter

spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, 'hdfs:/...')
df = ... # `df` 是 Spark 数据帧

# 创建转换器
converter = make_spark_converter(df)

# 从转换器生成 tensorflow 数据集
with converter.make_tf_dataset() as dataset:
    dataset = dataset.map(...)
    model.fit(dataset)
    
converter.delete()

使用 PySpark 和 SQL 分析 Petastorm 数据集

Petastorm 数据集可以被加载到 Spark DataFrame 中,使用 Spark 的工具进行数据集的分析和处理。通过 PySpark 你可以使用如下代码:

dataframe = spark.read.parquet(dataset_url)
dataframe.printSchema()
dataframe.count()
dataframe.select('id').show()

spark.sql('SELECT count(id) from parquet.`file:///tmp/hello_world_dataset`').collect()

贡献与支持

Petastorm 欢迎社区通过 GitHub 提交补丁和提供改进建议。开发者可以在项目的 GitHub 页面找到相关说明。

综上所述,Petastorm 是一个非常灵活且功能强大的数据访问库,适合在各种数据密集型的深度学习项目中使用,尤其是在需要在大数据集上进行高效训练和推理的场景中。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

稿定AI

稿定设计 是一个多功能的在线设计和创意平台,提供广泛的设计工具和资源,以满足不同用户的需求。从专业的图形设计师到普通用户,无论是进行图片处理、智能抠图、H5页面制作还是视频剪辑,稿定设计都能提供简单、高效的解决方案。该平台以其用户友好的界面和强大的功能集合,帮助用户轻松实现创意设计。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号