Commit: replace name deit-s with vit-s

Mathilde Caron committed May 22, 2021
1 parent a618f0b commit 4b96393
Showing 10 changed files with 47 additions and 44 deletions.
22 changes: 10 additions & 12 deletions README.md
@@ -8,7 +8,7 @@ PyTorch implementation and pretrained models for DINO. For details, see **Emergi
</div>

## Pretrained models
-You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the backbone in `onnx` format, as well as detailed arguments and training/evaluation logs.
+You can choose to download only the weights of the pretrained backbone used for downstream tasks, or the full checkpoint which contains backbone and projection head weights for both student and teacher networks. We also provide the backbone in `onnx` format, as well as detailed arguments and training/evaluation logs. Note that `DeiT-S` and `ViT-S` names refer exactly to the same architecture.

<table>
<tr>
@@ -19,7 +19,7 @@ You can choose to download only the weights of the pretrained backbone used for
<th colspan="6">download</th>
</tr>
<tr>
-<td>DeiT-S/16</td>
+<td>ViT-S/16</td>
<td>21M</td>
<td>74.5%</td>
<td>77.0%</td>
@@ -31,7 +31,7 @@ You can choose to download only the weights of the pretrained backbone used for
<td><a href="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain_eval_linear_log.txt">eval logs</a></td>
</tr>
<tr>
-<td>DeiT-S/8</td>
+<td>ViT-S/8</td>
<td>21M</td>
<td>78.3%</td>
<td>79.7%</td>
@@ -83,8 +83,8 @@ You can choose to download only the weights of the pretrained backbone used for
The pretrained models are available on PyTorch Hub.
```python
import torch
-deits16 = torch.hub.load('facebookresearch/dino:main', 'dino_deits16')
-deits8 = torch.hub.load('facebookresearch/dino:main', 'dino_deits8')
+vits16 = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
+vits8 = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
vitb16 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb16')
vitb8 = torch.hub.load('facebookresearch/dino:main', 'dino_vitb8')
resnet50 = torch.hub.load('facebookresearch/dino:main', 'dino_resnet50')
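
# A hedged usage sketch (an editor addition, not from the original README):
# extract features with one of the backbones loaded above. The preprocessing
# statistics are the standard ImageNet values, assumed here rather than taken
# from this repository.
from PIL import Image
import torchvision.transforms as T
preprocess = T.Compose([
    T.Resize(256),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = preprocess(Image.open("img.jpg")).unsqueeze(0)  # [1, 3, 224, 224]
with torch.no_grad():
    feats = vits16(img)  # CLS-token embedding: [1, 384] for ViT-S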
@@ -99,15 +99,15 @@ python main_dino.py --help
```

### Vanilla DINO training :sauropod:
-Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
+Run DINO with ViT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 day and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
```
-python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
+python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
```

### Multi-node training
We use Slurm and [submitit](https://github.com/facebookincubator/submitit) (`pip install submitit`). To train on 2 nodes with 8 GPUs each (total 16 GPUs):
```
-python run_with_submitit.py --nodes 2 --ngpus 8 --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
+python run_with_submitit.py --nodes 2 --ngpus 8 --arch vit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
```

<details>
@@ -125,15 +125,15 @@ python run_with_submitit.py --nodes 2 --ngpus 8 --use_volta32 --arch vit_base -
You can improve the performance of the vanilla run by:
- training for more epochs: `--epochs 300`,
- increasing the teacher temperature: `--teacher_temp 0.07 --warmup_teacher_temp_epochs 30`.
-- removing last layer normalization (only safe with `--arch deit_small`): `--norm_last_layer false`,
+- removing last layer normalization (only safe with `--arch vit_small`): `--norm_last_layer false`,

<details>
<summary>
Full command.
</summary>

```
-python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
+python run_with_submitit.py --arch vit_small --epochs 300 --teacher_temp 0.07 --warmup_teacher_temp_epochs 30 --norm_last_layer false --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
```

</details>
@@ -185,8 +185,6 @@ python video_generation.py --input_path output/attention \
--video_format avi
```

-Also, check out [this colab](https://gist.github.com/aquadzn/32ac53aa6e485e7c3e09b1a0914f7422) for a video inference notebook.

## Evaluation: k-NN classification on ImageNet
To evaluate a simple k-NN classifier with a single GPU on a pre-trained model, run:
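The command itself is cut off in this view; below is a minimal hypothetical sketch consistent with the `eval_knn.py` arguments in this commit (the `--data_path` flag and all paths are assumptions):
```
# hypothetical invocation; flags mirror the eval_knn.py argparse entries in this commit
python -m torch.distributed.launch --nproc_per_node=1 eval_knn.py --arch vit_small --patch_size 16 --pretrained_weights /path/to/checkpoint.pth --checkpoint_key teacher --data_path /path/to/imagenet
```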
4 changes: 2 additions & 2 deletions eval_knn.py
@@ -184,8 +184,8 @@ def __getitem__(self, idx):
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
parser.add_argument('--use_cuda', default=True, type=utils.bool_flag,
help="Should we store the features on GPU? We recommend setting this to False if you encounter OOM")
-parser.add_argument('--arch', default='deit_small', type=str,
-choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
+parser.add_argument('--arch', default='vit_small', type=str,
+choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
parser.add_argument("--checkpoint_key", default="teacher", type=str,
help='Key to use in the checkpoint (example: "teacher")')
8 changes: 4 additions & 4 deletions eval_linear.py
@@ -225,12 +225,12 @@ def forward(self, x):
if __name__ == '__main__':
parser = argparse.ArgumentParser('Evaluation with linear classification on ImageNet')
parser.add_argument('--n_last_blocks', default=4, type=int, help="""Concatenate [CLS] tokens
-for the `n` last blocks. We use `n=4` when evaluating DeiT-Small and `n=1` with ViT-Base.""")
+for the `n` last blocks. We use `n=4` when evaluating ViT-Small and `n=1` with ViT-Base.""")
parser.add_argument('--avgpool_patchtokens', default=False, type=utils.bool_flag,
help="""Whether ot not to concatenate the global average pooled features to the [CLS] token.
-We typically set this to False for DeiT-Small and to True with ViT-Base.""")
-parser.add_argument('--arch', default='deit_small', type=str,
-choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
+We typically set this to False for ViT-Small and to True with ViT-Base.""")
+parser.add_argument('--arch', default='vit_small', type=str,
+choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
4 changes: 2 additions & 2 deletions eval_video_segmentation.py
@@ -251,8 +251,8 @@ def color_normalize(x, mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225]):
if __name__ == '__main__':
parser = argparse.ArgumentParser('Evaluation with video object segmentation on DAVIS 2017')
parser.add_argument('--pretrained_weights', default='', type=str, help="Path to pretrained weights to evaluate.")
-parser.add_argument('--arch', default='deit_small', type=str,
-choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
+parser.add_argument('--arch', default='vit_small', type=str,
+choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
parser.add_argument('--patch_size', default=16, type=int, help='Patch resolution of the model.')
parser.add_argument("--checkpoint_key", default="teacher", type=str, help='Key to use in the checkpoint (example: "teacher")')
parser.add_argument('--output_dir', default=".", help='Path where to save segmentations')
12 changes: 6 additions & 6 deletions hubconf.py
@@ -19,12 +19,12 @@
dependencies = ["torch", "torchvision"]


-def dino_deits16(pretrained=True, **kwargs):
+def dino_vits16(pretrained=True, **kwargs):
"""
-DeiT-Small/16x16 pre-trained with DINO.
+ViT-Small/16x16 pre-trained with DINO.
Achieves 74.5% top-1 accuracy on ImageNet with k-NN classification.
"""
-model = vits.__dict__["deit_small"](patch_size=16, num_classes=0, **kwargs)
+model = vits.__dict__["vit_small"](patch_size=16, num_classes=0, **kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth",
@@ -34,12 +34,12 @@ def dino_deits16(pretrained=True, **kwargs):
return model


-def dino_deits8(pretrained=True, **kwargs):
+def dino_vits8(pretrained=True, **kwargs):
"""
-DeiT-Small/8x8 pre-trained with DINO.
+ViT-Small/8x8 pre-trained with DINO.
Achieves 78.3% top-1 accuracy on ImageNet with k-NN classification.
"""
-model = vits.__dict__["deit_small"](patch_size=8, num_classes=0, **kwargs)
+model = vits.__dict__["vit_small"](patch_size=8, num_classes=0, **kwargs)
if pretrained:
state_dict = torch.hub.load_state_dict_from_url(
url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth",
14 changes: 8 additions & 6 deletions main_dino.py
@@ -42,21 +42,21 @@ def get_args_parser():
parser = argparse.ArgumentParser('DINO', add_help=False)

# Model parameters
-parser.add_argument('--arch', default='deit_small', type=str,
-choices=['deit_tiny', 'deit_small', 'vit_base'] + torchvision_archs,
+parser.add_argument('--arch', default='vit_small', type=str,
+choices=['vit_tiny', 'vit_small', 'vit_base', 'deit_tiny', 'deit_small'] + torchvision_archs,
help="""Name of architecture to train. For quick experiments with ViTs,
-we recommend using deit_tiny or deit_small.""")
+we recommend using vit_tiny or vit_small.""")
parser.add_argument('--patch_size', default=16, type=int, help="""Size in pixels
of input square patches - default 16 (for 16x16 patches). Using smaller
values leads to better performance but requires more memory. Applies only
-for ViTs (deit_tiny, deit_small and vit_base). If <16, we recommend disabling
+for ViTs (vit_tiny, vit_small and vit_base). If <16, we recommend disabling
mixed precision training (--use_fp16 false) to avoid unstabilities.""")
parser.add_argument('--out_dim', default=65536, type=int, help="""Dimensionality of
the DINO head output. For complex and large datasets large values (like 65k) work well.""")
parser.add_argument('--norm_last_layer', default=True, type=utils.bool_flag,
help="""Whether or not to weight normalize the last layer of the DINO head.
Not normalizing leads to better performance but can make the training unstable.
-In our experiments, we typically set this paramater to False with deit_small and True with vit_base.""")
+In our experiments, we typically set this paramater to False with vit_small and True with vit_base.""")
parser.add_argument('--momentum_teacher', default=0.996, type=float, help="""Base EMA
parameter for teacher update. The value is increased to 1 during training with cosine schedule.
We recommend setting a higher value with small batches: for example use 0.9995 with batch size of 256.""")
@@ -153,7 +153,9 @@ def train_dino(args):
print(f"Data loaded: there are {len(dataset)} images.")

# ============ building student and teacher networks ... ============
-# if the network is a vision transformer (i.e. deit_tiny, deit_small, vit_base)
+# we changed the name DeiT-S for ViT-S to avoid confusions
+args.arch = args.arch.replace("deit", "vit")
+# if the network is a vision transformer (i.e. vit_tiny, vit_small, vit_base)
if args.arch in vits.__dict__.keys():
student = vits.__dict__[args.arch](
patch_size=args.patch_size,
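The `args.arch.replace("deit", "vit")` shim above keeps older launch commands that still pass the legacy names working; a minimal standalone sketch of the mapping (illustrative only, not part of the diff):

```python
# Illustrative only: how the backward-compatibility shim rewrites legacy
# architecture names onto the renamed ones.
for arch in ["deit_tiny", "deit_small", "vit_base"]:
    print(arch, "->", arch.replace("deit", "vit"))
# deit_tiny -> vit_tiny
# deit_small -> vit_small
# vit_base -> vit_base
```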
4 changes: 2 additions & 2 deletions utils.py
@@ -83,9 +83,9 @@ def load_pretrained_weights(model, pretrained_weights, checkpoint_key, model_nam
else:
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
url = None
if model_name == "deit_small" and patch_size == 16:
if model_name == "vit_small" and patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
elif model_name == "deit_small" and patch_size == 8:
elif model_name == "vit_small" and patch_size == 8:
url = "dino_deitsmall8_pretrain/dino_deitsmall8_pretrain.pth"
elif model_name == "vit_base" and patch_size == 16:
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
8 changes: 4 additions & 4 deletions video_generation.py
@@ -269,9 +269,9 @@ def __load_model(self):
"Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate."
)
url = None
-if self.args.arch == "deit_small" and self.args.patch_size == 16:
+if self.args.arch == "vit_small" and self.args.patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
-elif self.args.arch == "deit_small" and self.args.patch_size == 8:
+elif self.args.arch == "vit_small" and self.args.patch_size == 8:
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
elif self.args.arch == "vit_base" and self.args.patch_size == 16:
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
@@ -296,9 +296,9 @@ def parse_args():
parser = argparse.ArgumentParser("Generation self-attention video")
parser.add_argument(
"--arch",
default="deit_small",
default="vit_small",
type=str,
choices=["deit_tiny", "deit_small", "vit_base"],
choices=["vit_tiny", "vit_small", "vit_base"],
help="Architecture (support only ViT atm).",
)
parser.add_argument(
4 changes: 2 additions & 2 deletions vision_transformer.py
@@ -233,14 +233,14 @@ def get_intermediate_layers(self, x, n=1):
return output


-def deit_tiny(patch_size=16, **kwargs):
+def vit_tiny(patch_size=16, **kwargs):
model = VisionTransformer(
patch_size=patch_size, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
return model


-def deit_small(patch_size=16, **kwargs):
+def vit_small(patch_size=16, **kwargs):
model = VisionTransformer(
patch_size=patch_size, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4,
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
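With the factory functions renamed as above, a backbone can be instantiated directly; a short sketch (assuming `vision_transformer.py` is importable as `vits`, as `hubconf.py` does):

```python
import vision_transformer as vits

# ViT-S/16 backbone with no classification head, as used by the hub entries.
model = vits.vit_small(patch_size=16, num_classes=0)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")  # ~21M, matching the README table
```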
11 changes: 7 additions & 4 deletions visualize_attention.py
@@ -97,8 +97,8 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con

if __name__ == '__main__':
parser = argparse.ArgumentParser('Visualize Self-Attention maps')
-parser.add_argument('--arch', default='deit_small', type=str,
-choices=['deit_tiny', 'deit_small', 'vit_base'], help='Architecture (support only ViT atm).')
+parser.add_argument('--arch', default='vit_small', type=str,
+choices=['vit_tiny', 'vit_small', 'vit_base'], help='Architecture (support only ViT atm).')
parser.add_argument('--patch_size', default=8, type=int, help='Patch resolution of the model.')
parser.add_argument('--pretrained_weights', default='', type=str,
help="Path to pretrained weights to load.")
@@ -122,15 +122,18 @@ def display_instances(image, mask, fname="test", figsize=(5, 5), blur=False, con
if args.checkpoint_key is not None and args.checkpoint_key in state_dict:
print(f"Take key {args.checkpoint_key} in provided checkpoint dict")
state_dict = state_dict[args.checkpoint_key]
+# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
+# remove `backbone.` prefix induced by multicrop wrapper
+state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
msg = model.load_state_dict(state_dict, strict=False)
print('Pretrained weights found at {} and loaded with msg: {}'.format(args.pretrained_weights, msg))
else:
print("Please use the `--pretrained_weights` argument to indicate the path of the checkpoint to evaluate.")
url = None
if args.arch == "deit_small" and args.patch_size == 16:
if args.arch == "vit_small" and args.patch_size == 16:
url = "dino_deitsmall16_pretrain/dino_deitsmall16_pretrain.pth"
elif args.arch == "deit_small" and args.patch_size == 8:
elif args.arch == "vit_small" and args.patch_size == 8:
url = "dino_deitsmall8_300ep_pretrain/dino_deitsmall8_300ep_pretrain.pth" # model used for visualizations in our paper
elif args.arch == "vit_base" and args.patch_size == 16:
url = "dino_vitbase16_pretrain/dino_vitbase16_pretrain.pth"
