Project Icon

morph-net

在训练过程中优化深度网络结构的方法

MorphNet是一种在训练过程中优化深度网络结构的方法。通过使用正则化器优化FLOPs或模型大小等资源的消耗,MorphNet实现了受约束的网络结构优化。新方法FiGS采用概率性通道正则化,适用于剪枝和可微架构搜索。MorphNet可以在不改变网络拓扑的情况下调整卷积层的输出通道数,以简化模型并满足内存和延迟需求。项目由Elad Eban和Andrew Poon等人维护。

MorphNet: 快速简便的资源约束深度网络结构学习

[目录]

新产品:FiGS:细粒度随机架构搜索

FiGS 是我们在细粒度随机架构搜索中介绍的一种概率通道正则化方法。它的性能超越了我们之前的正则化器,可以用作剪枝算法或全功能的可微分架构搜索方法。这是推荐使用 MorphNet 的方式。在下面的文档中,它被称为 LogisticSigmoid 正则化器。

什么是 MorphNet?

MorphNet 是一种在训练过程中学习深度网络结构的方法。其关键原理是对网络结构学习问题进行连续松弛。简而言之,MorphNet 正则化器推动滤波器的影响降低,当它们足够小时,相应的输出通道将被标记为从网络中移除。

具体而言,通过添加针对特定资源(如 FLOPs 或模型大小)消耗的正则化器来引发激活稀疏性。当正则化损失添加到训练损失,并通过随机梯度下降或类似优化器最小化它们的和时,学习问题即成为网络结构的约束优化问题,约束由正则化器表示。我们在 CVPR 2018 会议上首次介绍了这一方法,论文为"MorphNet: Fast & Simple Resource-Constrained Learning of Deep Network Structure"。该方法的概述以及设备特定的延迟正则化器在 GTC 2019 上进行了展示。[幻灯片, 录音: YouTube, GTC on-demand]。我们新的概率剪枝方法称为 FiGS,详见细粒度随机架构搜索

使用方法

假设你有一个用于图像分类的工作中的卷积神经网络,但想缩小模型以满足某些约束(如内存、延迟)。给定现有模型(“种子网络”)和目标标准,MorphNet 将通过调整每个卷积层中的输出通道数来提出新模型。

请注意,MorphNet 不会更改网络的拓扑结构——建议的模型将拥有与种子网络相同的层数和连接模式。

使用 MorphNet 你必须:

  1. morphnet.network_regularizers 中选择一个正则化器。选择基于

    • 你的目标成本(如 FLOPs、延迟)
    • 添加新层到模型的能力:
      • 在任何希望剪枝的层之后添加我们的概率门控操作,并使用 LogisticSigmoid 正则化器。[推荐]
      • 如果无法添加新层,根据网络架构选择正则化器类型:种子网络有 BatchNorm 时使用 Gamma 正则化器;否则使用 GroupLasso [已弃用]

    注意:如果你使用 BatchNorm,必须启用比例参数(“伽马变量”),即如果使用 tf.keras.layers.BatchNormalization,请设置 scale=True

    注意:如果使用 LogisticSigmoid,别忘了添加概率门控操作!参见下文示例。

  2. 用阈值和模型的输出边界操作(和可选的输入边界操作)初始化正则化器。

    MorphNet 正则化器从输出边界开始爬取图,并对它遇到的一些操作施加正则化。当遇到任何输入边界操作时,它不会爬取到它们之后(输入边界中的操作不被正则化)。阈值决定哪些输出通道可以被消除。

  3. 将正则化项添加到你的损失中。

    正则化损失必须缩放。我们建议沿对数刻度搜索缩放超参数(正则化强度),跨越 1/(初始成本) 附近几个数量级。例如,如果种子网络始于 1e9 FLOPs,探索 1e-9 附近的正则化强度。

    注意:MorphNet 当前不会将正则化损失添加到 tf.GraphKeys.REGULARIZATION_LOSSES 集合;此选择可能会调整。

    注意:不要混淆 get_regularization_term()(应添加到训练中的损失)与 get_cost()(如果应用建议结构的网络估计成本)。

  4. 训练模型。

    注意:我们建议在此步骤中使用固定学习率(无衰减),但这不是绝对必要的。

  5. 使用 StructureExporter 保存建议的模型结构。

    导出的文件是 JSON 格式的。请注意,随着训练进程,建议的模型结构会发生变化。没有关于停训时间的具体指南,尽管你可能希望等待正则化损失(通过汇总报告)稳定下来。

  6. (可选)创建汇总操作以通过 TensorBoard 监控训练进度。

  7. 使用 StructureExporter 输出修改你的模型。

  8. 从头开始重训练模型,不使用 MorphNet 正则化器。

    注意:对于所有超参数(如学习率调度),使用标准值。

  9. (可选)均匀扩展网络以根据需要调整准确性与成本的权衡。或者,可以在结构学习步骤之前执行此步骤。

我们将第一轮训练称为 结构学习,第二轮称为 重训练

总结来说,MorphNet 的关键超参数是:

  • 正则化强度
  • 存活阈值

请注意,正则化器类型不是超参数,因为它由关注的指标(FLOPs、延迟)和 BatchNorm 的存在唯一决定。

正则化器类型

正则化器类可以在 network_regularizers/ 目录下找到。它们以使用的算法和试图最小化的目标成本命名。例如,LogisticSigmoidFlopsRegularizer 使用 Logistic-Sigmoid 概率方法正则化 FLOP 成本,而 GammaModelSizeRegularizer 通过使用批标准化的伽马变量正则化模型大小成本。

正则化器算法

  • [新] LogisticSigmoid 设计用于控制任何模型类型,但需要向模型添加简单的 门控层
  • GroupLasso 专为没有批标准化的模型设计。
  • Gamma 专为具有批标准化的模型设计;需要启用批标准化比例。

正则化器目标成本

  • Flops 目标是推理网络的 FLOP 计数。
  • 模型大小 目标是网络的权重数量。
  • 延迟 优化推理网络的估计推理延迟,基于特定硬件特性。

示例

添加 FLOPs 正则化器

下面的例子演示了如何使用 MorphNet 减少模型中的 FLOPs。在这例中,正则化器将从 logits 开始遍历图,并且不会越过 inputslabels 之前的任何操作;这允许指定 MorphNet 要优化的子图。

from morph_net.network_regularizers import flop_regularizer
from morph_net.tools import structure_exporter

def build_model(inputs, labels, is_training, ...):
  gated_relu = activation_gating.gated_relu_activation()

  net = tf.layers.conv2d(inputs, kernel=[5, 5], num_outputs=256)
  net = gated_relu(net, is_training=is_training)

  ...
  ...

  net = tf.layers.conv2d(net, kernel=[3, 3], num_outputs=1024)
  net = gated_relu(net, is_training=is_training)

  logits = tf.reduce_mean(net, [1, 2])
  logits = tf.layers.dense(logits, units=1024)
  return logits

inputs, labels = preprocessor()
logits = build_model(inputs, labels, is_training=True, ...)

network_regularizer = flop_regularizer.LogisticSigmoidFlopsRegularizer(
    output_boundary=[logits.op],
    input_boundary=[inputs.op, labels.op],
    alive_threshold=0.1  # 值为[0, 1]。此默认值适用于大多数情况。
)
regularization_strength = 1e-10
regularizer_loss = (network_regularizer.get_regularization_term() * regularization_strength)

model_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits)

optimizer = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)

train_op = optimizer.minimize(model_loss + regularizer_loss)

你应该通过 Tensorboard 监控结构学习训练的进展。特别是,应该考虑添加一个汇总来计算如果采用当前建议的结构,则当前 MorphNet 正则化损失和网络成本。

tf.summary.scalar('RegularizationLoss', regularizer_loss)
tf.summary.scalar(network_regularizer.cost_name, network_regularizer.get_cost())

TensorBoardDisplayOfFlops

较大的 regularization_strength 值将使有效 FLOP 计数收敛到较小值。如果 regularization_strength 足够大,FLOP 计数将降至零。相反,如果它足够小,FLOP 计数将保持初始值,网络结构将不会变化。regularization_strength 参数是你控制价格性能曲线位置的调节器。alive_threshold 参数用于确定何时激活存活。

提取由 MorphNet 学习到的架构

在训练过程中,应保存一个包含所学网络结构的 JSON 文件,即由 MorphNet 保持活跃(未移除)激活的层数。

exporter = structure_exporter.StructureExporter(
    network_regularizer.op_regularizer_manager)

with tf.Session() as sess:
  tf.global_variables_initializer().run()
  for step in range(max_steps):
    _, structure_exporter_tensors = sess.run([train_op, exporter.tensors])
    if (step % 1000 == 0):
      exporter.populate_tensor_values(structure_exporter_tensors)
      exporter.create_file_and_save_alive_counts(train_dir, step)

其他信息

联系人: morphnet@google.com

维护者

贡献者

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号