几何变换注意力机制
Takeru Miyato · Bernhard Jaeger · Max Welling · Andreas Geiger
OpenReview | arXiv | 项目主页
我们ICLR2024工作的官方复现代码:"GTA: 多视角Transformer的几何感知注意力机制",一种简单的方法来让您的多视角Transformer更具表现力!
(2024年3月15日):GTA机制对于图像生成这一纯2D任务也很有效。您可以在我们的准备就绪论文中找到实验细节,并在此分支中找到实现。
内容
本仓库包含以下不同的代码库,可以通过切换到相应的分支来访问:
您可以在这里找到多视角ViT的GTA代码,在这里找到图像ViT的GTA代码。
如果您有任何问题,请随时与我们联系!
设置
1. 创建环境并安装Python库
conda create -n gta python=3.9
conda activate gta
pip3 install -r requirements.txt
2. 下载数据集
export DATADIR=<path_to_datadir>
mkdir -p $DATADIR
CLEVR-TR
从此链接下载数据集并将其放在$DATADIR
下
MultiShapeNet Hard (MSN-Hard)
gsutil -m cp -r gs://kubric-public/tfds/kubric_frames/multi_shapenet_conditional/2.8.0/ ${DATADIR}/multi_shapenet_frames/
*预训练模型(MSN-Hard预训练模型即将上传)
训练
CLEVR-TR
torchrun --standalone --nnodes 1 --nproc_per_node 4 train.py runs/clevrtr/GTA/gta/config.yaml ${DATADIR}/clevrtr --seed=0
MSN-Hard
torchrun --standalone --nnodes 1 --nproc_per_node 4 train.py runs/msn/GTA/gta_so3/config.yaml ${DATADIR} --seed=0
PSNR、SSIM和LPIPS的评估
python evaluate.py runs/clevrtr/GTA/gta/config.yaml ${DATADIR}/clevrtr $PATH_TO_CHECKPOINT # CLEVR-TR
python evaluate.py runs/msn/GTA/gta_so3/config.yaml ${DATADIR} $PATH_TO_CHECKPOINT # MSN-Hard
致谢
本仓库建立在@stelzner创建的SRT和OSRT之上。我们要感谢他为SRT模型的开源贡献。 我们还要感谢@lucidrains提供J矩阵的值,这些值对于高效计算SO(3)的不可约表示是必需的。
引用
@inproceedings{Miyato2024GTA,
title={GTA: A Geometry-Aware Attention Mechanism for Multi-View Transformers},
author={Miyato,Takeru and Jaeger, Bernhard and Welling, Max and Geiger, Andreas},
booktitle={International Conference on Learning Representations (ICLR)},
year={2024}
}