Skip to content

Commit

Permalink
force dtype of masks to be the same as stft_repr in mel band roformer
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 18, 2023
1 parent 4056daf commit 59fcd67
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
2 changes: 2 additions & 0 deletions bs_roformer/mel_band_roformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,8 @@ def forward(
stft_repr = torch.view_as_complex(stft_repr)
masks = torch.view_as_complex(masks)

masks = masks.type(stft_repr.dtype)

# need to average the estimated mask for the overlapped frequencies

scatter_indices = repeat(freq_indices, 'f -> b 1 f t', b = batch, t = stft_repr.shape[-1])
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'BS-RoFormer',
packages = find_packages(exclude=[]),
version = '0.2.2',
version = '0.2.3',
license='MIT',
description = 'BS-RoFormer - Band-Split Rotary Transformer for SOTA Music Source Separation',
author = 'Phil Wang',
Expand Down

0 comments on commit 59fcd67

Please sign in to comment.