Merlin数据加载器
merlin-dataloader让您能够快速训练适用于TensorFlow、PyTorch和JAX的推荐模型。它通过提供GPU优化的数据加载器来消除训练推荐模型的最大瓶颈,这些加载器可以直接将数据读取到GPU中,然后使用dlpack进行零拷贝传输到TensorFlow和PyTorch。
Merlin数据加载器的优势包括:
- 比原生框架数据加载器快10倍以上
- 可处理大于内存的数据集
- 每个epoch进行洗牌
- 支持分布式训练
安装
Merlin-dataloader需要Python 3.7+版本。此外,GPU支持需要CUDA 11.0+。
使用Conda安装:
conda install -c nvidia -c rapidsai -c numba -c conda-forge merlin-dataloader python=3.7 cudatoolkit=11.2
从PyPi安装:
pip install merlin-dataloader
在NGC上也有Docker容器,其中包含merlin-dataloader及其依赖项
基本用法
# 从一组parquet文件获取merlin数据集
import merlin.io
dataset = merlin.io.Dataset(PARQUET_FILE_PATHS, engine="parquet")
# 从数据集创建Tensorflow数据加载器,每批加载65K个项目
from merlin.dataloader.tensorflow import Loader
loader = Loader(dataset, batch_size=65536)
# 获取单批数据。输入将是一个字典,键为列名,值为TensorFlow张量
inputs, target = next(loader)
# 使用数据加载器训练Keras模型
model = tf.keras.Model( ... )
model.fit(loader, epochs=5)