Make-A-Video - Pytorch (wip)
Implementation of Make-A-Video, new SOTA text-to-video generator from Meta AI, in Pytorch. They combine pseudo-3D convolutions (axial convolutions) and temporal attention and show better temporal fusion.
The pseudo-3D convolution is not a new concept. It has been explored before in other contexts, for example in protein contact prediction as "dimensional hybrid residual networks".
The gist of the paper is to take a SOTA text-to-image model (here DALL-E2 is used, but the same learnings would readily apply to Imagen), make a few minor modifications for attention across time and other ways to reduce the compute cost, do frame interpolation correctly, and end up with a great video model.
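The core trick is to factorize a full 3D convolution into a per-frame 2D convolution followed by a 1D convolution across time. The following is a minimal sketch of that idea only; FactorizedConv3d is a hypothetical name for illustration, and the library's PseudoConv3d may differ in its internals.
import torch
from torch import nn
from einops import rearrange

class FactorizedConv3d(nn.Module):  # hypothetical, illustrative only
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.spatial = nn.Conv2d(dim, dim, kernel_size, padding = kernel_size // 2)
        self.temporal = nn.Conv1d(dim, dim, kernel_size, padding = kernel_size // 2)

    def forward(self, video):  # (batch, channels, frames, height, width)
        b, c, f, h, w = video.shape
        x = rearrange(video, 'b c f h w -> (b f) c h w')
        x = self.spatial(x)                                   # 2D conv applied to each frame
        x = rearrange(x, '(b f) c h w -> (b h w) c f', b = b)
        x = self.temporal(x)                                  # 1D conv across the frame axis
        return rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

video = torch.randn(1, 256, 8, 16, 16)
out = FactorizedConv3d(256)(video)  # (1, 256, 8, 16, 16)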
Appreciation
- Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research
- Jonathan Ho for bringing about a revolution in generative artificial intelligence through his seminal paper
Install
$ pip install make-a-video-pytorch
Usage
Passing in video features
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
conv_out = conv(video) # (1, 256, 8, 16, 16)
attn_out = attn(video) # (1, 256, 8, 16, 16)
Passing in images (say, if one were to pretrain on images first), both temporal convolution and attention will be automatically skipped. In other words, you can use these modules straightforwardly in your 2D Unet and then port them over to a 3D Unet once that phase of the training is done. The temporal modules are initialized to output identity, as the paper had done (a sketch of such an initialization follows the example below).
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
images = torch.randn(1, 256, 16, 16) # (batch, features, height, width)
conv_out = conv(images) # (1, 256, 16, 16)
attn_out = attn(images) # (1, 256, 16, 16)
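One way to make a temporal convolution start out as an identity map is a dirac-initialized 1D kernel with a zeroed bias; the sketch below illustrates that assumption and is not necessarily the library's exact initialization.
import torch
from torch import nn

temporal_conv = nn.Conv1d(256, 256, 3, padding = 1)
nn.init.dirac_(temporal_conv.weight)   # kernel passes the center frame through unchanged
nn.init.zeros_(temporal_conv.bias)

x = torch.randn(4, 256, 8)             # (batch * spatial positions, channels, frames)
assert torch.allclose(temporal_conv(x), x, atol = 1e-6)  # behaves as identity before training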
You can also control the two modules so that, even when fed 3-dimensional features, they only train spatially.
import torch
from make_a_video_pytorch import PseudoConv3d, SpatioTemporalAttention
conv = PseudoConv3d(
dim = 256,
kernel_size = 3
)
attn = SpatioTemporalAttention(
dim = 256,
dim_head = 64,
heads = 8
)
video = torch.randn(1, 256, 8, 16, 16) # (batch, features, frames, height, width)
# with the setting below, the network will not train across time
conv_out = conv(video, enable_time = False) # (1, 256, 8, 16, 16)
attn_out = attn(video, enable_time = False) # (1, 256, 8, 16, 16)
Full SpaceTimeUnet that is agnostic to image or video training, and where, even if video is passed in, time can be ignored
import torch
from make_a_video_pytorch import SpaceTimeUnet
unet = SpaceTimeUnet(
dim = 64,
channels = 3,
dim_mult = (1, 2, 4, 8),
resnet_block_depths = (1, 1, 1, 2),
temporal_compression = (False, False, False, True),
self_attns = (False, False, False, True),
condition_on_timestep = False,
attn_pos_bias = False,
flash_attn = True
).cuda()
# train on images
images = torch.randn(1, 3, 128, 128).cuda()
images_out = unet(images)
assert images.shape == images_out.shape
# then train on video
video = torch.randn(1, 3, 16, 128, 128).cuda()
video_out = unet(video)
assert video_out.shape == video.shape
# or even treat your video as images
video_as_images_out = unet(video, enable_time = False)
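A possible two-stage schedule, pretraining on images and then continuing on video so the identity-initialized temporal layers learn motion, could be wired up roughly as below. The optimizer, noise level, and denoising-style objective are illustrative assumptions, not part of the library.
import torch
from torch.optim import Adam
from make_a_video_pytorch import SpaceTimeUnet

unet = SpaceTimeUnet(
    dim = 64,
    channels = 3,
    dim_mult = (1, 2, 4, 8),
    temporal_compression = (False, False, False, True),
    self_attns = (False, False, False, True),
    condition_on_timestep = False
).cuda()

opt = Adam(unet.parameters(), lr = 3e-4)

# stage 1 - images only; temporal convolution and attention are skipped automatically
for _ in range(1000):
    images = torch.randn(4, 3, 128, 128).cuda()        # stand-in for a real image batch
    noised = images + 0.1 * torch.randn_like(images)   # hypothetical corruption
    loss = (unet(noised) - images).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()

# stage 2 - switch to video; temporal layers start from identity and learn motion
for _ in range(1000):
    video = torch.randn(2, 3, 16, 128, 128).cuda()     # stand-in for a real video batch
    noised = video + 0.1 * torch.randn_like(video)
    loss = (unet(noised) - video).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()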
Todo
- give attention the best positional embeddings research has to offer
- soup up the attention
- add flash attention
- make sure dalle2-pytorch can accept SpaceTimeUnet for training
Citations
@misc{Singer2022,
title = {Make-A-Video: Text-to-Video Generation without Text-Video Data},
author = {Uriel Singer},
year = {2022},
url = {https://makeavideo.studio/Make-A-Video.pdf}
}
@inproceedings{rogozhnikov2022einops,
title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
author = {Alex Rogozhnikov},
booktitle = {International Conference on Learning Representations},
year = {2022},
url = {https://openreview.net/forum?id=oapKSVM2bcj}
}
@article{Dong2021AttentionIN,
title = {Attention is Not All You Need: Pure Attention Loses Rank Doubly Exponentially with Depth},
author = {Yihe Dong and Jean-Baptiste Cordonnier and Andreas Loukas},
journal = {ArXiv},
year = {2021},
volume = {abs/2103.03404}
}
@article{Zhang2021TokenST,
title = {Token Shift Transformer for Video Classification},
author = {Hao Zhang and Y. Hao and Chong-Wah Ngo},
journal = {Proceedings of the 29th ACM International Conference on Multimedia},
year = {2021}
}
@inproceedings{shleifer2022normformer,
title = {NormFormer: Improved Transformer Pretraining with Extra Normalization},
author = {Sam Shleifer and Myle Ott},
booktitle = {Submitted to The Tenth International Conference on Learning Representations },
year = {2022},
url = {https://openreview.net/forum?id=GMYWzWztDx5},
}
@inproceedings{dao2022flashattention,
title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
booktitle = {Advances in Neural Information Processing Systems},
year = {2022}
}