Project Icon

Ensemble-Pytorch

PyTorch集成学习框架助力模型优化

Ensemble-Pytorch是一个为PyTorch设计的集成学习框架,旨在提高深度学习模型的性能和鲁棒性。该框架支持多种集成策略,如Fusion、Voting、Bagging和Gradient Boosting,适用于分类和回归任务。作为PyTorch生态系统的一部分,Ensemble-Pytorch提供简洁的API和详细文档,便于研究人员和开发者实现和优化集成模型。

.. image:: ./docs/_images/badge_small.png

|github|_ |readthedocs|_ |codecov|_ |license|_

.. |github| image:: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/workflows/torchensemble-CI/badge.svg .. _github: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/actions

.. |readthedocs| image:: https://readthedocs.org/projects/ensemble-pytorch/badge/?version=latest .. _readthedocs: https://ensemble-pytorch.readthedocs.io/en/latest/index.html

.. |codecov| image:: https://codecov.io/gh/TorchEnsemble-Community/Ensemble-Pytorch/branch/master/graph/badge.svg?token=2FXCFRIDTV .. _codecov: https://codecov.io/gh/TorchEnsemble-Community/Ensemble-Pytorch

.. |license| image:: https://img.shields.io/github/license/TorchEnsemble-Community/Ensemble-Pytorch .. _license: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/blob/master/LICENSE

Ensemble PyTorch

Ensemble PyTorch是一个统一的集成框架,用于pytorch_,可以轻松提高深度学习模型的性能和鲁棒性。Ensemble-PyTorch是pytorch生态系统<https://pytorch.org/ecosystem/>__的一部分,这要求项目得到良好的维护。

  • 文档 <https://ensemble-pytorch.readthedocs.io/>__
  • 实验 <https://ensemble-pytorch.readthedocs.io/en/stable/experiment.html>__

安装

.. code:: bash

pip install torchensemble

示例

.. code:: python

from torchensemble import VotingClassifier  # 投票是一种经典的集成策略

# 加载数据
train_loader = DataLoader(...)
test_loader = DataLoader(...)

# 定义集成
ensemble = VotingClassifier(
    estimator=base_estimator,               # estimator是你的pytorch模型
    n_estimators=10,                        # 基础估计器的数量
)

# 设置优化器
ensemble.set_optimizer(
    "Adam",                                 # 参数优化器的类型
    lr=learning_rate,                       # 参数优化器的学习率
    weight_decay=weight_decay,              # 参数优化器的权重衰减
)

# 设置学习率调度器
ensemble.set_scheduler(
    "CosineAnnealingLR",                    # 学习率调度器的类型
    T_max=epochs,                           # 调度器的其他参数
)

# 训练集成
ensemble.fit(
    train_loader,
    epochs=epochs,                          # 训练轮数
)

# 评估集成
acc = ensemble.evaluate(test_loader)         # 测试准确率

支持的集成方法

+------------------------------+------------+---------------------------+-----------------------------+ | 集成名称 | 类型 | 源代码 | 问题 | +==============================+============+===========================+=============================+ | 融合 | 混合 | fusion.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 投票 [1]_ | 并行 | voting.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 神经森林 | 并行 | voting.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 装袋 [2]_ | 并行 | bagging.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 梯度提升 [3]_ | 顺序 | gradient_boosting.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 快照集成 [4]_ | 顺序 | snapshot_ensemble.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 对抗训练 [5]_ | 并行 | adversarial_training.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 快速几何集成 [6]_ | 顺序 | fast_geometric.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+ | 软梯度提升 [7]_ | 并行 | soft_gradient_boosting.py | 分类 / 回归 | +------------------------------+------------+---------------------------+-----------------------------+

依赖

  • scikit-learn>=0.23.0
  • torch>=1.4.0
  • torchvision>=0.2.2

参考文献

.. [1] Zhou, Zhi-Hua. Ensemble Methods: Foundations and Algorithms. CRC press, 2012.

.. [2] Breiman, Leo. Bagging Predictors. Machine Learning (1996): 123-140.

.. [3] Friedman, Jerome H. Greedy Function Approximation: A Gradient Boosting Machine. Annals of Statistics (2001): 1189-1232.

.. [4] Huang, Gao, et al. Snapshot Ensembles: Train 1, Get M For Free. ICLR, 2017.

.. [5] Lakshminarayanan, Balaji, et al. Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles. NIPS, 2017.

.. [6] Garipov, Timur, et al. Loss Surfaces, Mode Connectivity, and Fast Ensembling of DNNs. NeurIPS, 2018.

.. [7] Feng, Ji, et al. Soft Gradient Boosting Machine. ArXiv, 2020.

.. _pytorch: https://pytorch.org/

.. _pypi: https://pypi.org/project/torchensemble/

感谢所有贡献者

|contributors|

.. |contributors| image:: https://contributors-img.web.app/image?repo=TorchEnsemble-Community/Ensemble-Pytorch .. _contributors: https://github.com/TorchEnsemble-Community/Ensemble-Pytorch/graphs/contributors

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号