简而言之,Flash 是您一直梦想但没有时间构建的生产级研究框架。
开始使用
从 PyPI 安装:
pip install lightning-flash
查看 我们的安装指南 以获取更多选项。
三步快速入门
第一步:加载您的数据
在 Flash 中,所有的数据加载都是通过 DataModule
的 from_*
类方法进行的。
决定使用哪个 DataModule
以及哪些 from_*
方法可用,取决于您要执行的任务。
例如,对于数据存储在文件夹中的图像分割,您可以使用 SemanticSegmentationData
类的 from_folders
方法:
from flash.image import SemanticSegmentationData
dm = SemanticSegmentationData.from_folders(
train_folder="data/CameraRGB",
train_target_folder="data/CameraSeg",
val_split=0.1,
image_size=(256, 256),
num_classes=21,
)
第二步:配置您的模型
我们的任务自带预训练的主干网络和(如果适用)头部网络。
您可以使用 available_backbones
查看可用于任务的主干网络。
一旦选择了一个,就创建模型:
from flash.image import SemanticSegmentation
print(SemanticSegmentation.available_heads())
# ['deeplabv3', 'deeplabv3plus', 'fpn', ..., 'unetplusplus']
print(SemanticSegmentation.available_backbones('fpn'))
# ['densenet121', ..., 'xception'] # + 113 种模型
print(SemanticSegmentation.available_pretrained_weights('efficientnet-b0'))
# ['imagenet', 'advprop']
model = SemanticSegmentation(
head="fpn", backbone='efficientnet-b0', pretrained="advprop", num_classes=dm.num_classes)
第三步:微调!
from flash import Trainer
trainer = Trainer(max_epochs=3)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
trainer.save_checkpoint("semantic_segmentation_model.pt")
PyTorch 配方
使用 Flash 进行预测!
只需两行代码即可部署:
from flash.image import SemanticSegmentation
model = SemanticSegmentation.load_from_checkpoint("semantic_segmentation_model.pt")
model.serve()
或直接从原始数据进行预测。
from flash import Trainer
trainer = Trainer(strategy='ddp', accelerator="gpu", gpus=2)
dm = SemanticSegmentationData.from_folders(predict_folder="data/CameraRGB")
predictions = trainer.predict(model, dm)
Flash 训练策略
训练策略是可以与特定任务配合使用的 PyTorch 最先进训练配方。
查看这个 示例,ImageClassifier
支持来自 Learn2Learn 的 4 种 元学习算法。
如果您在生产中使用此模型,并且希望确保模型能够在最少标注数据的情况下快速适应新环境,这将非常有用。
from flash.image import ImageClassifier
model = ImageClassifier(
backbone="resnet18",
optimizer=torch.optim.Adam,
optimizer_kwargs={"lr": 0.001},
training_strategy="prototypicalnetworks",
training_strategy_kwargs={
"epoch_length": 10 * 16,
"meta_batch_size": 4,
"num_tasks": 200,
"test_num_tasks": 2000,
"ways": datamodule.num_classes,
"shots": 1,
"test_ways": 5,
"test_shots": 1,
"test_queries": 15,
},
)
具体来说,目前实现了以下方法:
- prototypicalnetworks : 来自 Snell et al. 2017, Prototypical Networks for Few-shot Learning
- maml : 来自 Finn et al. 2017, Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
- metaoptnet : 来自 Lee et al. 2019, Meta-Learning with Differentiable Convex Optimization
- anil : 来自 Raghu et al. 2020, Rapid Learning or Feature Reuse? Towards Understanding the Effectiveness of MAML
Flash 优化器 / 调度器
使用 Flash,切换 40 多种优化器和 15 多种调度器配方非常简单。可以找到可用优化器 一旦你做出了选择,就创建模型:
#### 可以选择如下优化器
from flash.image import ImageClassifier
# - 字符串值
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=None)
# - 可调用对象
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=functools.partial(torch.optim.Adadelta, eps=0.5), lr_scheduler=None)
# - 元组[字符串, 字典]: (字典中包含优化器的参数)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer=("Adadelta", {"eps": 0.5}), lr_scheduler=None)
#### 可以选择如下学习率调度器
# - 字符串值
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="constant_schedule")
# - 可调用对象
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=functools.partial(CyclicLR, step_size_up=1500, mode='exp_range', gamma=0.5))
# - 元组[字符串, 字典]: (字典中包含学习率调度器的参数)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler=("StepLR", {"step_size": 10}))
你也可以提前注册自定义的学习率调度器方案,并像上面那样使用它们:
from flash.image import ImageClassifier
@ImageClassifier.lr_schedulers_registry
def my_steplr_recipe(optimizer):
return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
model = ImageClassifier(backbone="resnet18", num_classes=2, optimizer="Adam", lr_scheduler="my_steplr_recipe")
Flash 变换
Flash 默认包含每个任务的一些简单增强处理,但你通常想要覆盖这些处理并控制自己的增强方案。为此,Flash 支持使用 InputTransform
自定义变换。
InputTransform
有点类似于变换的回调函数,带有一些钩子用于对样本或批次进行变换,既可以在设备/加速器上进行,也可以在设备/加速器外进行。此外,钩子可以特别用来只对输入或目标进行变换。有了这些钩子,复杂的变换(如 MixUp)也可以轻松实现。下面是一个示例(还包含了一个 albumentations 变换!):
import torch
import numpy as np
import albumentations
from flash import InputTransform
from flash.image import ImageClassificationData
from flash.image.classification.input_transform import AlbumentationsAdapter
def mixup(batch, alpha=1.0):
images = batch["input"]
targets = batch["target"].float().unsqueeze(1)
lam = np.random.beta(alpha, alpha)
perm = torch.randperm(images.size(0))
batch["input"] = images * lam + images[perm] * (1 - lam)
batch["target"] = targets * lam + targets[perm] * (1 - lam)
return batch
class MixUpInputTransform(InputTransform):
def train_input_per_sample_transform(self):
return AlbumentationsAdapter(albumentations.HorizontalFlip(p=0.5))
# 这将在批次传输到设备后应用!
def train_per_batch_transform_on_device(self):
return mixup
datamodule = ImageClassificationData.from_folders(
train_folder="data/train",
transform=MixUpInputTransform,
batch_size=2,
)
Flash Zero - 命令行中的 PyTorch 方案!
Flash Zero 是一个零代码的机器学习平台,直接内置在 lightning-flash 中,使用 Lightning CLI
。
要开始并查看可用任务,运行:
flash --help
例如,要用 resnet50
骨架和 2 个 GPU 训练一个图像分类器 10 个 epoch,使用你自己的数据,可以这样做:
flash image_classification --trainer.max_epochs 10 --trainer.gpus 2 --model.backbone resnet50 from_folders --train_folder {PATH_TO_DATA}
Kaggle Notebook 示例
- 🚢 使用 Lightning⚡️Flash 处理泰坦尼克号事故
- 🏠 使用 Lightning⚡Flash 进行房价预测
- 使用 Lightning⚡Flash 处理📋 表格数据
- 🙊 使用 Lightning⚡Flash 处理有毒评论
- 🫁 使用 Lightning⚡️Flash 进行 COVID 检测
贡献!
lightning + Flash 团队正努力为常见的深度学习用例构建更多的任务。但我们正在寻找像你这样出色的贡献者来提交新任务!
加入我们的 Slack 或阅读我们的 CONTRIBUTING 指南,以获得帮助,成为贡献者!
注意: Flash 目前正在对真实案例进行测试,并且正在积极开发中。如果你发现任何不符合预期功能的地方,请 提交问题。
社区
Flash 由我们的 核心贡献者 维护。
如需帮助或有问题,加入我们庞大的 Slack 社区!
引用
我们很高兴继续开源软件的强大传承,多年来我们受到了 Caffe、Theano、Keras、PyTorch、torchbearer 和 fast.ai 的启发。如果有其它关于此内容的论文要写,我们将乐意引用这些框架和相应的作者。
Flash 借鉴了许多不同框架中的模型,以涵盖如此广泛的领域和任务。完整的供应商列表可在 我们的文档 中找到。
许可证
请遵守此仓库中列出的 Apache 2.0 许可证。