Project Icon

tfrecord

允许在 python 中有效地读取和写入 tfrecord 文件

该库在Python中提供了高效读取和写入TFRecord文件的方法,并为PyTorch提供了可迭代的数据集读取器。支持无压缩和gzip压缩的TFRecord文件,通过创建索引文件可以避免多线程重复记录。用户还能使用transform函数进行特征后处理,如解码图像和归一化颜色范围。该库简化了多文件读取和顺序数据处理流程。

TFRecord 读写器

该库允许在 Python 中高效地读写 tfrecord 文件。该库还为 PyTorch 提供了 tfrecord 文件的可迭代数据集读取器。目前支持未压缩和 gzip 压缩的 TFRecords。

安装

pip3 install 'tfrecord[torch]'

使用方法

建议为每个 TFRecord 文件创建一个索引文件。使用多个工作进程时必须提供索引文件,否则加载器可能会返回重复的记录。您可以使用以下实用程序为单个 tfrecord 文件创建索引文件:

python3 -m tfrecord.tools.tfrecord2idx <tfrecord 路径> <索引路径>

要为目录中所有的 ".tfrecord" 文件创建 ".tfidnex" 文件,请运行:

tfrecord2idx <数据目录>

读写 tf.train.Example

在 PyTorch 中读取 tf.Example 记录

使用 TFRecordDataset 在 PyTorch 中读取 TFRecord 文件。

import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

使用 MultiTFRecordDataset 读取多个 TFRecord 文件。此类根据给定的概率从给定的 tfrecord 文件中采样。

import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)

data = next(iter(loader))
print(data)

无限和有限 PyTorch 数据集

默认情况下,MultiTFRecordDataset 是无限的,这意味着它会永远采样数据。您可以通过提供适当的标志使其变为有限:

dataset = MultiTFRecordDataset(..., infinite=False)

数据洗牌

当您提供队列大小时,TFRecordDataset 和 MultiTFRecordDataset 都会自动对数据进行洗牌。

dataset = TFRecordDataset(..., shuffle_queue_size=1024)

转换输入数据

您可以选择将函数作为 transform 参数传递,以在返回之前对特征进行后处理。 这可以用于解码图像、将颜色归一化到特定范围或填充可变长度序列。

import tfrecord
import cv2

def decode_image(features):
    # 从字节获取 BGR 图像
    features["image"] = cv2.imdecode(features["image"], -1)
    return features


description = {
    "image": "bytes",
}

dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)

data = next(iter(dataset))
print(data)

在 Python 中写入 tf.Example 记录

import tfrecord

writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({
    "image": (image_bytes, "byte"),
    "label": (label, "float"),
    "index": (index, "int")
})
writer.close()

在 Python 中读取 tf.Example 记录

import tfrecord

loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {
    "image": "byte",
    "label": "float",
    "index": "int"
})
for record in loader:
    print(record["label"])

读写 tf.train.SequenceExample

SequenceExample 可以使用上面显示的相同方法进行读写,只需额外添加一个参数 (读取时为 sequence_description,写入时为 sequence_datum),这会导致相应的 读/写函数将数据视为 SequenceExample。

将 SequenceExample 写入文件

import tfrecord

writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1], 'int')})
writer.close()

在 Python 中读取 SequenceExample

从 SequenceExample 读取会产生一个包含两个元素的元组。

import tfrecord

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None,
                                  context_description,
                                  sequence_description=sequence_description)

for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])

在 PyTorch 中读取 SequenceExample

如"转换输入"部分所述,可以将函数作为 transform 参数传递,以对特征进行后处理。这对于序列特征尤其有用,因为它们是可变长度序列,需要在批处理之前进行填充。

import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

PAD_WIDTH = 5
def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int ", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)

或者,您可以选择实现自定义的 collate_fn 来组装批次,例如,执行动态填充。

import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset

def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    context, feats = zip(*batch)
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})

context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int ", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
data = next(iter(loader))
print(data)
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

稿定AI

稿定设计 是一个多功能的在线设计和创意平台,提供广泛的设计工具和资源,以满足不同用户的需求。从专业的图形设计师到普通用户,无论是进行图片处理、智能抠图、H5页面制作还是视频剪辑,稿定设计都能提供简单、高效的解决方案。该平台以其用户友好的界面和强大的功能集合,帮助用户轻松实现创意设计。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号