Project Icon

kotlindl

高层次的深度学习API,用Kotlin编写,适用于JVM和安卓环境

KotlinDL是一种高层次的深度学习API,用Kotlin编写,适用于JVM和安卓环境。它利用TensorFlow和ONNX Runtime,为开发者提供从零训练深度学习模型、导入Keras和ONNX模型进行推理,以及迁移学习功能。KotlinDL旨在简化深度学习的部署,是生产环境的理想选择。提供详尽的文档、教程和丰富的代码示例,帮助开发者轻松上手并优化深度学习应用。

KotlinDL: Kotlin 中的高层深度学习 API 官方 JetBrains 项目

Kotlin Slack 频道

KotlinDL 是一个用 Kotlin 编写的高层深度学习 API,受 Keras 启发。它的底层使用 TensorFlow Java API 和 ONNX Runtime API for Java。KotlinDL 提供了简单的 API,用于从头开始训练深度学习模型,导入现有的 Keras 和 ONNX 模型进行推理,以及利用迁移学习为你的任务调整现有的预训练模型。

该项目旨在使 JVM 和 Android 开发人员更容易进行深度学习,并简化深度学习模型在生产环境中的部署。

以下是 KotlinDL 中经典卷积神经网络 LeNet 的示例:

private const val EPOCHS = 3
private const val TRAINING_BATCH_SIZE = 1000
private const val NUM_CHANNELS = 1L
private const val IMAGE_SIZE = 28L
private const val SEED = 12L
private const val TEST_BATCH_SIZE = 1000

private val lenet5Classic = Sequential.of(
    Input(
        IMAGE_SIZE,
        IMAGE_SIZE,
        NUM_CHANNELS
    ),
    Conv2D(
        filters = 6,
        kernelSize = intArrayOf(5, 5),
        strides = intArrayOf(1, 1, 1, 1),
        activation = Activations.Tanh,
        kernelInitializer = GlorotNormal(SEED),
        biasInitializer = Zeros(),
        padding = ConvPadding.SAME
    ),
    AvgPool2D(
        poolSize = intArrayOf(1, 2, 2, 1),
        strides = intArrayOf(1, 2, 2, 1),
        padding = ConvPadding.VALID
    ),
    Conv2D(
        filters = 16,
        kernelSize = intArrayOf(5, 5),
        strides = intArrayOf(1, 1, 1, 1),
        activation = Activations.Tanh,
        kernelInitializer = GlorotNormal(SEED),
        biasInitializer = Zeros(),
        padding = ConvPadding.SAME
    ),
    AvgPool2D(
        poolSize = intArrayOf(1, 2, 2, 1),
        strides = intArrayOf(1, 2, 2, 1),
        padding = ConvPadding.VALID
    ),
    Flatten(), // 3136
    Dense(
        outputSize = 120,
        activation = Activations.Tanh,
        kernelInitializer = GlorotNormal(SEED),
        biasInitializer = Constant(0.1f)
    ),
    Dense(
        outputSize = 84,
        activation = Activations.Tanh,
        kernelInitializer = GlorotNormal(SEED),
        biasInitializer = Constant(0.1f)
    ),
    Dense(
        outputSize = 10,
        activation = Activations.Linear,
        kernelInitializer = GlorotNormal(SEED),
        biasInitializer = Constant(0.1f)
    )
)


fun main() {
    val (train, test) = mnist()
    
    lenet5Classic.use {
        it.compile(
            optimizer = Adam(clipGradient = ClipGradientByValue(0.1f)),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY
        )
    
        it.logSummary()
    
        it.fit(dataset = train, epochs = EPOCHS, batchSize = TRAINING_BATCH_SIZE)
    
        val accuracy = it.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE).metrics[Metrics.ACCURACY]
    
        println("Accuracy: $accuracy")
    }
}

目录

库结构

KotlinDL 由几个模块组成:

  • kotlin-deeplearning-api API 接口和类
  • kotlin-deeplearning-impl 实现类和工具
  • kotlin-deeplearning-onnx 使用 ONNX Runtime 进行推理
  • kotlin-deeplearning-tensorflow 使用 TensorFlow 进行学习和推理
  • kotlin-deeplearning-visualization 可视化工具
  • kotlin-deeplearning-dataset 数据集类

模块 kotlin-deeplearning-tensorflowkotlin-deeplearning-dataset 仅适用于桌面 JVM,而其他制品也可以在 Android 上使用。

如何在项目中配置 KotlinDL

要在项目中使用 KotlinDL,请确保 mavenCentral 已添加到存储库列表中:

repositories {
    mavenCentral()
}

然后将必要的依赖项添加到你的 build.gradle 文件中。

要开始创建简单的神经网络或下载预训练模型,只需添加以下依赖项:

// build.gradle
dependencies {
    implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]'
}
// build.gradle.kts
dependencies {
    implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]")
}

使用 kotlin-deeplearning-onnx 模块进行 ONNX Runtime 推理:

// build.gradle
dependencies {
    implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]'
}
// build.gradle.kts
dependencies {
  implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]")
}

要在你的 JVM 项目中充分发挥 KotlinDL 的作用,请将以下依赖项添加到你的 build.gradle 文件中:

// build.gradle
dependencies {
    implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]'
    implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]'
    implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-visualization:[KOTLIN-DL-VERSION]'
}
// build.gradle.kts
dependencies {
  implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]")
  implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]")
  implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-visualization:[KOTLIN-DL-VERSION]")
}

最新的稳定版 KotlinDL 版本是 0.5.2,最新的不稳定版本是 0.6.0-alpha-1

更多详细信息以及 pom.xmlbuild.gradle.kts 示例,请参考 快速入门指南

在 Jupyter Notebook 中使用 KotlinDL

你可以在 Jupyter Notebook 中使用 Kotlin 内核来交互地使用 KotlinDL。为此,请在你的笔记本中添加所需的依赖项:

@file:DependsOn("org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:[KOTLIN-DL-VERSION]")

有关安装 Jupyter Notebook 和添加 Kotlin 内核的更多详细信息,请查看 快速入门指南

在 Android 项目中使用 KotlinDL

KotlinDL 支持在 Android 平台上进行 ONNX 模型的推理。 要在你的 Android 项目中使用 KotlinDL,请将以下依赖项添加到你的 build.gradle 文件中:

// build.gradle
implementation 'org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]'
// build.gradle.kts
implementation ("org.jetbrains.kotlinx:kotlin-deeplearning-onnx:[KOTLIN-DL-VERSION]")

更多详细信息,请参阅 快速入门指南

KotlinDL、ONNX Runtime、Android 和 JDK 版本

下表显示了 KotlinDL、TensorFlow、ONNX Runtime、Android 编译 SDK 和最低支持的 Java 版本之间的映射。

KotlinDL 版本最低 Java 版本ONNX Runtime 版本TensorFlow 版本Android: 编译 SDK 版本
0.1.*81.15
0.2.081.15
0.3.081.8.11.15
0.4.081.11.01.15
0.5.0-0.5.1111.12.11.1531
0.5.2111.14.01.1531
0.6.*111.16.01.1531

文档

示例和教程

你不需要有深度学习的先验经验来使用 KotlinDL。

我们正在努力包含详尽的文档,以帮助你入门。 目前,请随意查看我们准备的以下教程:

获取更多灵感,请查看本存储库中的 代码示例Sample Android App

在 GPU 上运行 KotlinDL

为了启用 GPU 上的训练和推理,请阅读这个 TensorFlow GPU 支持页面 并安装 CUDA 框架,以便在 GPU 设备上进行计算。

注意,只有 NVIDIA 设备得到支持。

如果你希望利用 GPU,还需要在项目中添加以下依赖项:

// build.gradle
implementation 'org.tensorflow:libtensorflow:1.15.0'
implementation 'org.tensorflow:libtensorflow_jni_gpu:1.15.0'
// build.gradle.kts
implementation ("org.tensorflow:libtensorflow:1.15.0")
implementation ("org.tensorflow:libtensorflow_jni_gpu:1.15.0")

在 Windows 上,需要以下分发版:

为了在 CUDA 设备上进行 ONNX 模型的推理,你还需要在项目中添加以下依赖项:

// build.gradle
api 'com.microsoft.onnxruntime:onnxruntime_gpu:1.16.0'
// build.gradle.kts
api("com.microsoft.onnxruntime:onnxruntime_gpu:1.16.0")

欲了解更多有关 ONNXRuntime 和 CUDA 版本兼容性的详细信息,请参阅 ONNXRuntime CUDA 执行提供者页面

日志记录

默认情况下,API 模块使用 kotlin-logging 库来组织与特定日志记录实现分离的日志记录过程。 你可以使用任何广泛知名的 JVM 日志库,比如 Simple Logging Facade for Java (SLF4J) 实现库,例如 Logback 或 Log4j/Log4j2。

如果你希望使用 log4j2,你还需要将以下依赖项和配置文件 log4j2.xml 添加到项目的 src/resource 文件夹中:

// build.gradle
implementation 'org.apache.logging.log4j:log4j-api:2.17.2'
implementation 'org.apache.logging.log4j:log4j-core:2.17.2'
implementation 'org.apache.logging.log4j:log4j-slf4j-impl:2.17.2'
// build.gradle.kts
implementation("org.apache.logging.log4j:log4j-api:2.17.2")
implementation("org.apache.logging.log4j:log4j-core:2.17.2")
implementation("org.apache.logging.log4j:log4j-slf4j-impl:2.17.2")
<Configuration status="WARN">
    <Appenders>
        <Console name="STDOUT" target="SYSTEM_OUT">
            <PatternLayout pattern="%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n"/>
        </Console>
    </Appenders>

    <Loggers>
        <Root level="debug">
            <AppenderRef ref="STDOUT" level="DEBUG"/>
        </Root>
        <Logger name="io.jhdf" level="off" additivity="true">
            <appender-ref ref="STDOUT" />
        </Logger>
    </Loggers>
</Configuration>

如果你希望使用 Logback,需要将以下依赖项和配置文件 logback.xml 添加到项目的 src/resource 文件夹中:

// build.gradle
implementation 'ch.qos.logback:logback-classic:1.4.5'
// build.gradle.kts
implementation("ch.qos.logback:logback-classic:1.4.5")
<configuration>
    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern>%d{HH:mm:ss.SSS} [%thread] %-5level %logger{36} - %msg%n</pattern>
        </encoder>
    </appender>

    <root level="info">
        <appender-ref ref="STDOUT"/>
    </root>
</configuration>

这些配置文件可以在 examples 模块中找到。

Fat Jar 问题

有一个已知的 Stack Overflow 问题 和 TensorFlow 问题 与在 Amazon EC2 实例上创建和执行 Fat Jar 相关。

java.lang.UnsatisfiedLinkError: /tmp/tensorflow_native_libraries-1562914806051-0/libtensorflow_jni.so: libtensorflow_framework.so.1: cannot open shared object file: No such file or directory

尽管描述此问题的 bug 已在 TensorFlow 1.14 版发布时关闭, 但它并没有完全修复,需要在构建脚本中添加额外的行。

一个简单的 解决方案 是将 TensorFlow 版本规范添加到 Jar 的 Manifest 中。 下面是创建 Fat Jar 的 Gradle 构建任务示例。

// build.gradle

task fatJar(type: Jar) {
    manifest {
        attributes 'Implementation-Version': '1.15'
    }
    classifier = 'all'
    from { configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) } }
    with jar
}
// build.gradle.kts

plugins {
    kotlin("jvm") version "1.5.31"
    id("com.github.johnrengelman.shadow") version "7.0.0"
}

tasks{
    shadowJar {
        manifest {
            attributes(Pair("Main-Class", "MainKt"))
            attributes(Pair("Implementation-Version", "1.15"))
        }
    }
}

限制

目前,仅支持一部分深度学习架构。以下是可用层的列表:

  • 核心层:
    • InputDenseFlattenReshapeDropoutBatchNorm
  • 卷积层:
    • Conv1DConv2DConv3D
    • Conv1DTransposeConv2DTransposeConv3DTranspose
    • DepthwiseConv2D
    • SeparableConv2D
  • 池化层:
    • MaxPool1DMaxPool2DMaxPooling3D
    • AvgPool1DAvgPool2DAvgPool3D
    • GlobalMaxPool1DGlobalMaxPool2DGlobalMaxPool3D
    • GlobalAvgPool1DGlobalAvgPool2DGlobalAvgPool3D
  • 合并层:
    • AddSubtractMultiply
    • AverageMaximumMinimum
    • Dot
    • Concatenate
  • 激活层:
    • ELULeakyReLUPReLUReLUSoftmaxThresholdedReLU
    • ActivationLayer
  • 裁剪层:
    • Cropping1DCropping2DCropping3D
  • 上采样层:
    • UpSampling1DUpSampling2DUpSampling3D
  • 零填充层:
    • ZeroPadding1DZeroPadding2DZeroPadding3D
  • 其他层:
    • PermuteRepeatVector

当前用于层实现的 TensorFlow 版本是 1.15 Java API,但本项目将在不久的将来切换到 TensorFlow 2.+。 然而,这并不会影响高级 API。目前仅支持在桌面上使用 TensorFlow 模型进行推理。

贡献

阅读 贡献指南

报告问题/支持

请使用 GitHub 问题 来提交功能请求和错误报告。 你也可以加入 Kotlin Slack 中的 #kotlindl 频道

行为准则

此项目和相应的社区受 JetBrains 开源和社区行为准则 管理。请确保你已阅读。

许可证

KotlinDL 采用 Apache 2.0 许可证 授权。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

吐司

探索Tensor.Art平台的独特AI模型,免费访问各种图像生成与AI训练工具,从Stable Diffusion等基础模型开始,轻松实现创新图像生成。体验前沿的AI技术,推动个人和企业的创新发展。

Project Cover

SubCat字幕猫

SubCat字幕猫APP是一款创新的视频播放器,它将改变您观看视频的方式!SubCat结合了先进的人工智能技术,为您提供即时视频字幕翻译,无论是本地视频还是网络流媒体,让您轻松享受各种语言的内容。

Project Cover

美间AI

美间AI创意设计平台,利用前沿AI技术,为设计师和营销人员提供一站式设计解决方案。从智能海报到3D效果图,再到文案生成,美间让创意设计更简单、更高效。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号