安装
pip install auraloss
如果你想使用MelSTFTLoss()
或FIRFilter()
,你需要指定额外安装(librosa和scipy)。
pip install auraloss[all]
使用
import torch
import auraloss
mrstft = auraloss.freq.MultiResolutionSTFTLoss()
input = torch.rand(8,1,44100)
target = torch.rand(8,1,44100)
loss = mrstft(input, target)
新功能:使用梅尔频谱图进行感知加权。
bs = 8
chs = 1
seq_len = 131072
sample_rate = 44100
# 你想比较的一些音频
target = torch.rand(bs, chs, seq_len)
pred = torch.rand(bs, chs, seq_len)
# 定义损失函数
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
fft_sizes=[1024, 2048, 8192],
hop_sizes=[256, 512, 2048],
win_lengths=[1024, 2048, 8192],
scale="mel",
n_bins=128,
sample_rate=sample_rate,
perceptual_weighting=True,
)
# 计算
loss = loss_fn(pred, target)
引用
如果你在你的工作中使用了这段代码,请考虑引用我们。
@inproceedings{steinmetz2020auraloss,
title={auraloss: {A}udio focused loss functions in {PyTorch}},
author={Steinmetz, Christian J. and Reiss, Joshua D.},
booktitle={Digital Music Research Network One-day Workshop (DMRN+15)},
year={2020}
}
损失函数
我们将损失函数分为时域方法和频域方法。 此外,我们还包括感知变换。
损失函数 | 接口 | 参考文献 |
---|---|---|
时域 | ||
误差信号比(ESR) | auraloss.time.ESRLoss() | Wright & Välimäki, 2019 |
直流误差(DC) | auraloss.time.DCLoss() | Wright & Välimäki, 2019 |
对数双曲余弦(Log-cosh) | auraloss.time.LogCoshLoss() | Chen et al., 2019 |
信噪比(SNR) | auraloss.time.SNRLoss() | |
尺度不变信号失真比(SI-SDR) | auraloss.time.SISDRLoss() | Le Roux et al., 2018 |
尺度相关信号失真比(SD-SDR) | auraloss.time.SDSDRLoss() | Le Roux et al., 2018 |
频域 | ||
聚合短时傅里叶变换 | auraloss.freq.STFTLoss() | Arik et al., 2018 |
聚合梅尔尺度短时傅里叶变换 | auraloss.freq.MelSTFTLoss(sample_rate) | |
多分辨率短时傅里叶变换 | auraloss.freq.MultiResolutionSTFTLoss() | Yamamoto et al., 2019* |
随机分辨率短时傅里叶变换 | auraloss.freq.RandomResolutionSTFTLoss() | Steinmetz & Reiss, 2020 |
和差短时傅里叶变换损失 | auraloss.freq.SumAndDifferenceSTFTLoss() | Steinmetz et al., 2020 |
感知变换 | ||
和差信号变换 | auraloss.perceptual.SumAndDifference() | |
FIR预加重滤波器 | auraloss.perceptual.FIRFilter() | Wright & Välimäki, 2019 |
- Wang et al., 2019也提出了一种多分辨率谱损失(Engel et al., 2020采用了这种方法),但他们没有包括对数幅度(L1距离)和谱收敛项,这两项最初由Arik et al., 2018引入,后来在Yamamoto et al., 2019的工作中扩展到多分辨率情况。
示例
目前我们包括了一个使用一组损失函数来训练TCN以模拟模拟动态范围压缩器的示例。详细信息请参阅examples/compressor
中的详细内容。我们提供了预训练模型、用于计算论文中指标的评估脚本,以及重新训练模型的脚本。
基于STFTLoss
类,您可以做一些更高级的事情。例如,您可以像Engel等人,2020那样计算线性和对数尺度的STFT误差。在这种情况下,我们不包括谱收敛项。
stft_loss = auraloss.freq.STFTLoss(
w_log_mag=1.0,
w_lin_mag=1.0,
w_sc=0.0,
)
还有一个Mel尺度的STFT损失,它有一些特殊要求。这个损失函数要求您设置采样率并指定正确的设备。
sample_rate = 44100
melstft_loss = auraloss.freq.MelSTFTLoss(sample_rate, device="cuda")
您还可以轻松构建一个具有64个频段的多分辨率Mel尺度STFT损失。确保传递正确的设备,即您要比较的张量所在的设备。
loss_fn = auraloss.freq.MultiResolutionSTFTLoss(
scale="mel",
n_bins=64,
sample_rate=sample_rate,
device="cuda"
)
如果您在立体声音频上计算损失,可能需要考虑和差(中侧)损失。下面我们展示了一个使用此损失函数的示例,结合感知权重和mel尺度以进一步提高感知相关性。
target = torch.rand(8, 2, 44100)
pred = torch.rand(8, 2, 44100)
loss_fn = auraloss.freq.SumAndDifferenceSTFTLoss(
fft_sizes=[1024, 2048, 8192],
hop_sizes=[256, 512, 2048],
win_lengths=[1024, 2048, 8192],
perceptual_weighting=True,
sample_rate=44100,
scale="mel",
n_bins=128,
)
loss = loss_fn(pred, target)
开发
使用pytest在本地运行测试。
python -m pytest