项目介绍:Lightning Flash
Lightning Flash是一个基于PyTorch的高效人工智能框架,它的目标是让各种复杂的AI任务变得简单易用。它提供了一套完整的工具,可以处理超过15种任务,涵盖7个不同的数据领域。这项工具是您梦寐以求的生产级研究框架,适合那些希望快速申请AI技术但没有时间从头构建的开发者。
快速入门
使用PyPI,安装Lightning Flash只需简单的一行命令:
pip install lightning-flash
三步快速使用Flash
-
加载数据:
数据加载通过DataModule
的from_*
类方法完成。以图像分割任务为例,当数据存储在文件夹中时,可以使用SemanticSegmentationData
类的from_folders
方法:from flash.image import SemanticSegmentationData dm = SemanticSegmentationData.from_folders( train_folder="data/CameraRGB", train_target_folder="data/CameraSeg", val_split=0.1, image_size=(256, 256), num_classes=21, )
-
配置模型:
Flash的任务模块中预装了预训练的骨干网络和头部结构。选择一个合适的模型并创建:from flash.image import SemanticSegmentation model = SemanticSegmentation( head="fpn", backbone='efficientnet-b0', pretrained="advprop", num_classes=dm.num_classes)
-
微调模型:
使用Trainer
对模型进行微调并保存模型:from flash import Trainer trainer = Trainer(max_epochs=3) trainer.finetune(model, datamodule=datamodule, strategy="freeze") trainer.save_checkpoint("semantic_segmentation_model.pt")
PyTorch配方
使用Flash进行预测
使用Flash可以在两行代码内实现服务部署:
from flash.image import SemanticSegmentation
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.serve()
或从原始数据直接进行预测:
from flash import Trainer
trainer = Trainer(strategy='ddp', accelerator="gpu", gpus=2)
dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB")
predictions = trainer.predict(model, dm)
Flash训练策略
Flash支持多种前沿训练策略,比如原型网络(Prototypical Networks)、模型无关元学习(MAML)等。这些策略特别有助于在生产环境中快速适应新的环境和标签数据。
Flash优化器和调度器
用户可以简单切换超过40种优化器和15种调度器。例如:
from flash.image import ImageClassifier
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)
Flash Zero: 从命令行实现PyTorch配方
Flash Zero是一个无代码机器学习平台,通过命令行可即时实现训练、测试等功能,极大降低了使用门槛。
开源社区
Lightning Flash由一个核心团队维护,并获得广泛的开源社区支持。如果您对加入贡献者行列感兴趣,可以阅读我们的贡献指南。
项目采用Apache 2.0许可证,并受到Caffe、Theano、Keras等项目的启发,继续将开源软件的优秀传统发扬光大。