项目介绍:nanodl
概述
nanodl 是一个基于 Jax 的库,专为从头开始设计和训练 transformer 模型而开发。在变换器模型的开发过程中,人们通常会遇到资源消耗巨大的挑战。为了应对这一问题,nanodl 提供了一系列精简但功能强大的工具和功能。这些工具通过 Jax 框架实现,使得神经网络的开发和分布式训练更加高效。
特性
nanodl 提供了许多强大的特性,包括:
- 兼具各类模块和层级结构,使用户可以从零开始构建定制的 transformer 模型。
- 支持多种模型,如 Gemma、LlaMa3、Mistral、GPT3、GPT4(推测版)、T5、Whisper、ViT、Mixers、CLIP 等。
- 提供数据并行分布式训练器,可以在多 GPU 或 TPU 上进行训练,无需手动编写训练循环。
- 提供数据加载器,简化了 Jax/Flax 的数据处理过程。
- 包含 Flax/Jax 中未涵盖的层,例如 RoPE、GQA、MQA 和 SWin 注意力等,支持更灵活的模型开发。
- 提供 GPU/TPU 加速的经典机器学习模型,如 PCA、KMeans、回归、高斯过程等。
- 支持不需要冗长代码的真随机数生成器。
- 提供一系列用于自然语言处理和计算机视觉任务的高级算法,如高斯模糊、BLEU、分词器等。
- 每个模型都被封装在一个独立的文件中,无需外部依赖,因此源码也可以轻松使用。
快速安装
要使用 nanodl,您需要 Python 3.9 或更高版本,并安装 JAX、FLAX 和 OPTAX(若需要运行训练则需 GPU 支持)。在设计和测试模型时可以使用 CPU,但训练器需要 GPU 或 TPU 支持。
使用以下命令来安装 nanodl:
pip install --upgrade pip
pip install jax flax optax
pip install nanodl
使用示例
nanodl 提供了丰富的 API 示例,帮助用户快速上手。以下简要介绍几个使用场景:
文本生成
用户可以使用 GPT4 模型进行文本生成,通过一个简单的 API 来训练和生成文本。
图像生成
借助扩散模型,用户能够训练图像数据并生成新图像。
音频处理
通过 Whisper 模型,用户可以处理音频数据并进行转录。
强化学习中的奖励模型
使用 Mistral 模型,可以在强化学习的上下文中训练奖励模型。
主成分分析(PCA)
用户可以使用 PCA 模型来进行数据降维和可视化。
社区与贡献
nanodl 是一个仍在开发中的项目,欢迎来自社区的贡献。用户可以通过撰写文档、修复错误、实现新功能或者改进代码等多种方式进行贡献。加入我们的 Discord 社区来交流想法或获取帮助。
赞助与引用
为了让有限资源的专家和企业也能以低成本构建灵活的模型,nanodl 的长远目标是构建和训练参数不超过 10 亿的精简版本模型,并优能与原始版本媲美。任何形式的赞助对我们来说都是一份支持。如果您愿意,请通过 GitHub 或邮件联系我们。
如果您希望引用 nanodl,请使用以下格式:
@software{nanodl2024github,
author = {Henry Ndubuaku},
title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.},
url = {http://github.com/hmunachi/nanodl},
year = {2024},
}