Commit add1e5c (1 parent: 9aa6da2)
Showing 60 changed files with 9,489 additions and 1 deletion.
@@ -1 +1,47 @@
# Dynamic-Vision-Transformer-MindSpore

## Requirements

* MindSpore 1.0 (https://www.mindspore.cn/install/en)
* jinja2 (https://anaconda.org/anaconda/jinja2)
* tqdm (for GPU only)
* mpi4py (for GPU only)

## Training

Execute the training script from the "src" directory. It creates the directory "../results/{DATETIME}__{EXPERIMENT_NAME}" and places all results there.

```
bash scripts/train_ascend.sh {0-7} EXPERIMENT_NAME --config=CONFIG_PATH --device {Ascend (default)|GPU} [TRAIN.PY_ARGUMENTS]

# training with feature reuse and relation reuse
bash scripts/train_ascend.sh 0-7 deit_dvt_12_49_196_w_f_w_r_adamw_originhead_dataaug_mixup --config=configs/local/vit_dvt/deit_dvt_12_49_196_w_f_w_r_adamw_originhead_dataaug_mixup.yml.j2
# training with feature reuse, without relation reuse
bash scripts/train_ascend.sh 0-7 deit_dvt_12_49_196_w_f_n_r_adamw_originhead_dataaug_mixup --config=configs/local/vit_dvt/deit_dvt_12_49_196_w_f_n_r_adamw_originhead_dataaug_mixup.yml.j2
# training without feature reuse, with relation reuse
bash scripts/train_ascend.sh 0-7 deit_dvt_12_49_196_n_f_w_r_adamw_originhead_dataaug_mixup --config=configs/local/vit_dvt/deit_dvt_12_49_196_n_f_w_r_adamw_originhead_dataaug_mixup.yml.j2
# training without feature reuse and without relation reuse
bash scripts/train_ascend.sh 0-7 deit_dvt_12_49_196_n_f_n_r_adamw_originhead_dataaug_mixup --config=configs/local/vit_dvt/deit_dvt_12_49_196_n_f_n_r_adamw_originhead_dataaug_mixup.yml.j2
# inference with feature reuse and relation reuse
bash scripts/inference_ascend.sh 0 deit_dvt_12_49_196_w_f_w_r_adamw_originhead_dataaug_mixup_inference --config=configs/local/vit_dvt/deit_dvt_12_49_196_w_f_w_r_adamw_originhead_dataaug_mixup_inference.yml.j2
```

|model|GFLOPs|top-1 acc (%)|
|-|-|-|
|deit-s/16|4.608|78.67|
|deit-s/32|1.145|72.116|
|vit-b/16|17.58|79.1|
|vit-b/32|4.41|73.972|

Top-1 accuracy on ImageNet vs. GFLOPs:

![](deit_dvt_vs_vit_inference.png)

## Profiling (Ascend)

Add the `--profile` parameter to the training script (it is an integer flag, e.g. `--profile 1`; profiling is off by default).

@@ -0,0 +1,83 @@
"""Some routines for executing scripts.""" | ||
|
||
# pylint: disable = import-outside-toplevel, c-extension-no-member | ||
|
||
import argparse | ||
import os | ||
import sys | ||
|
||
from utils.cfg_parser import ConfigObject, dump_yaml, parse_yaml, copy_data_to_cache | ||
|
||
VERSION = "0.1.0" | ||
|
||
def parse_args():
    """Function for parsing command line args and merging them with yaml.j2 config."""

    parser = argparse.ArgumentParser(description='ISDS Mindspore research code.')

    parser.add_argument('--version', action='version', version=f'{VERSION}')
    parser.add_argument('--device', type=str, default='Ascend', choices=["CPU", "GPU", "Ascend"],
                        help='Computing device.')
    parser.add_argument('--profile', type=int, default=0, help='Profiling mode.')
    parser.add_argument('--export_file', type=str, default='',
                        help='Exporting mode. Path to save exported model')
    parser.add_argument('--config', type=str, default='', help='Configuration file')
    parser.add_argument('--seed', type=int, default=1, help="Random seed.")
    parser.add_argument('--pretrained', type=str, default='', help='Pretrained model')
    parser.add_argument('--start_epoch', type=int, default=0,
                        help='Starting epoch for resumed training.')
    parser.add_argument('--num_epochs', type=int, default=300,
                        help="Number of epochs for training.")
    parser.add_argument('-b', '--batch_size', type=int, default=None,
                        help="Batch size for training.")
    parser.add_argument('--eval_batch_size', type=int, default=None,
                        help="Batch size for eval.")
    parser.add_argument('--export_batch_size', type=int, nargs='+', default=32,
                        help="Batch size for exported models.")
    parser.add_argument('-d', '--dataset', type=str, default="imagenet", help="Dataset name.")
    parser.add_argument('--stat', type=int, default=0, help="Save training statistics.")
    parser.add_argument('--train_url', type=str, default='', help='train_url')

    args = parser.parse_args()

    # A pretrained checkpoint stored in S3 is first copied to the local cache.
    pretrained = args.pretrained
    if pretrained.startswith('s3://'):
        pretrained_cache = pretrained.replace('s3://', '/cache/')
        copy_data_to_cache(pretrained, pretrained_cache)
        pretrained = pretrained_cache
    args.pretrained = pretrained

    # Export runs on a single device; training reads the device count from RANK_SIZE.
    if os.path.basename(sys.argv[0]) in ["export.py"]:
        device_num = 1
    else:
        device_num = int(os.getenv('RANK_SIZE', '1'))  # fall back to 1 if RANK_SIZE is unset
    print('........device_num={}'.format(device_num))

    # On GPU the process rank comes from MPI and is propagated via environment variables.
    if args.device == "GPU":
        from mpi4py import MPI
        comm = MPI.COMM_WORLD
        rank = comm.Get_rank()
        os.environ["RANK_ID"] = str(rank)
        os.environ["DEVICE_ID"] = str(rank)

    # Values substituted for the {{...}} placeholders in the yaml.j2 template.
    data = {
        "DEVICE_NUM": device_num,
        "VERSION": VERSION,
        "NUM_EPOCHS": args.num_epochs,
        "START_EPOCH": args.start_epoch,
        "DEVICE": args.device,
        "TRAIN_BATCH_SIZE": args.batch_size,
        "EVAL_BATCH_SIZE": args.eval_batch_size,
        "DATASET": args.dataset,
        "STAT": args.stat
    }
    yaml = parse_yaml(args.config, data)
    # print('yaml:', yaml)
    if os.path.basename(sys.argv[0]) not in ["export.py"]:
        dump_yaml(yaml, "config.yaml")

    # Merge the rendered config into the CLI namespace; YAML keys win on conflict.
    args = args.__dict__
    args.update(yaml)
    args = ConfigObject(args)

    return args
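
For orientation: `parse_yaml(args.config, data)` renders the `.yml.j2` template with the values in `data` and returns the resulting config mapping. The real `utils.cfg_parser` implementation is not part of this diff, so the sketch below is an assumption built on the public `jinja2` and `PyYAML` APIs, not the repository's code:

```
# minimal sketch, assuming parse_yaml is a thin jinja2 + PyYAML wrapper
import yaml
from jinja2 import Template

def parse_yaml_sketch(config_path, data):
    """Render a .yml.j2 template with `data`, then load the result as YAML."""
    with open(config_path, encoding='utf-8') as f:
        rendered = Template(f.read()).render(**data)  # fills {{VERSION}}, {{DEVICE_NUM}}, ...
    return yaml.safe_load(rendered)

# usage mirroring parse_args(): the returned dict is merged over the argparse
# namespace, so keys defined in the config override same-named CLI defaults
cfg = parse_yaml_sketch(
    "configs/cloud/vit/deit_base16_adamw_originhead_dataaug_mixup.yml.j2",
    {"VERSION": "0.1.0", "DEVICE_NUM": 8, "NUM_EPOCHS": 300, "START_EPOCH": 0,
     "DEVICE": "Ascend", "TRAIN_BATCH_SIZE": 256, "EVAL_BATCH_SIZE": 256,
     "DATASET": "imagenet", "STAT": 0})
```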

configs/cloud/vit/deit_base16_adamw_originhead_dataaug_mixup.yml.j2 (241 additions, 0 deletions)
@@ -0,0 +1,241 @@
# ViT:
# * base configuration
# * patch = 16

version: {{VERSION}}

{% set name = "deit_base16" %}
name: {{name}}

{% if DEVICE == "Ascend" %}
  {% if DATASET == "imagenet" %}
    {% set train_dataset_path = '/cache/bucket-d/data/imagenet/train' %}
    {% set eval_dataset_path = '/cache/bucket-d/data/imagenet/val' %}
    {% set num_classes = 1001 %}
    {% set train_len = 1281167 %}
    {% set val_len = 50000 %}
  {% elif DATASET == "imagenet100" %}
    {% set train_dataset_path = '/opt/npu/datasets/imagenet100/train' %}
    {% set eval_dataset_path = '/opt/npu/datasets/imagenet100/val' %}
    {% set num_classes = 101 %}
    {% set train_len = 130000 %}
    {% set val_len = 5000 %}
  {% endif %}
{% else %}
  {% if DATASET == "imagenet" %}
    {% set train_dataset_path = '/ssd/ssd0/datasets/ImageNet/train' %}
    {% set eval_dataset_path = '/ssd/ssd0/datasets/ImageNet/val' %}
    {% set num_classes = 1001 %}
    {% set train_len = 1281167 %}
    {% set val_len = 50000 %}
  {% elif DATASET == "imagenet100" %}
    {% set train_dataset_path = '/ssd/ssd0/datasets/imagenet100/train' %}
    {% set eval_dataset_path = '/ssd/ssd0/datasets/imagenet100/val' %}
    {% set num_classes = 101 %}
    {% set train_len = 130000 %}
    {% set val_len = 5000 %}
  {% endif %}
{% endif %}

profiler:
  func: "mindspore.profiler.profiling.Profiler"
  output_path: "data"
  is_detail: True
  is_show_op_path: True

{% set crop_size = 224 %}
crop_size: {{crop_size}}

{% if TRAIN_BATCH_SIZE is not none %}
  {% set train_batch_size = TRAIN_BATCH_SIZE %}
{% else %}
  {% set train_batch_size = 256 %}
{% endif %}
train_batch_size: {{train_batch_size}}
global_batch_size: {{train_batch_size * DEVICE_NUM}}
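# e.g. with train_batch_size = 256 on DEVICE_NUM = 8 Ascend devices,
# global_batch_size renders to 256 * 8 = 2048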
{% set base_cfg = {'d_model': 384, 'depth': 12, 'heads': 6, 'mlp_dim': 1536, 'dim_head': 64} %}
{% set normalized_shape = [base_cfg['d_model']] %}
{% set patch_size = 16 %}
network:
  func: 'networks.vit.ViT'
  d_model: {{base_cfg['d_model']}}
  image_size: {{crop_size}}
  patch_size: {{patch_size}}
  pool: cls
  dropout_rate: 0.1
  initialization:
    func: mindspore.common.initializer.Normal
    sigma: 1.0
  stem:
    func: 'networks.vit.VitStem'
    d_model: {{base_cfg['d_model']}}
    image_size: {{crop_size}}
    patch_size: {{patch_size}}
    initialization:
      func: mindspore.common.initializer.XavierUniform
  body:
    func: 'networks.transformer.Transformer'
    depth: {{base_cfg['depth']}}
    attention:
      _func: 'networks.transformer.Attention'
      size_cfg: {{base_cfg}}
      initialization:
        func: mindspore.common.initializer.XavierUniform
      activation:
        func: mindspore.nn.Softmax
      dropout_rate: 0.1
    feedforward:
      _func: 'networks.transformer.FeedForward'
      size_cfg: {{base_cfg}}
      initialization:
        func: mindspore.common.initializer.XavierUniform
      activation:
        func: mindspore.nn.GELU
      dropout_rate: 0.1
    normalization:
      _func: 'mindspore.nn.LayerNorm'
      normalized_shape: {{normalized_shape}}
  head:
    func: 'networks.vit.origin_head'
    size_cfg: {{base_cfg}}
    dropout_rate: 0.1
    num_classes: {{num_classes}}
    activation:
      func: mindspore.nn.GELU
    initialization:
      func: mindspore.common.initializer.XavierUniform
    normalization:
      func: 'mindspore.nn.LayerNorm'
      normalized_shape: {{normalized_shape}}
  norm:
    _func: 'mindspore.nn.LayerNorm'
    normalized_shape: {{normalized_shape}}
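# token arithmetic for this configuration (assuming a standard ViT patch stem):
# a 224 x 224 crop with patch_size 16 yields (224 / 16)^2 = 196 patch tokens,
# each embedded into d_model = 384 dimensions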

{% set resize_size = 256 %}

train_dataset:
  func: "utils.dataset.create_dataset"
  dataset_path: {{train_dataset_path}}
  do_train: True
  batch_size: {{train_batch_size}}
  resize_size: {{resize_size}}
  crop_size: {{crop_size}}
  target: {{DEVICE}}
  autoaugment: True
  num_classes: {{num_classes}}
  mixup: 0.2
{% set train_batches_num = train_len // (train_batch_size * DEVICE_NUM) %}
train_batches_num: {{train_batches_num}}
train_len: {{train_len}}
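# e.g. for ImageNet: 1281167 // (256 * 8) = 625 batches per epoch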

{% if EVAL_BATCH_SIZE is not none %}
  {% set eval_batch_size = EVAL_BATCH_SIZE %}
{% else %}
  {% set eval_batch_size = 256 %}
{% endif %}
eval_batch_size: {{eval_batch_size}}
eval_dataset:
  func: "utils.dataset.create_dataset"
  dataset_path: {{eval_dataset_path}}
  do_train: False
  batch_size: {{eval_batch_size}}
  resize_size: {{resize_size}}
  crop_size: {{crop_size}}
  target: {{DEVICE}}
val_len: {{val_len}}

{% set lr_per_bs256 = 0.00044375 %}
lr_schedule:
  func: "utils.lr_generator.get_lr"
  start_epoch: {{START_EPOCH}}
  lr_init: 0.0
  lr_end: 0.0
  lr_max: {{lr_per_bs256 * DEVICE_NUM * train_batch_size / 256}}
  warmup_epochs: 40
  total_epochs: {{NUM_EPOCHS}}
  steps_per_epoch: {{train_len // (train_batch_size * DEVICE_NUM)}}
  lr_decay_mode: 'cosine'
  poly_power: 2
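# linear LR scaling: lr_max = 0.00044375 * global_batch_size / 256,
# e.g. 0.00044375 * 2048 / 256 = 0.00355 for 8 devices at batch 256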

{% set weight_decay = 0.05 %}
{% set loss_scale = 1024 %}
optimizer:
  func: "nn.optimizers.adamw_gcnorm.AdamW"
  params:
    func: "nn.optimizers.beta_bias_wd_filter"
    params: null
    weight_decay: {{weight_decay}}
  learning_rate: null
  loss_scale: {{loss_scale}}

{% if STAT == 1 %}
  {% set sink_size = -1 %}
  {% set dataset_sink_mode = False %}
{% else %}
  {% set sink_size = train_batches_num %}
  {% set dataset_sink_mode = True %}
{% endif %}
sink_size: {{sink_size}}
dataset_sink_mode: {{dataset_sink_mode}}
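# with --stat 1 data sinking is disabled (sink_size = -1) so per-step
# statistics can be saved; otherwise a full epoch is sunk per call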

train_model:
  func: "mindspore.train.Model"
  network: null
  loss_fn:
    func: "nn.losses.cross_entropy.CrossEntropySmoothMixup"
    sparse: True
    reduction: "mean"
    smooth_factor: 0.1
    num_classes: {{num_classes}}
  optimizer: null
  loss_scale_manager:
    func: "mindspore.train.loss_scale_manager.FixedLossScaleManager"
    loss_scale: {{loss_scale}}
    drop_overflow_update: False
  amp_level: "O0"
  keep_batchnorm_fp32: False
  metrics:
    acc:
      func: "nn.metrics.DistAccuracy"
      batch_size: {{eval_batch_size}}
      device_num: {{DEVICE_NUM}}
      val_len: {{val_len}}
  eval_network:
    func: "nn.metrics.ClassifyCorrectCell"
    network: null

eval_model:
  func: "mindspore.train.Model"
  network: null
  loss_fn:
    func: "nn.losses.CrossEntropySmooth"
    sparse: True
    reduction: "mean"
    smooth_factor: 0.1
    num_classes: {{num_classes}}
  amp_level: "O0"
  keep_batchnorm_fp32: False
  metrics:
    acc:
      func: "nn.metrics.DistAccuracy"
      batch_size: {{eval_batch_size}}
      device_num: {{DEVICE_NUM}}
      val_len: {{val_len}}
  eval_network:
    func: "nn.metrics.ClassifyCorrectCell"
    network: null
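# note: training uses CrossEntropySmoothMixup to match the mixup-augmented
# labels from train_dataset, while eval uses plain CrossEntropySmooth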

{% set save_checkpoint_epochs = 5 %}
checkpoint_callback:
  func: "mindspore.train.callback.ModelCheckpoint"
  prefix: {{name}}
  directory: "/cache/checkpoints"
  config:
    func: "mindspore.train.callback.CheckpointConfig"
    save_checkpoint_steps: {{save_checkpoint_epochs * train_batches_num}}
    keep_checkpoint_max: 10
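# a checkpoint every 5 epochs, e.g. 5 * 625 = 3125 steps at global batch 2048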

eval:
  offset: 0
  interval: 1
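# presumably: evaluate every epoch (interval: 1) starting at epoch 0 (offset: 0);
# inferred from the field names, as the eval loop itself is not in this file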