eval linear logs + typo when padding pos encoding
Mathilde Caron committed May 13, 2021
1 parent b0f5bb4 commit fc997c8
Showing 2 changed files with 6 additions and 4 deletions.
README.md: 4 changes (2 additions & 2 deletions)
@@ -94,7 +94,7 @@ python main_dino.py --help
```

### Vanilla DINO training :sauropod:
- Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 days and the resulting checkpoint should reach 69.3% on k-NN eval and ~73.8% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](/to/do) logs for this run to help reproducibility.
+ Run DINO with DeiT-small network on a single node with 8 GPUs for 100 epochs with the following command. Training time is 1.75 days and the resulting checkpoint should reach 69.3% on k-NN eval and 74.0% on linear eval. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_vanilla_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.
```
python -m torch.distributed.launch --nproc_per_node=8 main_dino.py --arch deit_small --data_path /path/to/imagenet/train --output_dir /path/to/saving_dir
```
@@ -133,7 +133,7 @@ python run_with_submitit.py --arch deit_small --epochs 300 --teacher_temp 0.07 -

</details>

- The resulting pretrained model should reach 73.3% on k-NN eval and ~76.1% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](/to/do) logs for this run to help reproducibility.
+ The resulting pretrained model should reach 73.3% on k-NN eval and 76.0% on linear eval. Training time is 2.6 days with 16 GPUs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_log.txt) and [linear evaluation](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_boost_deitsmall16_eval.txt) logs (with batch size 256 at evaluation time) for this run to help reproducibility.

### ResNet-50 and other convnets trainings
This code also works for training DINO on convolutional networks, such as ResNet-50. We highly recommend adapting some optimization arguments in this case. For example, the following is a command to train DINO on ResNet-50 on a single node with 8 GPUs for 100 epochs. We provide [training](https://dl.fbaipublicfiles.com/dino/example_runs_logs/dino_rn50_log.txt) logs for this run.
vision_transformer.py: 6 changes (4 additions & 2 deletions)
@@ -232,15 +232,17 @@ def forward_selfattention(self, x):
            scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
            mode='bicubic',
        )
+       # sometimes there is a floating point error in the interpolation and so
+       # we need to pad the patch positional encoding.
        if w0 != patch_pos_embed.shape[-2]:
            helper = torch.zeros(h0)[None, None, None, :].repeat(1, dim, w0 - patch_pos_embed.shape[-2], 1).to(x.device)
            patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-2)
        if h0 != patch_pos_embed.shape[-1]:
            helper = torch.zeros(w0)[None, None, :, None].repeat(1, dim, 1, h0 - patch_pos_embed.shape[-1]).to(x.device)
-           pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)
+           patch_pos_embed = torch.cat((patch_pos_embed, helper), dim=-1)

        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        pos_embed = torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + pos_embed
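For context, below is a minimal standalone sketch of the padding logic the new comments describe: bicubic interpolation of the patch positional embeddings can come out one row or column short of the requested (w0, h0) grid because of floating point rounding in the scale factor, so the missing slice is zero-padded. With the old line, the height padding was assigned to `pos_embed` and then immediately overwritten two lines later, so it never reached `patch_pos_embed`; the one-word fix keeps it. The helper below (`interpolate_patch_pos_embed`) and its variable names are illustrative only, not the repository's API.

```
import math
import torch

# Hypothetical helper, not part of the repository: interpolates a sequence of
# patch positional embeddings to a (w0, h0) grid and zero-pads the result when
# bicubic resizing comes out one entry short due to floating point rounding.
def interpolate_patch_pos_embed(patch_pos_embed, w0, h0):
    # patch_pos_embed: (1, N, dim) with N a perfect square (e.g. 14 * 14 = 196)
    N, dim = patch_pos_embed.shape[1], patch_pos_embed.shape[2]
    side = int(math.sqrt(N))
    # reshape to a (1, dim, side, side) grid so interpolate treats it like an image
    grid = patch_pos_embed.reshape(1, side, side, dim).permute(0, 3, 1, 2)
    grid = torch.nn.functional.interpolate(
        grid, scale_factor=(w0 / side, h0 / side), mode='bicubic'
    )
    # zero-pad a missing row if the interpolated grid is one short along width
    if w0 != grid.shape[-2]:
        pad = torch.zeros(1, dim, w0 - grid.shape[-2], grid.shape[-1])
        grid = torch.cat((grid, pad), dim=-2)
    # zero-pad a missing column, same reason, along height
    if h0 != grid.shape[-1]:
        pad = torch.zeros(1, dim, grid.shape[-2], h0 - grid.shape[-1])
        grid = torch.cat((grid, pad), dim=-1)
    # back to (1, w0 * h0, dim), ready to be concatenated with the class token
    return grid.permute(0, 2, 3, 1).reshape(1, -1, dim)

# example: 14x14 = 196 positional embeddings of dim 384 resized for a 29x29 patch grid
out = interpolate_patch_pos_embed(torch.randn(1, 196, 384), 29, 29)
print(out.shape)  # torch.Size([1, 841, 384])
```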
