Skip to content

Commit

Permalink
fix target shape when batch size equals to 1 (pytorch#99)
Browse files Browse the repository at this point in the history
When target only contains one element, the shape of its numpy array will be `()`,  `assert target.shape[0] == output.shape[0]` will report errors. `np.atleast_1d` fix it.
  • Loading branch information
wandering007 authored and Sasha Sax committed Jul 26, 2018
1 parent fd3236b commit ad81dab
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion torchnet/meter/classerrormeter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def add(self, output, target):
if torch.is_tensor(output):
output = output.cpu().squeeze().numpy()
if torch.is_tensor(target):
target = target.cpu().squeeze().numpy()
target = np.atleast_1d(target.cpu().squeeze().numpy())
elif isinstance(target, numbers.Number):
target = np.asarray([target])
if np.ndim(output) == 1:
Expand Down

0 comments on commit ad81dab

Please sign in to comment.