Skip to content

Commit

Permalink
fix stft and istft in pyotrch 2.0.0
Browse files Browse the repository at this point in the history
fix stft and istft in pyotrch 2.0.0
in pytorch 2.0.0 not support real output(stft)and real input(istft)
  • Loading branch information
233lol authored Apr 6, 2023
1 parent 5cef5ee commit 5352251
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions separate.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,7 +392,8 @@ def run_model(self, mix, is_ckpt=False, is_match_mix=False):

def stft(self, x):
x = x.reshape([-1, self.chunk_size])
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
x = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True,return_complex=True)
x=torch.view_as_real(x)
x = x.permute([0,3,1,2])
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,self.dim_c,self.n_bins,self.dim_t])
return x[:,:,:self.dim_f]
Expand All @@ -402,6 +403,8 @@ def istft(self, x, freq_pad=None):
x = torch.cat([x, freq_pad], -2)
x = x.reshape([-1,2,2,self.n_bins,self.dim_t]).reshape([-1,2,self.n_bins,self.dim_t])
x = x.permute([0,2,3,1])
x=x.contiguous()
x=torch.view_as_complex(x)
x = torch.istft(x, n_fft=self.n_fft, hop_length=self.hop, window=self.window, center=True)
return x.reshape([-1,2,self.chunk_size])

Expand Down Expand Up @@ -936,4 +939,4 @@ def save_format(audio_path, save_format, mp3_bit_set):
try:
os.remove(audio_path)
except Exception as e:
print(e)
print(e)

0 comments on commit 5352251

Please sign in to comment.