forked from facebookresearch/dino
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhubconf.py
95 lines (83 loc) · 3.47 KB
/
hubconf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torchvision.models.resnet import resnet50
import vision_transformer as vits
# Package names that torch.hub requires to be importable before it loads
# any entrypoint from this hubconf.
dependencies = ["torch", "torchvision"]
def dino_deits16(pretrained=True, **kwargs):
    """
    Build the DeiT-Small model with 16x16 patches, optionally with DINO weights.

    The model is constructed with ``patch_size=16`` and ``num_classes=0``;
    any extra keyword arguments are forwarded to the ``deit_small`` builder.
    When ``pretrained`` is truthy, the DINO checkpoint is fetched with
    ``torch.hub.load_state_dict_from_url`` (mapped to CPU) and loaded with
    strict key matching.
    Reported result: 74.5% top-1 accuracy on ImageNet with k-NN classification.
    """
    build_deit_small = vits.__dict__["deit_small"]
    model = build_deit_small(patch_size=16, num_classes=0, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = "https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url=checkpoint_url, map_location="cpu")
    model.load_state_dict(weights, strict=True)
    return model
def dino_deits8(pretrained=True, **kwargs):
    """
    Build the DeiT-Small model with 8x8 patches, optionally with DINO weights.

    The model is constructed with ``patch_size=8`` and ``num_classes=0``;
    any extra keyword arguments are forwarded to the ``deit_small`` builder.
    When ``pretrained`` is truthy, the DINO checkpoint is fetched with
    ``torch.hub.load_state_dict_from_url`` (mapped to CPU) and loaded with
    strict key matching.
    Reported result: 78.3% top-1 accuracy on ImageNet with k-NN classification.
    """
    build_deit_small = vits.__dict__["deit_small"]
    model = build_deit_small(patch_size=8, num_classes=0, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = "https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url=checkpoint_url, map_location="cpu")
    model.load_state_dict(weights, strict=True)
    return model
def dino_vitb16(pretrained=True, **kwargs):
    """
    Build the ViT-Base model with 16x16 patches, optionally with DINO weights.

    The model is constructed with ``patch_size=16`` and ``num_classes=0``;
    any extra keyword arguments are forwarded to the ``vit_base`` builder.
    When ``pretrained`` is truthy, the DINO checkpoint is fetched with
    ``torch.hub.load_state_dict_from_url`` (mapped to CPU) and loaded with
    strict key matching.
    Reported result: 76.1% top-1 accuracy on ImageNet with k-NN classification.
    """
    build_vit_base = vits.__dict__["vit_base"]
    model = build_vit_base(patch_size=16, num_classes=0, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url=checkpoint_url, map_location="cpu")
    model.load_state_dict(weights, strict=True)
    return model
def dino_vitb8(pretrained=True, **kwargs):
    """
    Build the ViT-Base model with 8x8 patches, optionally with DINO weights.

    The model is constructed with ``patch_size=8`` and ``num_classes=0``;
    any extra keyword arguments are forwarded to the ``vit_base`` builder.
    When ``pretrained`` is truthy, the DINO checkpoint is fetched with
    ``torch.hub.load_state_dict_from_url`` (mapped to CPU) and loaded with
    strict key matching.
    Reported result: 77.4% top-1 accuracy on ImageNet with k-NN classification.
    """
    build_vit_base = vits.__dict__["vit_base"]
    model = build_vit_base(patch_size=8, num_classes=0, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url=checkpoint_url, map_location="cpu")
    model.load_state_dict(weights, strict=True)
    return model
def dino_resnet50(pretrained=True, **kwargs):
    """
    Build a torchvision ResNet-50, optionally with DINO weights.

    The network is always created with ``pretrained=False`` so torchvision's
    own weights are never requested; extra keyword arguments are forwarded to
    ``resnet50``. When ``pretrained`` is truthy, the DINO checkpoint is fetched
    with ``torch.hub.load_state_dict_from_url`` (mapped to CPU) and loaded with
    ``strict=False``.
    Reported result: 75.3% top-1 accuracy on the ImageNet linear evaluation
    benchmark (requires to train `fc`).
    Note that `fc.weight` and `fc.bias` are randomly initialized.
    """
    model = resnet50(pretrained=False, **kwargs)
    if not pretrained:
        return model

    checkpoint_url = "https://dl.fbaipublicfiles.com/dino/dino_resnet50_pretrain/dino_resnet50_pretrain.pth"
    weights = torch.hub.load_state_dict_from_url(url=checkpoint_url, map_location="cpu")
    # Non-strict load: key mismatches between checkpoint and model are tolerated.
    model.load_state_dict(weights, strict=False)
    return model