Jamba
PyTorch 实现 Jamba: "Jamba: 混合Transformer-Mamba语言模型"
安装
$ pip install jamba
使用
# 导入提供机器学习工具的 torch 库
import torch
# 从 jamba.model 模块中导入 Jamba 模型
from jamba.model import Jamba
# 创建一个形状为(1, 100)的随机整数张量,取值范围在0-100之间
# 这模拟了一个我们将通过模型传递的 token 批次
x = torch.randint(0, 100, (1, 100))
# 使用指定参数初始化 Jamba 模型
# dim: 输入数据的维度
# depth: 模型的层数
# num_tokens: 输入数据中唯一 token 的数量
# d_state: 模型中隐藏状态的维度
# d_conv: 模型中卷积层的维度
# heads: 模型中的注意力头数量
# num_experts: 模型中的专家网络数量
# num_experts_per_token: 每个输入 token 使用的专家数量
model = Jamba(
dim=512,
depth=6,
num_tokens=100,
d_state=256,
d_conv=128,
heads=8,
num_experts=8,
num_experts_per_token=2,
)
# 使用输入数据执行模型的前向传播
# 这将返回模型对输入数据中每个 token 的预测
output = model(x)
# 打印模型的预测结果
print(output)
训练
python3 train.py
许可证
MIT