Paxml (又名 Pax)
Pax 是一个配置和运行基于 Jax 的机器学习实验的框架。
快速开始
设置云 TPU VM
我们参考 此页面 以获取有关启动云 TPU 项目更详尽的文档。以下命令足以从公司机器创建一个具有 8 个核心的云 TPU VM。
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-8
export TPU_NAME=paxml
# 创建 TPU VM
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --version=$VERSION \
--project=$PROJECT \
--accelerator-type=$ACCELERATOR
如果您使用的是 TPU Pod 切片,请参考 此指南。使用 gcloud 并添加 --worker=all
选项,在本地机器上运行所有命令:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \
--worker=all --command="<commmands>"
以下快速入门部分假设您在单主机 TPU 上运行,因此可以 ssh 到 VM 并在那里运行命令。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
安装 Pax
在 ssh 到 VM 后,您可以从 PyPI 安装稳定版本的 paxml,或从 GitHub 安装开发版本。
从 PyPI(https://pypi.org/project/paxml/)安装稳定版本:
python3 -m pip install -U pip
python3 -m pip install paxml jax[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
如果遇到传递依赖问题并且您使用的是原生云 TPU VM 环境,请导航到相应的发布分支 rX.Y.Z 并下载 paxml/pip_package/requirements.txt
。此文件包含在原生云 TPU VM 环境中构建/测试相应版本所需的所有传递依赖项的确切版本。
git clone -b rX.Y.Z https://github.com/google/paxml
pip install --no-deps -r paxml/paxml/pip_package/requirements.txt
从 GitHub 安装开发版本,并便于编辑代码:
# 首先安装 praxis 的开发版本
git clone https://github.com/google/praxis
pip install -e praxis
git clone https://github.com/google/paxml
pip install -e paxml
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
运行测试模型
# 使用 pjit (SPMD) 的示例模型
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://<your-bucket>
# 使用 pmap 的示例模型
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \
--job_log_dir=gs://<your-bucket> \
--pmap_use_tensorstore=True
文档
请访问我们的 docs 文件夹 以获取文档和 Jupyter Notebook 教程。请参见以下部分有关在云 TPU VM 上运行 Jupyter Notebook 的说明。
运行 Notebook
您可以在刚刚安装 paxml 的 TPU VM 中运行 示例 notebooks。
在 v4-8
中启用 notebook 的步骤
-
使用端口转发 ssh 到 TPU VM
gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_NAME --zone=$ZONE --ssh-flag="-4 -L 8080:localhost:8080"
-
在 TPU VM 上安装 jupyter notebook 并降级 markupsafe
pip install notebook
pip install markupsafe==2.0.1
-
导出 jupyter 路径
export PATH=/home/$USER/.local/bin:$PATH
-
将 示例 notebooks 拷贝到您的 TPU VM
gcloud compute tpus tpu-vm scp $TPU_NAME:<TPU 内路径> <notebooks 的本地路径> --zone=$ZONE --project=$PROJECT
-
从 TPU VM 中启动 jupyter notebook 并记录 jupyter notebook 生成的令牌
jupyter notebook --no-browser --port=8080
-
然后在您的本地浏览器中访问:http://localhost:8080/ 并输入提供的令牌
注意:如果需要在第一个 notebook 仍占用 TPU 时启动第二个 notebook,可以运行
pkill -9 python3
以释放 TPU。
在 GPU 上运行
注意:NVIDIA 发布了支持 H100 FP8 和广泛 GPU 性能改进的 Pax 更新版本。请访问 NVIDIA Rosetta 仓库以获取更多详细信息和使用说明。
在 GPU 上运行 PGLE 工作流程
Profile Guided Latency Estimator (PGLE) 工作流程测量计算和集体操作的实际运行时间,并将配置文件信息反馈给 XLA 编译器以便做出更好的调度决策。
在 XLA/GPU 中使用 Profile Guided Latency Estimator 工作流程如下:
- 运行工作负载一次,启用异步集体操作和延迟隐藏调度器。
可以通过设置以下环境变量来实现:
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true"
- 使用 JAX profiler 收集和后处理配置文件,将提取的指令延迟保存到二进制 protobuf 文件中。
import os
from etils import epath
import jax
from jax.experimental import profiler as exp_profiler
# 定义您的配置文件目录
profile_dir = 'gs://my_bucket/profile'
jax.profiler.start_trace(profile_dir)
# 运行您的工作流程
# for i in range(10):
# train_step()
# 停止追踪
jax.profiler.stop_trace()
profile_dir = epath.Path(profile_dir)
directories = profile_dir.glob('plugins/profile/*/')
directories = [d for d in directories if d.is_dir()]
rundir = directories[-1]
logging.info('rundir: %s', rundir)
# 后处理配置文件
fdo_profile = exp_profiler.get_profiled_instructions_proto(os.fspath(rundir))
# 将配置文件 proto 保存到文件。
dump_dir = rundir / 'profile.pb'
dump_dir.parent.mkdir(parents=True, exist_ok=True)
dump_dir.write_bytes(fdo_profile)
完成此步骤后,您将在代码中打印出的 rundir
目录下获得一个 profile.pb
文件。
- 再次运行工作负载并将该文件传递给编译。
您需要将 profile.pb
文件传递给 --xla_gpu_pgle_profile_file_or_directory_path
标志。
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_pgle_profile_file_or_directory_path=/path/to/profile/profile.pb"
要在 XLA 中启用日志记录并检查配置文件是否正常,请将日志级别设置为包括 INFO
:
export TF_CPP_MIN_LOG_LEVEL=0
运行实际工作流程,如果在运行日志中找到这些日志,表示在延迟隐藏调度器中使用了 profiler:
2023-07-21 16:09:43.551600: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:478] Using PGLE profile from /tmp/profile/plugins/profile/2023_07_20_18_29_30/profile.pb
2023-07-21 16:09:43.551741: I external/xla/xla/service/gpu/gpu_hlo_schedule.cc:573] Found profile, using profile guided latency estimator
常见问题
-
Pax 在 Jax 上运行,您可以在 这里 找到有关在云 TPU 上运行 Jax 作业的详细信息,也可以在 这里 找到有关在云 TPU pod 上运行 Jax 作业的详细信息。
-
如果遇到依赖错误,请参考您正在安装的稳定版本对应的分支中的
requirements.txt
文件。 例如,对于 稳定版本 0.4.0,使用分支r0.4.0
并参考 requirements.txt 以获取稳定版本使用的依赖项的确切版本。
示例收敛运行
以下是一些在 c4 数据集 上的示例收敛运行。
1B 模型在 c4 数据集上
您可以在 TPU v4-8
上使用 c4.py 中的配置 C4Spmd1BAdam4Replicas
运行一个 1B
参数模型在 c4 数据集上:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
--job_log_dir=gs://<your-bucket>
您可以观察到损失曲线和 log perplexity
图如下:
<img src=paxml/docs/images/1B-loss.png width="400" height="300"> <img src=paxml/docs/images/1B-pplx.png width="400" height="300">
16B 模型在 c4 数据集上
您可以在 TPU v4-64
上使用 c4.py 中的配置 C4Spmd16BAdam32Replicas
运行一个 16B
参数模型在 c4 数据集上:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \
--job_log_dir=gs://<your-bucket>
您可以观察到损失曲线和 log perplexity
图如下:
<img src=paxml/docs/images/16B-loss.png width="400" height="300"> <img src=paxml/docs/images/16B-pplx.png width="400" height="300">
GPT3-XL 模型在 c4 数据集上
您可以在 TPU v4-128
上使用 c4.py 中的配置 C4SpmdPipelineGpt3SmallAdam64Replicas
运行 GPT3-XL 模型在 c4 数据集上:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \
--job_log_dir=gs://<your-bucket>
你可以观察到损失曲线和log perplexity
图如下:
<img src=paxml/docs/images/GPT3-XL-loss.png width="400" height="300"> <img src=paxml/docs/images/GPT3-XL-pplx.png width="400" height="300">
Cloud TPU v4 上的基准测试
PaLM 论文介绍了一种称为模型 FLOPs 利用率 (MFU) 的效率指标。MFU 衡量的方式是实际吞吐量(例如语言模型的每秒标记数)与系统利用 100% 峰值 FLOPs 的理论最大吞吐量之比。这与其他计算利用率的方式不同,因为它不包括在反向传递过程中花费在激活重新物化上的 FLOPs,这意味着通过 MFU 测量的效率直接转化为端到端的训练速度。
为了评估 TPU v4 Pods 上关键工作负载类别的 MFU,我们在一系列只包含解码器的 Transformer 语言模型 (GPT) 配置上进行了深入的基准测试,这些配置的参数规模从数十亿到数万亿不等,使用了 c4 数据集。下图显示了使用“弱扩展”模式进行训练的效率,其中我们按照使用芯片数量的比例增加模型规模。
<img src=paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png width="500" height="300">
Pax 在多片上
此存储库中的多片配置参考 1. 单片配置 语法和模型架构 和 2. MaxText 库 配置值。
我们提供了在 c4_multislice.py 下的示例运行,作为在多片上运行 Pax 的起点。
使用排队资源设置 Cloud TPU VMs
我们参考 此页面 获取有关使用排队资源进行多片 Cloud TPU 项目的更详尽的文档。以下步骤展示了设置 TPU 以运行该存储库中的示例配置。
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-128 # 或 v4-384,取决于你运行哪个配置
例如,要在 2 片 v4-128 上运行 C4Spmd22BAdam2xv4_128
,你需要如下方式设置 TPU:
export TPU_PREFIX=<your-prefix> # 新的 TPU 将基于该前缀创建
export QR_ID=$TPU_PREFIX
export NODE_COUNT=<number-of-slices> # 1, 2 或 4,取决于你运行哪个配置
# 创建一个 TPU VM
gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX
安装 Pax
前面描述的设置命令需要在所有切片中的所有工作节点上运行。你可以 1)分别 ssh 到每个工作节点和每个切片;或 2)使用 for 循环加 --worker=all
参数,如下命令。
for ((i=0; i<$NODE_COUNT; i++))
do
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all --command="pip install paxml && pip install orbax==0.1.1 && pip install \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
done
运行一个测试多片模型
为了运行多片配置,打开与 $NODE_COUNT 数量相同的终端。对于我们在 2 片上的实验(C4Spmd22BAdam2xv4_128
),打开两个终端。然后,在每个终端中分别运行以下命令。
从终端 0 开始,运行用于切片 0 的训练命令,如下所示:
export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-0 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS JAX_USE_PJRT_C_API_ON_TPU=1 \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"
从终端 1 开始,同时运行用于切片 1 的训练命令,如下所示:
export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-1 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS JAX_USE_PJRT_C_API_ON_TPU=1 \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"
MaxText 到 Pax
此表格涵盖了 MaxText 变量名如何转换为 Pax 的详细信息。
注意 MaxText 有一个“比例”,可乘以几个参数(base_num_decoder_layers, base_emb_dim, base_mlp_dim, base_num_heads)以得到最终值。
另一个需要提及的是,虽然 Pax 在 DCN 和 ICN MESH_SHAPE 中覆盖了一个数组,但在 MaxText 中,DCN 和 ICI 分别有单独的变量 data_parallelism, fsdp_parallelism 和 tensor_parallelism。由于这些值默认设置为 1,仅记录了大于 1 的变量在此表中。
也就是说,ICI_MESH_SHAPE = [ici_data_parallelism, ici_fsdp_parallelism, ici_tensor_parallelism]
和 DCN_MESH_SHAPE = [dcn_data_parallelism, dcn_fsdp_parallelism, dcn_tensor_parallelism]
Pax C4Spmd22BAdam2xv4_128 | MaxText 2xv4-128.sh | (应用比例后) | ||
---|---|---|---|---|
scale (应用到接下来的 4 个变量) | 3 | |||
NUM_LAYERS | 48 | base_num_decoder_layers | 16 | 48 |
MODEL_DIMS | 6144 | base_emb_dim | 2048 | 6144 |
HIDDEN_DIMS | 24576 | MODEL_DIMS * 4 (= base_mlp_dim) | 8192 | 24576 |
NUM_HEADS | 24 | base_num_heads | 8 | 24 |
DIMS_PER_HEAD | 256 | head_dim | 256 | |
PERCORE_BATCH_SIZE | 16 | per_device_batch_size | 16 | |
MAX_SEQ_LEN | 1024 | max_target_length | 1024 | |
VOCAB_SIZE | 32768 | vocab_size | 32768 | |
FPROP_DTYPE | jnp.bfloat16 | dtype | bfloat16 | |
USE_REPEATED_LAYER | TRUE | |||
SUMMARY_INTERVAL_STEPS | 10 | |||
ICI_MESH_SHAPE | [1, 64, 1] | ici_fsdp_parallelism | 64 | |
DCN_MESH_SHAPE | [2, 1, 1] | dcn_data_parallelism | 2 |
数据输入
介绍
输入是 BaseInput
类的一个实例,用于将数据输入模型进行训练/评估/解码。
class BaseInput:
def get_next(self):
pass
def reset(self):
pass
它的行为类似于一个迭代器:get_next()
返回一个 NestedMap
,其中每个字段是一个数值数组,批量大小为其领先维度。
每个输入由 BaseInput.HParams
的子类配置。
在此页面中,我们使用 p
表示一个 BaseInput.Params
的实例,并实例化为 input
。
多主机进料
在 Pax 中,数据始终是多主机的:每个 Jax 进程都会有一个单独的、独立的 input
实例化。它们的参数将拥有不同的 p.infeed_host_index
,由 Pax 自动设置。
因此,每个主机上看到的本地批量大小是 p.batch_size
,全局批量大小是 (p.batch_size * p.num_infeed_hosts)
。通常会看到 p.batch_size
设置为 jax.local_device_count() * PERCORE_BATCH_SIZE
。
由于这种多主机性质,input
必须被正确分片。
对于训练,每个 input
必须从不发出相同的批次,对于评估一个有限的数据集,每个 input
必须在相同数量的批次后终止。最佳解决方案是正确分片的输入实现,以使不同主机上的每个 input
不重叠。否则,也可以使用不同的随机种子来避免训练期间的重复批次。
评估数据的输入
input.reset()
从不在训练数据上调用,但可以用于评估(或解码)数据。
对于每次评估(或解码)运行,Pax 通过调用 input.get_next()
N
次从 input
获取 N
批次。使用的批次数,N
,可以是用户通过 p.eval_loop_num_batches
指定的固定数;或者 N
也可以是动态的(p.eval_loop_num_batches=None
),即我们调用 input.get_next()
直到耗尽其所有数据(通过引发 StopIteration
或 tf.errors.OutOfRange
)。
如果 p.reset_for_eval=True
,则忽略 p.eval_loop_num_batches
,并且 N
是由耗尽数据的批次数动态确定的。在这种情况下,应将 p.repeat
设置为 False,否则会导致无限次解码/评估。
如果 p.reset_for_eval=False
,则 Pax 将获取 p.eval_loop_num_batches
的批次。应设置为 p.repeat=True
,以确保数据不会过早耗尽。
请注意,LingvoEvalAdaptor 输入需要 p.reset_for_eval=True
。
| | `N`: 静态 | `N`: 动态 |
| ------------------------ | ----------------------- | ----------------------- |
| `p.reset_for_eval=True` | 每次评估运行使用 | 每次评估运行一个周期。 |
: : 前 `N` 批次。尚未 : `eval_loop_num_batches` :
: : 支持。 : 被忽略。输入必须是 :
: : : 有限的 :
: : : (`p.repeat=False`) :
| `p.reset_for_eval=False` | 每次评估运行使用 | 不支持。 |
: : 非重叠 `N` 批次在滚动 : :
: : 基础上,根据 : :
: : `eval_loop_num_batches` : :
: : 。输入必须重复 : :
: : 无限 : :
: : (`p.repeat=True`) 或 : :
: : 否则可能会引发 : :
: : 异常 : :
如果在一个周期内精确运行解码/评估(即 `p.reset_for_eval=True`),输入必须正确处理分片,这样每个分片在生产相同数量的批次后在相同的步骤上引发。这通常意味着输入必须填充评估数据。`SeqIOInput` 和 `LingvoEvalAdaptor` 会自动完成此操作(参见下文)。
### 评估指标
对于大多数输入,我们只调用 `get_next()` 获取批次数据。有一种评估数据例外,其中“如何计算指标”也在输入对象上定义。
这仅适用于定义一些规范评估基准的 `SeqIOInput`。具体来说,Pax 使用 `predict_metric_fns` 和 `score_metric_fns()` 由 SeqIO 任务定义来计算评估指标(尽管 Pax 不直接依赖 SeqIO 评估器)。
## 最佳实践
当一个模型使用多个输入时,无论是在训练/评估之间还是在预训练/微调之间使用不同的训练数据,用户必须确保这些输入使用的分词器是相同的,尤其是在导入由他人实现的不同输入时。
用户可以通过 `input.ids_to_strings()` 解码一些 ID 来检查分词器的正确性。
通过检查一些批数据来验证数据的正确性总是一个好主意。用户可以在 colab 中轻松重现参数并检查数据:
```python
p = ... # 指定预期的输入参数
inp = p.Instantiate()
b = inp.get_next()
print(b)
训练数据通常不应该使用固定的随机种子。这是因为如果训练任务被抢占,训练数据将开始重复。特别是,对于 Lingvo 输入,我们建议将 p.input.file_random_seed = 0
用于训练数据。
要测试是否正确处理了分片,用户可以手动设置不同的 p.num_infeed_hosts, p.infeed_host_index
值,并查看实例化的输入是否发出不同的批次。
输入类型
Pax 支持 3 种类型的输入:SeqIO、Lingvo 和自定义。
SeqIO
可以使用 SeqIOInput
导入数据集。
SeqIO 输入会自动处理评估数据的正确分片和填充。
Lingvo
可以使用 LingvoInputAdaptor
导入数据集。
输入完全委托给 Lingvo 实现,可能会或可能不会自动处理分片。
对于使用固定 packing_factor
的基于 GenericInput 的 Lingvo 输入实现,我们建议使用 LingvoInputAdaptorNewBatchSize
为内部 Lingvo 输入指定更大的批量大小,并将所需的(通常更小)的批量大小放在 p.batch_size
上。
对于评估数据,我们建议使用 LingvoEvalAdaptor
处理分片和填充以便在一个周期内运行评估。
自定义
BaseInput
的自定义子类。用户实现自己的子类,通常使用 tf.data
或 SeqIO。
用户还可以继承现有的输入类,仅定制批次的后处理。例如:
class MyInput(base_input.LingvoInputAdaptor):
def get_next(self):
batch = super().get_next()
# 修改批次:batch.new_field = ...
return batch
关键的 Pax 组件
超参数
超参数是定义模型和配置实验的重要组成部分。
为了更好地与 Python 工具集成,Pax/Praxis 使用基于 pythonic 数据类的配置风格来配置超参数。
class Linear(base_layer.BaseLayer):
"""没有偏置的线性层。"""
class HParams(BaseHParams):
"""此层类的相关超参数。
属性:
input_dims: 输入的深度。
output_dims: 输出的深度。
"""
input_dims: int = 0
output_dims: int = 0
嵌套
也可以嵌套 HParams 数据类,在下面的示例中,linear_tpl 属性是一个嵌套的 Linear.HParams。
class FeedForward(base_layer.BaseLayer):
"""带激活的前馈层。"""
class HParams(BaseHParams):
"""此层类的相关超参数。
属性:
input_dims: 输入的深度。
output_dims: 输出的深度。
has_bias: 是否添加偏置权重。
linear_tpl: 线性层参数。
activation_tpl: 激活层参数。
"""
input_dims: int = 0
output_dims: int = 0
has_bias: bool = True
linear_tpl: BaseHParams = sub_config_field(Linear.HParams)
activation_tpl: activations.BaseActivation.HParams = sub_config_field(
ReLU.HParams)
层
一个层表示一个可能带有可训练参数的任意函数。一个层可以包含其他层作为子层。层是模型的基本构建块。层从 Flax nn.Module 继承。
通常层定义两种方法:
setup
此方法创建可训练的权重和子层。
fprop
此方法定义前向传播函数,根据输入计算一些输出。此外,fprop 可能会添加摘要或跟踪辅助损失。
Fiddle 和共享层
Fiddle 是一个开源的 Python 优先配置库,设计用于 ML 应用。Pax/Praxis 支持与 Fiddle Config/Partial 互操作和一些高级功能,如快速检测错误和共享参数。
fdl_config = Linear.HParams.config(input_dims=1, output_dims=1)
# 拼写错误。
fdl_config.input_dimz = 31337 # 立即引发异常以快速捕捉拼写错误!
fdl_partial = Linear.HParams.partial(input_dims=1)
使用 Fiddle,层可以配置为共享(例如:仅实例化一次并共享可训练的权重)。
模型
模型仅定义网络,通常是一组层,并定义与模型交互的接口,如解码等。
一些示例基本模型包括:
- 语言模型
- 序列模型
- 分类模型
任务
一个任务包含一个或多个模型和学习者/优化器。最简单的任务子类是 SingleTask
,它需要以下 Hparams:
class HParams(base_task.BaseTask.HParams):
"""任务参数。
属性:
name: 此任务对象的名称,必须是一个有效的标识符。
model: 包含所有层的底层 JAX 模型。
train: 控制此任务如何训练的 HParams。
metrics: 基本指标聚合类,用于确定如何计算指标。
loss_aggregator: 损失聚合器类,用于确定如何聚合损失(例如,单一或多重损失)
vn: 控制变异噪声的 HParams。
发布
PyPI 版本 | 提交 |
---|---|
0.1.0 | 546370f5323ef8b27d38ddc32445d7d3d1e4da9a |
版权所有 2022 Google LLC
根据 Apache 许可证 2.0 版(“许可证”)获得许可;除非符合许可证,否则不得使用该文件。
您可以在以下网址获得许可证副本:
https://www.apache.org/licenses/LICENSE-2.0
除非适用法律要求或书面同意,否则根据许可证分发的软件按“原样”分发,无任何明示或暗示的担保。
参见许可证以了解管理权限和限制的特定语言。