Unsupervised domain adaptation (UDA) aims to transfer knowledge learned from a labeled source domain to a different, unlabeled target domain. Most existing UDA methods learn domain-invariant feature representations at either the domain level or the category level, using frameworks based on convolutional neural networks (CNNs). Motivated by the success of Transformers in various tasks, we find that cross-attention in Transformers is robust to noisy input pairs and therefore yields better feature alignment, so in this paper we adopt the Transformer for the challenging UDA task. Specifically, to generate accurate input pairs, we design a two-way center-aware labeling algorithm that produces pseudo labels for target samples. Along with the pseudo labels, we propose a weight-sharing triple-branch transformer framework that applies self-attention for source/target feature learning and cross-attention for source-target domain alignment. This design explicitly enforces the framework to learn discriminative domain-specific and domain-invariant representations simultaneously. The proposed method, dubbed CDTrans (cross-domain transformer), is one of the first attempts to solve UDA tasks with a pure transformer solution. Extensive experiments show that it achieves the best performance on the public UDA benchmarks Office-Home, Office-31, VisDA-2017, and DomainNet.
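To make the pipeline described above more concrete, here is a simplified NumPy sketch. It is not the CDTrans implementation; it rests on three assumptions: (1) "two-way" matching takes, for every source sample, its most similar target sample and, for every target sample, its most similar source sample (cosine similarity on features); (2) the "center-aware" step labels each target sample by its nearest probability-weighted class center and drops pairs whose source label and target pseudo label disagree; (3) the source-target branch uses queries from source tokens and keys/values from target tokens. All function names are hypothetical.

```python
import numpy as np

# Hypothetical, simplified sketch of pair generation and cross-attention; not the CDTrans code.

def l2_normalize(x, eps=1e-12):
    # Unit-length rows, so dot products become cosine similarities.
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def two_way_pairs(src_feats, tgt_feats):
    """Union of source->target and target->source nearest-neighbour matches."""
    sim = l2_normalize(src_feats) @ l2_normalize(tgt_feats).T  # (n_src, n_tgt)
    s2t = {(i, int(j)) for i, j in enumerate(sim.argmax(axis=1))}  # each source picks a target
    t2s = {(int(i), j) for j, i in enumerate(sim.argmax(axis=0))}  # each target picks a source
    return sorted(s2t | t2s)

def center_pseudo_labels(tgt_feats, tgt_probs):
    """Assign each target sample the class of its nearest (cosine) class center."""
    feats = l2_normalize(tgt_feats)
    # Class centers: target features averaged with the classifier's softmax scores as weights.
    centers = (tgt_probs.T @ feats) / (tgt_probs.sum(axis=0)[:, None] + 1e-12)
    return (feats @ l2_normalize(centers).T).argmax(axis=1)

def filter_pairs(pairs, src_labels, tgt_pseudo):
    # Keep a pair only if the source label agrees with the target pseudo label.
    return [(i, j) for i, j in pairs if src_labels[i] == tgt_pseudo[j]]

def cross_attention(src_tokens, tgt_tokens, w_q, w_k, w_v):
    """Single-head attention: queries from source tokens, keys/values from target tokens."""
    q, k, v = src_tokens @ w_q, tgt_tokens @ w_k, tgt_tokens @ w_v
    attn = softmax(q @ k.T / np.sqrt(q.shape[-1]), axis=-1)  # (n_src_tokens, n_tgt_tokens)
    return attn @ v

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    src, tgt = rng.normal(size=(6, 8)), rng.normal(size=(5, 8))
    src_labels = np.array([0, 1, 2, 0, 1, 2])
    tgt_probs = softmax(rng.normal(size=(5, 3)))  # stand-in for source-model predictions
    pairs = filter_pairs(two_way_pairs(src, tgt), src_labels,
                         center_pseudo_labels(tgt, tgt_probs))
    d = 8
    w_q, w_k, w_v = (rng.normal(size=(d, d)) for _ in range(3))
    tokens_s, tokens_t = rng.normal(size=(4, d)), rng.normal(size=(7, d))
    out = cross_attention(tokens_s, tokens_t, w_q, w_k, w_v)
    print(pairs, out.shape)  # out.shape is (4, 8): one fused vector per source token
```

Because every source and every target sample contributes at least one match before filtering, the two-way union covers both domains; the center-aware filter then trades some of that coverage for cleaner pairs.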
Office-31 (accuracy, %):

Methods | Avg. | A->D | A->W | D->A | D->W | W->A | W->D |
--- | --- | --- | --- | --- | --- | --- | --- |
Baseline(DeiT-S) | 86.7 | 87.6 | 86.9 | 74.9 | 97.7 | 73.5 | 99.6 |
CDTrans(DeiT-S) | 90.4 | 94.6 | 93.5 | 78.4 | 98.2 | 78 | 99.6 |
Baseline(DeiT-B) | 88.8 | 90.8 | 90.4 | 76.8 | 98.2 | 76.4 | 100 |
CDTrans(DeiT-B) | 92.6 | 97 | 96.7 | 81.1 | 99 | 81.9 | 100 |
Office-Home (accuracy, %):

Methods | Avg. | Ar->Cl | Ar->Pr | Ar->Re | Cl->Ar | Cl->Pr | Cl->Re | Pr->Ar | Pr->Cl | Pr->Re | Re->Ar | Re->Cl | Re->Pr |
--- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
Baseline(DeiT-S) | 69.8 | 55.6 | 73 | 79.4 | 70.6 | 72.9 | 76.3 | 67.5 | 51 | 81 | 74.5 | 53.2 | 82.7 |
CDTrans(DeiT-S) | 74.7 | 60.6 | 79.5 | 82.4 | 75.6 | 81.0 | 82.3 | 72.5 | 56.7 | 84.4 | 77.0 | 59.1 | 85.5 |
Baseline(DeiT-B) | 74.8 | 61.8 | 79.5 | 84.3 | 75.4 | 78.8 | 81.2 | 72.8 | 55.7 | 84.4 | 78.3 | 59.3 | 86 |
CDTrans(DeiT-B) | 80.5 | 68.8 | 85 | 86.9 | 81.5 | 87.1 | 87.3 | 79.6 | 63.3 | 88.2 | 82 | 66 | 90.6 |
VisDA-2017 (accuracy, %; Per-class is the mean of the 12 class-wise accuracies):

Methods | Per-class | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck |
--- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
Baseline(DeiT-B) | 67.3 | 98.1 | 48.1 | 84.6 | 65.2 | 76.3 | 59.4 | 94.5 | 11.8 | 89.5 | 52.2 | 94.5 | 34.1 |
CDTrans(DeiT-B) | 88.4 | 97.7 | 86.39 | 86.87 | 83.33 | 97.76 | 97.16 | 95.93 | 84.08 | 97.93 | 83.47 | 94.59 | 55.3 |
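VisDA-2017 results are reported as mean per-class accuracy (the Per-class column) rather than overall accuracy, so that frequent classes do not dominate the score. A minimal NumPy sketch of that metric; `per_class_accuracy` is a hypothetical helper, not a function from this repository.

```python
import numpy as np

def per_class_accuracy(y_true, y_pred, num_classes):
    """Return each class's accuracy and their unweighted mean (mean class accuracy)."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    accs = np.full(num_classes, np.nan)
    for c in range(num_classes):
        mask = y_true == c
        if mask.any():  # classes with no test samples stay NaN and are skipped by nanmean
            accs[c] = (y_pred[mask] == c).mean()
    return accs, np.nanmean(accs)

if __name__ == "__main__":
    y_true = [0, 0, 0, 1]
    y_pred = [0, 0, 1, 1]
    accs, mean_acc = per_class_accuracy(y_true, y_pred, num_classes=2)
    overall = np.mean(np.asarray(y_true) == np.asarray(y_pred))
    # Class accuracies are 2/3 and 1.0, so the mean is about 0.833, while overall accuracy is 0.75.
    print(accs, mean_acc, overall)
```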
DomainNet, DeiT-Small backbone (accuracy, %; rows are source domains, columns are target domains):

Base-S | clp | info | pnt | qdr | rel | skt | Avg. |
--- | --- | --- | --- | --- | --- | --- | --- |
clp | - | 21.2 | 44.2 | 15.3 | 59.9 | 46.0 | 37.3 |
info | 36.8 | - | 39.4 | 5.4 | 52.1 | 32.6 | 33.3 |
pnt | 47.1 | 21.7 | - | 5.7 | 60.2 | 39.9 | 34.9 |
qdr | 25.0 | 3.3 | 10.4 | - | 18.8 | 14.0 | 14.3 |
rel | 54.8 | 23.9 | 52.6 | 7.4 | - | 40.1 | 35.8 |
skt | 55.6 | 18.6 | 42.7 | 14.9 | 55.7 | - | 37.5 |
Avg. | 43.9 | 17.7 | 37.9 | 9.7 | 49.3 | 34.5 | 32.2 |

CDTrans-S | clp | info | pnt | qdr | rel | skt | Avg. |
--- | --- | --- | --- | --- | --- | --- | --- |
clp | - | 25.3 | 52.5 | 23.2 | 68.3 | 53.2 | 44.5 |
info | 47.6 | - | 48.3 | 9.9 | 62.8 | 41.1 | 41.9 |
pnt | 55.4 | 24.5 | - | 11.7 | 67.4 | 48.0 | 41.4 |
qdr | 36.6 | 5.3 | 19.3 | - | 33.8 | 22.7 | 23.5 |
rel | 61.5 | 28.1 | 56.8 | 12.8 | - | 47.2 | 41.3 |
skt | 64.3 | 26.1 | 53.2 | 23.9 | 66.2 | - | 46.7 |
Avg. | 53.08 | 21.86 | 46.02 | 16.3 | 59.7 | 42.44 | 39.9 |
DomainNet, DeiT-Base backbone (accuracy, %; rows are source domains, columns are target domains):

Base-B | clp | info | pnt | qdr | rel | skt | Avg. |
--- | --- | --- | --- | --- | --- | --- | --- |
clp | - | 24.2 | 48.9 | 15.5 | 63.9 | 50.7 | 40.6 |
info | 43.5 | - | 44.9 | 6.5 | 58.8 | 37.6 | 38.3 |
pnt | 52.8 | 23.3 | - | 6.6 | 64.6 | 44.5 | 38.4 |
qdr | 31.8 | 6.1 | 15.6 | - | 23.4 | 18.9 | 19.2 |
rel | 58.9 | 26.3 | 56.7 | 9.1 | - | 45.0 | 39.2 |
skt | 60.0 | 21.1 | 48.4 | 16.6 | 61.7 | - | 41.6 |
Avg. | 49.4 | 20.2 | 42.9 | 10.9 | 54.5 | 39.3 | 36.2 |

CDTrans-B | clp | info | pnt | qdr | rel | skt | Avg. |
--- | --- | --- | --- | --- | --- | --- | --- |
clp | - | 29.4 | 57.2 | 26.0 | 72.6 | 58.1 | 48.7 |
info | 57.0 | - | 54.4 | 12.8 | 69.5 | 48.4 | 48.4 |
pnt | 62.9 | 27.4 | - | 15.8 | 72.1 | 53.9 | 46.4 |
qdr | 44.6 | 8.9 | 29.0 | - | 42.6 | 28.5 | 30.7 |
rel | 66.2 | 31.0 | 61.5 | 16.2 | - | 52.9 | 45.6 |
skt | 69.0 | 29.6 | 59.0 | 27.2 | 72.5 | - | 51.5 |
Avg. | 59.9 | 25.3 | 52.2 | 19.6 | 65.9 | 48.4 | 45.2 |
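In the DomainNet tables above, the Avg. column averages each source row over its five targets, the Avg. row averages each target column over its five sources, and the bottom-right cell averages all 30 transfer tasks. A small NumPy sketch (the function name is hypothetical) that reproduces the Avg. entries of the Base-S table from its cells:

```python
import numpy as np

# Base-S cells copied from the table; NaN marks the unused same-domain ('-') entries.
NAN = np.nan
BASE_S = [
    [NAN, 21.2, 44.2, 15.3, 59.9, 46.0],  # clp -> clp, info, pnt, qdr, rel, skt
    [36.8, NAN, 39.4, 5.4, 52.1, 32.6],   # info
    [47.1, 21.7, NAN, 5.7, 60.2, 39.9],   # pnt
    [25.0, 3.3, 10.4, NAN, 18.8, 14.0],   # qdr
    [54.8, 23.9, 52.6, 7.4, NAN, 40.1],   # rel
    [55.6, 18.6, 42.7, 14.9, 55.7, NAN],  # skt
]

def domainnet_averages(acc):
    """acc[i][j]: accuracy for source i -> target j; the diagonal is ignored."""
    acc = np.array(acc, dtype=float)
    np.fill_diagonal(acc, np.nan)       # same-domain entries do not count
    row_avg = np.nanmean(acc, axis=1)   # "Avg." column: mean over targets for each source
    col_avg = np.nanmean(acc, axis=0)   # "Avg." row: mean over sources for each target
    overall = np.nanmean(acc)           # bottom-right cell: mean over all 30 tasks
    return row_avg, col_avg, overall

if __name__ == "__main__":
    row_avg, col_avg, overall = domainnet_averages(BASE_S)
    print(np.round(row_avg, 1), np.round(col_avg, 1), round(float(overall), 1))
```

Rounded to one decimal, the output matches the Avg. column (37.3, 33.3, 34.9, 14.3, 35.8, 37.5), the Avg. row, and the overall 32.2 of the Base-S table.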
pip install -r requirements.txt
(Our environment: Python 3.7, an NVIDIA V100 GPU, CUDA 10.1 and cudatoolkit 10.1.)
Download the UDA datasets Office-31, Office-Home, VisDA-2017, and DomainNet.
Then unzip them and arrange them under the data directory as follows. Each dataset folder must contain the .txt list files that give the path and label of every image; these files are already provided in the corresponding data/ subfolder of this project.
data
├── OfficeHomeDataset
│   ├── class_name
│   │   └── images
│   └── *.txt
├── domainnet
│   ├── class_name
│   │   └── images
│   └── *.txt
├── office31
│   ├── class_name
│   │   └── images
│   └── *.txt
└── visda
    ├── train
    │   ├── class_name
    │   │   └── images
    │   └── *.txt
    └── validation
        ├── class_name
        │   └── images
        └── *.txt
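The list files are already shipped with the project, so regenerating them is only needed for a custom split or a sanity check. A minimal sketch, assuming each line has the form `<image path> <integer class label>` separated by a space and that class indices follow the sorted class-folder names; both are assumptions, and the shipped files may use a different class order, so prefer them when they exist. `build_image_list` is a hypothetical helper.

```python
import os

IMAGE_EXTS = (".jpg", ".jpeg", ".png")

def build_image_list(root, out_path, rel_to=None):
    """Write one '<image path> <class index>' line per image under root/<class_name>/ (assumed format)."""
    classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
    lines = []
    for label, cls in enumerate(classes):  # class index = position in sorted folder names (assumption)
        for dirpath, dirnames, filenames in os.walk(os.path.join(root, cls)):
            dirnames.sort()  # deterministic traversal order
            for name in sorted(filenames):
                if name.lower().endswith(IMAGE_EXTS):
                    path = os.path.join(dirpath, name)
                    if rel_to is not None:
                        path = os.path.relpath(path, rel_to)
                    lines.append(f"{path} {label}")
    with open(out_path, "w") as f:
        f.write("\n".join(lines) + "\n")
    return classes, lines
```

For example, `build_image_list("data/visda/train", "my_train_list.txt")` would list every training image under its class folder; the output file name here is illustrative only.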
To keep the pre-training data comparable with prior work, we initialize our ViT-based model with DeiT parameters.
Download the ImageNet-pretrained transformer models DeiT-Small and DeiT-Base and move them to the ./data/pretrainModel directory.
We use 1 GPU for pre-training and 2 GPUs for UDA, each with 16 GB of memory.
Command format:
bash scripts/[pretrain/uda]/[office31/officehome/visda/domainnet]/run_*.sh [deit_base/deit_small]
DeiT-Base scripts
# Office-31 Source: Amazon -> Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_base
bash scripts/uda/office31/run_office_amazon.sh deit_base
# Office-Home Source: Art -> Target: Clipart, Product, Real_World
bash scripts/pretrain/officehome/run_officehome_Ar.sh deit_base
bash scripts/uda/officehome/run_officehome_Ar.sh deit_base
# VisDA-2017 Source: train -> Target: validation
bash scripts/pretrain/visda/run_visda.sh deit_base
bash scripts/uda/visda/run_visda.sh deit_base
# DomainNet Source: Clipart -> Target: Painting, Quickdraw, Real, Sketch, Infograph
bash scripts/pretrain/domainnet/run_domainnet_clp.sh deit_base
bash scripts/uda/domainnet/run_domainnet_clp.sh deit_base
DeiT-Small scripts
Replace deit_base with deit_small to reproduce the DeiT-Small results. For example, training on Office-31:
# Office-31 Source: Amazon -> Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_small
bash scripts/uda/office31/run_office_amazon.sh deit_small
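Every example above runs the pretrain script first and then the uda script with the same file name. A small Python wrapper can make that ordering explicit. It is a sketch that assumes, as in all examples here, that both stages share the script name; `build_commands` and `run_pipeline` are hypothetical names, not part of the repository, and nothing is executed unless `dry_run=False`.

```python
import shlex
import subprocess

# Hypothetical helper mirroring the scripts/[pretrain/uda]/[dataset]/run_*.sh layout.
STAGES = ("pretrain", "uda")  # pre-training runs before UDA, following the examples
DATASETS = ("office31", "officehome", "visda", "domainnet")
BACKBONES = ("deit_base", "deit_small")

def build_commands(dataset, script, backbone):
    """Return the pretrain and uda commands for one source-domain script, in order."""
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if backbone not in BACKBONES:
        raise ValueError(f"unknown backbone: {backbone}")
    return [["bash", f"scripts/{stage}/{dataset}/{script}", backbone] for stage in STAGES]

def run_pipeline(dataset, script, backbone, dry_run=True):
    for cmd in build_commands(dataset, script, backbone):
        print(" ".join(shlex.quote(part) for part in cmd))  # shlex.join needs Python 3.8+
        if not dry_run:
            # check=True raises CalledProcessError, so UDA never starts after a failed pre-training.
            subprocess.run(cmd, check=True)

if __name__ == "__main__":
    run_pipeline("office31", "run_office_amazon.sh", "deit_small")  # prints the two commands only
```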
Evaluation
# For example, VisDA-2017:
python test.py --config_file 'configs/uda.yml' \
    MODEL.DEVICE_ID "('0')" \
    TEST.WEIGHT "('../logs/uda/vit_base/visda/transformer_best_model.pth')" \
    DATASETS.NAMES 'VisDA' DATASETS.NAMES2 'VisDA' \
    OUTPUT_DIR '../logs/uda/vit_base/visda/' \
    DATASETS.ROOT_TRAIN_DIR './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TRAIN_DIR2 './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TEST_DIR './data/visda/validation/valid_image_list.txt'
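When scripting evaluation over several checkpoints, shell quoting of values such as "('0')" is easy to get wrong. The arguments after --config_file appear to be yacs-style KEY VALUE override pairs (an assumption based on the TransReID codebase). The sketch below, whose function and parameter names are hypothetical, builds the same arguments as the VisDA-2017 command above as a Python list, which can be passed to `subprocess.run(argv, check=True)` from the repository root without any shell quoting.

```python
def build_test_argv(weight, output_dir, train_list, test_list,
                    dataset="VisDA", device_id="0", config="configs/uda.yml"):
    """Hypothetical helper: build test.py arguments as KEY VALUE override pairs."""
    # Each value is the string the shell would deliver after stripping quotes,
    # e.g. "('0')" on the command line arrives as ('0').
    overrides = [
        ("MODEL.DEVICE_ID", f"('{device_id}')"),
        ("TEST.WEIGHT", f"('{weight}')"),
        ("DATASETS.NAMES", dataset),
        ("DATASETS.NAMES2", dataset),
        ("OUTPUT_DIR", output_dir),
        ("DATASETS.ROOT_TRAIN_DIR", train_list),
        ("DATASETS.ROOT_TRAIN_DIR2", train_list),
        ("DATASETS.ROOT_TEST_DIR", test_list),
    ]
    argv = ["python", "test.py", "--config_file", config]
    for key, value in overrides:
        argv += [key, value]
    return argv
```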
This codebase is built on TransReID.