Project Icon

tensordict

PyTorch张量集合操作的高效字典类工具

TensorDict是一个继承张量属性的字典类,为PyTorch提供简洁的张量集合操作方法。它支持异步设备传输、节点间快速通信、张量形状操作和分布式计算,提高了代码的可读性、紧凑性和模块化。这个工具适用于模型训练、优化器实现等机器学习任务,能有效简化代码结构,提升开发效率。

文档 - GitHub.io Discord 频道 基准测试 Python 版本 GitHub 许可证 pypi 版本 pypi 每日构建版本 下载量 下载量 代码覆盖率 CircleCI Conda - 平台 Conda(仅限频道)

📖 TensorDict

TensorDict 是一个类似字典的类,它继承了张量的属性,使在 PyTorch 中处理张量集合变得简单。它提供了一种简单直观的方式来操作和处理张量,让您可以专注于构建和训练模型。

主要特性 | 示例 | 安装 | 引用 | 许可证

主要特性

TensorDict 使您的代码更加易读、简洁、模块化和高效。 它抽象了定制操作,使您的代码更不容易出错,因为它会为您处理对叶节点的操作分发。

主要特性包括:

  • 🧮 可组合性TensorDicttorch.Tensor 的操作推广到张量集合。
  • ⚡️ 速度:异步传输到设备,通过 consolidate 实现快速节点间通信,兼容 torch.compile
  • ✂️ 形状操作:对 TensorDict 实例执行类似张量的操作,如索引、切片或连接。
  • 🌐 分布式/多进程能力:轻松将 TensorDict 实例分布在多个工作进程、设备和机器上。
  • 💾 序列化和内存映射
  • λ 函数式编程及与 torch.vmap 的兼容性
  • 📦 嵌套:嵌套 TensorDict 实例以创建层次结构。
  • 延迟预分配:为 TensorDict 实例预分配内存,无需初始化张量。
  • 📝 专用数据类用于 torch.Tensor(@tensorclass

tensordict.png

示例

本节展示了该库的几个突出应用。 查看我们的入门指南,了解 TensorDict 的功能概览!

快速设备间复制

TensorDict 优化了设备间的数据传输,使其安全且快速。 默认情况下,数据传输将异步进行,并在需要时调用同步。

# 快速安全的异步复制到 'cuda'
td_cuda = TensorDict(**dict_of_tensor, device="cuda")
# 快速安全的异步复制到 'cpu'
td_cpu = td_cuda.to("cpu")
# 强制同步复制
td_cpu = td_cuda.to("cpu", non_blocking=False)

编写优化器

例如,使用 TensorDict 你可以像为单个 torch.Tensor 编写 Adam 优化器一样,并将其应用于 TensorDict 输入。在 cuda 上,这些操作将依赖于融合内核,使其执行速度非常快:

class Adam:
    def __init__(self, weights: TensorDict, alpha: float=1e-3,
                 beta1: float=0.9, beta2: float=0.999,
                 eps: float = 1e-6):
        # 锁定以提高效率
        weights = weights.lock_()
        self.weights = weights
        self.t = 0

        self._mu = weights.data.clone()
        self._sigma = weights.data.mul(0.0)
        self.beta1 = beta1
        self.beta2 = beta2
        self.alpha = alpha
        self.eps = eps

    def step(self):
        self._mu.mul_(self.beta1).add_(self.weights.grad, 1 - self.beta1)
        self._sigma.mul_(self.beta2).add_(self.weights.grad.pow(2), 1 - self.beta2)
        self.t += 1
        mu = self._mu.div_(1-self.beta1**self.t)
        sigma = self._sigma.div_(1 - self.beta2 ** self.t)
        self.weights.data.add_(mu.div_(sigma.sqrt_().add_(self.eps)).mul_(-self.alpha))

训练模型

使用 tensordict 原语,大多数监督训练循环可以以通用方式重写:

for i, data in enumerate(dataset):
    # 模型读取和写入 tensordicts
    data = model(data)
    loss = loss_module(data)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

通过这种抽象级别,可以为高度异构的任务重复使用训练循环。 训练循环的每个单独步骤(数据收集和转换、模型预测、损失计算等)都可以针对特定用例进行定制,而不会影响其他步骤。 例如,上述示例可以轻松用于分类和分割任务等多种任务。

安装

使用 Pip

要安装最新稳定版本的 tensordict,只需运行

pip install tensordict

这适用于 Python 3.7 及以上版本以及 PyTorch 1.12 及以上版本。

要享受最新功能,可以使用

pip install tensordict-nightly

使用 Conda

conda-forge 频道安装 tensordict

conda install -c conda-forge tensordict

引用

如果您正在使用TensorDict,请使用以下BibTeX条目来引用这项工作:

@misc{bou2023torchrl,
      title={TorchRL: A data-driven decision-making library for PyTorch},
      author={Albert Bou and Matteo Bettini and Sebastian Dittert and Vikash Kumar and Shagun Sodhani and Xiaomeng Yang and Gianni De Fabritiis and Vincent Moens},
      year={2023},
      eprint={2306.00577},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}

免责声明

TensorDict目前处于测试阶段,这意味着可能会引入破坏兼容性的更改,但会有相应的保证。 希望这些更改不会太频繁,因为当前的路线图主要涉及添加新功能和构建与更广泛的PyTorch生态系统的兼容性。

许可证

TensorDict采用MIT许可证。有关详细信息,请参阅LICENSE文件。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号