📖 TensorDict
TensorDict 是一个类似字典的类,它继承了张量的属性,使在 PyTorch 中处理张量集合变得简单。它提供了一种简单直观的方式来操作和处理张量,让您可以专注于构建和训练模型。
主要特性
TensorDict 使您的代码更加易读、简洁、模块化和高效。 它抽象了定制操作,使您的代码更不容易出错,因为它会为您处理对叶节点的操作分发。
主要特性包括:
- 🧮 可组合性:
TensorDict
将torch.Tensor
的操作推广到张量集合。 - ⚡️ 速度:异步传输到设备,通过
consolidate
实现快速节点间通信,兼容torch.compile
。 - ✂️ 形状操作:对 TensorDict 实例执行类似张量的操作,如索引、切片或连接。
- 🌐 分布式/多进程能力:轻松将 TensorDict 实例分布在多个工作进程、设备和机器上。
- 💾 序列化和内存映射
- λ 函数式编程及与
torch.vmap
的兼容性 - 📦 嵌套:嵌套 TensorDict 实例以创建层次结构。
- ⏰ 延迟预分配:为 TensorDict 实例预分配内存,无需初始化张量。
- 📝 专用数据类用于 torch.Tensor(
@tensorclass
)
示例
本节展示了该库的几个突出应用。 查看我们的入门指南,了解 TensorDict 的功能概览!
快速设备间复制
TensorDict
优化了设备间的数据传输,使其安全且快速。
默认情况下,数据传输将异步进行,并在需要时调用同步。
# 快速安全的异步复制到 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# 快速安全的异步复制到 'cpu'
td_cpu = td_cuda.to("cpu")
# 强制同步复制
td_cpu = td_cuda.to("cpu", non_blocking=False)
编写优化器
例如,使用 TensorDict
你可以像为单个 torch.Tensor
编写 Adam 优化器一样,并将其应用于 TensorDict
输入。在 cuda
上,这些操作将依赖于融合内核,使其执行速度非常快:
class Adam:
def __init__(self, weights: TensorDict, alpha: float=1e-3,
beta1: float=0.9, beta2: float=0.999,
eps: float = 1e-6):
# 锁定以提高效率
weights = weights.lock_()
self.weights = weights
self.t = 0
self._mu = weights.data.clone()
self._sigma = weights.data.mul(0.0)
self.beta1 = beta1
self.beta2 = beta2
self.alpha = alpha
self.eps = eps
def step(self):
self._mu.mul_(self.beta1).add_(self.weights.grad, 1 - self.beta1)
self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), 1 - self.beta2)
self.t += 1
mu = self._mu.div_(1-self.beta1**self.t)
sigma = self._sigma.div_(1 - self.beta2 ** self.t)
self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))
训练模型
使用 tensordict 原语,大多数监督训练循环可以以通用方式重写:
for i, data in enumerate(dataset):
# 模型读取和写入 tensordicts
data = model(data)
loss = loss_module(data)
loss.backward()
optimizer.step()
optimizer.zero_grad()
通过这种抽象级别,可以为高度异构的任务重复使用训练循环。 训练循环的每个单独步骤(数据收集和转换、模型预测、损失计算等)都可以针对特定用例进行定制,而不会影响其他步骤。 例如,上述示例可以轻松用于分类和分割任务等多种任务。
安装
使用 Pip:
要安装最新稳定版本的 tensordict,只需运行
pip install tensordict
这适用于 Python 3.7 及以上版本以及 PyTorch 1.12 及以上版本。
要享受最新功能,可以使用
pip install tensordict-nightly
使用 Conda:
从 conda-forge
频道安装 tensordict
。
conda install -c conda-forge tensordict
引用
如果您正在使用TensorDict,请使用以下BibTeX条目来引用这项工作:
@misc{bou2023torchrl,
title={TorchRL: A data-driven decision-making library for PyTorch},
author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
year={2023},
eprint={2306.00577},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
免责声明
TensorDict目前处于测试阶段,这意味着可能会引入破坏兼容性的更改,但会有相应的保证。 希望这些更改不会太频繁,因为当前的路线图主要涉及添加新功能和构建与更广泛的PyTorch生态系统的兼容性。
许可证
TensorDict采用MIT许可证。有关详细信息,请参阅LICENSE文件。