Skip to content

Commit

Permalink
Update leaderboard.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bhkim94 authored Dec 4, 2020
1 parent bdde72a commit 21e7047
Showing 1 changed file with 33 additions and 8 deletions.
41 changes: 33 additions & 8 deletions models/eval/leaderboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,13 @@ def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, splits, seen
maskrcnn.load_state_dict(torch.load('weight_maskrcnn.pt'))
maskrcnn = maskrcnn.cuda()

prev_image = None
prev_action = None
nav_actions = ['MoveAhead_25', 'RotateLeft_90', 'RotateRight_90', 'LookDown_15', 'LookUp_15']

prev_class = 0
prev_center = torch.zeros(2)

done, success = False, False
actions = list()
fails = 0
Expand All @@ -92,13 +99,19 @@ def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, splits, seen
m_pred = model.extract_preds(m_out, [(traj_data, False)], feat, clean_special_tokens=False)
m_pred = list(m_pred.values())[0]

# get action and mask
action = m_pred['action_low']
if prev_image == curr_image and prev_action == action and prev_action in nav_actions and action in nav_actions and action == 'MoveAhead_25':
dist_action = m_out['out_action_low'][0][0].detach().cpu()
idx_rotateR = model.vocab['action_low'].word2index('RotateRight_90')
idx_rotateL = model.vocab['action_low'].word2index('RotateLeft_90')
action = 'RotateLeft_90' if dist_action[idx_rotateL] > dist_action[idx_rotateR] else 'RotateRight_90'

# check if <<stop>> was predicted
if m_pred['action_low'] == cls.STOP_TOKEN:
print("\tpredicted STOP")
break

# get action and mask
action = m_pred['action_low']

mask = None
if model.has_interaction(action):
class_dist = m_pred['action_low_mask'][0]
Expand All @@ -114,9 +127,19 @@ def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, splits, seen
else:
masks = out['masks'][out['labels'] == pred_class].detach().cpu()
scores = out['scores'][out['labels'] == pred_class].detach().cpu()

if prev_class != pred_class:
scores, indices = scores.sort(descending=True)
masks = masks[indices]
prev_class = pred_class
prev_center = masks[0].squeeze(dim=0).nonzero().double().mean(dim=0)
else:
cur_centers = torch.stack([m.nonzero().double().mean(dim=0) for m in masks.squeeze(dim=1)])
distances = ((cur_centers - prev_center)**2).sum(dim=1)
distances, indices = distances.sort()
masks = masks[indices]
prev_center = cur_centers[0]

scores, indices = scores.sort(descending=True)
masks = masks[indices]
mask = np.squeeze(masks[0].numpy(), axis=0)

# use predicted action and mask (if available) to interact with the env
Expand All @@ -136,6 +159,9 @@ def evaluate(cls, env, model, r_idx, resnet, traj_data, args, lock, splits, seen
# next time-step
t += 1

prev_image = curr_image
prev_action = action

# actseq
seen_ids = [t['task'] for t in splits['tests_seen']]
actseq = {traj_data['task_id']: actions}
Expand Down Expand Up @@ -224,7 +250,7 @@ def save_results(self):
'tests_unseen': list(self.unseen_actseqs)}

save_path = os.path.dirname(self.args.model_path)
save_path = os.path.join(save_path, self.args.dout + '_tests_actseqs_dump_' + datetime.now().strftime("%Y%m%d_%H%M%S_%f") + '.json')
save_path = os.path.join(save_path, 'tests_actseqs_dump_' + datetime.now().strftime("%Y%m%d_%H%M%S_%f") + '.json')
with open(save_path, 'w') as r:
json.dump(results, r, indent=4, sort_keys=True)

Expand All @@ -240,8 +266,7 @@ def save_results(self):
# settings
parser.add_argument('--splits', type=str, default="data/splits/oct21.json")
parser.add_argument('--data', type=str, default="data/json_feat_2.1.0")
parser.add_argument('--dout', type=str, default="15")
parser.add_argument('--model_path', type=str, default="exp/weight_batch4_50epoch/net_epoch_15.pth")
parser.add_argument('--model_path', type=str, default="exp/pretrained/pretrained.pth")
parser.add_argument('--model', type=str, default='models.model.seq2seq_im_mask')
parser.add_argument('--preprocess', dest='preprocess', action='store_true')
parser.add_argument('--gpu', dest='gpu', action='store_true', default=True)
Expand Down

0 comments on commit 21e7047

Please sign in to comment.