项目介绍
torchinfo 是一个专为 PyTorch 模型用户开发的工具,用于提供丰富的模型总结信息。其功能类似于 TensorFlow 的 model.summary()
API,可以帮助用户在调试神经网络时查看模型的可视化信息。torchinfo 提供的功能比简单的 print(your_model)
更为全面,因而成为用户了解和调试其网络架构的得力助手。
项目背景
torchinfo 是对原始项目 torchsummary 和 torchsummaryX 的重写,解决了这两个项目遗留的问题,提供了全新的 API。此项目与 PyTorch 版本 1.4.0 及以上兼容,支持 Python 3.8 及更新版本。
核心功能
torchinfo 通过一个简单的接口,帮助用户轻松获取 PyTorch 模型的详尽信息。用户可以通过以下任意一种安装方式来获取此工具:
pip install torchinfo
或使用 conda 安装:
conda install -c conda-forge torchinfo
使用示例
使用 torchinfo 非常简单,只需如下几行代码即可获得模型的详细信息:
from torchinfo import summary
model = ConvNet()
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))
这段代码将会输出模型的层次结构、每层的输入和输出形状、参数数量等信息。
支持的功能
- 支持 RNNs, LSTMs 及其他递归层的总结。
- 可视化分支输出,帮助探索指定深度的模型层次。
- 提供 ModelStatistics 对象,其中包含所有总结数据字段。
- 支持在 Jupyter Notebook 和 Google Colab 环境中的使用。
新增特性
- 详细模式:显示权重和偏置层的详细信息。
- 灵活接受输入数据或者仅输入形状。
- 可自定义行列宽度以及批次维度。
示例代码
用户可以通过不同配置选项探索模型的总结信息。例如,下面展示了一种使用 LSTM 网络进行总结的方式:
class LSTMNet(nn.Module):
def __init__(self, vocab_size=20, embed_dim=300, hidden_dim=512, num_layers=2):
super().__init__()
self.hidden_dim = hidden_dim
self.embedding = nn.Embedding(vocab_size, embed_dim)
self.encoder = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
self.decoder = nn.Linear(hidden_dim, vocab_size)
def forward(self, x):
embed = self.embedding(x)
out, hidden = self.encoder(embed)
out = self.decoder(out)
out = out.view(-1, out.size(2))
return out, hidden
summary(
LSTMNet(),
(1, 100),
dtypes=[torch.long],
verbose=2,
col_width=16,
col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
row_settings=["var_names"],
)
贡献与支持
torchinfo 的开发者欢迎社区参与到项目的开发中来,无论是通过提交问题或者拉取请求。项目的开发基于最新版的 Python,并确保向下兼容到 Python 3.8。贡献者可以使用下面的几行命令来准备开发环境:
- 安装开发依赖包:
pip install -r requirements-dev.txt
- 使用
pre-commit
进行自动格式化:pre-commit install
- 运行单元测试:
pytest
感谢任何能够帮助改进、测试和增强这个项目功能的社区贡献者。