Merlin Dataloader 项目介绍
Merlin Dataloader 是一个用于加速推荐模型训练的工具,支持 TensorFlow、PyTorch 和 JAX。该项目的主要目标是解决推荐模型在训练过程中遇到的数据加载瓶颈。通过提供 GPU 优化的数据加载器,Merlin Dataloader 可以直接将数据读取到 GPU,并使用 dlpack 无复制地传输到 TensorFlow 和 PyTorch。
Merlin Dataloader 的优点
Merlin Dataloader 的使用带来了显著的性能提升和功能扩展:
- 速度提升:相较于原生框架的数据加载器,速度提高超过十倍,这显著缩短了模型训练的时间。
- 处理超大数据集:能够有效处理那些内存无法容纳的大型数据集,使得大规模数据训练成为可能。
- 每轮洗牌:支持每个训练周期的数据洗牌,以增强模型的泛化能力。
- 分布式训练:能够在分布式环境下高效运行,适应现代大规模机器学习任务的需求。
安装指南
Merlin Dataloader 需要 Python 3.7 及以上版本。若需 GPU 支持,还需安装 CUDA 11.0 或更高版本。以下是安装方法:
-
使用 Conda 安装:
conda install -c nvidia -c rapidsai -c numba -c conda-forge merlin-dataloader python=3.7 cudatoolkit=11.2
-
使用 PyPi 安装:
pip install merlin-dataloader
此外,NGC 上提供了包含 merlin-dataloader 及其依赖的 Docker 容器,便于快速搭建环境。
基本用法
以下是一个简单的使用样例:
# 从一组 parquet 文件中获取 merlin 数据集
import merlin.io
dataset = merlin.io.Dataset(PARQUET_FILE_PATHS, engine="parquet")
# 从该数据集中创建 TensorFlow 数据加载器,每批加载 65K 条数据
from merlin.dataloader.tensorflow import Loader
loader = Loader(dataset, batch_size=65536)
# 获取一批数据。输入将是一个从列名到 TensorFlow 张量的字典
inputs, target = next(loader)
# 使用数据加载器训练 Keras 模型
model = tf.keras.Model( ... )
model.fit(loader, epochs=5)
Merlin Dataloader 提供的这些功能和特性,使得在处理大规模、复杂的推荐系统训练任务时变得更加轻松和高效。对于需要快速迭代和处理大数据集的推荐系统开发者而言,这无疑是一个强有力的工具。