fix rn50 hubconf + visu on cpu + ref to YK video
Mathilde Caron committed May 2, 2021
1 parent a15f6af commit 1d06521
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,7 +1,7 @@
 # Self-Supervised Vision Transformers with DINO

 PyTorch implementation and pretrained models for DINO. For details, see **Emerging Properties in Self-Supervised Vision Transformers**.
-[[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)]
+[[`blogpost`](https://ai.facebook.com/blog/dino-paws-computer-vision-with-self-supervised-transformers-and-10x-more-efficient-training)] [[`arXiv`](https://arxiv.org/abs/2104.14294)] [[`Yannic Kilcher's video`](https://www.youtube.com/watch?v=h3ij3F3cPIk)]

 <div align="center">
 <img width="100%" alt="DINO illustration" src=".github/dino.gif">
4 changes: 2 additions & 2 deletions hubconf.py
@@ -70,7 +70,7 @@ def dino_vitb8(pretrained=True, **kwargs):
 def dino_resnet50(pretrained=True, **kwargs):
     """
     ResNet-50 pre-trained with DINO.
-    Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark.
+    Achieves 75.3% top-1 accuracy on ImageNet linear evaluation benchmark (requires to train `fc`).
     Note that `fc.weight` and `fc.bias` are randomly initialized.
     """
     model = resnet50(pretrained=False, **kwargs)
@@ -79,5 +79,5 @@ def dino_resnet50(pretrained=True, **kwargs):
             url="https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth",
             map_location="cpu",
         )
-        model.load_state_dict(state_dict, strict=True)
+        model.load_state_dict(state_dict, strict=False)
     return model
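
The switch to `strict=False`, together with the docstring note that `fc.weight` and `fc.bias` are randomly initialized, suggests that the released checkpoint carries no entries for the `fc` layer. That is an inference, not something the diff states. A minimal sketch of the resulting behavior, using a hypothetical `TinyNet` as a stand-in for ResNet-50 and a simulated backbone-only checkpoint:

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for ResNet-50: one backbone layer plus an `fc` head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(self.backbone(x))


# Simulated checkpoint that holds backbone weights only (assumption: no `fc.*` keys).
checkpoint = {
    name: tensor
    for name, tensor in TinyNet().state_dict().items()
    if not name.startswith("fc.")
}

model = TinyNet()

# strict=True rejects a checkpoint with missing keys.
try:
    model.load_state_dict(checkpoint, strict=True)
    strict_failed = False
except RuntimeError:
    strict_failed = True

# strict=False loads what matches and reports the rest.
result = model.load_state_dict(checkpoint, strict=False)
missing = list(result.missing_keys)
unexpected = list(result.unexpected_keys)
print("strict load failed:", strict_failed)
print("missing keys:", missing)
```

With `strict=False`, `load_state_dict` returns the missing and unexpected keys instead of raising, so a caller can check that only the `fc` parameters were left at their random initialization.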
5 changes: 3 additions & 2 deletions visualize_attention.py
@@ -98,12 +98,13 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
         obtained by thresholding the self-attention maps to keep xx% of the mass.""")
     args = parser.parse_args()

+    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
     # build model
     model = vits.__dict__[args.arch](patch_size=args.patch_size, num_classes=0)
     for p in model.parameters():
         p.requires_grad = False
     model.eval()
-    model.cuda()
+    model.to(device)
     if os.path.isfile(args.pretrained_weights):
         state_dict = torch.load(args.pretrained_weights, map_location="cpu")
         if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
@@ -158,7 +159,7 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
     w_featmap = img.shape[-2] // args.patch_size
     h_featmap = img.shape[-1] // args.patch_size

-    attentions = model.forward_selfattention(img.cuda())
+    attentions = model.forward_selfattention(img.to(device))

     nh = attentions.shape[1] # number of head
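
The change in `visualize_attention.py` follows a common device-agnostic pattern: pick the device once, then move both the model and every input tensor to it, so the script also runs on machines without a GPU. A minimal sketch, assuming a hypothetical `nn.Linear` in place of the ViT and a random tensor in place of the preprocessed image:

```python
import torch
import torch.nn as nn

# Fall back to the CPU when CUDA is not available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hypothetical stand-in for the ViT built in visualize_attention.py.
model = nn.Linear(16, 4)
for p in model.parameters():
    p.requires_grad = False  # inference only, as in the script
model.eval()
model.to(device)

# Hypothetical input; the real script feeds a preprocessed image.
img = torch.randn(1, 16)

# The input must live on the same device as the model's weights.
with torch.no_grad():
    out = model(img.to(device))

print(out.shape, out.device.type)
```

Calling `.cuda()` directly on the model or the input, as the removed lines did, raises an error on a CPU-only machine, which is what the `device` variable avoids.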
