KotlinDL: Kotlin 中的高层深度学习 API
KotlinDL 是一个用 Kotlin 编写的高层深度学习 API,受 Keras 启发。它的底层使用 TensorFlow Java API 和 ONNX Runtime API for Java。KotlinDL 提供了简单的 API,用于从头开始训练深度学习模型,导入现有的 Keras 和 ONNX 模型进行推理,以及利用迁移学习为你的任务调整现有的预训练模型。
该项目旨在使 JVM 和 Android 开发人员更容易进行深度学习,并简化深度学习模型在生产环境中的部署。
以下是 KotlinDL 中经典卷积神经网络 LeNet 的示例:
private const val EPOCHS = 3
private const val TRAINING_BATCH_SIZE = 1000
private const val NUM_CHANNELS = 1L
private const val IMAGE_SIZE = 28L
private const val SEED = 12L
private const val TEST_BATCH_SIZE = 1000
private val lenet5Classic = Sequential.of(
Input(
IMAGE_SIZE,
IMAGE_SIZE,
NUM_CHANNELS
),
Conv2D(
filters = 6,
kernelSize = intArrayOf(5, 5),
strides = intArrayOf(1, 1, 1, 1),
activation = Activations.Tanh,
kernelInitializer = GlorotNormal(SEED),
biasInitializer = Zeros(),
padding = ConvPadding.SAME
),
AvgPool2D(
poolSize = intArrayOf(1, 2, 2, 1),
strides = intArrayOf(1, 2, 2, 1),
padding = ConvPadding.VALID
),
Conv2D(
filters = 16,
kernelSize = intArrayOf(5, 5),
strides = intArrayOf(1, 1, 1, 1),
activation = Activations.Tanh,
kernelInitializer = GlorotNormal(SEED),
biasInitializer = Zeros(),
padding = ConvPadding.SAME
),
AvgPool2D(
poolSize = intArrayOf(1, 2, 2, 1),
strides = intArrayOf(1, 2, 2, 1),
padding = ConvPadding.VALID
),
Flatten(), // 3136
Dense(
outputSize = 120,
activation = Activations.Tanh,
kernelInitializer = GlorotNormal(SEED),
biasInitializer = Constant(0.1f)
),
Dense(
outputSize = 84,
activation = Activations.Tanh,
kernelInitializer = GlorotNormal(SEED),
biasInitializer = Constant(0.1f)
),
Dense(
outputSize = 10,
activation = Activations.Linear,
kernelInitializer = GlorotNormal(SEED),
biasInitializer = Constant(0.1f)
)
)
fun main() {
val (train, test) = mnist()
lenet5Classic.use {
it.compile(
optimizer = Adam(clipGradient = ClipGradientByValue(0.1f)),
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
metric = Metrics.ACCURACY
)
it.logSummary()
it.fit(dataset = train, epochs = EPOCHS, batchSize = TRAINING_BATCH_SIZE)
val accuracy = it.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE).metrics[Metrics.ACCURACY]
println("Accuracy: $accuracy")
}
}
目录
- 库结构
- 如何在项目中配置 KotlinDL
- KotlinDL、ONNX Runtime、Android 和 JDK 版本
- 文档
- 示例和教程
- 在 GPU 上运行 KotlinDL
- 日志记录
- Fat Jar 问题
- 限制
- 贡献
- 报告问题/支持
- 行为准则
- 许可证
库结构
KotlinDL 由几个模块组成:
kotlin-deeplearning-api
API 接口和类kotlin-deeplearning-impl
实现类和工具kotlin-deeplearning-onnx
使用 ONNX Runtime 进行推理kotlin-deeplearning-tensorflow
使用 TensorFlow 进行学习和推理kotlin-deeplearning-visualization
可视化工具kotlin-deeplearning-dataset
数据集类
模块 kotlin-deeplearning-tensorflow
和 kotlin-deeplearning-dataset
仅适用于桌面 JVM,而其他制品也可以在 Android 上使用。
如何在项目中配置 KotlinDL
要在项目中使用 KotlinDL,请确保 mavenCentral
已添加到存储库列表中:
repositories {
mavenCentral()
}
然后将必要的依赖项添加到你的 build.gradle
文件中。
要开始创建简单的神经网络或下载预训练模型,只需添加以下依赖项:
// build.gradle
dependencies {
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]'
}
// build.gradle.kts
dependencies {
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]")
}
使用 kotlin-deeplearning-onnx
模块进行 ONNX Runtime 推理:
// build.gradle
dependencies {
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]'
}
// build.gradle.kts
dependencies {
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]")
}
要在你的 JVM 项目中充分发挥 KotlinDL 的作用,请将以下依赖项添加到你的 build.gradle
文件中:
// build.gradle
dependencies {
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]'
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]'
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-visualization:[KOTLIN-DL-VERSION]'
}
// build.gradle.kts
dependencies {
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]")
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]")
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-visualization:[KOTLIN-DL-VERSION]")
}
最新的稳定版 KotlinDL 版本是 0.5.2
,最新的不稳定版本是 0.6.0-alpha-1
。
更多详细信息以及 pom.xml
和 build.gradle.kts
示例,请参考 快速入门指南。
在 Jupyter Notebook 中使用 KotlinDL
你可以在 Jupyter Notebook 中使用 Kotlin 内核来交互地使用 KotlinDL。为此,请在你的笔记本中添加所需的依赖项:
@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]")
有关安装 Jupyter Notebook 和添加 Kotlin 内核的更多详细信息,请查看 快速入门指南。
在 Android 项目中使用 KotlinDL
KotlinDL 支持在 Android 平台上进行 ONNX 模型的推理。 要在你的 Android 项目中使用 KotlinDL,请将以下依赖项添加到你的 build.gradle 文件中:
// build.gradle
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]'
// build.gradle.kts
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]")
更多详细信息,请参阅 快速入门指南。
KotlinDL、ONNX Runtime、Android 和 JDK 版本
下表显示了 KotlinDL、TensorFlow、ONNX Runtime、Android 编译 SDK 和最低支持的 Java 版本之间的映射。
KotlinDL 版本 | 最低 Java 版本 | ONNX Runtime 版本 | TensorFlow 版本 | Android: 编译 SDK 版本 |
---|---|---|---|---|
0.1.* | 8 | 1.15 | ||
0.2.0 | 8 | 1.15 | ||
0.3.0 | 8 | 1.8.1 | 1.15 | |
0.4.0 | 8 | 1.11.0 | 1.15 | |
0.5.0-0.5.1 | 11 | 1.12.1 | 1.15 | 31 |
0.5.2 | 11 | 1.14.0 | 1.15 | 31 |
0.6.* | 11 | 1.16.0 | 1.15 | 31 |
文档
- 演讲和视频:
- 使用 KotlinDL 进行深度学习 (Zinoviev Alexey at Huawei Developer Group HDG UK 2021, 幻灯片)
- KotlinDL 深度学习简介 (Zinoviev Alexey at Kotlin Budapest User Group 2021, 幻灯片)
- KotlinDL 变更日志
- 完整的 KotlinDL API 参考
示例和教程
你不需要有深度学习的先验经验来使用 KotlinDL。
我们正在努力包含详尽的文档,以帮助你入门。 目前,请随意查看我们准备的以下教程:
- 快速入门指南
- 创建你的第一个神经网络
- 导入 Keras 模型
- 迁移学习
- 使用 Functional API 的迁移学习
- 在 JVM 上运行 ONNX 模型的推理
- 在 Android 上运行 ONNX 模型的推理
获取更多灵感,请查看本存储库中的 代码示例 和 Sample Android App。
在 GPU 上运行 KotlinDL
为了启用 GPU 上的训练和推理,请阅读这个 TensorFlow GPU 支持页面 并安装 CUDA 框架,以便在 GPU 设备上进行计算。
注意,只有 NVIDIA 设备得到支持。
如果你希望利用 GPU,还需要在项目中添加以下依赖项:
// build.gradle
implementation 'org.tensorflow:libtensorflow:1.15.0'
implementation 'org.tensorflow:libtensorflow_jni_gpu:1.15.0'
// build.gradle.kts
implementation ("org.tensorflow:libtensorflow:1.15.0")
implementation ("org.tensorflow:libtensorflow_jni_gpu:1.15.0")
在 Windows 上,需要以下分发版:
- CUDA cuda_10.0.130_411.31_win10
- cudnn-7.6.3
- C++ 可再发行包
为了在 CUDA 设备上进行 ONNX 模型的推理,你还需要在项目中添加以下依赖项:
// build.gradle
api 'com.microsoft.onnxruntime:onnxruntime_gpu:1.16.0'
// build.gradle.kts
api("com.microsoft.onnxruntime:onnxruntime_gpu:1.16.0")
欲了解更多有关 ONNXRuntime 和 CUDA 版本兼容性的详细信息,请参阅 ONNXRuntime CUDA 执行提供者页面。
日志记录
默认情况下,API 模块使用 kotlin-logging 库来组织与特定日志记录实现分离的日志记录过程。 你可以使用任何广泛知名的 JVM 日志库,比如 Simple Logging Facade for Java (SLF4J) 实现库,例如 Logback 或 Log4j/Log4j2。
如果你希望使用 log4j2,你还需要将以下依赖项和配置文件 log4j2.xml
添加到项目的 src/resource
文件夹中:
// build.gradle
implementation 'org.apache.logging.log4j:log4j-api:2.17.2'
implementation 'org.apache.logging.log4j:log4j-core:2.17.2'
implementation 'org.apache.logging.log4j:log4j-slf4j-impl:2.17.2'
// build.gradle.kts
implementation("org.apache.logging.log4j:log4j-api:2.17.2")
implementation("org.apache.logging.log4j:log4j-core:2.17.2")
implementation("org.apache.logging.log4j:log4j-slf4j-impl:2.17.2")
<Configuration status="WARN">
<Appenders>
<Console name="STDOUT" target="SYSTEM_OUT">
<PatternLayout pattern="%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"/>
</Console>
</Appenders>
<Loggers>
<Root level="debug">
<AppenderRef ref="STDOUT" level="DEBUG"/>
</Root>
<Logger name="io.jhdf" level="off" additivity="true">
<appender-ref ref="STDOUT" />
</Logger>
</Loggers>
</Configuration>
如果你希望使用 Logback,需要将以下依赖项和配置文件 logback.xml
添加到项目的 src/resource
文件夹中:
// build.gradle
implementation 'ch.qos.logback:logback-classic:1.4.5'
// build.gradle.kts
implementation("ch.qos.logback:logback-classic:1.4.5")
<configuration>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
</encoder>
</appender>
<root level="info">
<appender-ref ref="STDOUT"/>
</root>
</configuration>
这些配置文件可以在 examples
模块中找到。
Fat Jar 问题
有一个已知的 Stack Overflow 问题 和 TensorFlow 问题 与在 Amazon EC2 实例上创建和执行 Fat Jar 相关。
java.lang.UnsatisfiedLinkError: /tmp/tensorflow_native_libraries-1562914806051-0/libtensorflow_jni.so: libtensorflow_framework.so.1: cannot open shared object file: No such file or directory
尽管描述此问题的 bug 已在 TensorFlow 1.14 版发布时关闭, 但它并没有完全修复,需要在构建脚本中添加额外的行。
一个简单的 解决方案 是将 TensorFlow 版本规范添加到 Jar 的 Manifest 中。 下面是创建 Fat Jar 的 Gradle 构建任务示例。
// build.gradle
task fatJar(type: Jar) {
manifest {
attributes 'Implementation-Version': '1.15'
}
classifier = 'all'
from { configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) } }
with jar
}
// build.gradle.kts
plugins {
kotlin("jvm") version "1.5.31"
id("com.github.johnrengelman.shadow") version "7.0.0"
}
tasks{
shadowJar {
manifest {
attributes(Pair("Main-Class", "MainKt"))
attributes(Pair("Implementation-Version", "1.15"))
}
}
}
限制
目前,仅支持一部分深度学习架构。以下是可用层的列表:
- 核心层:
Input
、Dense
、Flatten
、Reshape
、Dropout
、BatchNorm
。
- 卷积层:
Conv1D
、Conv2D
、Conv3D
;Conv1DTranspose
、Conv2DTranspose
、Conv3DTranspose
;DepthwiseConv2D
;SeparableConv2D
。
- 池化层:
MaxPool1D
、MaxPool2D
、MaxPooling3D
;AvgPool1D
、AvgPool2D
、AvgPool3D
;GlobalMaxPool1D
、GlobalMaxPool2D
、GlobalMaxPool3D
;GlobalAvgPool1D
、GlobalAvgPool2D
、GlobalAvgPool3D
。
- 合并层:
Add
、Subtract
、Multiply
;Average
、Maximum
、Minimum
;Dot
;Concatenate
。
- 激活层:
ELU
、LeakyReLU
、PReLU
、ReLU
、Softmax
、ThresholdedReLU
;ActivationLayer
。
- 裁剪层:
Cropping1D
、Cropping2D
、Cropping3D
。
- 上采样层:
UpSampling1D
、UpSampling2D
、UpSampling3D
。
- 零填充层:
ZeroPadding1D
、ZeroPadding2D
、ZeroPadding3D
。
- 其他层:
Permute
、RepeatVector
。
当前用于层实现的 TensorFlow 版本是 1.15 Java API,但本项目将在不久的将来切换到 TensorFlow 2.+。 然而,这并不会影响高级 API。目前仅支持在桌面上使用 TensorFlow 模型进行推理。
贡献
阅读 贡献指南。
报告问题/支持
请使用 GitHub 问题 来提交功能请求和错误报告。 你也可以加入 Kotlin Slack 中的 #kotlindl 频道。
行为准则
此项目和相应的社区受 JetBrains 开源和社区行为准则 管理。请确保你已阅读。
许可证
KotlinDL 采用 Apache 2.0 许可证 授权。