From c345c33ad4425d77cac70db813dc1a8aa17de760 Mon Sep 17 00:00:00 2001
From: afiaka87 <3994972+afiaka87@users.noreply.github.com>
Date: Sun, 13 Jun 2021 07:41:30 -0500
Subject: [PATCH 1/4] Expose flops_profiler, attn_dropout, ff_dropout

---
 train_dalle.py | 29 +++++++++++++++++++++++------
 1 file changed, 23 insertions(+), 6 deletions(-)

diff --git a/train_dalle.py b/train_dalle.py
index 81ce315a..7e663d9d 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -54,11 +54,8 @@
     help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')
 
-parser.add_argument(
-    '--amp',
-    action='store_true',
-    help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.'
-)
+parser.add_argument('--amp', action='store_true',
+                    help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')
 
 parser.add_argument('--wandb_name', default='dalle_train_transformer',
                     help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')
 
@@ -67,6 +64,8 @@
 
 train_group = parser.add_argument_group('Training settings')
 
+train_group.add_argument('--flops_profiler', dest = 'flops_profiler', action='store_true', help = 'Exits after printing detailed flops/runtime analysis of forward/backward')
+
 train_group.add_argument('--epochs', default = 20, type = int, help = 'Number of epochs')
 
 train_group.add_argument('--save_every_n_steps', default = 1000, type = int, help = 'Save a checkpoint every n steps')
@@ -95,6 +94,10 @@
 
 model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension')
 
+train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.')
+
+train_group.add_argument('--attn_dropout', default = 0.0, type = float, help = 'Attention dropout.')
+
 model_group.add_argument('--reversible', dest = 'reversible', action='store_true')
 
 model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
@@ -151,6 +154,8 @@ def cp_path_to_dir(cp_path, tag):
 DIM_HEAD = args.dim_head
 REVERSIBLE = args.reversible
 LOSS_IMG_WEIGHT = args.loss_img_weight
+FF_DROPOUT = args.ff_dropout
+ATTN_DROPOUT = args.attn_dropout
 
 ATTN_TYPES = tuple(args.attn_types.split(','))
 
@@ -233,6 +238,8 @@ def cp_path_to_dir(cp_path, tag):
     reversible=REVERSIBLE,
     loss_img_weight=LOSS_IMG_WEIGHT,
     attn_types=ATTN_TYPES,
+    ff_dropout=FF_DROPOUT,
+    attn_dropout=ATTN_DROPOUT,
 )
 
 # configure OpenAI VAE for float16s
@@ -342,6 +349,14 @@ def group_weight(model):
         'enabled': args.amp,
         'opt_level': 'O1',
     },
+    "flops_profiler": {
+        "enabled": args.flops_profiler,
+        "profile_step": 200,
+        "module_depth": -1,
+        "top_modules": 1,
+        "detailed": True,
+        "output_file": None # TODO Can't get this to work.
+    },
 }
 
 if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2:
@@ -419,6 +434,7 @@ def save_model(path):
     if data_sampler:
         data_sampler.set_epoch(epoch)
     for i, (text, images) in enumerate(distr_dl):
+
         if i % 10 == 0 and distr_backend.is_root_worker():
             t = time.time()
         if args.fp16:
@@ -472,7 +488,8 @@ def save_model(path):
 
             if not avoid_model_calls:
                 log['image'] = wandb.Image(image, caption=decoded_text)
-
+        if i == 201:
+            raise StopIteration("E
         if i % 10 == 9 and distr_backend.is_root_worker():
             sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
             log["sample_per_sec"] = sample_per_sec

From 01370177b7b0b7dc8390012afe243f15170c55fa Mon Sep 17 00:00:00 2001
From: afiaka87 <3994972+afiaka87@users.noreply.github.com>
Date: Mon, 14 Jun 2021 19:52:38 -0500
Subject: [PATCH 2/4] Fix syntax error

---
 train_dalle.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/train_dalle.py b/train_dalle.py
index 7e663d9d..20849aba 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -434,7 +434,9 @@ def save_model(path):
     if data_sampler:
         data_sampler.set_epoch(epoch)
     for i, (text, images) in enumerate(distr_dl):
-
+        if i == 201 and args.flops_profiler: 
+            raise StopIteration("Profiler has finished running. Stopping training early.")
+
         if i % 10 == 0 and distr_backend.is_root_worker():
             t = time.time()
         if args.fp16:
@@ -488,8 +490,6 @@ def save_model(path):
 
             if not avoid_model_calls:
                 log['image'] = wandb.Image(image, caption=decoded_text)
-        if i == 201:
-            raise StopIteration("E
         if i % 10 == 9 and distr_backend.is_root_worker():
             sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
             log["sample_per_sec"] = sample_per_sec

From 659b2d86a1295f6f97c0a83dd8b55c473bdc1e76 Mon Sep 17 00:00:00 2001
From: afiaka87 <3994972+afiaka87@users.noreply.github.com>
Date: Mon, 14 Jun 2021 20:09:59 -0500
Subject: [PATCH 3/4] Expose profiler, ff dropout, attn dropout

---
 train_dalle.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/train_dalle.py b/train_dalle.py
index 20849aba..2d31cdeb 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -434,7 +434,7 @@ def save_model(path):
     if data_sampler:
         data_sampler.set_epoch(epoch)
     for i, (text, images) in enumerate(distr_dl):
-        if i == 201 and args.flops_profiler: 
+        if i == 201 and args.flops_profiler:
             raise StopIteration("Profiler has finished running. Stopping training early.")
 
         if i % 10 == 0 and distr_backend.is_root_worker():

From 96a3286d03e4d1a8a792a1ae27b386d4b763851a Mon Sep 17 00:00:00 2001
From: afiaka87 <3994972+afiaka87@users.noreply.github.com>
Date: Mon, 14 Jun 2021 20:27:08 -0500
Subject: [PATCH 4/4] `--flops_profiler`, `--attn_dropout`, `--ff_dropout`

---
 train_dalle.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/train_dalle.py b/train_dalle.py
index 2d31cdeb..7399ecb1 100644
--- a/train_dalle.py
+++ b/train_dalle.py
@@ -434,9 +434,6 @@ def save_model(path):
     if data_sampler:
         data_sampler.set_epoch(epoch)
     for i, (text, images) in enumerate(distr_dl):
-        if i == 201 and args.flops_profiler:
-            raise StopIteration("Profiler has finished running. Stopping training early.")
-
         if i % 10 == 0 and distr_backend.is_root_worker():
             t = time.time()
         if args.fp16:
@@ -495,6 +492,9 @@ def save_model(path):
             log["sample_per_sec"] = sample_per_sec
             print(epoch, i, f'sample_per_sec - {sample_per_sec}')
 
+        if i == 201 and args.flops_profiler:
+            raise StopIteration("Profiler has finished running. Stopping training early.")
+
         if distr_backend.is_root_worker():
             wandb.log(log)
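
Note on the net effect of the series: train_dalle.py gains --flops_profiler (DeepSpeed's flops profiler, configured to profile at step 200, with the training loop raising StopIteration at i == 201, one step after the profile prints), plus --attn_dropout and --ff_dropout, which flow through ATTN_DROPOUT/FF_DROPOUT into the DALLE constructor. The snippet below is a minimal standalone sketch of an equivalent profiler pass outside the training script; it is not code from this PR, and it assumes DeepSpeed's documented FlopsProfiler API plus toy hyperparameters chosen only for illustration.

import torch
from deepspeed.profiling.flops_profiler import FlopsProfiler
from dalle_pytorch import DALLE, DiscreteVAE

# Toy-sized model; real runs use the sizes parsed from the CLI flags above.
vae = DiscreteVAE(image_size=256, num_layers=3, num_tokens=8192, codebook_dim=512)
dalle = DALLE(
    dim=512,
    vae=vae,
    num_text_tokens=10000,
    text_seq_len=256,
    depth=2,
    heads=8,
    dim_head=64,
    attn_dropout=0.1,  # what --attn_dropout feeds through ATTN_DROPOUT
    ff_dropout=0.1,    # what --ff_dropout feeds through FF_DROPOUT
)

prof = FlopsProfiler(dalle)
text = torch.randint(0, 10000, (1, 256))
images = torch.randn(1, 3, 256, 256)

prof.start_profile()                          # begin counting flops/params
loss = dalle(text, images, return_loss=True)  # the forward pass being profiled
prof.stop_profile()
# Mirrors the deepspeed_config keys added in PATCH 1/4.
prof.print_model_profile(module_depth=-1, top_modules=1, detailed=True)
prof.end_profile()

A profiling run of the training script itself would then look something like the following (the dataset path is illustrative):

python train_dalle.py --image_text_folder ./data --flops_profiler --attn_dropout 0.1 --ff_dropout 0.1

DeepSpeed prints the profile at profile_step 200, and the loop stops itself one step later, which is why the early-exit check lands at i == 201.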