Skip to content

Commit

Permalink
Fix PITLossWrapper usage on GPU (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
JorisCos authored Mar 11, 2020
1 parent 1069ce1 commit c2d7663
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion asteroid/losses/pit_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_pw_losses(loss_func, est_targets, targets, **kwargs):
losses using broadcasting.
"""
batch_size, n_src, *_ = targets.shape
pair_wise_losses = torch.empty(batch_size, n_src, n_src)
pair_wise_losses = targets.new_empty(batch_size, n_src, n_src)
for est_idx, est_src in enumerate(est_targets.transpose(0, 1)):
for target_idx, target_src in enumerate(targets.transpose(0, 1)):
pair_wise_losses[:, est_idx, target_idx] = loss_func(
Expand Down

0 comments on commit c2d7663

Please sign in to comment.