PyTorch Scatter 项目介绍
项目概述
PyTorch Scatter 是一个专为 PyTorch 设计的小型扩展库,提供了高度优化的稀疏更新操作,如 scatter 和 segment 操作。这些操作通常用于根据给定的“组索引”张量进行的归约运算。尤其是在 PyTorch 主包缺少这些功能的情况下,PyTorch Scatter 能够填补这一空白。
scatter 和 segment 操作是这套库的核心。segment 操作需要“组索引”张量是排序的,而 scatter 操作则没有此要求。该库支持多种归约类型,包括"sum"
、"mean"
、"min"
和"max"
。
除了基本的 scatter 和 segment 功能外,库中还提供了一些更高级的复合函数,例如 scatter_std
、scatter_logsumexp
、scatter_softmax
和 scatter_log_softmax
。
功能特点
- 高效性:所有操作均经过高度优化。
- 多平台支持:支持 CPU 和 GPU,并提供相应的后向实现。
- 数据类型兼容性:可以处理多种数据类型。
- 可追溯性:实现了完整的可追溯功能。
- 可广播性:支持广播操作。
安装指南
使用 Anaconda 安装
用户可以通过 Anaconda 升级部署 pytorch-scatter
,这支持所有主流的操作系统、PyTorch 和 CUDA 组合。确保您已经安装了 pytorch >= 1.8.0
,然后运行:
conda install pytorch-scatter -c pyg
二进制安装
此外,pytorch-scatter 提供了适用于所有主流 OS、PyTorch 和 CUDA 组合的 pip 安装轮包。例如,要为 PyTorch 2.5.0 安装二进制文件,运行:
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.5.0+${CUDA}.html
${CUDA}
可以替换为 cpu
、cu118
、cu121
或者 cu124
,具体取决于您的 PyTorch 安装。
从源码安装
如果选择从源码安装,请确保至少安装了 PyTorch 1.4.0,并确保您的环境变量中 cuda/bin
和 cuda/include
路径正确设置。然后运行:
pip install torch-scatter
使用实例
以下是一个简单的使用示例:
import torch
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)
print(out)
# 输出:
# tensor([[0, 0, 4, 3, 2, 0],
# [2, 4, 3, 0, 0, 0]])
print(argmax)
# 输出:
# tensor([[5, 5, 3, 4, 0, 1]
# [1, 4, 3, 5, 5, 5]])
运行测试
运行以下命令以执行项目的单元测试:
pytest
C++ API
PyTorch Scatter 也提供了 C++ API,包含与 Python 模型等效的 C++ 实现。若要使用 C++ API,需要将 TorchLib
添加到 -DCMAKE_PREFIX_PATH
中。例如,如果通过 conda
安装,它可能位于 {CONDA}/lib/python{X.X}/site-packages/torch
:
mkdir build
cd build
# 如需支持 CUDA,请添加 -DWITH_CUDA=on
cmake -DCMAKE_PREFIX_PATH="..." ..
make
make install
通过使用 PyTorch Scatter,开发者可以更轻松地实现高效的稀疏操作,使 PyTorch 项目更加强大和灵活。