CLAP
本仓库通过对比语言-音频预训练(CLAP)提供音频和文本的表示
使用CLAP,您可以为自己的模型或不同的下游任务提取任何给定音频和文本的潜在表示。
所有代码均来自以下论文,该论文已被IEEE国际声学、语音和信号处理会议(ICASSP 2023)接受:
最新更新:
1. 我们发布了新的CLAP预训练检查点,这些检查点是在我们的数据集收集仓库中的音乐和语音数据集上预训练的。
2. CLAP模型已被HuggingFace Transformers整合和支持。非常感谢Younes Belkada和Arthur Zucker为HuggingFace支持做出的贡献。
关于本项目
本项目是LAION的一个项目,旨在学习更好的音频理解并获取更多音频数据。 这是一个开源项目。我们采用了open_clip的代码库作为本项目的基础。
非常感谢@cfoster0允许我们使用他的仓库名称。
架构
对比语言-音频预训练,简称CLAP。参考CLIP(对比语言-图像预训练)架构,CLAP架构如下所示。
快速开始
我们为CLAP模型提供了PyPI库:
pip install laion-clap
然后您可以按照以下用法进行操作,或参考unit_test.py。
关于API的文档,请参考hook.py。
import numpy as np
import librosa
import torch
import laion_clap
# 量化
def int16_to_float32(x):
return (x / 32767.0).astype(np.float32)
def float32_to_int16(x):
x = np.clip(x, a_min=-1., a_max=1.)
return (x * 32767.).astype(np.int16)
model = laion_clap.CLAP_Module(enable_fusion=False)
model.load_ckpt() # 下载默认的预训练检查点
# 直接从音频文件获取音频嵌入
audio_file = [
'/home/data/test_clap_short.wav',
'/home/data/test_clap_long.wav'
]
audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=False)
print(audio_embed[:,-20:])
print(audio_embed.shape)
# 从音频数据获取音频嵌入
audio_data, _ = librosa.load('/home/data/test_clap_short.wav', sr=48000) # 采样率应为48000
audio_data = audio_data.reshape(1, -1) # 将其转换为(1,T)或(N,T)的形状
audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=False)
print(audio_embed[:,-20:])
print(audio_embed.shape)
# 直接从音频文件获取音频嵌入,但返回torch张量
audio_file = [
'/home/data/test_clap_short.wav',
'/home/data/test_clap_long.wav'
]
audio_embed = model.get_audio_embedding_from_filelist(x = audio_file, use_tensor=True)
print(audio_embed[:,-20:])
print(audio_embed.shape)
# 从音频数据获取音频嵌入
audio_data, _ = librosa.load('/home/data/test_clap_short.wav', sr=48000) # 采样率应为48000
audio_data = audio_data.reshape(1, -1) # 将其转换为(1,T)或(N,T)的形状
audio_data = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float() # 在将数据输入模型之前进行量化
audio_embed = model.get_audio_embedding_from_data(x = audio_data, use_tensor=True)
print(audio_embed[:,-20:])
print(audio_embed.shape)
# 从文本获取文本嵌入:
text_data = ["我喜欢对比学习", "我喜欢预训练模型"]
text_embed = model.get_text_embedding(text_data)
print(text_embed)
print(text_embed.shape)
# 从文本获取文本嵌入,但返回torch张量:
text_data = ["我喜欢对比学习", "我喜欢预训练模型"]
text_embed = model.get_text_embedding(text_data, use_tensor=True)
print(text_embed)
print(text_embed.shape)
预训练模型
预训练检查点可以在这里找到。 请参考上一节了解如何加载和运行检查点。 对于PyPI库,630k-audioset-best.pt和630k-audioset-fusion-best.pt是我们的默认模型(非融合和融合)
我们进一步根据您的使用情况提供以下预训练模型:
- 对于10秒以下的一般音频:630k-audioset-best.pt或630k-best.pt
- 对于可变长度的一般音频:630k-audioset-fusion-best.pt或630k-fusion-best.pt
- 对于音乐:music_audioset_epoch_15_esc_90.14.pt
- 对于音乐和语音:music_speech_epoch_15_esc_89.25.pt
- 对于语音、音乐和一般音频:music_speech_audioset_epoch_15_esc_89.98.pt
这里列出的每个模型设置的检查点是训练中平均mAP分数最高的。 平均mAP分数是通过对以下4个分数取平均值计算得出的:AudioCaps上的A-->T mAP@10,AudioCaps上的T-->A mAP@10,Clotho上的A-->T mAP@10,以及Clotho上的T-->A mAP@10。
要使用上述预训练模型,您需要自己加载检查点,如:
2023年4月7日更新:我们发布了3个更大的CLAP模型,这些模型除了在LAION-Audio-630k上训练外,还在音乐和语音数据集上进行了训练。以下是模型的描述及其性能:
music_speech_audioset_epoch_15_esc_89.98.pt
:在音乐+语音+Audioset+LAION-Audio-630k上训练。零样本ESC50性能为89.98%,GTZAN性能为51%。music_audioset_epoch_15_esc_90.14.pt
:在音乐+Audioset+LAION-Audio-630k上训练。零样本ESC50性能为90.14%,GTZAN性能为71%。music_speech_epoch_15_esc_89.25.pt
:在音乐+语音+LAION-Audio-630k上训练。零样本ESC50性能为89.25%,GTZAN性能为69%。
该模型使用了更大的音频编码器。使用pip API加载模型:
import laion_clap
model = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base')
model.load_ckpt('checkpoint_path/checkpoint_name.pt')
请注意,这是为正在进行大规模下游任务的人提供的临时发布版本。 我们将在未来发布一个更全面的模型版本,包含详细的实验。 使用此模型时请自行承担风险。
- 所有新的检查点都没有使用融合进行训练。
music_speech_audioset_epoch_15_esc_89.98.pt
的训练数据集大小约为400万个样本。零样本GTZAN分数是使用提示"This audio is asong."进行评估的。
环境安装
如果您想检查并在项目中重用我们的模型,而不是直接使用pip库,您需要安装与我们相同的环境,请运行以下命令:
conda create env -n clap python=3.10
conda activate clap
git clone https://github.com/LAION-AI/CLAP.git
cd CLAP
# 您也可以按照官方说明安装PyTorch (https://pytorch.org/get-started/locally/)
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/torch_stable.html
pip install -r requirements.txt
数据集格式
我们使用webdataset格式的训练数据。有关我们数据集的详细信息,请参见https://github.com/LAION-AI/audio-dataset。
由于版权原因,我们无法发布用于训练此模型的数据集。但是,我们发布了LAION-audio-630K,这是我们用来组成数据集的数据源,其中包含每个音频的链接及其说明。更多详情请参阅LAION-audio-630K。您可以下载数据集,自行预处理并在本地进行训练。要在本地数据集上进行训练,请将训练脚本中的--remotedata
(参见experiment_scripts文件夹)更改为--datasetpath <您的数据集目录>
。
您可以在这里找到我们数据集格式的示例。 它包含完整的ESC50数据集,按照第一个5折交叉验证进行拆分。
训练、微调和评估
请在experiment_scripts文件夹中查找训练、微调和评估(零样本和检索)的脚本。
其中包含的脚本是我们在SLURM集群上用于训练模型的脚本。
您需要更改脚本以适应自己的环境。
例如,在单机多GPU设置中,您可能想使用torchrun
而不是srun
来运行脚本。
要在单GPU机器上进行训练,请使用CUDA_VISIBLE_DEVICES=0 python -m ...
而不是srun
。
我们使用Weights and Biases进行实验记录。您需要在环境中配置weights and biases。
要在本地数据集上进行训练,请将训练脚本中的--remotedata
(参见experiment_scripts文件夹)更改为--datasetpath <您的数据集目录>
。
核心代码
请参考main.py、train.py、data.py和model.py以快速熟悉我们的模型。
可复现性
Clotho数据集的预处理示例(webdataset格式)可以在这里下载(下载即表示您同意Clotho数据集中描述的许可条款)。使用48kHz AudioSet预训练的音频编码器可以在这里找到,其中HTSAT-fullset-imagenet-map=0.467.ckpt
是用于初始化我们的HTSAT音频编码器的检查点。通过加载音频编码器检查点并在相同数据集上训练,您应该能获得类似的结果。
在Clotho数据集上训练模型的脚本可以在这里找到。您需要将datasetpath
和pretrained-audio
替换为指向您自己的目录。您可以查看在单个A100 GPU上训练脚本的报告作为参考。
由于大多数数据集有版权限制,很遗憾我们无法直接分享其他预处理过的数据集。由关键词到描述模型生成的Audioset描述可以在这里找到。
使用ESC50官方分割进行零样本分类
以下是使用pip API在ESC50官方分割的第一个子集上进行零样本分类的示例代码:
import laion_clap
import glob
import json
import torch
import numpy as np
device = torch.device('cuda:0')
# 下载 https://drive.google.com/drive/folders/1scyH43eQAcrBz-5fAw44C6RNBhC3ejvX?usp=sharing 并解压 ./ESC50_1/test/0.tar 到 ./ESC50_1/test/
esc50_test_dir = './ESC50_1/test/*/'
class_index_dict_path = './class_labels/ESC50_class_labels_indices_space.json'
# 加载模型
model = laion_clap.CLAP_Module(enable_fusion=False, device=device)
model.load_ckpt()
# 获取类别索引字典
class_index_dict = {v: k for v, k in json.load(open(class_index_dict_path)).items()}
# 获取所有数据
audio_files = sorted(glob.glob(esc50_test_dir + '**/*.flac', recursive=True))
json_files = sorted(glob.glob(esc50_test_dir + '**/*.json', recursive=True))
ground_truth_idx = [class_index_dict[json.load(open(jf))['tag'][0]] for jf in json_files]
with torch.no_grad():
ground_truth = torch.tensor(ground_truth_idx).view(-1, 1)
# 获取文本特征
all_texts = ["This is a sound of " + t for t in class_index_dict.keys()]
text_embed = model.get_text_embedding(all_texts)
audio_embed = model.get_audio_embedding_from_filelist(x=audio_files)
ranking = torch.argsort(torch.tensor(audio_embed) @ torch.tensor(text_embed).t(), descending=True)
preds = torch.where(ranking == ground_truth)[1]
preds = preds.cpu().numpy()
metrics = {}
metrics[f"mean_rank"] = preds.mean() + 1
metrics[f"median_rank"] = np.floor(np.median(preds)) + 1
for k in [1, 5, 10]:
metrics[f"R@{k}"] = np.mean(preds < k)
# map@10
metrics[f"mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0))
print(
f"零样本分类结果: "
+ "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in metrics.items()])
)
对于ESC50数据集,您可以从这里下载我们处理好的webdataset格式的ESC50,并将./test/0.tar
解压到./test/
。或者您可以下载原始ESC50数据集,并自行将标签处理成class_labels/ESC50_class_labels_indices_space.json
的格式(将_
替换为空格)。
结果应该与以下相同:
对于model = laion_clap.CLAP_Module(enable_fusion=True, device=device)
:mean_rank: 1.2425 median_rank: 1.0000 R@1: 0.9050 R@5: 0.9900 R@10: 0.9925 mAP@10: 0.9407
对于model = laion_clap.CLAP_Module(enable_fusion=False, device=device)
:mean_rank: 1.1450 median_rank: 1.0000 R@1: 0.9275 R@5: 0.9975 R@10: 1.0000 mAP@10: 0.9556
注意,这些结果略高于论文中报告的结果,因为我们使用了ESC50的训练+测试数据,并删除了其他训练数据集(主要是freesound)中的重叠数据。
引用
如果您发现这个项目和LAION-Audio-630K数据集有用,请引用我们的论文:
@inproceedings{laionclap2023,
title = {Large-scale Contrastive Language-Audio Pretraining with Feature Fusion and Keyword-to-Caption Augmentation},
author = {Wu*, Yusong and Chen*, Ke and Zhang*, Tianyu and Hui*, Yuchen and Berg-Kirkpatrick, Taylor and Dubnov, Shlomo},
booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing, ICASSP},
year = {2023}
}
@inproceedings{htsatke2022,
author = {Ke Chen and Xingjian Du and Bilei Zhu and Zejun Ma and Taylor Berg-Kirkpatrick and Shlomo Dubnov},
title = {HTS-AT: A Hierarchical Token-Semantic Audio Transformer for Sound Classification and Detection},
booktitle={IEEE International Conference on Acoustics, Speech and Signal Processing, ICASSP},
year = {2022}
}
致谢
这个项目正在进行中,因此代码库和模型可能不完美或存在bug。 我们非常感谢任何形式的贡献或提出的问题。 如果您发现bug或有任何建议,请随时开启一个issue或联系我们。 如果您想积极地为这个项目做出贡献,请加入LAION的Discord。