Project Icon

pix2struct

基于截图解析的视觉语言预训练模型

Pix2Struct是一个基于截图解析的视觉语言预训练模型。该模型可处理图像描述、图表问答和界面元素理解等多种任务。项目提供预训练的Base和Large模型检查点,以及9个下游任务的微调代码。Pix2Struct在多个视觉语言任务中表现优异,为相关研究提供了有力支持。

Pix2Struct

本仓库包含了Pix2Struct: 截图解析作为视觉语言理解的预训练的代码。

我们发布了Base和Large模型的预训练检查点,以及在论文中讨论的九个下游任务上微调它们的代码。我们无法发布预训练数据,但可以使用C4数据集中公开发布的URL复制这些数据。

入门

克隆GitHub仓库,安装pix2struct包,并运行测试以确保所有依赖项都成功安装。

git clone https://github.com/google-research/pix2struct.git
cd pix2struct
conda create -n pix2struct python=3.9
conda activate pix2struct
pip install -e ."[dev]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pytest

如果尚未安装,你可能需要先安装Java(sudo apt install default-jre)和conda

我们将使用Google Cloud Storage (GCS)进行数据和模型存储。在接下来的文档中,我们假设你自己的存储桶和目录的路径存储在PIX2STRUCT_DIR环境变量中:

export PIX2STRUCT_DIR="gs://<your_bucket>/<path_to_pix2struct_dir>"

运行实验的代码在查找预处理数据时假定使用此环境变量。

数据预处理

我们的数据预处理脚本默认使用Dataflow运行,利用Apache Beam库。通过关闭--后出现的标志,也可以在本地运行。

在接下来的文档中,我们假设GCP项目信息存储在以下环境变量中:

export GCP_PROJECT=<your_project_id>
export GCP_REGION=<your_region>

以下是预处理每个数据集所需的命令。结果将写入$PIX2STRUCT_DIR/data/<task_name>/preprocessed/,这是tasks.py中假定的文件结构。

TextCaps

mkdir -p data/textcaps
cd data/textcaps
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_train.json
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_val.json
curl -O https://dl.fbaipublicfiles.com/textvqa/data/textcaps/TextCaps_0.1_test.json
curl -O https://dl.fbaipublicfiles.com/textvqa/images/train_val_images.zip
curl -O https://dl.fbaipublicfiles.com/textvqa/images/test_images.zip
unzip train_val_images.zip
rm train_val_images.zip
unzip test_images.zip
rm test_images.zip
cd ..
gsutil -m cp -r textcaps_data $PIX2STRUCT_DIR/data/textcaps
python -m pix2struct.preprocessing.convert_textcaps \
  --textcaps_dir=$PIX2STRUCT_DIR/data/textcaps \
  --output_dir=$PIX2STRUCT_DIR/data/textcaps/processed \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

ChartQA

mkdir -p data/chartqa
cd data/chartqa
git clone https://github.com/vis-nlp/ChartQA.git
cp -r ChartQA/ChartQA\ Dataset/* ./
rm -rf ChartQA
cd ..
gsutil -m cp -r chartqa $PIX2STRUCT_DIR/data/chartqa
python -m pix2struct.preprocessing.convert_chartqa \
  --data_dir=$PIX2STRUCT_DIR/data/chartqa \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

RICO图像

Screen2Words、RefExp和Widget Captioning都需要RICO数据集中的图像。如果你想使用这些数据集中的任何一个,请在继续之前处理RICO图像。

cd data
wget https://storage.googleapis.com/crowdstf-rico-uiuc-4540/rico_dataset_v0.1/unique_uis.tar.gz
tar xvfz unique_uis.tar.gz
rm unique_uis.tar.gz
gsutil -m cp -r combined $PIX2STRUCT_DIR/data/rico_images

Widget Captioning

如果你还没有设置RICO,请在继续之前先进行设置。

mkdir -p data/widget_captioning
cd data/widget_captioning
git clone https://github.com/google-research-datasets/widget-caption.git
cp widget-caption/widget_captions.csv ./
cp widget-caption/split/*.txt ./
mv dev.txt val.txt
rm -rf widget-caption
cd ..
gsutil -m cp -r widget_captioning $PIX2STRUCT_DIR/data/widget_captioning
python -m pix2struct.preprocessing.convert_widget_captioning \
  --data_dir=$PIX2STRUCT_DIR/data/widget_captioning \
  --image_dir=$PIX2STRUCT_DIR/data/rico_images \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

Screen2Words

如果你还没有设置RICO,请在继续之前先进行设置。

cd data
git clone https://github.com/google-research-datasets/screen2words.git
gsutil -m cp -r screen2words $PIX2STRUCT_DIR/data/screen2words
python -m pix2struct.preprocessing.convert_screen2words \
  --screen2words_dir=$PIX2STRUCT_DIR/data/screen2words \
  --rico_dir=$PIX2STRUCT_DIR/data/rico_images \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

RefExp

如果你还没有设置RICO,请在继续之前先进行设置。

mkdir -p data/refexp
cd data/refexp
wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/train.tfrecord
wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/dev.tfrecord
wget https://github.com/google-research-datasets/uibert/raw/main/ref_exp/test.tfrecord
mv dev.tfrecord val.tfrecord
cd ..
gsutil -m cp -r refexp $PIX2STRUCT_DIR/data/refexp
python -m pix2struct.preprocessing.convert_refexp \
  --data_dir=$PIX2STRUCT_DIR/data/refexp \
  --image_dir=$PIX2STRUCT_DIR/data/rico_images \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

DocVQA

mkdir -p data/docvqa
cd data/docvqa

官方来源下载DocVQA(单文档视觉问答)(需要注册)。以下步骤假设train/val/test.tar.gz文件位于data/docvqa中。

tar xvf train.tar.gz
tar xvf val.tar.gz
tar xvf test.tar.gz
rm -r *.tar.gz */ocr_results

cd ..
gsutil -m cp -r docvqa $PIX2STRUCT_DIR/data/docvqa
python -m pix2struct.preprocessing.convert_docvqa \
  --data_dir=$PIX2STRUCT_DIR/data/docvqa \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

InfographicVQA

mkdir -p data/infographicvqa
cd data/infographicvqa

此网站下载InfographicVQA任务1(需要注册)。以下步骤假设train/val/test.jsonzip文件位于data/infographicvqa中。

for split in train val test
do
  unzip infographicVQA_${split}_v1.0_images.zip
  mv infographicVQA_${split}_v1.0_images $split
  mv infographicVQA_${split}_v1.0.json $split/${split}_v1.0.json
done
rm *.zip

cd ..
gsutil -m cp -r infographicvqa $PIX2STRUCT_DIR/data/infographicvqa
python -m pix2struct.preprocessing.convert_docvqa \
  --data_dir=$PIX2STRUCT_DIR/data/infographicvqa \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

OCR-VQA

mkdir -p data/ocrvqa
cd data/ocrvqa

按照OCR-VQA网站上的说明将数据下载到data/ocrvqa(需要爬取)。以下步骤假设data/ocrvqa包含一个名为images的目录和一个名为dataset.json的文件。

cd ..
gsutil -m cp -r ocrvqa $PIX2STRUCT_DIR/data/ocrvqa
python -m pix2struct.preprocessing.convert_ocrvqa \
  --data_dir=$PIX2STRUCT_DIR/data/ocrvqa \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

AI2D

mkdir -p data/
cd data/
wget https://ai2-public-datasets.s3.amazonaws.com/diagrams/ai2d-all.zip
unzip ai2d-all.zip
rm ai2d-all.zip
gsutil -m cp -r ai2d $PIX2STRUCT_DIR/data/ai2d
python -m pix2struct.preprocessing.convert_ai2d \
  --data_dir=$PIX2STRUCT_DIR/data/ai2d \
  --test_ids_path=gs://pix2struct-data/ai2d_test_ids.csv \
  -- \
  --runner=DataflowRunner \
  --save_main_session \
  --project=$GCP_PROJECT \
  --region=$GCP_REGION \
  --temp_location=$PIX2STRUCT_DIR/data/temp \
  --staging_location=$PIX2STRUCT_DIR/data/staging \
  --setup_file=./setup.py

运行实验

主要实验是作为T5X库的轻量级包装实现的。为简洁起见,我们演示了在Screen2Words数据集上微调预训练的基础Pix2Struct模型的示例工作流程。要扩展到更大的设置,请参阅T5X文档。

设置TPU

按照官方说明在Cloud TPU VM上运行JAX,这允许您直接ssh到TPU主机。

在此示例中,我们使用的是v3-8 TPU:

TPU_TYPE=v3-8
TPU_NAME=pix2struct-$TPU_TYPE
TPU_ZONE=europe-west4-a
gcloud compute tpus tpu-vm create $TPU_NAME \
  --zone=$TPU_ZONE \
  --accelerator-type=$TPU_TYPE \
  --version=tpu-vm-base
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$TPU_ZONE

一旦您ssh到TPU主机后,按照"入门"说明安装pix2struct包。

训练

以下命令将启动训练循环,该循环包括训练步骤与验证集上的评估交替进行。

python -m t5x.train \
  --gin_search_paths="pix2struct/configs" \
  --gin_file="models/pix2struct.gin" \
  --gin_file="runs/train.gin" \
  --gin_file="sizes/base.gin" \
  --gin_file="optimizers/adafactor.gin" \
  --gin_file="schedules/screen2words.gin" \
  --gin_file="init/pix2struct_base_init.gin" \
  --gin.MIXTURE_OR_TASK_NAME="'screen2words'" \
  --gin.MODEL_DIR="'$PIX2STRUCT_DIR/experiments/screen2words_base'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 4096, 'targets': 128}" \
  --gin.BATCH_SIZE=32

评估

以下命令在测试集上评估模型。你需要将检查点路径替换为根据验证性能实际选择的路径。

python -m t5x.eval \
  --gin_search_paths="pix2struct/configs" \
  --gin_file="models/pix2struct.gin" \
  --gin_file="runs/eval.gin" \
  --gin_file="sizes/base.gin" \
  --gin.MIXTURE_OR_TASK_NAME="'screen2words'" \
  --gin.CHECKPOINT_PATH="'$PIX2STRUCT_DIR/experiments/screen2words_base/checkpoint_286600'" \
  --gin.EVAL_OUTPUT_DIR="'$PIX2STRUCT_DIR/experiments/test_exp/test_eval'" \
  --gin.EVAL_SPLIT="'test'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 4096, 'targets': 128}" \
  --gin.BATCH_SIZE=32

微调后的检查点

除了在configs/init目录中指定和发布的预训练检查点外,我们还发布了所有任务的微调模型检查点。

任务GCS 路径 (Base)GCS 路径 (Large)
TextCapsgs://pix2struct-data/textcaps_base/checkpoint_280400gs://pix2struct-data/textcaps_large/checkpoint_180600
ChartQAgs://pix2struct-data/chartqa_base/checkpoint_287600gs://pix2struct-data/charqa_large/checkpoint_182600
WidgetCaptioninggs://pix2struct-data/widget_captioning_base/checkpoint_281600gs://pix2struct-data/widget_captioning_large/checkpoint_181600
Screen2Wordsgs://pix2struct-data/screen2words_base/checkpoint_282600gs://pix2struct-data/screen2words_large/checkpoint_183000
RefExpgs://pix2struct-data/refexp_base/checkpoint_290000gs://pix2struct-data/refexp_large/checkpoint_187800
DocVQAgs://pix2struct-data/docvqa_base/checkpoint_284400gs://pix2struct-data/docvqa_large/checkpoint_184000
InfographicVQAgs://pix2struct-data/infographicvqa_base/checkpoint_284000gs://pix2struct-data/infographicvqa_large/checkpoint_182000
OCR-VQAgs://pix2struct-data/ocrvqa_base/checkpoint_290000gs://pix2struct-data/ocrvqa_large/checkpoint_188400
AI2Dgs://pix2struct-data/ai2d_base/checkpoint_284400gs://pix2struct-data/ai2d_large/checkpoint_184000

这些检查点与上述文档中的评估命令以及下面提到的两种推理方法兼容。请确保configs/sizes下的配置文件与检查点保持一致。

推理

我们提供了两种进行推理的方法。出于测试和演示目的,这些方法可以在CPU上运行。在这种情况下,请将JAX_PLATFORMS环境变量设置为cpu

命令行示例

我们提供了一个用于对单个示例进行推理的最小脚本。这个路径仅在极小规模下进行了测试,不适用于大规模推理。对于大规模推理,我们建议设置一个带有占位符标签的自定义任务,并运行上述文档中的评估脚本(t5x.eval)。

在以下示例中,我们展示了使用在TextCaps任务上微调的基础大小检查点来预测图像标题的命令。对于也接受文本提示(如VQA中的问题)的任务,你还可以通过text标志提供问题(除了用image标志指定图像)。

python -m pix2struct.example_inference \
  --gin_search_paths="pix2struct/configs" \
  --gin_file=models/pix2struct.gin \
  --gin_file=runs/inference.gin \
  --gin_file=sizes/base.gin \
  --gin.MIXTURE_OR_TASK_NAME="'placeholder_pix2struct'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
  --gin.BATCH_SIZE=1 \
  --gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'" \
  --image=$HOME/test_image.jpg

Web演示

为了提供更加用户友好的演示,我们还提供了上述推理脚本的基于Web的替代方案。运行此命令时,假设你在本地运行演示,可以通过localhost:8080(或通过port标志指定的任何端口)访问Web演示。然后,你可以上传自定义图像和可选提示,而不是通过命令行指定它们。

python -m pix2struct.demo \
  --gin_search_paths="pix2struct/configs" \
  --gin_file=models/pix2struct.gin \
  --gin_file=runs/inference.gin \
  --gin_file=sizes/base.gin \
  --gin.MIXTURE_OR_TASK_NAME="'placeholder_pix2struct'" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 2048, 'targets': 128}" \
  --gin.BATCH_SIZE=1 \
  --gin.CHECKPOINT_PATH="'gs://pix2struct-data/textcaps_base/checkpoint_280400'"

清理

当你完成TPU VM的使用后,记得删除实例:

gcloud compute tpus tpu-vm delete $TPU_NAME --zone=$TPU_ZONE

注意

这不是官方支持的Google产品。

项目侧边栏1项目侧边栏2
推荐项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号