Project Icon

torch-imle

将离散优化算法融入深度学习的创新方法

torch-imle是一个PyTorch库,通过I-MLE梯度估计器将离散优化算法融入深度学习。它使用创新的采样和分布方法,实现了离散优化问题在深度学习中的应用,如最短路径学习。该库采用Perturb-and-MAP方法和新颖的噪声扰动来近似采样复杂分布,并提供替代经验分布。torch-imle通过梯度下降学习最优路径权重,为深度学习中的离散优化问题提供强大的解决方案。

torch-imle

这是一个简洁且独立的PyTorch库,实现了我们在NeurIPS 2021论文《Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions》中提出的I-MLE梯度估计器。

该仓库包含一个库,用于将任何组合黑盒求解器转换为可微分层。NeurIPS论文中所有实验的复现代码都可在NEC欧洲实验室的官方仓库中找到。

概述

隐式最大似然估计(I-MLE)使得在标准深度学习架构中包含离散组合优化算法(如Dijkstra算法或整数线性规划求解器)成为可能。I-MLE的核心思想是定义一个隐式的最大似然目标,其梯度用于更新模型的上游参数。每个I-MLE实例需要两个要素:

  1. 一种从复杂且难以处理的分布中近似采样的方法,该分布由组合求解器在解空间上诱导,其中最优解具有最高的概率质量。为此,我们使用扰动映射(又称Gumbel-max技巧),并提出了一种针对特定问题的新型噪声扰动家族。

  2. 一种计算替代经验分布的方法:普通MLE减少当前分布与经验分布之间的KL散度。由于在我们的设置中无法获得经验分布,我们必须设计替代经验分布。这里我们提出了两种广泛适用且实践效果良好的替代分布家族。

示例

例如,让我们考虑一个简单游戏的地图,任务是找到从左上角到右下角的最短路径。较暗的区域成本较高,较亮的区域成本较低。 在中间,你可以看到当我们使用提出的伽玛噪声分布之和来采样路径时会发生什么。 在右侧,你可以看到每个格子的边际概率结果(每个格子成为采样路径一部分的概率)。

[图片1] [图片2] [图片3]

梯度和学习

假设最优最短路径是左侧的路径。 从随机权重开始,模型可以通过梯度下降学习产生将导致最优最短路径的权重,方法是最小化生成路径与黄金路径之间的汉明损失。 这里我们展示了训练过程中产生的路径(中间)和相应的地图权重(右侧)。

输入噪声温度设置为0.0,目标噪声温度设置为0.0

[图片4] [图片5] [图片6]

输入噪声温度设置为1.0,目标噪声温度设置为1.0

[图片7] [图片8] [图片9]

输入噪声温度设置为2.0,目标噪声温度设置为2.0

[图片10] [图片11] [图片12]

输入噪声温度设置为5.0,目标噪声温度设置为5.0

[图片13] [图片14] [图片15]

输入噪声温度设置为5.0,目标噪声温度设置为0.0

[图片16] [图片17] [图片18]

所有动画都由这个脚本生成。

代码

使用这个库非常简单 -- 请参考这个示例。假设我们有一个实现黑盒组合求解器(如Dijkstra算法)的方法:

import numpy as np

import torch
from torch import Tensor

def torch_solver(weights_batch: Tensor) -> Tensor:
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)

我们可以通过以下方式获得相应的分布和梯度:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

def torch_solver(weights_batch: Tensor) -> Tensor:
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)

imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                    noise_distribution=noise_distribution,
                    nb_samples=10,
                    input_noise_temperature=input_noise_temperature,
                    target_noise_temperature=target_noise_temperature)

或者,使用简单的函数注解:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)

使用I-MLE的论文

  • Patrick Betz, Mathias Niepert, Pasquale Minervini, 和 Heiner Stuckenschmidt:《Backpropagating through Markov Logic Networks》,NeSy'20/21 @ IJCLR:第15届神经符号学习与推理国际研讨会

参考文献

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE:} Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {NeurIPS},
  series    = {Proceedings of Machine Learning Research},
  publisher = {{PMLR}},
  year      = {2021}
}
项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号