项目介绍:LibMTL
LibMTL 是一个开源的多任务学习(MTL)库,基于大家熟知的深度学习框架 PyTorch 构建。它的目标是为研究人员和开发人员提供一个统一、全面且可扩展的平台来实现和评估多任务学习算法。
项目特色
-
统一性:LibMTL 提供了一致的代码库和评估过程,涵盖从数据处理到度量指标和超参数调整的所有环节。这种统一性使得用户可以对不同的 MTL 算法进行量化、公平和一致的比较。
-
全面性:LibMTL 支持许多先进的 MTL 方法,包括8种架构和16种优化策略。同时,它还提供了对不同领域的多个基准数据集的公平比较。
-
可扩展性:LibMTL 遵循模块化设计原则,允许用户灵活添加自定义组件或进行个性化修改。用户可以快速开发新的优化策略和架构,或在新的应用场景中应用现有的 MTL 算法。
整体框架
LibMTL 的整体架构由多个模块组成,这些模块在项目的文档中有详细介绍。用户可以通过查阅文档,了解每个模块的具体功能和使用方法。
支持的算法
LibMTL 支持一系列的现代 MTL 算法,包括但不限于:
- 优化策略:如 Equal Weighting、GradNorm、Uncertainty Weights、MGDA、PCGrad 等。
- 架构:如 Hard Parameter Sharing、Cross-Stitch Networks、MTAN、Learning to Branch 等。
支持的基准数据集
LibMTL 涵盖多种应用领域的数据集,例如:
- NYUv2 和 Cityscapes:用于场景理解,任务包括语义分割、深度估计等。
- Office-31 和 Office-Home:图像识别任务。
- QM9 和 PAWS-X:分子属性预测和释义识别任务。
安装指南
用户只需执行以下步骤即可安装 LibMTL:
# 创建虚拟环境
conda create -n libmtl python=3.8
conda activate libmtl
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 -f https://download.pytorch.org/whl/torch_stable.html
# 克隆项目仓库
git clone https://github.com/median-research-group/LibMTL.git
# 安装 LibMTL
cd LibMTL
pip install -r requirements.txt
pip install -e .
快速开始
以下是使用 NYUv2 数据集进行模型训练的简单示例:
- 下载并预处理数据集。
- 使用命令行运行模型,指定各种参数:
python main.py --weighting EW --arch HPS --dataset_path /path/to/nyuv2 --gpu_id 0 --scheduler step --mode train --save_path PATH
引用与贡献者
如果在您的研究或开发中使用了 LibMTL,请按照以下格式进行引用:
@article{lin2023libmtl,
title={{LibMTL}: A {P}ython Library for Multi-Task Learning},
author={Baijiong Lin and Yu Zhang},
journal={Journal of Machine Learning Research},
volume={24},
number={209},
pages={1--7},
year={2023}
}
项目由 Baijiong Lin 开发并维护。用户可以通过GitHub问题或者邮箱与开发者联系,提供建议或反馈。
LibMTL 为开发者提供完善的功能和丰富的文档支持,既适用于学术研究,也可用于实际工程项目的多任务学习开发。