Project Icon

Trainer

基于PyTorch的通用模型训练框架

Trainer是一个基于PyTorch的开源模型训练框架,具有简洁的代码结构和灵活的优化控制。该框架支持自动优化、高级优化循环、批量大小查找、分布式训练和Accelerate集成。此外,Trainer提供回调功能、性能分析和多种实验日志记录选项,包括Tensorboard和ClearML等。这个框架适用于各类深度学习任务,能够简化训练流程并提升效率。

👟 Trainer

一个基于PyTorch的通用模型训练器,具有简洁的代码基础和独特的见解。

安装

从GitHub安装:

git clone https://github.com/coqui-ai/Trainer
cd Trainer
make install

从PyPI安装:

pip install trainer

建议从GitHub安装,因为它更稳定。

实现模型

继承并重写TrainerModel()中的函数。

使用自动优化训练模型

参见MNIST示例

使用高级优化训练模型

通过👟,您可以根据需要定义整个优化周期,就像下面的GAN示例一样。它为更高级的训练循环提供了更多的底层控制和灵活性。

您只需使用scaled_backward()函数来处理混合精度训练。

...

def optimize(self, batch, trainer):
    imgs, _ = batch

    # 采样噪声
    z = torch.randn(imgs.shape[0], 100)
    z = z.type_as(imgs)

    # 训练判别器
    imgs_gen = self.generator(z)
    logits = self.discriminator(imgs_gen.detach())
    fake = torch.zeros(imgs.size(0), 1)
    fake = fake.type_as(imgs)
    loss_fake = trainer.criterion(logits, fake)

    valid = torch.ones(imgs.size(0), 1)
    valid = valid.type_as(imgs)
    logits = self.discriminator(imgs)
    loss_real = trainer.criterion(logits, valid)
    loss_disc = (loss_real + loss_fake) / 2

    # 更新判别器
    _, _ = self.scaled_backward(loss_disc, None, trainer, trainer.optimizer[0])

    if trainer.total_steps_done % trainer.grad_accum_steps == 0:
        trainer.optimizer[0].step()
        trainer.optimizer[0].zero_grad()

    # 训练生成器
    imgs_gen = self.generator(z)

    valid = torch.ones(imgs.size(0), 1)
    valid = valid.type_as(imgs)

    logits = self.discriminator(imgs_gen)
    loss_gen = trainer.criterion(logits, valid)

    # 更新生成器
    _, _ = self.scaled_backward(loss_gen, None, trainer, trainer.optimizer[1])
    if trainer.total_steps_done % trainer.grad_accum_steps == 0:
        trainer.optimizer[1].step()
        trainer.optimizer[1].zero_grad()
    return {"model_outputs": logits}, {"loss_gen": loss_gen, "loss_disc": loss_disc}

...

参见GAN训练示例,其中使用了梯度累积。

使用批量大小查找器进行训练

有关使用批量大小查找器进行训练的测试脚本,请参见此处

批量大小查找器从默认的BS(默认为2048,也可以由用户定义)开始,搜索可以适应您的硬件的最大批量大小。您应该预期它会运行多次训练,直到找到合适的批量大小。要使用它,您需要调用trainer.fit_with_largest_batch_size(starting_batch_size=2048),而不是调用trainer.fit(),其中starting_batch_size是您想要开始搜索的批量大小。如果您想尽可能多地使用GPU内存,这非常有用。

使用DDP进行训练

$ python -m trainer.distribute --script path/to/your/train.py --gpus "0,1"

我们不使用.spawn()来启动多GPU训练,因为它会导致某些限制。

  • 所有内容都必须是可序列化的。
  • .spawn()在子进程中训练模型,主进程中的模型不会更新。
  • 当N很大时,具有N个进程的DataLoader会变得非常慢。

使用Accelerate进行训练

TrainingArgs中将use_accelerate设置为True将启用使用Accelerate进行训练。

您也可以将其用于多GPU或分布式训练。

CUDA_VISIBLE_DEVICES="0,1,2" accelerate launch --multi_gpu --num_processes 3 train_recipe_autoregressive_prompt.py

参见Accelerate文档

添加回调

👟支持回调以自定义您的运行。您可以在模型实现中设置回调,也可以显式地将它们提供给Trainer。

请查看trainer.utils.callbacks以查看可用的回调。

以下是如何为权重重新初始化向👟Trainer对象提供显式回调的示例。

def my_callback(trainer):
    print(" > 我的回调被调用了。")

trainer = Trainer(..., callbacks={"on_init_end": my_callback})
trainer.fit()

性能分析示例

  • 根据需要创建torch性能分析器,并将其传递给trainer。
    import torch
    profiler = torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
        on_trace_ready=torch.profiler.tensorboard_trace_handler("./profiler/"),
        record_shapes=True,
        profile_memory=True,
        with_stack=True,
    )
    prof = trainer.profile_fit(profiler, epochs=1, small_run=64)
    然后运行Tensorboard
    
  • 运行Tensorboard。
    tensorboard --logdir="./profiler/"
    

支持的实验记录器

要添加新的记录器,您必须继承BaseDashboardLogger并重写其函数。

匿名遥测

我们不断寻求改进🐸以服务社区。为了更好地了解社区的需求并相应地解决这些需求,当您运行trainer时,我们会收集精简的匿名使用统计数据。

当然,如果您不希望这样做,您可以通过设置环境变量TRAINER_TELEMETRY=0来选择退出。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号