Skip to content

Commit

Permalink
fix a bug in trainers/pvnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
pengsida committed Mar 9, 2020
1 parent b9ef443 commit 2ec747f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 2 deletions.
2 changes: 1 addition & 1 deletion lib/datasets/linemod/pvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __getitem__(self, index_tuple):
inp, kpt_2d, mask = self._transforms(inp, kpt_2d, mask)

vertex = pvnet_data_utils.compute_vertex(mask, kpt_2d).transpose(2, 0, 1)
ret = {'inp': inp, 'mask': mask.astype(np.uint8), 'vertex': vertex, 'img_id': img_id, 'meta': ''}
ret = {'inp': inp, 'mask': mask.astype(np.uint8), 'vertex': vertex, 'img_id': img_id, 'meta': {}}
# visualize_utils.visualize_linemod_ann(torch.tensor(inp), kpt_2d, mask, True)

return ret
Expand Down
1 change: 1 addition & 0 deletions lib/train/trainers/pvnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def forward(self, batch):
loss = 0

if 'pose_test' in batch['meta'].keys():
loss = torch.tensor(0).to(batch['inp'].device)
return output, loss, {}, {}

weight = batch['mask'][:, None].float()
Expand Down
2 changes: 1 addition & 1 deletion lib/train/trainers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def val(self, epoch, data_loader, evaluator=None, recorder=None):
batch[k] = batch[k].cuda()

with torch.no_grad():
output, loss, loss_stats, image_stats = self.network(batch)
output, loss, loss_stats, image_stats = self.network.module(batch)
if evaluator is not None:
evaluator.evaluate(output, batch)

Expand Down

0 comments on commit 2ec747f

Please sign in to comment.