简介
这个库按照Scaling and evaluating sparse autoencoders(Gao等人,2024年)中详述的方法,在HuggingFace语言模型的残差流激活上训练_k_稀疏自编码器(SAEs)。
这是一个精简、简单的库,具有少量配置选项。与大多数其他SAE库(例如SAELens)不同,它不会在磁盘上缓存激活,而是即时计算。这使我们能够扩展到非常大的模型和数据集,且无存储开销,但缺点是对同一模型和数据集尝试不同的超参数会比缓存激活时更慢(因为激活会被重新计算)。我们可能会在未来添加缓存作为一个选项。
遵循Gao等人的方法,我们使用TopK激活函数,直接强制激活达到所需的稀疏程度。这与其他使用L1惩罚项在损失函数中实现的库不同。我们认为TopK是对L1方法的帕累托改进,因此不打算支持它。
加载预训练SAEs
要从HuggingFace Hub加载预训练的SAE,你可以使用Sae.load_from_hub
方法,如下所示:
from sae import Sae
sae = Sae.load_from_hub("EleutherAI/sae-llama-3-8b-32x", hookpoint="layers.10")
这将加载Llama 3 8B的残差流层10的SAE,该SAE使用32倍扩展因子进行训练。你也可以使用Sae.load_many
一次性加载所有层的SAEs:
saes = Sae.load_many("EleutherAI/sae-llama-3-8b-32x")
saes["layers.10"]
load_many
返回的字典保证按钩点名称自然排序。对于钩点命名为embed_tokens
、layers.0
、...、layers.n
的常见情况,这意味着SAEs将按层数排序。然后我们可以如下收集模型前向传播的SAE激活:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
inputs = tokenizer("Hello, world!", return_tensors="pt")
with torch.inference_mode():
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")
outputs = model(**inputs, output_hidden_states=True)
latent_acts = []
for sae, hidden_state in zip(saes.values(), outputs.hidden_states):
latent_acts.append(sae.encode(hidden_state))
# 对潜在激活进行操作
训练SAEs
要从命令行训练SAEs,你可以使用以下命令:
python -m sae EleutherAI/pythia-160m togethercomputer/RedPajama-Data-1T-Sample
CLI支持TrainConfig
类提供的所有配置选项。你可以通过运行python -m sae --help
查看它们。
程序化使用很简单。这里有一个例子:
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from sae import SaeConfig, SaeTrainer, TrainConfig
from sae.data import chunk_and_tokenize
MODEL = "EleutherAI/pythia-160m"
dataset = load_dataset(
"togethercomputer/RedPajama-Data-1T-Sample",
split="train",
trust_remote_code=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenized = chunk_and_tokenize(dataset, tokenizer)
gpt = AutoModelForCausalLM.from_pretrained(
MODEL,
device_map={"": "cuda"},
torch_dtype=torch.bfloat16,
)
cfg = TrainConfig(
SaeConfig(gpt.config.hidden_size), batch_size=16
)
trainer = SaeTrainer(cfg, tokenized, gpt)
trainer.fit()
自定义钩点
默认情况下,SAEs在模型的残差流激活上进行训练。但是,你也可以通过指定自定义钩点模式在任何其他子模块的激活上训练SAEs。这些模式类似于标准PyTorch模块名称(如h.0.ln_1
),但也允许使用Unix模式匹配语法,包括通配符和字符集。例如,要在GPT-2的每个注意力模块的输出和每个MLP的内部激活上训练SAEs,你可以使用以下代码:
python -m sae gpt2 togethercomputer/RedPajama-Data-1T-Sample --hookpoints "h.*.attn" "h.*.mlp.act"
要限制在前三层:
python -m sae gpt2 togethercomputer/RedPajama-Data-1T-Sample --hookpoints "h.[012].attn" "h.[012].mlp.act"
我们目前不支持对每个钩点的学习率、潜在变量数量或其他超参数进行精细的手动控制。默认情况下,expansion_ratio
选项用于根据每个钩点输出的宽度选择适当数量的潜在变量。然后根据潜在变量的数量使用反平方根缩放法则为每个钩点设置默认学习率。如果你手动设置潜在变量数量或学习率,它将应用于所有钩点。
分布式训练
我们通过PyTorch的torchrun
命令支持分布式训练。默认情况下,我们使用分布式数据并行方法,这意味着每个SAE的权重在每个GPU上都有副本。
torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --batch_size 1 --layers 16 24 --k 192 --grad_acc_steps 8 --ctx_len 2048
这很简单,但内存效率非常低。如果你想为模型的多个层训练SAEs,我们建议使用--distribute_modules
标志,它将不同层的SAEs分配到不同的GPU上。目前,我们要求GPU数量能够整除你要训练SAEs的层数。
torchrun --nproc_per_node gpu -m sae meta-llama/Meta-Llama-3-8B --distribute_modules --batch_size 1 --layer_stride 2 --grad_acc_steps 8 --ctx_len 2048 --k 192 --load_in_8bit --micro_acc_steps 2
上述命令为Llama 3 8B的每个_偶数_层训练一个SAE,使用所有可用的GPU。它在8个小批量上累积梯度,并在将每个小批量输入SAE编码器之前将其分成2个微批量,从而节省大量内存。它还使用bitsandbytes
以8位精度加载模型。此命令在8个GPU节点上每个GPU需要不超过48GB的内存。
待办事项
我们希望在不久的将来添加以下几个功能:
- 微调预训练SAEs
- 支持缓存激活
- 评估移植到模型中的SAEs的KL散度
如果你想帮助实现这些功能,请随时提交PR!你可以在EleutherAI Discord的sparse-autoencoders频道与我们合作。