-
Notifications
You must be signed in to change notification settings - Fork 4
/
criteria.py
61 lines (38 loc) · 1.79 KB
/
criteria.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
from conv_stft import ConvSTFT
class stftm_loss(object):
    """Loss on the real and imaginary STFT parts of a prediction vs. a target.

    Both signals go through the same fixed ``ConvSTFT`` (``'hanning'`` window,
    ``'complex'`` features), are split into real and imaginary halves, and are
    compared part-by-part according to ``loss_type``.

    Args:
        frame_size: STFT window length. It is also passed to ``ConvSTFT`` as its
            third argument (presumably the FFT length, as in DCCRN's
            ``conv_stft`` -- TODO confirm), and the real/imag split uses it too.
        frame_shift: hop between successive STFT frames.
        loss_type: ``'mae'`` (mean of |d_real| + |d_imag|), ``'char'`` or
            ``'hybrid'``. The last two rely on the ``char_loss`` / ``edge_loss``
            hooks, which must be supplied by a subclass or assigned on the
            instance.
        device: device the STFT module is moved to. The default ``'cuda'``
            matches the previous hard-coded ``.cuda()`` call.

    Raises:
        ValueError: if ``loss_type`` is not one of the supported values.
    """

    _LOSS_TYPES = ('mae', 'char', 'hybrid')

    def __init__(self, frame_size=512, frame_shift=256, loss_type='mae', device='cuda'):
        # Validate before building the STFT so a typo fails fast and cheaply.
        if loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"unsupported loss_type {loss_type!r}; expected one of {self._LOSS_TYPES}")
        # Previously never stored, which made every __call__ raise AttributeError.
        self.loss_type = loss_type
        # Must agree with the FFT length handed to ConvSTFT below.
        # Hard-coding 512 broke any non-default frame_size.
        self.fft_len = frame_size
        self.stft = ConvSTFT(frame_size, frame_shift, frame_size, 'hanning', 'complex',
                             fix=True).to(device)

    def __call__(self, outputs, labels):
        """Return the scalar loss between ``outputs`` and ``labels``.

        Raises:
            ValueError: if ``self.loss_type`` was changed to an unsupported value.
            NotImplementedError: for ``'char'`` / ``'hybrid'`` when the
                corresponding hook has not been provided.
        """
        if self.loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"unsupported loss_type {self.loss_type!r}; expected one of {self._LOSS_TYPES}")
        out_real, out_imag = self.get_stftm(outputs)
        lab_real, lab_imag = self.get_stftm(labels)
        if self.loss_type == 'mae':
            loss = torch.mean(torch.abs(out_real - lab_real) + torch.abs(out_imag - lab_imag))
        elif self.loss_type == 'char':
            loss = self.char_loss(out_real, lab_real) + self.char_loss(out_imag, lab_imag)
        else:  # 'hybrid': down-weighted edge term plus the char term
            loss = (self.edge_loss(out_real, lab_real) + self.edge_loss(out_imag, lab_imag)) * 0.05 + \
                self.char_loss(out_real, lab_real) + self.char_loss(out_imag, lab_imag)
        return loss

    def char_loss(self, output, label):
        """Hook for the ``'char'`` / ``'hybrid'`` term; override or assign it to use those modes."""
        raise NotImplementedError(
            "stftm_loss.char_loss is not defined; provide it (subclass or instance "
            "attribute) to use loss_type='char' or 'hybrid'")

    def edge_loss(self, output, label):
        """Hook for the ``'hybrid'`` edge term; override or assign it to use that mode."""
        raise NotImplementedError(
            "stftm_loss.edge_loss is not defined; provide it (subclass or instance "
            "attribute) to use loss_type='hybrid'")

    def get_stftm(self, ipt):
        """Return ``(real, imag)`` halves of the complex STFT of ``ipt``.

        The first ``fft_len // 2 + 1`` channels along dim 1 are taken as the
        real part and the remainder as the imaginary part. This assumes
        ConvSTFT's ``'complex'`` output stacks them that way -- TODO confirm.
        """
        specs = self.stft(ipt)
        n_bins = self.fft_len // 2 + 1
        return specs[:, :n_bins], specs[:, n_bins:]
class mag_loss(object):
    """Mean absolute error between STFT magnitudes of a prediction and a target.

    Both signals go through the same fixed ``ConvSTFT`` (``'hanning'`` window,
    ``'complex'`` features). The magnitude ``sqrt(real**2 + imag**2 + 1e-8)``
    is computed per bin and the mean absolute difference is returned.

    Args:
        frame_size: STFT window length. It is also passed to ``ConvSTFT`` as its
            third argument (presumably the FFT length, as in DCCRN's
            ``conv_stft`` -- TODO confirm), and the real/imag split uses it too.
        frame_shift: hop between successive STFT frames.
        loss_type: only ``'mae'`` is supported.
        device: device the STFT module is moved to. The default ``'cuda'``
            matches the previous hard-coded ``.cuda()`` call.

    Raises:
        ValueError: if ``loss_type`` is not supported.
    """

    _LOSS_TYPES = ('mae',)

    def __init__(self, frame_size=512, frame_shift=256, loss_type='mae', device='cuda'):
        # Validate before building the STFT so a typo fails fast and cheaply.
        if loss_type not in self._LOSS_TYPES:
            raise ValueError(
                f"unsupported loss_type {loss_type!r}; expected one of {self._LOSS_TYPES}")
        # Previously never stored, which made every __call__ raise AttributeError.
        self.loss_type = loss_type
        # Must agree with the FFT length handed to ConvSTFT below.
        # Hard-coding 512 broke any non-default frame_size.
        self.fft_len = frame_size
        self.stft = ConvSTFT(frame_size, frame_shift, frame_size, 'hanning', 'complex',
                             fix=True).to(device)

    def __call__(self, outputs, labels):
        """Return the mean absolute magnitude difference between ``outputs`` and ``labels``.

        Raises:
            ValueError: if ``self.loss_type`` was changed to an unsupported value
                (previously this surfaced as an ``UnboundLocalError``).
        """
        if self.loss_type != 'mae':
            raise ValueError(
                f"unsupported loss_type {self.loss_type!r}; expected one of {self._LOSS_TYPES}")
        out_mags = self.get_mag(outputs)
        lab_mags = self.get_mag(labels)
        return torch.mean(torch.abs(out_mags - lab_mags))

    def get_mag(self, ipt):
        """Return the per-bin STFT magnitude of ``ipt``.

        The first ``fft_len // 2 + 1`` channels along dim 1 are taken as the
        real part and the remainder as the imaginary part. This assumes
        ConvSTFT's ``'complex'`` output stacks them that way -- TODO confirm.
        """
        specs = self.stft(ipt)
        n_bins = self.fft_len // 2 + 1
        real = specs[:, :n_bins]
        imag = specs[:, n_bins:]
        # The epsilon keeps sqrt (and its gradient) finite at zero magnitude.
        return torch.sqrt(real ** 2 + imag ** 2 + 1e-8)