PyTorch Scatter
此软件包包含用于 PyTorch 的一小部分高度优化的稀疏更新(scatter 和 segment)操作的扩展库,主包中没有这些操作。 根据给定的“组索引”张量,scatter 和 segment 操作大致可以描述为归约操作。 Segment 操作要求“组索引”张量排序,而 scatter 操作不受此限制。
该包包含以下带有归约类型 "sum"|"mean"|"min"|"max"
的操作:
- 基于任意索引的 scatter
- 基于排序索引的 segment_coo
- 基于指针压缩索引的 segment_csr
此外,我们还提供以下复合函数,它们在底层使用 scatter_*
操作:scatter_std
、scatter_logsumexp
、scatter_softmax
和 scatter_log_softmax
。
所有包含的操作都是可广播的,适用于不同的数据类型,同时在 CPU 和 GPU 上实现,并且具有相应的反向实现,且完全可追溯。
安装
Anaconda
更新: 你现在可以通过 Anaconda 安装 pytorch-scatter
,适用于所有主要操作系统/PyTorch/CUDA 组合 🤗
假设你已经安装了 pytorch >= 1.8.0
,只需运行
conda install pytorch-scatter -c pyg
二进制文件
我们还提供了适用于所有主要操作系统/PyTorch/CUDA 组合的 pip 轮子,详见 此处。
PyTorch 2.3
要安装适用于 PyTorch 2.3.0 的二进制文件,只需运行
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.3.0+${CUDA}.html
其中 ${CUDA}
应替换为 cpu
、cu118
或 cu121
,具体取决于你的 PyTorch 安装。
cpu | cu118 | cu121 | |
---|---|---|---|
Linux | ✅ | ✅ | ✅ |
Windows | ✅ | ✅ | ✅ |
macOS | ✅ |
PyTorch 2.2
要安装适用于 PyTorch 2.2.0 的二进制文件,只需运行
pip install torch-scatter -f https://data.pyg.org/whl/torch-2.2.0+${CUDA}.html
其中 ${CUDA}
应替换为 cpu
、cu118
或 cu121
,具体取决于你的 PyTorch 安装。
cpu | cu118 | cu121 | |
---|---|---|---|
Linux | ✅ | ✅ | ✅ |
Windows | ✅ | ✅ | ✅ |
macOS | ✅ |
注意: 适用于 PyTorch 1.4.0、PyTorch 1.5.0、PyTorch 1.6.0、PyTorch 1.7.0/1.7.1、PyTorch 1.8.0/1.8.1、PyTorch 1.9.0、PyTorch 1.10.0/1.10.1/1.10.2、PyTorch 1.11.0、PyTorch 1.12.0/1.12.1、PyTorch 1.13.0/1.13.1、PyTorch 2.0.0/2.0.1 和 PyTorch 2.1.0/2.1.1/2.1.2 的旧版本二进制文件也可用(使用相同的步骤)。
对于旧版本,您需要明确指定最新支持的版本号,或通过 pip install --no-index
安装以防止手动从源代码安装。
您可以在 这里 查看最新支持的版本号。
从源代码安装
确保安装了至少 PyTorch 1.4.0,并验证 cuda/bin
和 cuda/include
是否在你的 $PATH
和 $CPATH
中,例如:
$ python -c "import torch; print(torch.__version__)"
>>> 1.4.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
$ echo $CPATH
>>> /usr/local/cuda/include:...
然后运行:
pip install torch-scatter
在没有 NVIDIA 驱动程序的 docker 容器中运行时,PyTorch 需要评估计算能力并可能会失败。
在这种情况下,确保通过环境变量 TORCH_CUDA_ARCH_LIST
设置计算能力,例如:
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
示例
import torch
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)
print(out)
tensor([[0, 0, 4, 3, 2, 0],
[2, 4, 3, 0, 0, 0]])
print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
[1, 4, 3, 5, 5, 5]])
运行测试
pytest
C++ 接口
torch-scatter
还提供了 C++ 接口,其中包含 Python 模型的 C++ 等效版本。
为此,我们需要将 TorchLib
添加到 -DCMAKE_PREFIX_PATH
(例如,如果通过 conda
安装,它可能位于 {CONDA}/lib/python{X.X}/site-packages/torch
中):
mkdir build
cd build
# 添加 -DWITH_CUDA=on 以支持 CUDA 支持
cmake -DCMAKE_PREFIX_PATH="..." ..
make
make install