SpTr: PyTorch空间稀疏Transformer库
SparseTransformer (SpTr) 为具有可变token数量的稀疏transformer(例如用于3D点云的窗口transformer)提供了快速、内存高效且易于使用的实现。
SpTr 已被以下工作采用:
安装
安装依赖
pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch_scatter==2.0.9
pip install torch_geometric==1.7.2
编译sptr
python3 setup.py install
使用
SpTr可以轻松应用于当前大多数基于transformer的3D点云网络中,只需进行几处小修改。首先,定义注意力模块sptr.VarLengthMultiheadSA
。然后,将输入特征和索引包装成sptr.SparseTrTensor
,并将其传入模块。就这么简单。下面是一个简单的示例。对于更复杂的用法,您可以参考上述工作的代码(例如SphereFormer、StratifiedFormer)。
示例
import sptr
# 定义模块
dim = 48
num_heads = 3
indice_key = 'sptr_0'
window_size = np.array([0.4, 0.4, 0.4]) # 对于基于体素的方法也可以是整数
shift_win = False # 是否采用移位窗口
self.attn = sptr.VarLengthMultiheadSA(
dim,
num_heads,
indice_key,
window_size,
shift_win
)
# 将输入特征和索引包装成SparseTrTensor。注意:索引可以是基于体素方法的整数,也可以是基于点的方法的浮点数(即xyz)
# feats: [N, C], indices: [N, 4],第0列为批次索引
input_tensor = sptr.SparseTrTensor(feats, indices, spatial_shape=None, batch_size=None)
output_tensor = self.attn(input_tensor)
# 从输出张量中提取特征
output_feats = output_tensor.query_feats
作者
Xin Lai(香港中文大学计算机科学与工程系博士生,xinlai@cse.cuhk.edu.hk) - 初始CUDA实现,维护。
Fanbin Lu(香港中文大学计算机科学与工程系博士生) - 改进CUDA实现,维护。
Yukang Chen(香港中文大学计算机科学与工程系博士生) - 维护。
引用
如果您发现本项目有用,请考虑引用
@inproceedings{lai2023spherical,
title={Spherical Transformer for LiDAR-based 3D Recognition},
author={Lai, Xin and Chen, Yukang and Lu, Fanbin and Liu, Jianhui and Jia, Jiaya},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2023}
}
@inproceedings{lai2022stratified,
title={Stratified transformer for 3d point cloud segmentation},
author={Lai, Xin and Liu, Jianhui and Jiang, Li and Wang, Liwei and Zhao, Hengshuang and Liu, Shu and Qi, Xiaojuan and Jia, Jiaya},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={8500--8509},
year={2022}
}
许可证
本项目采用Apache License 2.0许可证。