Project Icon

GradCache

突破GPU/TPU内存限制,实现对比学习无限扩展

Gradient Cache技术突破了GPU/TPU内存限制,可以无限扩展对比学习的批处理大小。仅需一个GPU即可完成原本需要8个V100 GPU的训练,并能够用更具成本效益的高FLOP低内存系统替换大内存GPU/TPU。该项目支持Pytorch和JAX框架,并已整合至密集段落检索工具DPR。

GradCache 项目介绍

GradCache 简介

GradCache 是一种简单的方法,可以将对比学习批次的扩展提升到超过 GPU/TPU 内存限制。这意味着过去需要重型硬件(例如 8 个 V100 GPU)进行的训练,现在可以在单个 GPU 上完成。此外,GradCache 还允许用户用更加经济高效的高 FLOP 低 RAM 系统替代高 RAM 的 GPU/TPU。

该项目提供了一个通用的 GradCache 实现,支持 Pytorch 和 JAX 框架。这一技术已经在论文《在内存限制条件下扩展深度对比学习的批次大小》中进行过描述,并被集成到 Dense Passage Retrieval (DPR) 系统中。

安装指南

要安装 GradCache,首先您需要安装所需的深度学习后端(Pytorch 或 JAX)。然后克隆项目库并运行 pip 进行安装:

git clone https://github.com/luyug/GradCache
cd GradCache
pip install .

对于开发用途,可使用以下安装方式:

pip install --editable .

如何使用

GradCache 的功能通过 GradCache 类实现。如果您正在开发新项目,而不是修补旧项目,还可以查看我们提供的具备简化工作量的功能方法

初始化

要使用 GradCache 首先需要初始化 GradCache 类,其 __init__ 方法需要对缓存进行定义,并且包含多个功能参数以方便调整模型行为。您也可以通过继承方式来使用。

grad_cache.GradCache(  
  models: List[nn.Module],  
  chunk_sizes: Union[int, List[int]],  
  loss_fn: Callable[..., Tensor],  
  split_input_fn: Callable[[Any, int], Any] = None,  
  get_rep_fn: Callable[..., Tensor] = None,  
  fp16: bool = False,  
  scaler: GradScaler = None,  
)
  • models:要通过 GradCache 更新的编码器模型列表。
  • chunk_sizes:块大小的整数或每个模型的块大小整数列表。该值基于可用的 GPU 内存设定,不宜过小以避免 GPU 未充分利用。
  • loss_fn:计算模型表示的损失函数。
  • split_input_fn:可选函数,用于根据定义的 chunk_sizes 将输入切分为小块。
  • get_rep_fn:可选函数,用于模型输出获取表示张量。
  • fp16scaler:是否使用混合精度训练。

缓存梯度步骤

可调用 cache_step 来运行缓存梯度计算步骤:

cache_step(  
  *model_inputs,  
  no_sync_except_last: bool = False,  
  **loss_kwargs  
)

在大多数情况下,通过该方法可以让模型在一个虚拟较大批次下运行,就如同在足够大的硬件上运行。执行此函数之后,模型的参数将更新。

使用 Huggingface Transformers 的示例

以下是一个简单示例:使用 BERT 模型创建双编码器来学习标签和文本的嵌入空间。

from transformers import AutoTokenizer, AutoModel
from grad_cache import GradCache
from grad_cache.loss import SimpleContrastiveLoss

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoder1 = AutoModel.from_pretrained("bert-base-uncased").cuda()
encoder2 = AutoModel.from_pretrained("bert-base-uncased").cuda()

loss_fn = SimpleContrastiveLoss()
gc = GradCache(
  models=[encoder1, encoder2],
  chunk_sizes=2,
  loss_fn=loss_fn,
  get_rep_fn=lambda v: v.pooler_output
)

创建模型输入并运行缓存步骤:

xx = tokenizer(["this is an apple"], return_tensors='pt', padding=True)
yy = tokenizer(["apple sells laptop"], return_tensors='pt', padding=True)

gc(xx, yy, reduction='mean')

分布式训练和多 GPU 支持

GradCache 可以与分布式数据并行 (Distributed Data Parallel) 模型结合使用,实现跨设备的梯度计算与通信。

from torch.nn.parallel import DistributedDataParallel

encoder1_ddp = DistributedDataParallel(encoder1, device_ids=[local_rank], output_device=local_rank)
encoder2_ddp = DistributedDataParallel(encoder2, device_ids=[local_rank], output_device=local_rank)

loss_fn_dist = DistributedContrastiveLoss()
gc = GradCache(
  models=[encoder1_ddp, encoder2_ddp],
  chunk_sizes=2,
  loss_fn=loss_fn_dist,
  get_rep_fn=lambda v: v.pooler_output
)

同样,您可以运行缓存步骤:

gc(xx, yy, no_sync_except_last=True, reduction='mean')

功能方法

项目还提供了便捷的功能装饰器,例如 cachedcat_input_tensor,简化缓存的模型和损失函数调用过程。这些装饰器特别适用于处理小批量数据构建大批量进行训练。


总之,GradCache 为那些内存受限的深度对比学习任务提供了一种高效且易于实施的方法,极大地降低了硬件门槛,具有广泛的应用前景。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号