E(n)-等变变换器
E(n)-等变变换器的实现,它扩展了Welling的E(n)-等变图神经网络的思想,结合了注意力机制和变换器架构的理念。
更新:用于设计抗体中的CDR环!
安装
$ pip install En-transformer
使用方法
import torch
from en_transformer import EnTransformer
model = EnTransformer(
dim = 512,
depth = 4, # 深度
dim_head = 64, # 每个头的维度
heads = 8, # 头的数量
edge_dim = 4, # 边特征的维度
neighbors = 64, # 只在N个最近邻坐标之间进行注意力计算 - 设为0可关闭此功能
talking_heads = True, # 使用Shazeer的talking heads https://arxiv.org/abs/2003.02436
checkpoint = True, # 使用检查点以便以较小的内存代价增加深度(并增加关注的邻居数)
use_cross_product = True, # 使用叉积向量(由@MattMcPartlon提出的想法)
num_global_linear_attn_heads = 4 # 如果上面的邻居数较少,可以分配一定数量的注意力头通过线性注意力弱关注所有其他节点(https://arxiv.org/abs/1812.01243)
)
feats = torch.randn(1, 1024, 512)
coors = torch.randn(1, 1024, 3)
edges = torch.randn(1, 1024, 1024, 4)
mask = torch.ones(1, 1024).bool()
feats, coors = model(feats, coors, edges, mask = mask) # (1, 1024, 512), (1, 1024, 3)
让网络同时处理原子和键类型的嵌入
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10, # 唯一节点的数量,比如原子
rel_pos_emb = True, # 如果你的序列不是无序集合,将此设为true。它将加速收敛
num_edge_tokens = 5, # 唯一边的数量,比如键类型
dim = 128,
edge_dim = 16,
depth = 3,
heads = 4,
dim_head = 32,
neighbors = 8
)
atoms = torch.randint(0, 10, (1, 16)) # 10种不同类型的原子
bonds = torch.randint(0, 5, (1, 16, 16)) # 5种不同类型的键(n x n)
coors = torch.randn(1, 16, 3) # 原子空间坐标
feats_out, coors_out = model(atoms, coors, edges = bonds) # (1, 16, 512), (1, 16, 3)
如果你只想关注由邻接矩阵定义的稀疏邻居(例如原子),你需要设置一个额外的标志,然后传入N x N
邻接矩阵。
import torch
from en_transformer import EnTransformer
model = EnTransformer(
num_tokens = 10,
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
neighbors = 0,
only_sparse_neighbors = True, # 必须设置为true
num_adj_degrees = 2, # 从传入的一级邻居派生的度数
adj_dim = 8 # 是否将邻接度信息作为边嵌入传递
)
atoms = torch.randint(0, 10, (1, 16))
coors = torch.randn(1, 16, 3)
# 简单假设一个单链原子
i = torch.arange(atoms.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
# 必须传入邻接矩阵
feats_out, coors_out = model(atoms, coors, adj_mat = adj_mat) # (1, 16, 512), (1, 16, 3)
边
如果你需要传入连续的边
import torch
from en_transformer import EnTransformer
from en_transformer.utils import rot
model = EnTransformer(
dim = 512,
depth = 1,
heads = 4,
dim_head = 32,
edge_dim = 4,
num_nearest_neighbors = 0,
only_sparse_neighbors = True
)
feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)
i = torch.arange(feats.shape[1])
adj_mat = (i[:, None] <= (i[None, :] + 1)) & (i[:, None] >= (i[None, :] - 1))
feats1, coors1 = model(feats, coors, adj_mat = adj_mat, edges = edges)
示例
要运行蛋白质主链坐标去噪的玩具任务,首先安装 sidechainnet
$ pip install sidechainnet
然后
$ python denoise.py
待办事项
引用
@misc{satorras2021en,
title = {E(n) 等变图神经网络},
author = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
year = {2021},
eprint = {2102.09844},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{shazeer2020talkingheads,
title = {会说话的头部注意力},
author = {Noam Shazeer and Zhenzhong Lan and Youlong Cheng and Nan Ding and Le Hou},
year = {2020},
eprint = {2003.02436},
archivePrefix = {arXiv},
primaryClass = {cs.LG}
}
@misc{liu2021swin,
title = {Swin Transformer V2:扩展容量和分辨率},
author = {Ze Liu and Han Hu and Yutong Lin and Zhuliang Yao and Zhenda Xie and Yixuan Wei and Jia Ning and Yue Cao and Zheng Zhang and Li Dong and Furu Wei and Baining Guo},
year = {2021},
eprint = {2111.09883},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}
@inproceedings{Kim2020TheLC,
title = {自注意力的李普希茨常数},
author = {Hyunjik Kim and George Papamakarios and Andriy Mnih},
booktitle = {国际机器学习会议},
year = {2020},
url = {https://api.semanticscholar.org/CorpusID:219530837}
}
@article {Mahajan2023.07.15.549154,
author = {Sai Pooja Mahajan and Jeffrey A. Ruffolo and Jeffrey J. Gray},
title = {来自等变图转换器的上下文蛋白质和抗体编码},
elocation-id = {2023.07.15.549154},
year = {2023},
doi = {10.1101/2023.07.15.549154},
publisher = {Cold Spring Harbor Laboratory},
URL = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154},
eprint = {https://www.biorxiv.org/content/early/2023/07/29/2023.07.15.549154.full.pdf},
journal = {bioRxiv}
}
@article{Bondarenko2023QuantizableTR,
title = {可量化的转换器:通过帮助注意力头什么都不做来消除异常值},
author = {Yelysei Bondarenko and Markus Nagel and Tijmen Blankevoort},
journal = {ArXiv},
year = {2023},
volume = {abs/2306.12929},
url = {https://api.semanticscholar.org/CorpusID:259224568}
}
@inproceedings{Arora2023ZoologyMA,
title = {动物学:测量和改进高效语言模型的召回率},
author = {Simran Arora and Sabri Eyuboglu and Aman Timalsina and Isys Johnson and Michael Poli and James Zou and Atri Rudra and Christopher R'e},
year = {2023},
url = {https://api.semanticscholar.org/CorpusID:266149332}
}