项目介绍:paxml
paxml,又称Pax,是一个基于Jax的框架,用于配置和运行机器学习实验。它提供了一个强大的工具集合,帮助研究人员和工程师在云端和本地环境中高效地进行大规模机器学习模型的训练与评估。
快速入门
设置Cloud TPU虚拟机
使用Cloud TPU进行机器学习计算是paxml的一个重要应用。你可以通过以下命令在项目中创建一个8核的Cloud TPU虚拟机:
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-8
export TPU_NAME=paxml
#create a TPU VM
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --version=$VERSION \
--project=$PROJECT \
--accelerator-type=$ACCELERATOR
如果你使用TPU Pod切片,可以参考官方指南对其进行更多设置。
安装Pax
在成功连接到TPU虚拟机之后,你可以从PyPI安装paxml的稳定版本,也可以从GitHub安装开发版本:
从PyPI安装稳定版本:
python3 -m pip install -U pip
python3 -m pip install paxml jax[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
运行测试模型
一旦安装完成,你可以通过以下命令运行一个简单的测试模型:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://<your-bucket>
文档资源
paxml还提供了广泛的文档及Jupyter Notebook教程。你可以访问这里获取更多信息。
在TPU上运行Notebook
要在安装好的TPU虚拟机中运行Jupyter Notebook,以下是启动步骤:
- 使用SSH并进行端口转发进入TPU虚拟机。
- 在TPU虚拟机中安装Jupyter Notebook并调整相关配置。
- 从本地浏览器访问并运行Notebook。
在GPU上运行
paxml同样支持在NVIDIA GPU上运行,并提供了一些性能改进和FP8支持的功能。详细信息可以参考NVIDIA Rosetta仓库。
常见问题解答
paxml依赖于Jax运行,如果你在过程中遇到依赖问题,可以查阅requirements.txt
以获取解决方案。对于特定版本的问题解决,可以参考对应发布分支的requirements.txt
文件。
案例实验
paxml在各种模型和数据集上运行效果良好。以下是一些在c4数据集上的示例,包括1B参数模型、16B参数模型和GPT3-XL模型的运行结果,对应的命令及其产生的损失曲线和复杂度图表都可以帮助用户直观了解模型的性能。
性能基准
paxml还在TPU Pods上进行了一系列性能基准测试,展现了在弱缩放模式下大语言模型在TPU上的训练效率。
高级功能:多切片配置与MaxText集成
paxml支持多切片配置,这种设置可以显著提高跨多个TPU设备的训练效率。同时,paxml与MaxText仓库进行了良好的集成,共享了一些参数配置方式。
通过利用paxml,开发者能够显著简化模型训练的配置过程,并在多种硬件环境中高效复用这些流程。这使得paxml成为一个对于想要在不同硬件环境中进行大规模机器学习实验的研究人员来说不可或缺的工具。