Project Icon

sae

高效训练语言模型k稀疏自编码器的开源库

这是一个用于训练语言模型k稀疏自编码器(SAE)的开源库。它使用TopK激活函数实现激活稀疏,可扩展至大型模型和数据集,无需额外存储。该库支持加载HuggingFace Hub预训练SAE,提供命令行和编程接口,允许自定义hookpoint训练任意子模块。支持分布式训练,适用于大规模语言模型。

简介

这个库按照Scaling and evaluating sparse autoencoders(Gao等人,2024年)中详述的方法,在HuggingFace语言模型的残差流激活上训练_k_稀疏自编码器(SAEs)。

这是一个精简、简单的库,具有少量配置选项。与大多数其他SAE库(例如SAELens)不同,它不会在磁盘上缓存激活,而是即时计算。这使我们能够扩展到非常大的模型和数据集,且无存储开销,但缺点是对同一模型和数据集尝试不同的超参数会比缓存激活时更慢(因为激活会被重新计算)。我们可能会在未来添加缓存作为一个选项。

遵循Gao等人的方法,我们使用TopK激活函数,直接强制激活达到所需的稀疏程度。这与其他使用L1惩罚项在损失函数中实现的库不同。我们认为TopK是对L1方法的帕累托改进,因此不打算支持它。

加载预训练SAEs

要从HuggingFace Hub加载预训练的SAE,你可以使用Sae.load_from_hub方法,如下所示:

from sae import Sae

sae = Sae.load_from_hub("EleutherAI/sae-llama-3-8b-32x", hookpoint="layers.10")

这将加载Llama 3 8B的残差流层10的SAE,该SAE使用32倍扩展因子进行训练。你也可以使用Sae.load_many一次性加载所有层的SAEs:

saes = Sae.load_many("EleutherAI/sae-llama-3-8b-32x")
saes["layers.10"]

load_many返回的字典保证按钩点名称自然排序。对于钩点命名为embed_tokenslayers.0、...、layers.n的常见情况,这意味着SAEs将按层数排序。然后我们可以如下收集模型前向传播的SAE激活:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
inputs = tokenizer("Hello, world!", return_tensors="pt")

with torch.inference_mode():
    model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
    outputs = model(**inputs, output_hidden_states=True)

    latent_acts = []
    for sae, hidden_state in zip(saes.values(), outputs.hidden_states):
        latent_acts.append(sae.encode(hidden_state))

# 对潜在激活进行操作

训练SAEs

要从命令行训练SAEs,你可以使用以下命令:

python -m sae EleutherAI/pythia-160m togethercomputer/RedPajama-Data-1T-Sample

CLI支持TrainConfig类提供的所有配置选项。你可以通过运行python -m sae --help查看它们。

程序化使用很简单。这里有一个例子:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from sae import SaeConfig, SaeTrainer, TrainConfig
from sae.data import chunk_and_tokenize

MODEL = "EleutherAI/pythia-160m"
dataset = load_dataset(
    "togethercomputer/RedPajama-Data-1T-Sample",
    split="train",
    trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenized = chunk_and_tokenize(dataset, tokenizer)


gpt = AutoModelForCausalLM.from_pretrained(
    MODEL,
    device_map={"": "cuda"},
    torch_dtype=torch.bfloat16,
)

cfg = TrainConfig(
    SaeConfig(gpt.config.hidden_size), batch_size=16
)
trainer = SaeTrainer(cfg, tokenized, gpt)

trainer.fit()

自定义钩点

默认情况下,SAEs在模型的残差流激活上进行训练。但是,你也可以通过指定自定义钩点模式在任何其他子模块的激活上训练SAEs。这些模式类似于标准PyTorch模块名称(如h.0.ln_1),但也允许使用Unix模式匹配语法,包括通配符和字符集。例如,要在GPT-2的每个注意力模块的输出和每个MLP的内部激活上训练SAEs,你可以使用以下代码:

python -m sae gpt2 togethercomputer/RedPajama-Data-1T-Sample --hookpoints "h.*.attn" "h.*.mlp.act"

要限制在前三层:

python -m sae gpt2 togethercomputer/RedPajama-Data-1T-Sample --hookpoints "h.[012].attn" "h.[012].mlp.act"

我们目前不支持对每个钩点的学习率、潜在变量数量或其他超参数进行精细的手动控制。默认情况下,expansion_ratio选项用于根据每个钩点输出的宽度选择适当数量的潜在变量。然后根据潜在变量的数量使用反平方根缩放法则为每个钩点设置默认学习率。如果你手动设置潜在变量数量或学习率,它将应用于所有钩点。

分布式训练

我们通过PyTorch的torchrun命令支持分布式训练。默认情况下,我们使用分布式数据并行方法,这意味着每个SAE的权重在每个GPU上都有副本。

torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048

这很简单,但内存效率非常低。如果你想为模型的多个层训练SAEs,我们建议使用--distribute_modules标志,它将不同层的SAEs分配到不同的GPU上。目前,我们要求GPU数量能够整除你要训练SAEs的层数。

torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2

上述命令为Llama 3 8B的每个_偶数_层训练一个SAE,使用所有可用的GPU。它在8个小批量上累积梯度,并在将每个小批量输入SAE编码器之前将其分成2个微批量,从而节省大量内存。它还使用bitsandbytes以8位精度加载模型。此命令在8个GPU节点上每个GPU需要不超过48GB的内存。

待办事项

我们希望在不久的将来添加以下几个功能:

  • 微调预训练SAEs
  • 支持缓存激活
  • 评估移植到模型中的SAEs的KL散度

如果你想帮助实现这些功能,请随时提交PR!你可以在EleutherAI Discord的sparse-autoencoders频道与我们合作。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号