TFRecord reader and writer
This library allows reading and writing TFRecord files efficiently in Python. It also provides an IterableDataset reader of TFRecord files for PyTorch. Currently uncompressed and gzip-compressed TFRecords are supported.
Installation
pip3 install 'tfrecord[torch]'
Usage
It is recommended to create an index file for each TFRecord file. An index file must be provided when using multiple workers, otherwise the loader may return duplicate records. You can use this utility program to create an index file for a single tfrecord file:
python3 -m tfrecord.tools.tfrecord2idx <tfrecord path> <index path>
To create ".tfindex" files for all ".tfrecord" files in a directory, run:
tfrecord2idx <data dir>
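The same can be done from Python. A minimal sketch, assuming tfrecord.tools.tfrecord2idx exposes a create_index(tfrecord_file, index_file) helper (verify against your installed version):

import os
from tfrecord.tools.tfrecord2idx import create_index  # assumed helper; check your version

data_dir = "/tmp/data"
for name in os.listdir(data_dir):
    if name.endswith(".tfrecord"):
        path = os.path.join(data_dir, name)
        # write the index next to the record, e.g. data.tfrecord -> data.tfindex
        create_index(path, path.replace(".tfrecord", ".tfindex"))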
Reading & Writing tf.train.Example
Reading tf.Example records in PyTorch
Use TFRecordDataset to read TFRecord files in PyTorch.
import torch
from tfrecord.torch.dataset import TFRecordDataset
tfrecord_path = "/tmp/data.tfrecord"
index_path = None
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
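As noted above, an index file must be provided when the DataLoader uses multiple workers, otherwise duplicate records may be returned. A minimal sketch, assuming "/tmp/data.tfindex" was created with the tfrecord2idx utility shown earlier:

import torch
from tfrecord.torch.dataset import TFRecordDataset

tfrecord_path = "/tmp/data.tfrecord"
index_path = "/tmp/data.tfindex"  # created with tfrecord2idx; required for num_workers > 1
description = {"image": "byte", "label": "float"}
dataset = TFRecordDataset(tfrecord_path, index_path, description)
# With an index file present, each worker reads its own shard of the records.
loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=4)
data = next(iter(loader))
print(data)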
Use MultiTFRecordDataset to read multiple TFRecord files. This class samples from the given tfrecord files with the given probability.
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset
tfrecord_pattern = "/tmp/{}.tfrecord"
index_pattern = "/tmp/{}.index"
splits = {
    "dataset1": 0.8,
    "dataset2": 0.2,
}
description = {"image": "byte", "label": "int"}
dataset = MultiTFRecordDataset(tfrecord_pattern, index_pattern, splits, description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Infinite and finite PyTorch dataset
By default, MultiTFRecordDataset is infinite, meaning that it samples the data forever. You can make it finite by providing the appropriate flag:
dataset = MultiTFRecordDataset(..., infinite=False)
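With infinite=False the iterator stops once the underlying files are exhausted, so a conventional epoch loop works. A minimal sketch, reusing the file patterns and split names from above:

import torch
from tfrecord.torch.dataset import MultiTFRecordDataset

dataset = MultiTFRecordDataset("/tmp/{}.tfrecord", "/tmp/{}.index",
                               {"dataset1": 0.8, "dataset2": 0.2},
                               {"image": "byte", "label": "int"},
                               infinite=False)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
for epoch in range(3):
    for batch in loader:  # each pass terminates instead of sampling forever
        pass              # training step would go here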
Shuffling the data
Both TFRecordDataset and MultiTFRecordDataset automatically shuffle the data when you provide a queue size.
dataset = TFRecordDataset(..., shuffle_queue_size=1024)
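Compressed TFRecords
As mentioned above, gzip-compressed TFRecords are supported as well. A minimal sketch, assuming your installed version exposes a compression_type argument on TFRecordDataset (check the signature before relying on it):

import torch
from tfrecord.torch.dataset import TFRecordDataset

dataset = TFRecordDataset("/tmp/data.tfrecord.gz",
                          index_path=None,
                          description={"image": "byte", "label": "float"},
                          compression_type="gzip")  # assumed argument; None for uncompressed files
data = next(iter(dataset))
print(data)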
Transforming input
You can optionally pass a function as the transform argument to perform post-processing of features before returning them. This can, for example, be used to decode images, normalize colors to a certain range, or pad variable-length sequences.
import tfrecord
import cv2
def decode_image(features):
    # get BGR image from bytes
    features["image"] = cv2.imdecode(features["image"], -1)
    return features
description = {
    "image": "byte",
}
dataset = tfrecord.torch.TFRecordDataset("/tmp/data.tfrecord",
                                         index_path=None,
                                         description=description,
                                         transform=decode_image)
data = next(iter(dataset))
print(data)
Writing tf.Example records in Python
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({
    "image": (image_bytes, "byte"),
    "label": (label, "float"),
    "index": (index, "int")
})
writer.close()
Reading tf.Example records in Python
import tfrecord
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None, {
    "image": "byte",
    "label": "float",
    "index": "int"
})
for record in loader:
    print(record["label"])
Reading & Writing tf.train.SequenceExample
SequenceExamples can be read and written using the same methods shown above with an extra argument (sequence_description for reading and sequence_datum for writing), which causes the respective read/write functions to treat the data as a SequenceExample.
Writing SequenceExamples to file
import tfrecord
writer = tfrecord.TFRecordWriter("/tmp/data.tfrecord")
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [0, 1, 0], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1, 1], 'int')})
writer.write({'length': (3, 'int'), 'label': (1, 'int')},
             {'tokens': ([[0, 0, 1], [1, 0, 0]], 'int'), 'seq_labels': ([0, 1], 'int')})
writer.close()
Reading SequenceExamples in Python
Reading from a SequenceExample yields a tuple containing two elements.
import tfrecord
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
loader = tfrecord.tfrecord_loader("/tmp/data.tfrecord", None,
                                  context_description,
                                  sequence_description=sequence_description)
for context, sequence_feats in loader:
    print(context["label"])
    print(sequence_feats["seq_labels"])
Reading SequenceExamples in PyTorch
如"转换输入"部分所述,可以将函数作为 transform
参数传递,以对特征进行后处理。这对于序列特征尤其有用,因为它们是可变长度序列,需要在批处理之前进行填充。
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
PAD_WIDTH = 5
def pad_sequence_feats(data):
    context, features = data
    for k, v in features.items():
        features[k] = np.pad(v, ((0, PAD_WIDTH - len(v)), (0, 0)), 'constant')
    return (context, features)
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          transform=pad_sequence_feats,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32)
data = next(iter(loader))
print(data)
Alternatively, you can choose to implement a custom collate_fn to assemble the batch, for example, to perform dynamic padding.
import torch
import numpy as np
from tfrecord.torch.dataset import TFRecordDataset
def collate_fn(batch):
    from torch.utils.data._utils import collate
    from torch.nn.utils import rnn
    context, feats = zip(*batch)
    feats_ = {k: [torch.Tensor(d[k]) for d in feats] for k in feats[0]}
    return (collate.default_collate(context),
            {k: rnn.pad_sequence(f, True) for (k, f) in feats_.items()})
context_description = {"length": "int", "label": "int"}
sequence_description = {"tokens": "int", "seq_labels": "int"}
dataset = TFRecordDataset("/tmp/data.tfrecord",
                          index_path=None,
                          description=context_description,
                          sequence_description=sequence_description)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
data = next(iter(loader))
print(data)