Whisper JAX:让你的语音转文字功能速度快70倍!

Ray

项目简介

该存储库包含 OpenAI 的 Whisper 模型的优化 JAX代码,主要基于 Hugging Face Transformers Whisper 实现。与 OpenAI的 PyTorch 代码相比,Whisper JAX 的运行速度快了70 倍以上,使其成为可用的最快的 Whisper 运行。

JAX 代码在 CPU、GPU 和 TPU 上兼容,并且可以独立运行或作为推理端点。

安装

Whisper JAX 使用 Python 3.9 和 JAX 版本 0.4.5 进行了测试。安装假定您的设备上已安装最新版本的 JAX 包。您可以使用官方 JAX 安装指南来执行此操作:

https://github.com/google/jax#installation

一旦安装了适当版本的 JAX,就可以通过 pip 安装 Whisper JAX

pip install git+https://github.com/sanchit-gandhi/whisper-jax.git

要将 Whisper JAX 包更新到最新版本,只需运行:

pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git

管道使用

运行 Whisper JAX 的推荐方式是通过FlaxWhisperPipline抽象类。此类处理所有必要的预处理和后处理,并包装生成方法以实现跨加速器设备的数据并行性。

Whisper JAX 利用 JAX 的pmap功能实现跨 GPU/TPU 设备的数据并行性。该函数在第一次调用时进行即时 (JIT)编译。此后,该函数将被缓存,使其能够以超快的速度运行:

from whisper_jax import FlaxWhisperPipline
# instantiate pipeline
pipeline = FlaxWhisperPipline("openai/whisper-large-v2")
# JIT compile the forward call - slow, but we only do once
text = pipeline("audio.mp3")
# used cached function thereafter - super fast!!
text = pipeline("audio.mp3")

半精度

通过在实例化管道时传递 dtype 参数,可以以半精度运行模型计算。通过以半精度存储中间张量,这将大大加快计算速度。模型权重的精度没有变化。

对于大多数 GPU,dtype 应设置为jnp.float16. 对于 A100 GPU 或 TPU,dtype 应设置为jnp.bfloat16:

from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
# instantiate pipeline in bfloat16
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16)

批处理

Whisper JAX 还提供跨加速器设备批量处理单个音频输入的选项。音频首先被分成 30 秒的片段,然后将片段分派到模型进行并行转录。所得到的转录在边界处缝合在一起以给出单个、统一的转录。实际上,如果选择的批处理大小足够大,则与顺序转录音频样本相比,批处理提供了 10 倍的加速,并且 WER 1的损失不到 1%。

要启用批处理,请batch_size在实例化管道时传递参数:

from whisper_jax import FlaxWhisperPipline
# instantiate pipeline with batching
pipeline = FlaxWhisperPipline("openai/whisper-large-v2", batch_size=16)

任务

默认情况下,管道以所说的语言转录音频文件。对于语音翻译,请将参数设置 task为"translate":

# translate
text = pipeline("audio.mp3", task="translate")

时间戳

FlaxWhisperPipline还支持时间戳预测。请注意,启用时间戳将需要对前向调用进行第二次 JIT 编译,包括时间戳输出:

# transcribe and return timestamps
outputs = pipeline("audio.mp3",  task="transcribe", return_timestamps=True)
text = outputs["text"]  # transcription
chunks = outputs["chunks"]  # transcription + timestamps

高级用法

更高级的用户可能希望探索不同的并行化技术。Whisper JAX 代码构建在T5x 代码库之上,这意味着它可以使用 T5x 分区约定的模型、激活和数据并行性来运行。要使用 T5x 分区,必须定义逻辑轴规则和模型分区数量。更多详细信息,用户可参考官方T5x分区指南:https://github.com/google-research/t5x/blob/main/docs/usage/partitioning.md

基准测试

我们将 Whisper JAX 与官方OpenAI 实现和🤗 Transformers 实现进行比较。我们在长度不断增加的音频样本上对模型进行基准测试,并报告 10 次重复运行的平均推理时间(以秒为单位)。对于所有三个系统,我们将预加载的音频文件传递给模型并测量前向传递的时间。将加载音频文件的任务留给系统会增加所有基准时间的相等偏移,因此加载和转录音频文件的实际时间将高于报告的数字。

OpenAI 和 Transformers 都在 GPU 上的 PyTorch 中运行。Whisper JAX 在 GPU 和 TPU 上的 JAX 中运行。OpenAI 按照说话的顺序依次转录音频。Transformers 和 Whisper JAX 都使用批处理算法,其中音频块被一起批处理并并行转录(请参阅批处理部分)。

表 :长度增加的音频文件的平均推理时间(以秒为单位)。GPU设备是单个A100 40GB GPU。TPU 设备是单个 TPU v4-8。

benchmark

项目链接

https://github.com/sanchit-gandhi/whisper-jax

avatar
0
0
0
最新项目
Project Cover

豆包MarsCode

豆包 MarsCode 是一款革命性的编程助手,通过AI技术提供代码补全、单测生成、代码解释和智能问答等功能,支持100+编程语言,与主流编辑器无缝集成,显著提升开发效率和代码质量。

Project Cover

AI写歌

Suno AI是一个革命性的AI音乐创作平台,能在短短30秒内帮助用户创作出一首完整的歌曲。无论是寻找创作灵感还是需要快速制作音乐,Suno AI都是音乐爱好者和专业人士的理想选择。

Project Cover

白日梦AI

白日梦AI提供专注于AI视频生成的多样化功能,包括文生视频、动态画面和形象生成等,帮助用户快速上手,创造专业级内容。

Project Cover

有言AI

有言平台提供一站式AIGC视频创作解决方案,通过智能技术简化视频制作流程。无论是企业宣传还是个人分享,有言都能帮助用户快速、轻松地制作出专业级别的视频内容。

Project Cover

Kimi

Kimi AI助手提供多语言对话支持,能够阅读和理解用户上传的文件内容,解析网页信息,并结合搜索结果为用户提供详尽的答案。无论是日常咨询还是专业问题,Kimi都能以友好、专业的方式提供帮助。

Project Cover

讯飞绘镜

讯飞绘镜是一个支持从创意到完整视频创作的智能平台,用户可以快速生成视频素材并创作独特的音乐视频和故事。平台提供多样化的主题和精选作品,帮助用户探索创意灵感。

Project Cover

讯飞文书

讯飞文书依托讯飞星火大模型,为文书写作者提供从素材筹备到稿件撰写及审稿的全程支持。通过录音智记和以稿写稿等功能,满足事务性工作的高频需求,帮助撰稿人节省精力,提高效率,优化工作与生活。

Project Cover

阿里绘蛙

绘蛙是阿里巴巴集团推出的革命性AI电商营销平台。利用尖端人工智能技术,为商家提供一键生成商品图和营销文案的服务,显著提升内容创作效率和营销效果。适用于淘宝、天猫等电商平台,让商品第一时间被种草。

Project Cover

AIWritePaper论文写作

AIWritePaper论文写作是一站式AI论文写作辅助工具,简化了选题、文献检索至论文撰写的整个过程。通过简单设定,平台可快速生成高质量论文大纲和全文,配合图表、参考文献等一应俱全,同时提供开题报告和答辩PPT等增值服务,保障数据安全,有效提升写作效率和论文质量。

投诉举报邮箱: service@vectorlightyear.com
@2024 懂AI·鲁ICP备2024100362号-6·鲁公网安备37021002001498号