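A bug was discovered in neighbor selection in the presence of masking. If you ran any experiments with masking prior to version 0.1.12, please rerun them. 🙏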
## EGNN - PyTorch
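Implementation of E(n)-Equivariant Graph Neural Networks, in PyTorch. It may eventually be used for AlphaFold2 replication. The technique relies on simple invariant features (such as relative distances between nodes), and according to the paper it beat the previous methods it was compared against, including the SE3 Transformer and Lie Conv, in both accuracy and performance. It reached state of the art on dynamical system models, molecular activity prediction tasks, and more.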
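For reference, the layer update from the paper (Satorras et al., 2021) can be written as below, where h_i are node features, x_i coordinates, a_ij optional edge attributes, and phi_e, phi_x, phi_h are MLPs. The paper normalizes the coordinate sum with C = 1/(M - 1) for M nodes. The only geometric input to the MLPs is the invariant squared distance, which is what keeps the coordinate update equivariant to rotations, reflections and translations.

```latex
\begin{aligned}
\mathbf{m}_{ij} &= \phi_e\left(\mathbf{h}_i^l, \mathbf{h}_j^l, \lVert \mathbf{x}_i^l - \mathbf{x}_j^l \rVert^2, a_{ij}\right) \\
\mathbf{x}_i^{l+1} &= \mathbf{x}_i^l + C \sum_{j \neq i} \left(\mathbf{x}_i^l - \mathbf{x}_j^l\right) \phi_x\left(\mathbf{m}_{ij}\right) \\
\mathbf{m}_i &= \sum_{j \neq i} \mathbf{m}_{ij} \\
\mathbf{h}_i^{l+1} &= \phi_h\left(\mathbf{h}_i^l, \mathbf{m}_i\right)
\end{aligned}
```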
## Install

```bash
$ pip install egnn-pytorch
```
## Usage

```python
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512)
layer2 = EGNN(dim = 512)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)

feats, coors = layer1(feats, coors)
feats, coors = layer2(feats, coors) # (1, 16, 512), (1, 16, 3)
```
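To make the update equations concrete, here is a minimal, self-contained sketch of a dense (all-pairs) EGNN layer in plain PyTorch. `MinimalEGNNLayer` is a hypothetical name: it ignores masks, edges and neighbor limits, uses a residual feature update as its own assumption, and is not the egnn_pytorch implementation.

```python
import torch
from torch import nn


class MinimalEGNNLayer(nn.Module):
    """Hypothetical dense E(n)-equivariant layer sketch (not the egnn_pytorch code)."""

    def __init__(self, dim: int, m_dim: int = 16):
        super().__init__()
        # phi_e: (h_i, h_j, squared distance) -> message m_ij
        self.edge_mlp = nn.Sequential(
            nn.Linear(2 * dim + 1, m_dim), nn.SiLU(),
            nn.Linear(m_dim, m_dim), nn.SiLU(),
        )
        # phi_x: message -> scalar weight on the relative vector x_i - x_j
        self.coor_mlp = nn.Sequential(nn.Linear(m_dim, m_dim), nn.SiLU(), nn.Linear(m_dim, 1))
        # phi_h: (h_i, aggregated message) -> feature update
        self.node_mlp = nn.Sequential(nn.Linear(dim + m_dim, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, feats: torch.Tensor, coors: torch.Tensor):
        b, n, d = feats.shape
        rel = coors[:, :, None, :] - coors[:, None, :, :]  # (b, n, n, 3): x_i - x_j
        dist2 = (rel ** 2).sum(dim=-1, keepdim=True)        # (b, n, n, 1): invariant input
        h_i = feats[:, :, None, :].expand(b, n, n, d)
        h_j = feats[:, None, :, :].expand(b, n, n, d)
        m_ij = self.edge_mlp(torch.cat((h_i, h_j, dist2), dim=-1))
        # drop self-messages (j == i) from the aggregated message
        not_self = (~torch.eye(n, dtype=torch.bool, device=feats.device))[None, :, :, None]
        m_ij = m_ij * not_self
        # equivariant coordinate update with C = 1 / (n - 1); rel is zero when j == i
        coors = coors + (rel * self.coor_mlp(m_ij)).sum(dim=2) / (n - 1)
        # invariant feature update (the residual connection is an assumption of this sketch)
        feats = feats + self.node_mlp(torch.cat((feats, m_ij.sum(dim=2)), dim=-1))
        return feats, coors


layer = MinimalEGNNLayer(dim=8)
out_feats, out_coors = layer(torch.randn(2, 5, 8), torch.randn(2, 5, 3))
print(out_feats.shape, out_coors.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 3])
```

Rotating and translating the input coordinates should leave the features unchanged and transform the output coordinates the same way, which is easy to check numerically in float64.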
## With edges

```python
import torch
from egnn_pytorch import EGNN

layer1 = EGNN(dim = 512, edge_dim = 4)
layer2 = EGNN(dim = 512, edge_dim = 4)

feats = torch.randn(1, 16, 512)
coors = torch.randn(1, 16, 3)
edges = torch.randn(1, 16, 16, 4)

feats, coors = layer1(feats, coors, edges)
feats, coors = layer2(feats, coors, edges) # (1, 16, 512), (1, 16, 3)
```
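An `edges` tensor of shape `(batch, n, n, edge_dim)` gives every pair (i, j) its own feature vector. One plausible way such features enter the message function is by concatenation with the other per-pair inputs, sketched below with a hypothetical helper `edge_mlp_input`; the exact ordering and processing inside egnn_pytorch is not implied.

```python
import torch


def edge_mlp_input(feats: torch.Tensor, coors: torch.Tensor, edges: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: build the per-pair input to an edge MLP.

    feats: (b, n, d), coors: (b, n, 3), edges: (b, n, n, e) -> (b, n, n, 2 * d + 1 + e)
    """
    b, n, d = feats.shape
    # squared relative distance, the invariant geometric input
    rel_dist2 = ((coors[:, :, None, :] - coors[:, None, :, :]) ** 2).sum(dim=-1, keepdim=True)
    h_i = feats[:, :, None, :].expand(b, n, n, d)
    h_j = feats[:, None, :, :].expand(b, n, n, d)
    # the user-supplied edge features are appended as extra channels
    return torch.cat((h_i, h_j, rel_dist2, edges), dim=-1)


x = edge_mlp_input(torch.randn(1, 16, 512), torch.randn(1, 16, 3), torch.randn(1, 16, 16, 4))
print(x.shape)  # torch.Size([1, 16, 16, 1029])
```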
## Full EGNN network

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    num_positions = 1024,           # unless what you are passing in is an unordered set, set this to the maximum sequence length
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 8,
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
```
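`num_nearest_neighbors` limits message passing to the closest nodes by relative distance. Below is a minimal sketch of that selection, assuming padded positions (`mask == False`) must never be chosen as neighbors; the helper `nearest_neighbors` is hypothetical, and excluding each node from its own neighbor list is a choice of this sketch, not necessarily the library's.

```python
import torch


def nearest_neighbors(coors: torch.Tensor, mask: torch.Tensor, k: int):
    """Hypothetical helper: indices of the k nearest valid neighbors of every node.

    coors: (b, n, 3), mask: (b, n) bool -> indices (b, n, k), distances (b, n, k)
    """
    dist = torch.cdist(coors, coors)  # (b, n, n) Euclidean distances
    n = coors.shape[1]
    self_mask = torch.eye(n, dtype=torch.bool, device=coors.device)[None]
    # a pair is invalid if the neighbor is padding or is the node itself
    invalid = ~mask[:, None, :] | self_mask
    dist = dist.masked_fill(invalid, float('inf'))
    # smallest distances first, so invalid (infinite) pairs are only picked as a last resort
    dists, idx = dist.topk(k, dim=-1, largest=False)
    return idx, dists


coors = torch.randn(1, 10, 3)
mask = torch.ones(1, 10, dtype=torch.bool)
mask[:, -3:] = False  # the last 3 positions are padding
idx, dists = nearest_neighbors(coors, mask, k=4)
print(idx.shape)  # torch.Size([1, 10, 4])
```

Because masked distances are pushed to infinity before `topk`, padding can only be selected when a node has fewer than k valid neighbors, which the finite distances make easy to detect.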
Only attend to sparse neighbors, given to the network as an adjacency matrix.

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors per node (each node is also marked adjacent to itself) - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```
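The chain mask in this example marks (i, j) as adjacent exactly when |i - j| <= 1, so the diagonal is included. A quick sanity check on a shorter length, using an equivalent `abs()` formulation:

```python
import torch

n = 6
i = torch.arange(n)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

# equivalent formulation: |i - j| <= 1
assert torch.equal(adj_mat, (i[:, None] - i[None, :]).abs() <= 1)

# per row: the node itself plus up to 2 chain neighbors (the endpoints have only 1)
print(adj_mat.sum(dim=-1).tolist())  # [2, 3, 3, 3, 3, 2]
```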
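You can also have the network automatically determine the Nth-degree neighbors, and pass in an adjacency embedding (depending on the degree) to be used as edges, with two extra keyword arguments

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_adj_degrees = 3,           # fetch up to 3rd-degree neighbors
    adj_dim = 8,                   # pass an adjacency degree embedding to the EGNN layer, to be used in the edge MLP
    only_sparse_neighbors = True
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors per node (each node is also marked adjacent to itself) - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```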
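Conceptually, `num_adj_degrees` expands the 1-hop adjacency matrix into multi-hop neighborhoods, and `adj_dim` embeds each pair's degree so it can serve as an edge feature. The sketch below computes hop counts with repeated boolean matrix products and embeds them with `nn.Embedding`; the helper `adjacency_degrees` is hypothetical, and the library's own bookkeeping (for example how self-loops or distant pairs are labeled) may differ.

```python
import torch
from torch import nn


def adjacency_degrees(adj_mat: torch.Tensor, num_degrees: int) -> torch.Tensor:
    """Hypothetical helper: label each pair with its hop count (1..num_degrees), 0 if farther.

    adj_mat: (n, n) bool 1-hop adjacency -> (n, n) long
    """
    adj = adj_mat.float()
    reach = adj_mat.bool()
    degrees = reach.long()  # 1 for direct neighbors
    for degree in range(2, num_degrees + 1):
        next_reach = (reach.float() @ adj) > 0  # extend every path by one hop
        degrees = degrees.masked_fill(next_reach & ~reach, degree)  # label newly reached pairs
        reach = next_reach
    return degrees


# the chain adjacency from the example, with n = 6
i = torch.arange(6)
adj_mat = (i[:, None] - i[None, :]).abs() <= 1
degrees = adjacency_degrees(adj_mat, num_degrees=3)
print(degrees[0].tolist())  # [1, 1, 2, 3, 0, 0]

# embed the degree (0 = not within 3 hops) as an 8-dimensional edge feature
adj_emb = nn.Embedding(3 + 1, 8)(degrees)  # (6, 6, 8)
```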
## Edges

If you need to pass in continuous edges

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    edge_dim = 4,
    num_nearest_neighbors = 3
)

feats = torch.randint(0, 21, (1, 1024))
coors = torch.randn(1, 1024, 3)
mask = torch.ones_like(feats).bool()

continuous_edges = torch.randn(1, 1024, 1024, 4)

# naive adjacency matrix
# assuming the sequence is connected as a chain, with at most 2 neighbors per node (each node is also marked adjacent to itself) - (1024, 1024)
i = torch.arange(1024)
adj_mat = (i[:, None] >= (i[None, :] - 1)) & (i[:, None] <= (i[None, :] + 1))

feats_out, coors_out = net(feats, coors, edges = continuous_edges, mask = mask, adj_mat = adj_mat) # (1, 1024, 32), (1, 1024, 3)
```
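With `num_nearest_neighbors` set, only k of the n candidate neighbors per node are used, so a dense `(b, n, n, edge_dim)` edge tensor has to be reduced to the selected neighbors. A minimal sketch of that step with `torch.gather`, assuming neighbor indices of shape `(b, n, k)` such as those returned by `topk`; `gather_neighbor_edges` is a hypothetical name.

```python
import torch


def gather_neighbor_edges(edges: torch.Tensor, nbhd_idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (b, n, n, e) edges + (b, n, k) indices -> (b, n, k, e)."""
    e = edges.shape[-1]
    index = nbhd_idx[..., None].expand(-1, -1, -1, e)  # repeat indices across edge channels
    return edges.gather(2, index)  # out[b, i, k] = edges[b, i, nbhd_idx[b, i, k]]


edges = torch.randn(1, 1024, 1024, 4)
coors = torch.randn(1, 1024, 3)
# indices of the 3 closest nodes to each node, by Euclidean distance
nbhd_idx = torch.cdist(coors, coors).topk(3, dim=-1, largest=False).indices
print(gather_neighbor_edges(edges, nbhd_idx).shape)  # torch.Size([1, 1024, 3, 4])
```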
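## Stability

The initial EGNN architecture suffered from instability when there was a high number of neighbors. Fortunately, two solutions appear to largely mitigate this: normalizing the relative coordinates (`norm_coors = True`) and clamping the coordinate weights (`coor_weights_clamp_value`), both used below.

```python
import torch
from egnn_pytorch import EGNN_Network

net = EGNN_Network(
    num_tokens = 21,
    dim = 32,
    depth = 3,
    num_nearest_neighbors = 32,
    norm_coors = True,              # normalize the relative coordinates
    coor_weights_clamp_value = 2.   # absolute clamped value for the coordinate weights, needed if you increase the number of nearest neighbors
)

feats = torch.randint(0, 21, (1, 1024)) # (1, 1024)
coors = torch.randn(1, 1024, 3)         # (1, 1024, 3)
mask = torch.ones_like(feats).bool()    # (1, 1024)

feats_out, coors_out = net(feats, coors, mask = mask) # (1, 1024, 32), (1, 1024, 3)
```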
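Both stabilizers can be pictured on the coordinate update: normalizing rescales each relative vector x_i - x_j to unit length, and clamping bounds the scalar weight that multiplies it, so no single neighbor can move a node by more than the clamp value. The sketch below assumes exactly that, plus averaging over neighbors; `normalize_rel_coors` and `stabilized_coor_update` are hypothetical helpers, and the library may differ in details (for instance a learnable scale on the normalized vectors, or summing instead of averaging).

```python
import torch


def normalize_rel_coors(rel: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical: scale each relative vector (..., 3) to unit length; zero vectors stay zero."""
    norm = rel.norm(dim=-1, keepdim=True)
    return rel / norm.clamp(min=eps)


def stabilized_coor_update(coors, rel, coor_weights, clamp_value=2.0):
    """Hypothetical: x_i + mean_j(clamp(w_ij) * normalize(x_i - x_j)).

    coors: (b, n, 3), rel: (b, n, n, 3), coor_weights: (b, n, n, 1)
    """
    w = coor_weights.clamp(min=-clamp_value, max=clamp_value)
    return coors + (normalize_rel_coors(rel) * w).mean(dim=2)


coors = torch.randn(1, 8, 3)
rel = coors[:, :, None, :] - coors[:, None, :, :]
huge_weights = torch.full((1, 8, 8, 1), 1e4)  # stand-in for an exploding coordinate MLP output
out = stabilized_coor_update(coors, rel, huge_weights)
print((out - coors).norm(dim=-1).max() <= 2.0)  # tensor(True)
```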
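## All parameters

```python
import torch
from egnn_pytorch import EGNN

model = EGNN(
    dim = 512,                          # input dimension (512 is just an example value)
    edge_dim = 0,                       # dimension of the edges, if they exist; should be > 0
    m_dim = 16,                         # hidden model dimension
    fourier_features = 0,               # number of Fourier features for encoding the relative distance - defaults to none, as in the paper
    num_nearest_neighbors = 0,          # cap on the number of neighbors doing message passing, selected by relative distance
    dropout = 0.0,                      # dropout
    norm_feats = False,                 # whether to layernorm the features
    norm_coors = False,                 # whether to normalize the coordinates, using a strategy from the SE(3) Transformers paper
    update_feats = True,                # whether to update the features - you can build a layer that only updates one or the other
    update_coors = True,                # whether to update the coordinates
    only_sparse_neighbors = False,      # only allow message passing along adjacent neighbors, using the adjacency matrix passed in
    valid_radius = float('inf'),        # the valid radius each node considers for message passing
    m_pool_method = 'sum',              # whether to mean or sum pool for the output node representation
    soft_edges = False,                 # extra GLU on the edges, purportedly helps stabilize the network in the updated version of the paper
    coor_weights_clamp_value = None     # clamping of the coordinate updates, again for stabilization purposes
)
```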
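`fourier_features` encodes the scalar relative distance with sinusoids at several scales before it enters the edge MLP. A sketch of one such encoding follows; the power-of-two scales and keeping the raw distance alongside the sin/cos terms are both assumptions, and `fourier_encode_dist` is a hypothetical helper name.

```python
import torch


def fourier_encode_dist(dist: torch.Tensor, num_encodings: int = 4) -> torch.Tensor:
    """Hypothetical: (...,) distances -> (..., 2 * num_encodings + 1) features."""
    x = dist[..., None]  # (..., 1)
    scales = 2 ** torch.arange(num_encodings, dtype=dist.dtype, device=dist.device)  # 1, 2, 4, ...
    scaled = x / scales  # (..., num_encodings)
    # sin and cos at every scale, plus the raw distance itself
    return torch.cat((scaled.sin(), scaled.cos(), x), dim=-1)


rel_dist = torch.rand(1, 16, 16)  # e.g. pairwise (squared) distances
print(fourier_encode_dist(rel_dist).shape)  # torch.Size([1, 16, 16, 9])
```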
## Examples

To run a protein backbone denoising example, first install `sidechainnet`

```bash
$ pip install sidechainnet
```

Then

```bash
$ python denoise_sparse.py
```
## Tests

Make sure you have PyTorch Geometric installed locally

```bash
$ python setup.py test
```
## Citations

```bibtex
@misc{satorras2021en,
    title   = {E(n) Equivariant Graph Neural Networks},
    author  = {Victor Garcia Satorras and Emiel Hoogeboom and Max Welling},
    year    = {2021},
    eprint  = {2102.09844},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```