Pix2Struct
本仓库包含了Pix2Struct: 截图解析作为视觉语言理解的预训练的代码。
我们发布了Base和Large模型的预训练检查点,以及在论文中讨论的九个下游任务上微调它们的代码。我们无法发布预训练数据,但可以使用C4数据集中公开发布的URL复制这些数据。
入门
克隆GitHub仓库,安装pix2struct
包,并运行测试以确保所有依赖项都成功安装。
git clone https://github.com/google-research/pix2struct.git
cd pix2struct
conda create -n pix2struct python=3.9
conda activate pix2struct
pip install -e ."[dev]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pytest
如果尚未安装,你可能需要先安装Java(sudo apt install default-jre
)和conda。
我们将使用Google Cloud Storage (GCS)进行数据和模型存储。在接下来的文档中,我们假设你自己的存储桶和目录的路径存储在PIX2STRUCT_DIR
环境变量中:
export PIX2STRUCT_DIR="gs://<your_bucket>/<path_to_pix2struct_dir>"
运行实验的代码在查找预处理数据时假定使用此环境变量。
数据预处理
我们的数据预处理脚本默认使用Dataflow运行,利用Apache Beam库。通过关闭--
后出现的标志,也可以在本地运行。
在接下来的文档中,我们假设GCP项目信息存储在以下环境变量中:
export GCP_PROJECT=<your_project_id>
export GCP_REGION=<your_region>
以下是预处理每个数据集所需的命令。结果将写入$PIX2STRUCT_DIR/data/<task_name>/preprocessed/
,这是tasks.py
中假定的文件结构。
TextCaps
mkdir -p data/textcaps
cd data/textcaps
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_val.json
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_test.json
curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
unzip train_val_images.zip
rm train_val_images.zip
unzip test_images.zip
rm test_images.zip
cd ..
gsutil -m cp -r textcaps_data $PIX2STRUCT_DIR/data/textcaps
python -m pix2struct.preprocessing.convert_textcaps \
--textcaps_dir=$PIX2STRUCT_DIR/data/textcaps \
--output_dir=$PIX2STRUCT_DIR/data/textcaps/processed \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
ChartQA
mkdir -p data/chartqa
cd data/chartqa
git clone https://github.com/vis-nlp/ChartQA.git
cp -r ChartQA/ChartQA\ Dataset/* ./
rm -rf ChartQA
cd ..
gsutil -m cp -r chartqa $PIX2STRUCT_DIR/data/chartqa
python -m pix2struct.preprocessing.convert_chartqa \
--data_dir=$PIX2STRUCT_DIR/data/chartqa \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
RICO图像
Screen2Words、RefExp和Widget Captioning都需要RICO数据集中的图像。如果你想使用这些数据集中的任何一个,请在继续之前处理RICO图像。
cd data
wget https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
tar xvfz unique_uis.tar.gz
rm unique_uis.tar.gz
gsutil -m cp -r combined $PIX2STRUCT_DIR/data/rico_images
Widget Captioning
如果你还没有设置RICO,请在继续之前先进行设置。
mkdir -p data/widget_captioning
cd data/widget_captioning
git clone https://github.com/google-research-datasets/widget-caption.git
cp widget-caption/widget_captions.csv ./
cp widget-caption/split/*.txt ./
mv dev.txt val.txt
rm -rf widget-caption
cd ..
gsutil -m cp -r widget_captioning $PIX2STRUCT_DIR/data/widget_captioning
python -m pix2struct.preprocessing.convert_widget_captioning \
--data_dir=$PIX2STRUCT_DIR/data/widget_captioning \
--image_dir=$PIX2STRUCT_DIR/data/rico_images \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
Screen2Words
如果你还没有设置RICO,请在继续之前先进行设置。
cd data
git clone https://github.com/google-research-datasets/screen2words.git
gsutil -m cp -r screen2words $PIX2STRUCT_DIR/data/screen2words
python -m pix2struct.preprocessing.convert_screen2words \
--screen2words_dir=$PIX2STRUCT_DIR/data/screen2words \
--rico_dir=$PIX2STRUCT_DIR/data/rico_images \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
RefExp
如果你还没有设置RICO,请在继续之前先进行设置。
mkdir -p data/refexp
cd data/refexp
wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/train.tfrecord
wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/dev.tfrecord
wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/test.tfrecord
mv dev.tfrecord val.tfrecord
cd ..
gsutil -m cp -r refexp $PIX2STRUCT_DIR/data/refexp
python -m pix2struct.preprocessing.convert_refexp \
--data_dir=$PIX2STRUCT_DIR/data/refexp \
--image_dir=$PIX2STRUCT_DIR/data/rico_images \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
DocVQA
mkdir -p data/docvqa
cd data/docvqa
从官方来源下载DocVQA(单文档视觉问答)(需要注册)。以下步骤假设train/val/test.tar.gz文件位于data/docvqa
中。
tar xvf train.tar.gz
tar xvf val.tar.gz
tar xvf test.tar.gz
rm -r *.tar.gz */ocr_results
cd ..
gsutil -m cp -r docvqa $PIX2STRUCT_DIR/data/docvqa
python -m pix2struct.preprocessing.convert_docvqa \
--data_dir=$PIX2STRUCT_DIR/data/docvqa \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
InfographicVQA
mkdir -p data/infographicvqa
cd data/infographicvqa
从此网站下载InfographicVQA任务1(需要注册)。以下步骤假设train/val/test.json
和zip
文件位于data/infographicvqa
中。
for split in train val test
do
unzip infographicVQA_${split}_v1.0_images.zip
mv infographicVQA_${split}_v1.0_images $split
mv infographicVQA_${split}_v1.0.json $split/${split}_v1.0.json
done
rm *.zip
cd ..
gsutil -m cp -r infographicvqa $PIX2STRUCT_DIR/data/infographicvqa
python -m pix2struct.preprocessing.convert_docvqa \
--data_dir=$PIX2STRUCT_DIR/data/infographicvqa \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
OCR-VQA
mkdir -p data/ocrvqa
cd data/ocrvqa
按照OCR-VQA网站上的说明将数据下载到data/ocrvqa
(需要爬取)。以下步骤假设data/ocrvqa
包含一个名为images
的目录和一个名为dataset.json
的文件。
cd ..
gsutil -m cp -r ocrvqa $PIX2STRUCT_DIR/data/ocrvqa
python -m pix2struct.preprocessing.convert_ocrvqa \
--data_dir=$PIX2STRUCT_DIR/data/ocrvqa \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
AI2D
mkdir -p data/
cd data/
wget https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip
unzip ai2d-all.zip
rm ai2d-all.zip
gsutil -m cp -r ai2d $PIX2STRUCT_DIR/data/ai2d
python -m pix2struct.preprocessing.convert_ai2d \
--data_dir=$PIX2STRUCT_DIR/data/ai2d \
--test_ids_path=gs://pix2struct-data/ai2d_test_ids.csv \
-- \
--runner=DataflowRunner \
--save_main_session \
--project=$GCP_PROJECT \
--region=$GCP_REGION \
--temp_location=$PIX2STRUCT_DIR/data/temp \
--staging_location=$PIX2STRUCT_DIR/data/staging \
--setup_file=./setup.py
运行实验
主要实验是作为T5X库的轻量级包装实现的。为简洁起见,我们演示了在Screen2Words数据集上微调预训练的基础Pix2Struct模型的示例工作流程。要扩展到更大的设置,请参阅T5X文档。
设置TPU
按照官方说明在Cloud TPU VM上运行JAX,这允许您直接ssh
到TPU主机。
在此示例中,我们使用的是v3-8
TPU:
TPU_TYPE=v3-8
TPU_NAME=pix2struct-$TPU_TYPE
TPU_ZONE=europe-west4-a
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$TPU_ZONE \
--accelerator-type=$TPU_TYPE \
--version=tpu-vm-base
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$TPU_ZONE
一旦您ssh
到TPU主机后,按照"入门"说明安装pix2struct
包。
训练
以下命令将启动训练循环,该循环包括训练步骤与验证集上的评估交替进行。
python -m t5x.train \
--gin_search_paths="pix2struct/configs" \
--gin_file="models/pix2struct.gin" \
--gin_file="runs/train.gin" \
--gin_file="sizes/base.gin" \
--gin_file="optimizers/adafactor.gin" \
--gin_file="schedules/screen2words.gin" \
--gin_file="init/pix2struct_base_init.gin" \
--gin.MIXTURE_OR_TASK_NAME="'screen2words'" \
--gin.MODEL_DIR="'$PIX2STRUCT_DIR/experiments/screen2words_base'" \
--gin.TASK_FEATURE_LENGTHS="{'inputs': 4096, 'targets': 128}" \
--gin.BATCH_SIZE=32
评估
以下命令在测试集上评估模型。你需要将检查点路径替换为根据验证性能实际选择的路径。
python -m t5x.eval \
--gin_search_paths="pix2struct/configs" \
--gin_file="models/pix2struct.gin" \
--gin_file="runs/eval.gin" \
--gin_file="sizes/base.gin" \
--gin.MIXTURE_OR_TASK_NAME="'screen2words'" \
--gin.CHECKPOINT_PATH="'$PIX2STRUCT_DIR/experiments/screen2words_base/checkpoint_286600'" \
--gin.EVAL_OUTPUT_DIR="'$PIX2STRUCT_DIR/experiments/test_exp/test_eval'" \
--gin.EVAL_SPLIT="'test'" \
--gin.TASK_FEATURE_LENGTHS="{'inputs': 4096, 'targets': 128}" \
--gin.BATCH_SIZE=32
微调后的检查点
除了在configs/init
目录中指定和发布的预训练检查点外,我们还发布了所有任务的微调模型检查点。
任务 | GCS 路径 (Base) | GCS 路径 (Large) |
---|---|---|
TextCaps | gs://pix2struct-data/textcaps_base/checkpoint_280400 | gs://pix2struct-data/textcaps_large/checkpoint_180600 |
ChartQA | gs://pix2struct-data/chartqa_base/checkpoint_287600 | gs://pix2struct-data/charqa_large/checkpoint_182600 |
WidgetCaptioning | gs://pix2struct-data/widget_captioning_base/checkpoint_281600 | gs://pix2struct-data/widget_captioning_large/checkpoint_181600 |
Screen2Words | gs://pix2struct-data/screen2words_base/checkpoint_282600 | gs://pix2struct-data/screen2words_large/checkpoint_183000 |
RefExp | gs://pix2struct-data/refexp_base/checkpoint_290000 | gs://pix2struct-data/refexp_large/checkpoint_187800 |
DocVQA | gs://pix2struct-data/docvqa_base/checkpoint_284400 | gs://pix2struct-data/docvqa_large/checkpoint_184000 |
InfographicVQA | gs://pix2struct-data/infographicvqa_base/checkpoint_284000 | gs://pix2struct-data/infographicvqa_large/checkpoint_182000 |
OCR-VQA | gs://pix2struct-data/ocrvqa_base/checkpoint_290000 | gs://pix2struct-data/ocrvqa_large/checkpoint_188400 |
AI2D | gs://pix2struct-data/ai2d_base/checkpoint_284400 | gs://pix2struct-data/ai2d_large/checkpoint_184000 |
这些检查点与上述文档中的评估命令以及下面提到的两种推理方法兼容。请确保configs/sizes
下的配置文件与检查点保持一致。
推理
我们提供了两种进行推理的方法。出于测试和演示目的,这些方法可以在CPU上运行。在这种情况下,请将JAX_PLATFORMS
环境变量设置为cpu
。
命令行示例
我们提供了一个用于对单个示例进行推理的最小脚本。这个路径仅在极小规模下进行了测试,不适用于大规模推理。对于大规模推理,我们建议设置一个带有占位符标签的自定义任务,并运行上述文档中的评估脚本(t5x.eval
)。
在以下示例中,我们展示了使用在TextCaps任务上微调的基础大小检查点来预测图像标题的命令。对于也接受文本提示(如VQA中的问题)的任务,你还可以通过text
标志提供问题(除了用image
标志指定图像)。
python -m pix2struct.example_inference \
--gin_search_paths="pix2struct/configs" \
--gin_file=models/pix2struct.gin \
--gin_file=runs/inference.gin \
--gin_file=sizes/base.gin \
--gin.MIXTURE_OR_TASK_NAME="'placeholder_pix2struct'" \
--gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
--gin.BATCH_SIZE=1 \
--gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'" \
--image=$HOME/test_image.jpg
Web演示
为了提供更加用户友好的演示,我们还提供了上述推理脚本的基于Web的替代方案。运行此命令时,假设你在本地运行演示,可以通过localhost:8080
(或通过port
标志指定的任何端口)访问Web演示。然后,你可以上传自定义图像和可选提示,而不是通过命令行指定它们。
python -m pix2struct.demo \
--gin_search_paths="pix2struct/configs" \
--gin_file=models/pix2struct.gin \
--gin_file=runs/inference.gin \
--gin_file=sizes/base.gin \
--gin.MIXTURE_OR_TASK_NAME="'placeholder_pix2struct'" \
--gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
--gin.BATCH_SIZE=1 \
--gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'"
清理
当你完成TPU VM的使用后,记得删除实例:
gcloud compute tpus tpu-vm delete $TPU_NAME --zone=$TPU_ZONE
注意
这不是官方支持的Google产品。