Keras 3:为人类设计的深度学习
Keras 3 是一个支持多种后端的深度学习框架,支持 JAX、TensorFlow 和 PyTorch。您可以轻松构建和训练计算机视觉、自然语言处理、音频处理、时间序列预测、推荐系统等模型。
- 加速模型开发:借助 Keras 的高级用户体验以及易于调试的运行时(如 PyTorch 或 JAX 急切执行),加快深度学习解决方案的发布。
- 最先进的性能:通过选择最适合您模型架构的最快后端(通常是 JAX!),相比其他框架,速度提升可达 20% 至 350%。基准测试请看这里。
- 数据中心级训练:从您的笔记本电脑到大规模 GPU 或 TPU 集群,轻松扩展。
加入近三百万开发者的行列,从初创公司到全球企业,一起利用 Keras 3 的强大功能。
安装
使用 pip 安装
Keras 3 可以在 PyPI 上以 keras
包的形式获得。请注意,Keras 2 仍然可以通过 tf-keras
包使用。
- 安装
keras
:
pip install keras --upgrade
- 安装后端包。
要使用 keras
,您还需要安装所选的后端:tensorflow
、jax
或 torch
。请注意,tensorflow
是使用某些 Keras 3 功能所必需的:某些预处理层以及 tf.data
管道。
本地安装
最小化安装
Keras 3 与 Linux 和 MacOS 系统兼容。对于 Windows 用户,我们建议使用 WSL2 运行 Keras。要安装本地开发版本:
- 安装依赖项:
pip install -r requirements.txt
- 从根目录运行安装命令。
python pip_build.py --install
- 当创建更新
keras_export
公共 API 的 PR 时,运行 API 生成脚本:
./shell/api_gen.sh
添加 GPU 支持
requirements.txt
文件将安装仅支持 CPU 的 TensorFlow、JAX 和 PyTorch 版本。对于 GPU 支持,我们还提供了单独的 requirements-{backend}-cuda.txt
文件,用于 TensorFlow、JAX 和 PyTorch。这些文件通过 pip
安装所有 CUDA 依赖项,并要求预先安装 NVIDIA 驱动程序。我们建议为每个后端创建一个干净的 Python 环境,以避免 CUDA 版本不匹配。例如,以下是在 conda
中创建 Jax GPU 环境的方法:
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
配置您的后端
您可以导出环境变量 KERAS_BACKEND
或编辑本地配置文件 ~/.keras/keras.json
来配置后端。可用的后端选项有:"tensorflow"
、"jax"
、"torch"
。例如:
export KERAS_BACKEND="jax"
在 Colab 中,您可以这样做:
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
注意:后端必须在导入 keras
之前配置,并且在包导入后不能更改后端。
向后兼容性
Keras 3 旨在作为 tf.keras
的直接替代品(使用 TensorFlow 后端时)。只需使用现有的 tf.keras
代码,确保对 model.save()
的调用使用最新的 .keras
格式,您就完成了。
如果您的 tf.keras
模型不包含自定义组件,您可以立即在 JAX 或 PyTorch 上运行它。
如果它包含自定义组件(例如,自定义层或自定义 train_step()
),通常可以在几分钟内将其转换为与后端无关的实现。
此外,无论您使用哪个后端,Keras 模型都可以使用任何格式的数据集:您可以使用现有的 tf.data.Dataset
管道或 PyTorch 的 DataLoaders
训练模型。
为什么使用 Keras 3?
- 在任何框架之上运行您的高级 Keras 工作流——随时利用每个框架的优势,例如 JAX 的可扩展性和性能或 TensorFlow 的生产生态系统选项。
- 编写可以在任何框架的低级工作流中使用的自定义组件(如层、模型、指标)。
- 您可以获取一个 Keras 模型,并在从头开始编写的原生 TF、JAX 或 PyTorch 训练循环中训练它。
- 您可以获取一个 Keras 模型,并将其作为 PyTorch 原生
Module
的一部分,或作为 JAX 原生模型函数的一部分使用。
- 通过避免框架锁定,使您的机器学习代码面向未来。
- 作为 PyTorch 用户:最终获得 Keras 的强大功能和可用性!
- 作为 JAX 用户:获得一个功能齐全、经过实战考验且文档完善的建模和训练库。
阅读更多内容,请参见 Keras 3 发布公告。