视频帧插值的统一金字塔循环网络
本项目是我们 CVPR 2023 论文《视频帧插值的统一金字塔循环网络》的官方实现。
简介
我们提出了 UPR-Net,一种用于帧插值的新型统一金字塔循环网络。UPR-Net 采用灵活的金字塔框架,利用轻量级循环模块进行双向光流估计和中间帧合成。在每个金字塔层级,它利用估计的双向光流生成前向变形表示用于帧合成;在不同金字塔层级之间,它实现了光流和中间帧的迭代细化。特别是,我们表明我们的迭代合成策略可以显著提高大幅度运动情况下帧插值的鲁棒性。尽管极其轻量级(170万参数),我们的 UPR-Net 基础版本在大量基准测试中都取得了出色的性能。
Python 和 CUDA 环境
本代码已在 PyTorch 1.6 和 CUDA 10.2 环境下测试。它应该也兼容更高版本的 PyTorch 和 CUDA。运行以下命令初始化环境:
conda create --name uprnet python=3.7
conda activate uprnet
conda install pytorch==1.6.0 torchvision==0.7.0 cudatoolkit=10.2 -c pytorch
pip3 install cupy_cuda102==9.4.0
pip3 install -r requirements.txt
特别是,运行前向变形操作需要 CuPy 包(详情请参考 softmax-splatting)。如果您的 CUDA 版本低于 10.2(但不低于 9.2),我们建议将上述命令中的 cudatoolkit=10.2
替换为 cudatoolkit=9.2
,并将 cupy_cuda102==9.4.0
替换为 cupy_cuda92==9.6.0
。
演示
我们将训练好的模型权重放在 checkpoints
中,并提供了一个脚本来测试我们的帧插值模型。给定两个连续的输入帧和所需的时间步长,运行以下命令,您将在 ./demo/output
目录中获得估计的双向光流和插值帧。
python3 -m demo.interp_imgs \
--frame0 demo/images/beanbags0.png \
--frame1 demo/images/beanbags1.png \
--time_period 0.5
这里的 time_period
(0~1 之间的浮点数)表示您想要插值的中间帧的时间步长。
在 Vimeo90K 上训练
默认情况下,我们的模型在 Vimeo90K 上训练。如果您想训练我们的模型,请下载 Vimeo90K。
默认训练配置
您可以运行以下命令来训练我们 UPR-Net 的基础版本:
CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m torch.distributed.launch \
--nproc_per_node=4 --master_port=10000 -m tools.train \
--world_size=4 \
--data_root /path/to/vimeo_triplet \
--train_log_root /path/to/train_log \
--exp_name upr-base \
--batch_size 8 \
--nr_data_worker 2
请(必须)将 data_root
指定为用于训练的 vimeo_triplet 路径,并(可选)将 train_log_root
指定为保存日志的路径(训练权重和 tensorboard 日志)。我们不建议在此代码库目录下保存日志文件。如果未明确指定 train_log_root
,默认情况下所有日志将保存在 ../upr-train-log
中。为方便起见,我们还特意将 train_log_root
软链接到 ./train-log
。
一些训练技巧
-
如果您想训练我们 UPR-Net 的 large 或 LARGE 版本,请将参数
model_size
指定为large
或LARGE
。 -
如果您暂停了训练,想从先前的检查点重新开始训练,请在训练命令中将参数
resume
设置为True
。 -
默认情况下,我们将总批量大小设置为 32,并使用 4 个 GPU 进行分布式训练,每个 GPU 在一个批次中处理 8 个样本(在我们的训练命令中将
batch_size
设置为 8)。因此,如果您使用 2 个 GPU 进行训练,请在训练命令中将batch_size
设置为 16。 -
您可以通过运行类似
tensorboard --logdir=./train-log/upr-base/tensorboard
的命令,使用 TensorBoard 查看训练曲线、插值和光流。
基准测试
训练好的模型权重
我们已将训练好的模型权重放在 ./checkpoints
中。我们 UPR-Net 的 base/large/LARGE 版本的权重分别命名为 upr.pkl
、upr_large.pkl
、upr_llarge.pkl
。
基准数据集
我们在 Vimeo90K、UCF101、SNU-FILM 和 4K1000FPS 上评估了我们的 UPR-Net 系列。
如果您想训练和测试我们的模型,请下载 Vimeo90K、UCF101、SNU-FILM 和 4K1000FPS。
基准测试脚本
我们提供了脚本来测试 Vimeo90K、UCF101、SNU-FILM 和 4K1000FPS 上的帧插值精度。运行这些脚本时,您应该配置基准数据集的路径。
python3 -m tools.benchmark_vimeo90k --data_root /path/to/vimeo_triplet/
python3 -m tools.benchmark_ucf101 --data_root /path/to/ucf101/
python3 -m tools.benchmark_snufilm --data_root /path/to/SNU-FILM/
python3 -m tools.benchmark_8x_4k1000fps --test_data_path /path/to/4k1000fps/test
默认情况下,我们测试 UPR-Net 的基础版本。要测试 large/LARGE 版本,请在基准测试脚本中更改相应的参数(model_size
和 model_file
)。
此外,运行以下命令可以测试我们的运行时间。
python -m tools.runtime
我们的基准测试结果
我们在 UCF101、Vimeo90K、SNU-FILM 上的基准测试结果如下表所示。您可以通过运行我们的基准测试脚本来验证我们的结果。运行时间是在单个 2080TI GPU 上测量的,用于插值两个 640x480 的帧。
我们在 4K1000FPS 上的基准测试结果如下表所示。
致谢
我们借鉴了 RIFE、softmax-splatting 和 EBME 的一些代码。我们感谢这些作者的出色工作。在使用我们的代码时,请同时注意 RIFE、softmax-splatting 和 EBME 的许可证。
引用
@inproceedings{jin2023unified,
title={A Unified Pyramid Recurrent Network for Video Frame Interpolation},
author={Jin, Xin and Wu, Longhai and Chen, Jie and Chen, Youxin and Koo,
Jayoon and Hahm, Cheul-hee},
booktitle={Proceedings of the IEEE conference on computer vision and pattern
recognition},
year={2023}
}