Mixture of Experts Project Introduction
Project Overview
Mixture of Experts (MoE) is a PyTorch-based project that uses sparse gating (Sparsely Gated) to greatly increase a language model's parameter capacity while keeping the amount of computation roughly constant. It is largely a line-by-line port of the original TensorFlow implementation, with a number of enhancements. The project recommends ST Mixture of Experts as its continuation and updated successor.
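To give a sense of what sparse gating means in practice, here is a minimal, illustrative sketch of top-k routing; the names and shapes are assumptions for illustration only and do not reflect the library's internal API:
import torch
import torch.nn.functional as F
from torch import nn

# Minimal sketch of top-k ("sparse") gating: a learned router scores every expert
# for every token, but only the top-k experts are actually evaluated, so the
# parameter count grows with the number of experts while per-token compute stays roughly flat.
num_experts, dim, k = 16, 512, 2
tokens = torch.randn(8, dim)                    # (num_tokens, dim), illustrative shapes
router = nn.Linear(dim, num_experts)            # learned gate

logits = router(tokens)                         # (num_tokens, num_experts)
topk_vals, topk_idx = logits.topk(k, dim = -1)  # keep only the k best experts per token
weights = F.softmax(topk_vals, dim = -1)        # combination weights over the chosen experts
# each token would then be dispatched only to the experts in topk_idx,
# and their outputs combined with `weights`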
Installation
Mixture of Experts can be installed with a single command:
$ pip install mixture_of_experts
Usage
A basic MoE layer can be created with a few lines of Python:
import torch
from torch import nn
from mixture_of_experts import MoE
moe = MoE(
    dim = 512,
    num_experts = 16,               # increase this to add parameters without adding per-token computation
    hidden_dim = 512 * 4,           # hidden dimension of each expert, defaults to 4 * dim
    activation = nn.LeakyReLU,      # activation used inside each expert
    second_policy_train = 'random', # policy for routing tokens to the second-place expert during training
    second_policy_eval = 'random',  # one of 'all' | 'none' | 'threshold' | 'random'
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # extra per-expert capacity in case gating is not perfectly balanced
    capacity_factor_eval = 2.,
    loss_coef = 1e-2                # multiplier on the auxiliary load-balancing loss
)
inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs)
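The layer returns the transformed tensor together with an auxiliary load-balancing loss, which encourages balanced expert usage and is typically added to the task loss during training. A minimal sketch of a training step; the projection head and regression targets are hypothetical stand-ins for a real model and task:
import torch
from torch import nn
from mixture_of_experts import MoE

moe = MoE(dim = 512, num_experts = 16)
head = nn.Linear(512, 512)    # hypothetical downstream layer standing in for the rest of the model
optimizer = torch.optim.Adam(list(moe.parameters()) + list(head.parameters()), lr = 1e-4)

inputs = torch.randn(4, 1024, 512)
targets = torch.randn(4, 1024, 512)    # hypothetical regression targets

out, aux_loss = moe(inputs)
task_loss = nn.functional.mse_loss(head(out), targets)
loss = task_loss + aux_loss    # add the balancing loss to the task loss
optimizer.zero_grad()
loss.backward()
optimizer.step()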
These settings are enough to run on a single machine. To build a two-level hierarchical mixture of experts instead, pass a tuple for num_experts; (4, 4) below creates 4 expert groups with 4 experts each:
import torch
from mixture_of_experts import HeirarchicalMoE
moe = HeirarchicalMoE(
    dim = 512,
    num_experts = (4, 4)
)
inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs)
Scaling Up the Parameter Count
You can also increase the number of experts to build much larger models, for example a network with roughly one billion parameters:
import torch
from mixture_of_experts import HeirarchicalMoE
moe = HeirarchicalMoE(
    dim = 512,
    num_experts = (22, 22)
).cuda()
inputs = torch.randn(1, 1024, 512).cuda()
out, aux_loss = moe(inputs)
total_params = sum(p.numel() for p in moe.parameters())
print(f'number of parameters - {total_params}')
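The one-billion figure can be estimated by hand, assuming each expert is the default two-layer feedforward block with a hidden dimension of 4 * dim:
dim, hidden = 512, 512 * 4
num_experts = 22 * 22                             # (22, 22) -> 484 experts in total
params_per_expert = dim * hidden + hidden * dim   # two weight matrices per expert (assumed default structure)
print(num_experts * params_per_expert)            # ~1.0e9, before counting the comparatively small gating networks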
Custom Expert Networks
If you want a more sophisticated expert network, you can define your own module and pass it to the MoE class:
import torch
from torch import nn
from mixture_of_experts import MoE
class Experts(nn.Module):
    def __init__(self, dim, num_experts = 16):
        super().__init__()
        self.w1 = nn.Parameter(torch.randn(num_experts, dim, dim * 4))
        self.w2 = nn.Parameter(torch.randn(num_experts, dim * 4, dim * 4))
        self.w3 = nn.Parameter(torch.randn(num_experts, dim * 4, dim))
        self.act = nn.LeakyReLU(inplace = True)

    def forward(self, x):
        # einsum indices: e = expert, n = tokens routed to that expert, d / h = feature dimensions
        hidden1 = self.act(torch.einsum('end,edh->enh', x, self.w1))
        hidden2 = self.act(torch.einsum('end,edh->enh', hidden1, self.w2))
        out = torch.einsum('end,edh->enh', hidden2, self.w3)
        return out
experts = Experts(512, num_experts = 16)
moe = MoE(
    dim = 512,
    num_experts = 16,
    experts = experts
)
inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs)
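As the einsum pattern above suggests, a custom expert module receives a tensor shaped (num_experts, tokens_per_expert, dim) and should return a tensor of the same shape. A quick, illustrative sanity check on the example above:
# the MoE output keeps the shape of its input, and aux_loss is a scalar balancing term
assert out.shape == inputs.shape
assert aux_loss.numel() == 1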
In this way, you can adapt and tune the model to suit different applications and research directions.