node-mlx
一个基于 MLX 的 Node.js 机器学习框架。
该项目与苹果公司无关,你可以通过赞助我来支持开发。
支持的平台
GPU 支持:
- 搭载 Apple Silicon 的 Mac
CPU 支持:
- x64 架构的 Mac
- x64/arm64 架构的 Linux
(目前尚不支持 Windows,但我将来会尝试让 MLX 在其上运行)
请注意,目前 MLX 没有计划支持 Apple Silicon 以外的 GPU。对于 NVIDIA GPU 的计算,你必须使用 TensorFlow.js,或者等待有人将 PyTorch 移植到 Node.js(这应该不会太难)。
示例
- llama3.js - Llama 3 的 JavaScript 实现。
- llm.js - 使用 JavaScript 在本地加载语言模型。
- train-model-with-js - 使用 JavaScript 训练简单的文本生成模型。
- train-llama3-js - 使用 parquet 数据集训练小型 Llama3 模型。
- train-japanese-llama3-js - 训练日语语言模型。
- fine-tune-decoder-js - 微调仅解码器模型。
使用方法
import mlx from '@frost-beta/mlx';
const {core: mx, nn} = mlx;
const model = new nn.Sequential(
new nn.Sequential(new nn.Linear(2, 10), nn.relu),
new nn.Sequential(new nn.Linear(10, 10), new nn.ReLU()),
new nn.Linear(10, 1),
mx.sigmoid,
);
const y = model.forward(mx.random.normal([32, 2]));
console.log(y);
API
目前还没有 JavaScript API 的文档,请查看 TypeScript 定义以了解可用的 API,并参考 MLX 官方网站的文档。
JavaScript API 基本上是通过将 API 名称从 snake_case 转换为 camelCase 来复制官方 Python API。例如,Python 中的 mx.not_equal
API 在 JavaScript 中被重命名为 mx.notEqual
。
由于 JavaScript 的限制,有一些例外情况:
- JavaScript 数字始终是浮点值,因此
mx.array(42)
的默认 dtype 是mx.float32
而不是mx.int32
。 mx.var
API 被重命名为mx.variance
。- 运算符重载不起作用,使用
mx.add(a, b)
而不是a + b
。 - 通过
[]
运算符进行索引不起作用,请使用array.item
和array.itemPut_
方法代替(_
后缀表示原地操作)。 delete array
不起作用,你必须等待垃圾回收来释放数组的内存。Module
实例不能作为函数使用,必须使用forward
方法。
未实现的功能
一些功能尚未支持,将在未来实现:
distributed
模块尚未实现。mx.custom_function
API 尚未实现。- 不支持使用 JavaScript 数组作为索引。
- 传递给
mx.vmap
的函数必须有所有参数都是mx.array
。 mx.compile
的捕获inputs
/outputs
参数尚未实现。- 从 JavaScript 数组创建
mx.array
时,该数组必须只包含原始值。 - API 只接受普通参数,例如
mx.uniform(0, 1, [2, 2])
。尚未实现像mx.uniform({shape: [2, 2]})
这样的命名参数调用。 - 尚不支持
.npz
张量格式。
仅限 JavaScript 的 API
node-mlx 中有一些新的 API,用于解决 JavaScript 特有的问题。
mx.tidy
这与 TensorFlow.js 的 tf.tidy
API 相同,它会清理传递函数中分配的所有中间张量,除了返回的张量。
let result = mx.tidy(() => {
return model.forward(x);
});
mx.dispose
这与 TensorFlow.js 的 tf.dispose
API 相同,它会清理对象中找到的所有张量。
mx.dispose({ a: mx.array([1, 2, 3, 4]) })
复数
JavaScript 中没有内置的复数,我们使用对象来表示它们:
interface Complex {
re: number;
im: number;
}
你也可以使用 mx.Complex(real, imag?)
辅助函数来创建复数。
索引
JavaScript 中的切片表示为对象:
interface Slice {
start: number | null;
stop: number | null;
step: number | null;
}
你也可以使用 mx.Slice(start?, stop?, step?)
辅助函数来创建切片。
JavaScript 标准不允许使用 ...
作为值。要使用省略号作为索引,请使用字符串 "..."
。
使用数组作为索引时,请确保指定了整数 dtype,因为默认 dtype 是 float32
,例如 a.index(mx.array([ 1, 2, 3 ], mx.uint32))
。
以下是将 Python 索引代码翻译为 JavaScript 的一些示例:
获取器
Python | JavaScript |
---|---|
array[None] | array.index(null) |
array[Ellipsis, ...] | array.index('...', '...') |
array[1, 2] | array.index(1, 2) |
array[True, False] | array.index(true, false) |
array[1::2] | array.index(mx.Slice(1, None, 2)) |
array[mx.array([1, 2])] | array.index(mx.array([1, 2], mx.int32)) |
array[..., 0, True, 1::2] | array.index('...', 0, true, mx.Slice(1, null, 2) |
设置器
Python | JavaScript |
---|---|
array[None] = 1 | array.indexPut_(null, 1) |
array[Ellipsis, ...] = 1 | array.indexPut_(['...', '...'], 1) |
array[1, 2] = 1 | array.indexPut_([1, 2], 1) |
array[True, False] = 1 | array.indexPut_([true, false], 1) |
array[1::2] = 1 | array.indexPut_(mx.Slice(1, null, 2), 1) |
array[mx.array([1, 2])] = 1 | array.indexPut_(mx.array([1, 2], mx.int32), 1) |
array[..., 0, True, 1::2] = 1 | array.indexPut_(['...', 0, true, mx.Slice(1, null, 2)], 1) |
Python与JavaScript索引类型之间的转换
Python | JavaScript |
---|---|
None | null |
Ellipsis | "..." |
... | "..." |
123 | 123 |
True | true |
False | false |
: 或 :: | mx.Slice() |
1: 或 1:: | mx.Slice(1) |
:3 或 :3: | mx.Slice(null, 3) |
::2 | mx.Slice(null, null, 2) |
1:3 | mx.Slice(1, 3) |
1::2 | mx.Slice(1, null, 2) |
:3:2 | mx.Slice(null, 3, 2) |
1:3:2 | mx.Slice(1, 3, 2) |
mx.array([1, 2]) | mx.array([1, 2], mx.int32) |
构建
对于非苹果芯片的Mac平台,你必须安装blas。
# Linux
sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
# x64 Mac
brew install openblas
这个项目混合了C++和TypeScript代码,并使用cmake-js来构建原生代码。
git clone --recursive https://github.com/frost-beta/node-mlx.git
cd node-mlx
npm install
npm run build -p 8
npm run test
发布
预构建的二进制文件会上传到GitHub Releases,当从npm仓库安装node-mlx时,总是会下载预构建的二进制文件,没有从源代码构建的备选方案。
package.json
中的版本号始终是0.0.1-dev
,表示本地开发版本,npm包只能通过GitHub工作流在推送新标签时发布。
版本控制
在达到官方Python API的功能和稳定性之前,这个项目的npm版本将保持在0.0.x
。
与上游同步
测试和大部分TypeScript源代码是从官方MLX项目的Python代码转换而来的。当更新deps/mlx
子模块时,需要审查每个新的提交,确保Python API、测试和实现的变更也反映在这个仓库中。