基于Jax的从零开始设计和训练变压器模型的库。
作者:Henry Ndubuaku(点击 Discord 和 Docs 徽章可以跳转)
N/B:代码以教学为主,重复性较高。 每个模型都独立放置在一个文件中,没有文件间的依赖关系。
概述
开发和训练基于变压器的模型通常需要大量资源和时间,而AI/ML专家们常常需要为特定问题构建规模较小的版本。Jax是一个资源低但功能强大的框架,加速了神经网络的开发并简化了分布式训练,但现有的用Jax开发变压器的资源有限。NanoDL解决了这个问题,具有以下特性:
- 提供丰富的模块和层,方便从零开始创建定制的变压器模型。
- 提供丰富的模型选择,如Gemma、LlaMa3、Mistral、GPT3、GPT4(推测)、T5、Whisper、ViT、Mixers、CLIP等。
- 数据并行分布式训练器模型可在多个GPU或TPU上运行,无需手动训练循环。
- 数据加载器,简化了Jax/Flax的数据处理过程,使其更高效。
- 提供Flax/Jax中未提供的层,如RoPE、GQA、MQA和SWin注意力机制,使模型开发更灵活。
- GPU/TPU加速的经典机器学习模型,如PCA、KMeans、回归、高斯过程等。
- Jax中的真正随机数生成器,不需要冗长的代码。
- 提供一系列高级算法用于NLP和计算机视觉任务,如高斯模糊、BLEU、分词器等。
- 每个模型都独立放置在一个文件中,没有外部依赖,因此源代码也易于使用。
- Jax中的真正随机数生成器,不需要冗长的代码(在后续部分中展示示例)。
有些实验性和/或未完成的功能(如MAMBA、KAN、BitNet、GAT和RLHF)在仓库中,但尚未通过软件包提供,可以从该仓库复制。 欢迎在我们的讨论、问题和 pull request 线程中提供反馈!请在 Discord 上报告任何功能请求、问题、疑问或关注的内容,或只是告诉我们你正在进行的工作!
快速安装
您需要 Python 3.9 或更高版本,以及能够正常工作的JAX安装,FLAX安装,OPTAX安装(运行训练需要 GPU 支持,否则仅支持创建)。 模型可以在 CPU 上设计和测试,但训练器都是数据并行分布式的,需要一个或多个 GPU/TPUS。如果仅需要JAX的CPU版本:
pip install --upgrade pip # 支持 manylinux2010 轮子。
pip install jax flax optax
然后,从 PyPi 安装 nanodl:
pip install nanodl
nanodl 的样子是什么样子的?
我们提供了 nanodl API 的各种示例用法。
import jax
import nanodl
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import GPT4, GPTDataParallelTrainer
# 准备数据集
batch_size = 8
max_length = 50
vocab_size = 1000
# 创建随机数据
data = nanodl.uniform(
shape=(batch_size, max_length),
minval=0, maxval=vocab_size-1
).astype(jnp.int32)
# 移动以创建下一令牌预测数据集
dummy_inputs, dummy_targets = data[:, :-1], data[:, 1:]
# 创建数据集和数据加载器
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(
dataset, batch_size=batch_size, shuffle=True, drop_last=False
)
# 模型参数
hyperparams = {
'num_layers': 1,
'hidden_dim': 256,
'num_heads': 2,
'feedforward_dim': 256,
'dropout': 0.1,
'vocab_size': vocab_size,
'embed_dim': 256,
'max_length': max_length,
'start_token': 0,
'end_token': 50,
}
# 推测的 GPT4 模型
model = GPT4(**hyperparams)
trainer = GPTDataParallelTrainer(
model, dummy_inputs.shape, 'params.pkl'
)
trainer.train(
train_loader=dataloader, num_epochs=100, val_loader=dataloader
) # 使用实际验证数据
# 从起始令牌生成
start_tokens = jnp.array([[123, 456]])
# 记住加载训练参数
params = trainer.load_params('params.pkl')
outputs = model.apply(
{'params': params},
start_tokens,
rngs={'dropout': nanodl.time_rng_key()},
method=model.generate
)
视觉示例
import nanodl
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import DiffusionModel, DiffusionDataParallelTrainer
image_size = 32
block_depth = 2
batch_size = 8
widths = [32, 64, 128]
input_shape = (101, image_size, image_size, 3)
images = nanodl.normal(shape=input_shape)
# 使用你自己的图像
dataset = ArrayDataset(images)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# 创建扩散模型
diffusion_model = DiffusionModel(image_size, widths, block_depth)
# 在你的数据上训练
trainer = DiffusionDataParallelTrainer(diffusion_model,
input_shape=images.shape,
weights_filename='params.pkl',
learning_rate=1e-4)
trainer.train(dataloader, 10)
# 生成一些样本:每个模型都是 Flax.linen 模块
# 通常使用
params = trainer.load_params('params.pkl')
generated_images = diffusion_model.apply({'params': params},
num_images=5,
diffusion_steps=5,
method=diffusion_model.generate)
音频示例
import jax
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Whisper, WhisperDataParallelTrainer
# 虚拟数据参数
batch_size = 8
max_length = 50
embed_dim = 256
vocab_size = 1000
# 生成数据:替换为实际的分词/量化数据
dummy_targets = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_inputs = jnp.ones((101, max_length, embed_dim))
dataset = ArrayDataset(dummy_inputs, dummy_targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# 模型参数
hyperparams = {
'num_layers': 1,
'hidden_dim': 256,
'num_heads': 2,
'feedforward_dim': 256,
'dropout': 0.1,
'vocab_size': 1000,
'embed_dim': embed_dim,
'max_length': max_length,
'start_token': 0,
'end_token': 50,
}
# 初始化模型
model = Whisper(**hyperparams)
# 在你的数据上训练
trainer = WhisperDataParallelTrainer(model,
dummy_inputs.shape,
dummy_targets.shape,
'params.pkl')
trainer.train(dataloader, 2, dataloader)
# 采样推理
params = trainer.load_params('params.pkl')
# 对于多个样本,通常使用 model.generate_batch
transcripts = model.apply({'params': params},
dummy_inputs[:1],
method=model.generate)
RLHF 的奖励模型示例
import nanodl
import jax.numpy as jnp
from nanodl import ArrayDataset, DataLoader
from nanodl import Mistral, RewardModel, RewardDataParallelTrainer
# 生成虚拟数据
batch_size = 8
max_length = 10
# 替换为实际分词数据
dummy_chosen = jnp.ones((101, max_length), dtype=jnp.int32)
dummy_rejected = jnp.zeros((101, max_length), dtype=jnp.int32)
# 创建数据集和数据加载器
dataset = ArrayDataset(dummy_chosen, dummy_rejected)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
# 模型参数
hyperparams = {
'num_layers': 1,
'hidden_dim': 256,
'num_heads': 2,
'feedforward_dim': 256,
'dropout': 0.1,
'vocab_size': 1000,
'embed_dim': 256,
'max_length': max_length,
'start_token': 0,
'end_token': 50,
'num_groups': 2,
'window_size': 5,
'shift_size': 2
}
# 从 Mistral 初始化奖励模型
model = Mistral(**hyperparams)
reward_model = RewardModel(model, dim=hyperparams['hidden_dim'], dropout=0.1)
# 训练奖励模型
trainer = RewardDataParallelTrainer(reward_model, dummy_chosen.shape, 'reward_model_weights.pkl')
trainer.train(dataloader, 5, dataloader)
params = trainer.load_params('reward_model_weights.pkl')
# 使用与常规 Flax 模型相同的方式调用
rewards = reward_model.apply({'params': params},
dummy_chosen,
rngs={'dropout': nanodl.time_rng_key()})
PCA 示例
import nanodl
from nanodl import PCA
# 使用实际数据
data = nanodl.normal(shape=(1000, 10))
# 初始化并训练 PCA 模型
pca = PCA(n_components=2)
pca.fit(data)
# 获取 PCA 变换
transformed_data = pca.transform(data)
# 获取逆变换
original_data = pca.inverse_transform(transformed_data)
# 从分布中采样
X_sampled = pca.sample(n_samples=1000, key=None)
这仍在开发中,效果很好,但预计会有不完善的地方,因此非常鼓励贡献!
- 在不改变设计模式的情况下进行更改。
- 如有必要,为更改编写测试。
- 使用
pip3 install -e .
在本地安装。 - 使用
python3 -m unittest discover -s tests
运行测试。 - 然后提交一个 pull request。
贡献可以有多种形式:
- 撰写文档。
- 修复错误。
- 实现论文。
- 编写覆盖率高的测试。
- 优化现有代码。
- 进行实验并提交实际示例到示例部分。
- 报告错误。
- 回应报告的问题。
加入 Discord 服务器了解更多。
赞助
名称“ NanoDL”代表 Nano Deep Learning。模型正迅速扩大规模,因此限制了资源有限的专家和公司无法在没有高昂成本的情况下构建灵活的模型。 继Phi模型成功后,长期目标是在NanoDL的基础上构建和训练所有可用模型的nano版本,同时确保其性能与原始模型竞争,总参数数量不超过1B。训练好的权重将通过此库提供。 任何形式的赞助、资金将有助于训练资源。 您可以通过 GitHub 或发送邮件到 ndubuakuhenry@gmail.com 进行赞助。
引用 nanodl
引用这个仓库:
@software{nanodl2024github,
author = {Henry Ndubuaku},
title = {NanoDL: A Jax-based library for designing and training transformer models from scratch.},
url = {http://github.com/hmunachi/nanodl},
year = {2024},
}