Skip to content

Commit

Permalink
resize img in visu
Browse files Browse the repository at this point in the history
  • Loading branch information
Mathilde Caron committed Jun 8, 2021
1 parent 4b96393 commit 9417599
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
8 changes: 4 additions & 4 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,15 +137,15 @@ def restart_from_checkpoint(ckp_path, run_variables=None, **kwargs):
if key in checkpoint and value is not None:
try:
msg = value.load_state_dict(checkpoint[key], strict=False)
print("=> loaded {} from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
print("=> loaded '{}' from checkpoint '{}' with msg {}".format(key, ckp_path, msg))
except TypeError:
try:
msg = value.load_state_dict(checkpoint[key])
print("=> loaded {} from checkpoint '{}'".format(key, ckp_path))
print("=> loaded '{}' from checkpoint: '{}'".format(key, ckp_path))
except ValueError:
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
print("=> failed to load '{}' from checkpoint: '{}'".format(key, ckp_path))
else:
print("=> failed to load {} from checkpoint '{}'".format(key, ckp_path))
print("=> key '{}' not found in checkpoint: '{}'".format(key, ckp_path))

    # reload variables important for the run
if run_variables is not None:
Expand Down
2 changes: 2 additions & 0 deletions video_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def __load_model(self):
)
state_dict = state_dict[self.args.checkpoint_key]
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
msg = model.load_state_dict(state_dict, strict=False)
print(
"Pretrained weights found at {} and loaded with msg: {}".format(
Expand Down
34 changes: 19 additions & 15 deletions visualize_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,9 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
parser.add_argument("--checkpoint_key", default="teacher", type=str,
help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument("--image_path", default=None, type=str, help="Path of the image to load.")
parser.add_argument("--image_size", default=(480, 480), type=int, nargs="+", help="Resize image.")
parser.add_argument('--output_dir', default='.', help='Path where to save visualizations.')
parser.add_argument("--threshold", type=float, default=0.6, help="""We visualize masks
parser.add_argument("--threshold", type=float, default=None, help="""We visualize masks
obtained by thresholding the self-attention maps to keep xx% of the mass.""")
args = parser.parse_args()

Expand Down Expand Up @@ -162,6 +163,7 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
print(f"Provided image path {args.image_path} is non valid.")
sys.exit(1)
transform = pth_transforms.Compose([
pth_transforms.Resize(args.image_size),
pth_transforms.ToTensor(),
pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
Expand All @@ -181,17 +183,18 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
# we keep only the output patch attention
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

# we keep only a certain percentage of the mass
val, idx = torch.sort(attentions)
val /= torch.sum(val, dim=1, keepdim=True)
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - args.threshold)
idx2 = torch.argsort(idx)
for head in range(nh):
th_attn[head] = th_attn[head][idx2[head]]
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
# interpolate
th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
if args.threshold is not None:
# we keep only a certain percentage of the mass
val, idx = torch.sort(attentions)
val /= torch.sum(val, dim=1, keepdim=True)
cumval = torch.cumsum(val, dim=1)
th_attn = cumval > (1 - args.threshold)
idx2 = torch.argsort(idx)
for head in range(nh):
th_attn[head] = th_attn[head][idx2[head]]
th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
# interpolate
th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()

attentions = attentions.reshape(nh, w_featmap, h_featmap)
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=args.patch_size, mode="nearest")[0].cpu().numpy()
Expand All @@ -204,6 +207,7 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
plt.imsave(fname=fname, arr=attentions[j], format='png')
print(f"{fname} saved.")

image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
for j in range(nh):
display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False)
if args.threshold is not None:
image = skimage.io.imread(os.path.join(args.output_dir, "img.png"))
for j in range(nh):
display_instances(image, th_attn[j], fname=os.path.join(args.output_dir, "mask_th" + str(args.threshold) + "_head" + str(j) +".png"), blur=False)

0 comments on commit 9417599

Please sign in to comment.