Project Overview
TFRecord is a library for efficiently reading and writing TFRecord files in Python. It also provides an IterableDataset reader for TFRecord files for use in PyTorch. Currently, uncompressed and gzip-compressed TFRecord files are supported.
Installation
Install the library with pip:
pip3 install 'tfrecord[torch]'
Usage
Recommended practice
To avoid duplicate records when loading with multiple workers, it is recommended to create an index file for each TFRecord file. An index file for a single TFRecord file can be created with the following utility:
python3 -m tfrecord.tools.tfrecord2idx <tfrecord 路径> <索引路径>
To create ".tfidx" files for every ".tfrecord" file in a directory, run:
tfrecord2idx <数据目录>
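For reference, an index file is just a text file with one "offset length" pair per record, which lets each worker seek directly to its share of the records. The following is a minimal pure-Python sketch of what the indexing tool does, assuming the standard TFRecord on-disk framing (8-byte little-endian payload length, 4-byte length CRC, payload, 4-byte payload CRC); it is illustrative only, not the library's implementation, which also verifies the file as it scans:

```python
import struct

def build_index(tfrecord_path, index_path):
    """Scan a TFRecord file and write one "offset length" pair per record.

    Assumes standard TFRecord framing: 8-byte little-endian payload length,
    4-byte length CRC, the payload itself, then a 4-byte payload CRC.
    """
    with open(tfrecord_path, "rb") as infile, open(index_path, "w") as outfile:
        while True:
            offset = infile.tell()
            header = infile.read(8)
            if len(header) < 8:
                break  # end of file
            payload_len, = struct.unpack("<Q", header)
            # Skip the length CRC, the payload, and the payload CRC.
            infile.seek(4 + payload_len + 4, 1)
            outfile.write(f"{offset} {infile.tell() - offset}\n")
```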
Reading & writing tf.train.Example
Reading tf.Example records in PyTorch
Use the TFRecordDataset class to read TFRecord files:
import torch
from tfrecord.torch.dataset import TFRecordDataset
tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Use the MultiTFRecordDataset class to read multiple TFRecord files. This class samples from the given TFRecord files with the given probabilities:
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset
tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Infinite and finite PyTorch datasets
By default, MultiTFRecordDataset is infinite, meaning it samples data forever. You can make it finite by providing the appropriate flag:
dataset = MultiTFRecordDataset(..., infinite=False)
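The probability-based sampling described above can be illustrated with a small pure-Python sketch (this is not the library's actual code; `sample_iterators` and the iterator factories are hypothetical names): one stream is picked per item according to the weights, and an exhausted stream is either restarted (infinite) or dropped (finite):

```python
import random

def sample_iterators(iterator_factories, weights, infinite=True):
    """Yield items by picking one stream at a time according to `weights`.

    When a stream is exhausted: restart it if `infinite` is True,
    otherwise drop it and stop once every stream has been consumed.
    """
    factories = list(iterator_factories)
    weights = list(weights)
    iterators = [make() for make in factories]
    while iterators:
        i = random.choices(range(len(iterators)), weights=weights)[0]
        try:
            yield next(iterators[i])
        except StopIteration:
            if infinite:
                iterators[i] = factories[i]()  # restart the exhausted stream
            else:
                del iterators[i], factories[i], weights[i]
```

With infinite=False the generator terminates after every stream is drained, which mirrors the finite-dataset behavior above.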
Shuffling the data
Both TFRecordDataset and MultiTFRecordDataset automatically shuffle the data when a queue size is provided:
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
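Queue-based shuffling can be sketched in pure Python (an illustrative approximation, not the library's code): incoming items fill a fixed-size buffer, and once the buffer is full a random buffered item is emitted for each new arrival, so a larger queue gives a more thorough shuffle at the cost of memory:

```python
import random

def shuffle_stream(iterable, queue_size):
    """Approximate streaming shuffle with a bounded buffer.

    Buffers up to `queue_size` items and emits a random buffered item
    as each new item arrives; drains the remainder at the end.
    """
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) >= queue_size:
            yield buffer.pop(random.randrange(len(buffer)))
    random.shuffle(buffer)  # drain whatever is left
    yield from buffer
```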
Transforming input
You can optionally pass a function as the transform argument to perform post-processing of features. This can, for example, be used to decode images, normalize colors, or pad variable-length sequences:
import tfrecord
import cv2
def decode_image(features):
    # get BGR image from bytes
    features["image"] = cv2.imdecode(features["image"], -1)
    return features
description = {
    "image": "byte",
}
dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)
data = next(iter(dataset))
print(data)
Writing tf.Example records in Python
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({
    "image": (image_bytes, "byte"),
    "label": (label, "float"),
    "index": (index, "int")
})
writer.close()
Reading tf.Example records in Python
import tfrecord
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {
    "image": "byte",
    "label": "float",
    "index": "int"
})
for record in loader:
    print(record["label"])
Reading & writing tf.train.SequenceExample
Writing SequenceExamples to file
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (2, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1], 'int')})
writer.close()
Reading SequenceExamples in Python
Reading a SequenceExample yields a tuple of two elements:
import tfrecord
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None,
                                  context_description,
                                  sequence_description=sequence_description)
for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])
Reading SequenceExamples in PyTorch
As above, you can pass a function as the transform argument to perform post-processing of features. This is especially useful for sequence features, which need to be padded:
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
PAD_WIDTH = 5
def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Alternatively, you can implement a custom collate_fn to assemble the batch, for example to perform dynamic padding:
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    context, feats = zip(*batch)
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
data = next(iter(loader))
print(data)
The goal of this project is to simplify working with TFRecord files and to provide a flexible interface for handling a variety of data types and scenarios, making it a practical tool for large datasets.