diff --git a/asteroid/losses/pit_wrapper.py b/asteroid/losses/pit_wrapper.py index 434d46e..b100247 100644 --- a/asteroid/losses/pit_wrapper.py +++ b/asteroid/losses/pit_wrapper.py @@ -131,7 +131,7 @@ def get_pw_losses(loss_func, est_targets, targets, **kwargs): losses using broadcasting. """ batch_size, n_src, *_ = targets.shape - pair_wise_losses = torch.empty(batch_size, n_src, n_src) + pair_wise_losses = targets.new_empty(batch_size, n_src, n_src) for est_idx, est_src in enumerate(est_targets.transpose(0, 1)): for target_idx, target_src in enumerate(targets.transpose(0, 1)): pair_wise_losses[:, est_idx, target_idx] = loss_func(