KAN-GPT: 基于Kolmogorov-Arnold网络的GPT实现
KAN-GPT是一个使用PyTorch实现的基于Kolmogorov-Arnold网络(KANs)的生成式预训练Transformer(GPT)模型,用于语言建模任务。本文将介绍这个项目的主要特点以及相关学习资源,帮助感兴趣的读者快速上手。
项目简介
KAN-GPT项目的主要特点包括:
- 使用PyTorch实现GPT模型架构
- 引入Kolmogorov-Arnold网络作为核心组件
- 专注于语言建模任务
- 开源且易于使用和扩展
项目地址: https://github.com/AdityaNG/kan-gpt
安装使用
可以通过pip直接安装KAN-GPT:
pip install kan_gpt
快速上手
以下是一个简单的使用示例:
from kan_gpt.model import GPT
from transformers import GPT2Tokenizer
model_config = GPT.get_default_config()
model_config.model_type = "gpt2"
model_config.vocab_size = 50257
model_config.block_size = 1024
model = GPT(model_config)
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
prompt = "Bangalore is often described as the "
prompt_encoded = tokenizer.encode(
text=prompt, add_special_tokens=False
)
x = torch.tensor(prompt_encoded).unsqueeze(0)
model.eval()
y = model.generate(x, 50) # sample 50 tokens
result = tokenizer.decode(y[0])
print(result)
更详细的使用说明可以参考项目中的KAN_GPT.ipynb和kan_gpt/prompt.py。
开发指南
如果你想参与KAN-GPT的开发,可以按照以下步骤搭建开发环境:
- 克隆项目代码:
git clone https://github.com/AdityaNG/kan-gpt
cd kan-gpt
git pull
- 下载数据集:
python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
python3 -m kan_gpt.download_dataset --dataset mnist
python3 -m kan_gpt.download_dataset --dataset webtext
- 安装开发依赖:
pip install -r requirements.txt
pip install -e .
模型训练
可以使用以下命令来训练模型:
python -m kan_gpt.train
模型推理
训练好模型后,可以使用以下命令进行文本生成:
python -m kan_gpt.prompt --prompt "Bangalore is often described as the " --model_path <checkpoint>
更多资源
KAN-GPT是一个有趣的将Kolmogorov-Arnold网络应用于GPT的尝试。无论你是对语言模型感兴趣,还是想了解KAN的应用,都可以从这个项目中获得启发。欢迎大家尝试使用并为项目做出贡献!