Skip to content

Commit

Permalink
Update seq2seq_im_mask.py
Browse files Browse the repository at this point in the history
  • Loading branch information
bhkim94 authored Jan 30, 2024
1 parent 2c2f8fd commit a0942af
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions models/model/seq2seq_im_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(self, args, vocab):

# paths
self.root_path = os.getcwd()
self.feat_pt = 'feat_conv.pt'
self.feat_pt = 'feat_conv_panoramic.pt'

# params
self.max_subgoals = 25
Expand Down Expand Up @@ -143,9 +143,11 @@ def featurize(self, batch, load_mask=True, load_frames=True):
if load_frames and not self.test_mode:
root = self.get_task_root(ex)
if not swapColor:
im = torch.load(os.path.join(root, self.feat_pt))
else:
im = torch.load(os.path.join(root, 'feat_conv_colorSwap{}.pt'.format(swapColor)))
im = torch.load(os.path.join(root, self.feat_pt))[2]
elif swapColor in [1, 2]:
im = torch.load(os.path.join(root, 'feat_conv_colorSwap{}_panoramic.pt'.format(swapColor)))[2]
elif swapColor in [3, 4, 5, 6]:
im = torch.load(os.path.join(root, 'feat_conv_onlyAutoAug{}_panoramic.pt'.format(swapColor - 2)))[2]
feat['frames'].append(im)

# tensorization and padding
Expand Down

0 comments on commit a0942af

Please sign in to comment.