TensorFlow.js
TensorFlow.js 是一个开源的硬件加速 JavaScript 库,用于训练和部署机器学习模型。
在浏览器中开发机器学习
使用灵活且直观的 API,从零开始使用低级 JavaScript 线性代数库或高级层 API 构建模型。
在 Node.js 中开发机器学习
在 Node.js 运行时使用与 TensorFlow.js 相同的 API 执行原生 TensorFlow。
运行现有模型
使用 TensorFlow.js 模型转换器在浏览器中运行预先存在的 TensorFlow 模型。
重新训练现有模型
使用连接到浏览器的传感器数据或其他客户端数据重新训练现有的机器学习模型。
关于此代码库
该代码库包含组合多个软件包的逻辑和脚本。
API:
- TensorFlow.js Core,一个灵活的低级 API,用于神经网络和数值计算。
- TensorFlow.js Layers,一个高级 API,实现类似 Keras 的功能。
- TensorFlow.js Data,一个简单的 API,用于加载和准备数据,类似于 tf.data。
- TensorFlow.js Converter,用于将 TensorFlow SavedModel 导入到 TensorFlow.js 的工具。
- TensorFlow.js Vis,用于 TensorFlow.js 模型的浏览器可视化工具。
- TensorFlow.js AutoML,一组用于加载和运行 AutoML Edge 生成的模型的 API。
后端/平台:
- TensorFlow.js CPU Backend,Node.js 和浏览器的纯 JavaScript 后端。
- TensorFlow.js WebGL Backend,浏览器的 WebGL 后端。
- TensorFlow.js WASM Backend,浏览器的 WebAssembly 后端。
- TensorFlow.js WebGPU,浏览器的 WebGPU 后端。
- TensorFlow.js Node,通过 TensorFlow C++ 适配器实现的 Node.js 平台。
- TensorFlow.js React Native,通过 expo-gl 适配器实现的 React Native 平台。
如果您关心包大小,可以单独导入这些包。
如果您正在寻找 Node.js 的支持,请查看 TensorFlow.js Node 目录。
示例
画廊
请务必查看 TensorFlow.js 相关项目的画廊。
预训练模型
请务必查看我们的模型库,我们在 NPM 上托管了预训练模型。
基准测试
- 本地基准测试工具。使用此网页工具收集 您本地设备 上使用 CPU、WebGL 或 WASM 后端的 TensorFlow.js 模型和内核的性能相关指标(速度、内存等)。您可以通过遵循指南 来基准测试自定义模型。
- 多设备基准测试工具。使用此工具在 一组远程设备 上收集相同的性能相关指标。
入门
在 JavaScript 项目中获取 TensorFlow.js 有两种主要方式: 通过 脚本标签 或 从 NPM 安装并使用构建工具如 Parcel、WebPack 或 Rollup。
通过脚本标签
将以下代码添加到 HTML 文件中:
<html>
<head>
<!-- 加载 TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs/dist/tf.min.js"> </script>
<!-- 将您的代码放置在下面的 script 标签中。您也可以使用外部 .js 文件 -->
<script>
// 注意这里没有 'import' 语句。'tf' 在首页可用
// 因为上面的 script 标签。
// 为线性回归定义一个模型。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
// 为训练准备模型:指定损失函数和优化器。
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// 生成一些用于训练的合成数据。
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// 使用数据训练模型。
model.fit(xs, ys).then(() => {
// 使用模型对模型未见过的数据点进行推理:
// 打开浏览器开发工具查看输出
model.predict(tf.tensor2d([5], [1, 1])).print();
});
</script>
</head>
<body>
</body>
</html>
在浏览器中打开该 HTML 文件,代码应能运行!
通过 NPM
使用 yarn 或 npm 将 TensorFlow.js 添加到您的项目中。注意:由于我们使用 ES2017 语法(如 import
),此工作流程假定您正在使用现代浏览器或使用打包工具/转译器将代码转换为旧浏览器可以理解的形式。请参阅我们的示例,了解我们如何使用 Parcel 构建代码。不过,您可以自由选择任何您喜欢的构建工具。
import * as tf from '@tensorflow/tfjs';
// 为线性回归定义一个模型。
const model = tf.sequential();
model.add(tf.layers.dense({units: 1, inputShape: [1]}));
// 为训练准备模型:指定损失函数和优化器。
model.compile({loss: 'meanSquaredError', optimizer: 'sgd'});
// 生成一些用于训练的合成数据。
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]);
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]);
// 使用数据训练模型。
model.fit(xs, ys).then(() => {
// 使用模型对模型未见过的数据点进行推理:
model.predict(tf.tensor2d([5], [1, 1])).print();
});
导入预训练模型
我们支持从以下途径导入预训练模型:
不同后端支持的各种操作
请参阅以下内容:
了解更多
TensorFlow.js 是 TensorFlow 生态系统的一部分。了解更多信息:
- 想要获取社区帮助,请在 TensorFlow 论坛 使用
tfjs
标签。 - [TensorFlow.js 网站](https://