Data2Vec-Audio-Base-960h 项目介绍
项目背景
Data2Vec-Audio-Base-960h 是一个由 Facebook 开发的数据2vec模型,该模型针对语音识别任务进行了特别训练和微调,使用了 Librispeech 数据集的960小时语音数据。该模型的特点是使用16kHz采样率的语音输入。
自监督学习框架
Data2Vec 是一个通用框架,旨在增强自监督学习方法的通用性。在传统的自监督学习中,语音、自然语言处理和计算机视觉都有各自的专用算法和目标。数据2vec的核心理念是通过自提炼过程,预测完整输入数据的潜在表示,而这种表示是基于输入的掩码视图生成的。
不同于预测特定模态的目标(如单词、视觉代币或人类语音单位),数据2vec 预测的表示包含完整输入的信息。这种方法在语音识别、图像分类和自然语言理解的主要基准测试上表现出新的技术水平或与主流方法竞争的能力。
预训练方法
数据2vec在结构上采用了标准的Transformer架构。与其他模型不同的是,它更关注于生成语境化的潜在表示,使其不仅限于某一特定模态。此方法帮助模型在训练期间识别和学习更复杂的特征。
模型使用
Data2Vec-Audio-Base-960h 可以作为独立的声学模型来转录音频文件,代码示例如下:
from transformers import Wav2Vec2Processor, Data2VecForCTC
from datasets import load_dataset
import torch
# 加载模型和处理器
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h")
model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
# 加载样本数据集并读取音频文件
ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
# 进行分词
input_values = processor(ds[0]["audio"]["array"], return_tensors="pt", padding="longest").input_values # batch size为1
# 获取logits
logits = model(input_values).logits
# 取得argmax并解码
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
模型评估
可以利用该模型对LibriSpeech数据集中的"clean"和"other"分支进行评估,以计算单词错误率(WER):
from transformers import Wav2Vec2Processor, Data2VecForCTC
from datasets import load_dataset
import torch
from jiwer import wer
# 加载模型和处理器
processor = Wav2Vec2Processor.from_pretrained("facebook/data2vec-audio-base-960h").to("cuda")
model = Data2VecForCTC.from_pretrained("facebook/data2vec-audio-base-960h")
librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")
def map_to_pred(batch):
input_values = processor(batch["audio"]["array"], return_tensors="pt", padding="longest").input_values
with torch.no_grad():
logits = model(input_values.to("cuda")).logits
predicted_ids = torch.argmax(logits, dim=-1)
transcription = processor.batch_decode(predicted_ids)
batch["transcription"] = transcription
return batch
result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])
print("WER:", wer(result["text"], result["transcription"]))
评估结果(WER):
- "clean":2.77
- "other":7.08
总结
Data2Vec-Audio-Base-960h 展现了在语音识别任务中使用通用自监督学习框架的潜力。该模型在不同的测试集上表现出显著的识别能力。这表明自监督学习框架在多模态任务中的广泛适用性,尤其是在语音识别领域的前景。