Keras / TensorFlow 的 visualkeras
简介
Visualkeras 是一个 Python 包,用于帮助可视化 Keras(独立或包含在 tensorflow 中)神经网络架构。它允许轻松定制样式以满足大多数需求。该模块支持分层样式架构生成,非常适合 CNN(卷积神经网络),以及图形样式架构,适用于大多数模型,包括普通前馈网络。 有关引用本项目的帮助,请参阅此处。
模型支持
模式 | Sequential | Functional | 子类模型 |
---|---|---|---|
visualkeras.layered_view() | 是(1) | 部分支持(1,2) | 未测试 |
visualkeras.graph_view() | 是 | 是 | 未测试 |
1:任何超过 3 维的张量将被渲染为具有延长 z 轴的 3D 张量。
2:仅适用于每层不超过一个输入或输出的线性模型。非线性模型将按顺序显示。
版本支持
我们目前仅支持 Keras 2.0 及以上版本。我们计划在未来的更新中添加对 Keras 1.0 版本的支持。
安装
要从 PyPi 安装已发布的版本(最后更新:2024 年 7 月 19 日),请执行:
pip install visualkeras
要将 visualkeras 更新到最新版本,请在上述命令中添加 --upgrade
标志。
如果您想要最新(可能不稳定)的功能,也可以直接从 GitHub master 分支安装:
pip install git+https://github.com/paulgavrikov/visualkeras
使用方法
生成神经网络架构很简单:
import visualkeras
model = ...
visualkeras.layered_view(model).show() # 使用系统查看器显示
visualkeras.layered_view(model, to_file='output.png') # 写入磁盘
visualkeras.layered_view(model, to_file='output.png').show() # 写入并显示
为了帮助理解一些最重要的参数,我们将使用 VGG16 CNN 架构(参见 example.py)。
默认
visualkeras.layered_view(model)
图例
您可以设置图例参数来描述颜色和层类型之间的关系。也可以传递自定义的 PIL.ImageFont
(或者不设置,visualkeras 将使用默认的 PIL 字体)。请注意,根据您的操作系统,您可能需要提供所需字体的完整路径。
from PIL import ImageFont
font = ImageFont.truetype("arial.ttf", 32) # 严禁使用 comic sans!
visualkeras.layered_view(model, legend=True, font=font) # font 是可选的!
扁平样式
visualkeras.layered_view(model, draw_volume=False)
间距和逻辑分组
两层之间的全局距离可以通过 spacing
控制。为了生成逻辑分组,可以添加一个特殊的虚拟 keras 层 visualkeras.SpacingDummyLayer()
。
model = ...
...
model.add(visualkeras.SpacingDummyLayer(spacing=100))
...
visualkeras.layered_view(model, spacing=0)
自定义颜色映射
可以为每种层类型提供自定义的填充和轮廓颜色映射。
from tensorflow.python.keras.layers import Dense, Conv2D, Flatten, Dropout, MaxPooling2D, ZeroPadding2D
from collections import defaultdict
color_map = defaultdict(dict)
color_map[Conv2D]['fill'] = 'orange'
color_map[ZeroPadding2D]['fill'] = 'gray'
color_map[Dropout]['fill'] = 'pink'
color_map[MaxPooling2D]['fill'] = 'red'
color_map[Dense]['fill'] = 'green'
color_map[Flatten]['fill'] = 'teal'
visualkeras.layered_view(model, color_map=color_map)
隐藏层
某些模型可能包含太多层,难以可视化或理解。在这种情况下,在不修改 keras 模型的情况下隐藏(忽略)某些层会很有帮助。Visualkeras 允许通过层类型(type_ignore
)或在 keras 层序列中的索引(index_ignore
)来忽略层。
visualkeras.layered_view(model, type_ignore=[ZeroPadding2D, Dropout, Flatten])
缩放尺寸
Visualkeras通过输出形状计算每一层的大小。值被转换为像素。然后应用缩放。默认情况下,visualkeras会放大x和y维度并缩小z维度的大小,因为这被认为在视觉上最具吸引力。但是,可以使用scale_xy
和scale_z
来控制缩放。此外,为了防止过小或过大的选项,可以设置最小值和最大值(min_xy
、min_z
、max_xy
、max_z
)。
visualkeras.layered_view(model, scale_xy=1, scale_z=1, max_z=1000)
注意:缩放后的模型可能会隐藏层的真实复杂性,但在视觉上更具吸引力。
绘制信息文本
通过text_callable
参数,可以将一个函数传递给layered_view
函数,用于在特定层的下方或上方绘制文本。该函数应具有以下属性:
-
接受两个参数:第一个是模型中层的索引。这个索引忽略了
type_ignore
、index_ignore
中列出的层,也忽略了SpacingDummyLayer
类的层。第二个参数是模型中给定索引处使用的层对象。 -
返回两个参数:第一个返回值是包含要绘制的文本的字符串。第二个返回值是一个布尔值,指示文本是否要绘制在表示该层的方框上方。
以下函数旨在描述层的名称及其维度。它将产生如下图所示的输出:
def text_callable(layer_index, layer):
# 每隔一个文本绘制在层的上方,第一个在下方
above = bool(layer_index%2)
# 获取层的输出形状
output_shape = [x for x in list(layer.output_shape) if x is not None]
# 如果输出形状是元组列表,我们只取第一个
if isinstance(output_shape[0], tuple):
output_shape = list(output_shape[0])
output_shape = [x for x in output_shape if x is not None]
# 存储将要绘制的文本的变量
output_shape_txt = ""
# 创建输出形状的字符串表示
for ii in range(len(output_shape)):
output_shape_txt += str(output_shape[ii])
if ii < len(output_shape) - 2: # 在维度之间添加x,例如3x3
output_shape_txt += "x"
if ii == len(output_shape) - 2: # 在最后两个维度之间添加换行符,例如3x3 \n 64
output_shape_txt += "\n"
# 将层的名称作为新行添加到文本中
output_shape_txt += f"\n{layer.name}"
# 返回文本值和是否应该在层上方绘制
return output_shape_txt, above
注意:使用padding
参数避免长文本在图像的左边或右边被截断。同时使用SpacingDummyLayers
来避免不同层的文本交错。
反转视图
在某些用例中,反转架构视图以查看每一层的背面可能很有用。例如,在可视化解码器类架构时。在这种情况下,我们可以将draw_reversed设置为True。以下两张图分别显示了同一模型在draw_reversed设置为False和True时的情况。
visualkeras.layered_view(model, draw_reversed=False) # 默认行为
visualkeras.layered_view(model, draw_reversed=True)
显示层尺寸(在图例中)
可以在图例中显示层的尺寸。为此,在layered_view
中设置legend=True
和show_dimension=True
。这是一种比创建text_callable
参数的可调用对象更简单的替代方法,用于在每层上方或下方显示尺寸。
visualkeras.layered_view(model, legend=True, show_dimension=True)
常见问题
此处记录的功能X不起作用
主分支可能领先于pypi。考虑升级到最新(可能不稳定)的版本,如"安装"中所讨论的。
安装aggdraw失败
这很可能是由于缺少gcc / g++组件(例如在Elementary OS上)。尝试通过您的包管理器安装它们,例如:
sudo apt-get install gcc
sudo apt-get install g++
.show()没有打开窗口
您可能尚未配置默认图像查看器。您可以通过大多数包管理器安装imagemagick:
sudo apt-get install imagemagick
引用
如果您发现这个项目对您的研究有帮助,请考虑在您的出版物中按以下方式引用它。
@misc{Gavrikov2020VisualKeras,
author = {Gavrikov, Paul},
title = {visualkeras},
year = {2020},
publisher = {GitHub},
journal = {GitHub repository},
howpublished = {\url{https://github.com/paulgavrikov/visualkeras}},
}