Enformer - Pytorch 项目简介
项目概述
Enformer-Pytorch 是一个用于基因表达预测的深度学习模型,由 Deepmind 的注意力网络 Enformer 实现。该项目旨在使用 Pytorch 框架重现原始的 Tensorflow Sonnet 代码,帮助用户更方便地进行下游任务所需的预训练模型微调。近期的更新还包括对伪块染色质可及性预测的精调。
安装指南
要使用 Enformer-Pytorch 项目,只需在命令行中执行以下命令安装:
$ pip install enformer-pytorch
使用方法
以下是在 pytorch 中加载和使用 Enformer 模型的基本示例:
import torch
from enformer_pytorch import Enformer
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = {'human': 5313, 'mouse': 1643},
target_length = 896,
)
seq = torch.randint(0, 5, (1, 196_608))
output = model(seq)
output['human']
output['mouse']
该模型支持序列数据作为输入,允许用户直接提供序列的独热编码形式,并能获取嵌入向量用于微调。
预训练模型
Deepmind 发布了 Enformer 模型的 Tensorflow Sonnet 权重,已被移植到 Pytorch 并上传至 Huggingface。用户可以通过使用预训练功能来加载模型。同时,也可以根据需求调整目标序列的长度。
from enformer_pytorch import from_pretrained
enformer = from_pretrained('EleutherAI/enformer-official-rough')
微调
Enformer-Pytorch 项目提供各种微调方法,可用于在新轨道或背景数据上进行模型精调。
新轨道微调
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import HeadAdapterWrapper
enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = HeadAdapterWrapper(
enformer = enformer,
num_tracks = 128,
).cuda()
seq = torch.randint(0, 5, (1, 196_608 // 2)).cuda()
target = torch.randn(1, 200, 128).cuda() # 128 tracks
loss = model(seq, target = target)
loss.backward()
背景数据微调
import torch
from enformer_pytorch import from_pretrained
from enformer_pytorch.finetune import ContextAdapterWrapper
enformer = from_pretrained('EleutherAI/enformer-official-rough')
model = ContextAdapterWrapper(
enformer = enformer,
context_dim = 1024
).cuda()
seq = torch.randint(0, 5, (1, 196_608 // 2)).cuda()
target = torch.randn(1, 200, 4).cuda()
context = torch.randn(4, 1024).cuda()
loss = model(
seq,
context = context,
target = target
)
loss.backward()
数据处理
用户可以利用 GenomicIntervalDataset
轻松地从 .bed
文件中获取任何长度的序列,为模型训练提供必要的数据。
import torch
import polars as pl
from enformer_pytorch import Enformer, GenomeIntervalDataset
ds = GenomeIntervalDataset(
bed_file = './sequences.bed',
fasta_file = './hg38.ml.fa',
return_seq_indices = True,
context_length = 196_608,
)
model = Enformer.from_hparams(
dim = 1536,
depth = 11,
heads = 8,
output_heads = dict(human = 5313, mouse = 1643),
target_length = 896,
)
seq = ds[0]
pred = model(seq, head = 'human')
未来计划
该项目未来计划继续优化模型权重加载及微调过程,提供更为便捷的训练工具,旨在进一步简化基因组数据分析的实现流程。
通过 Enformer-Pytorch 项目,研究者们能够更加方便地利用最新的深度学习技术为基因组学研究服务。无论是科学家或是开发人员,均可利用该工具在复杂的基因组数据中提取有价值的生物学信息。