Nougat-Latex-Base项目介绍
Nougat-Latex-Base是一个基于Donut模型的图像到文本转换工具,它专注于从图像生成LaTeX代码。该项目从facebook的nougat-base模型精调而来,并利用了im2latex-100k数据集以提高其转换LaTeX代码的能力。针对初始编码器输入图像尺寸不适合用于公式图像这一问题,Nougat-Latex-Base对输入分辨率进行了调整,并采用自适应填充的方法,确保在任意环境中的公式图像片段,其缩放后的分辨率能与训练数据尽量匹配,从而避免因重新缩放而导致的质量下降。
模型评估
Nougat-Latex-Base的性能在由Wikipedia、arXiv以及im2latex-100k收集的图像-方程配对数据集中进行评估,这些数据是由lukas-blecher整理的。对比结果如下表所示:
模型 | 语义词准确率(token_acc) ↑ | 标准编辑距离(normed edit distance) ↓ |
---|---|---|
pix2tex | 0.5346 | 0.10312 |
pix2tex* | 0.60 | 0.10 |
nougat-latex-based | 0.623850 | 0.06180 |
其中,pix2tex是一个由ResNet、ViT和文本解码器组成的架构,初次出现在LaTeX-OCR项目中。
pix2tex*为LaTeX-OCR项目中报告的数据;pix2tex为作者根据发布的检查点自己评估的结果;nougat-latex-based采用beam-search策略生成结果进行评估。
使用指南
环境需求
在项目中运行Nougat-Latex-Base模型需要安装以下环境:
pip install transformers >= 4.34.0
模型应用
由于部分API接口可能会导致响应结果被截断,建议用户在本地运行模型以获得完整的转译结果。使用步骤如下:
- 下载仓库
git clone git@github.com:NormXU/nougat-latex-ocr.git
cd ./nougat-latex-ocr
- 推理示例
import torch
from PIL import Image
from transformers import VisionEncoderDecoderModel
from transformers.models.nougat import NougatTokenizerFast
from nougat_latex import NougatLaTexProcessor
model_name = "Norm/nougat-latex-base"
device = "cuda" if torch.cuda.is_available() else "cpu"
# 初始化模型
model = VisionEncoderDecoderModel.from_pretrained(model_name).to(device)
# 初始化处理器
tokenizer = NougatTokenizerFast.from_pretrained(model_name)
latex_processor = NougatLaTexProcessor.from_pretrained(model_name)
# 测试
image = Image.open("path/to/latex/image.png")
if not image.mode == "RGB":
image = image.convert('RGB')
pixel_values = latex_processor(image, return_tensors="pt").pixel_values
decoder_input_ids = tokenizer(tokenizer.bos_token, add_special_tokens=False,
return_tensors="pt").input_ids
with torch.no_grad():
outputs = model.generate(
pixel_values.to(device),
decoder_input_ids=decoder_input_ids.to(device),
max_length=model.decoder.config.max_length,
early_stopping=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
use_cache=True,
num_beams=5,
bad_words_ids=[[tokenizer.unk_token_id]],
return_dict_in_generate=True,
)
sequence = tokenizer.batch_decode(outputs.sequences)[0]
sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "").replace(tokenizer.bos_token, "")
print(sequence)
这段代码示例展示了如何从一张LaTeX图像生成LaTeX代码,用户可以根据需要修改图像路径并观察输出结果。该项目能够为有LaTeX需求的用户提供强大且高效的图像转码解决方案。