Merge pull request lucidrains#302 from afiaka87/patch-19
Expose flops_profiler, attn_dropout, ff_dropout
lucidrains authored Jun 15, 2021
2 parents dc147ca + 96a3286 commit cec0797
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions train_dalle.py
@@ -54,11 +54,8 @@
help='(experimental) - Enable DeepSpeed 16 bit precision. Reduces VRAM.')


parser.add_argument(
'--amp',
action='store_true',
help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.'
)
parser.add_argument('--amp', action='store_true',
help='Apex "O1" automatic mixed precision. More stable than 16 bit precision. Can\'t be used in conjunction with deepspeed zero stages 1-3.')

parser.add_argument('--wandb_name', default='dalle_train_transformer',
help='Name W&B will use when saving results.\ne.g. `--wandb_name "coco2017-full-sparse"`')
@@ -67,6 +64,8 @@

train_group = parser.add_argument_group('Training settings')

train_group.add_argument('--flops_profiler', dest = 'flops_profiler', action='store_true', help = 'Exits after printing detailed flops/runtime analysis of forward/backward')

train_group.add_argument('--epochs', default = 20, type = int, help = 'Number of epochs')

train_group.add_argument('--save_every_n_steps', default = 1000, type = int, help = 'Save a checkpoint every n steps')
@@ -95,6 +94,10 @@

model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension')

train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.')

train_group.add_argument('--attn_dropout', default = 0.0, type = float, help = 'Attention dropout.')

model_group.add_argument('--reversible', dest = 'reversible', action='store_true')

model_group.add_argument('--loss_img_weight', default = 7, type = int, help = 'Image loss weight')
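
The three flags added above can be exercised in isolation; a minimal sketch (not part of the commit), assuming only the flag names, types, and defaults shown in the diff, with an illustrative argv list:

import argparse

# Stand-alone parser mirroring the three flags this commit exposes.
parser = argparse.ArgumentParser()
parser.add_argument('--flops_profiler', action='store_true',
                    help='Exits after printing detailed flops/runtime analysis of forward/backward')
parser.add_argument('--ff_dropout', default=0.0, type=float, help='Feed forward dropout.')
parser.add_argument('--attn_dropout', default=0.0, type=float, help='Attention dropout.')

# Illustrative invocation; in the real script these come from the command line.
args = parser.parse_args(['--ff_dropout', '0.1', '--attn_dropout', '0.1', '--flops_profiler'])
print(args.ff_dropout, args.attn_dropout, args.flops_profiler)  # 0.1 0.1 True
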
@@ -151,6 +154,8 @@ def cp_path_to_dir(cp_path, tag):
DIM_HEAD = args.dim_head
REVERSIBLE = args.reversible
LOSS_IMG_WEIGHT = args.loss_img_weight
FF_DROPOUT = args.ff_dropout
ATTN_DROPOUT = args.attn_dropout

ATTN_TYPES = tuple(args.attn_types.split(','))

@@ -233,6 +238,8 @@ def cp_path_to_dir(cp_path, tag):
reversible=REVERSIBLE,
loss_img_weight=LOSS_IMG_WEIGHT,
attn_types=ATTN_TYPES,
ff_dropout=FF_DROPOUT,
attn_dropout=ATTN_DROPOUT,
)

# configure OpenAI VAE for float16s
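
Downstream, ff_dropout and attn_dropout act as dropout probabilities inside the transformer's feed-forward and attention blocks. A minimal PyTorch sketch of that wiring (an assumed block structure for illustration, not dalle_pytorch's exact modules):

import torch
from torch import nn

class FeedForward(nn.Module):
    # Assumed structure for illustration only.
    def __init__(self, dim, mult=4, ff_dropout=0.0):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Dropout(ff_dropout),  # probability supplied via --ff_dropout
            nn.Linear(dim * mult, dim),
        )

    def forward(self, x):
        return self.net(x)

ff = FeedForward(dim=512, ff_dropout=0.1)
out = ff(torch.randn(2, 64, 512))  # (batch, seq_len, dim)
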
@@ -342,6 +349,14 @@ def group_weight(model):
'enabled': args.amp,
'opt_level': 'O1',
},
"flops_profiler": {
"enabled": args.flops_profiler,
"profile_step": 200,
"module_depth": -1,
"top_modules": 1,
"detailed": True,
"output_file": None # TODO Can't get this to work.
},
}

if deepspeed_config.get('zero_optimization', {}).get('stage', 0) >= 2:
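
The flops_profiler section above follows DeepSpeed's flops profiler config schema: profile_step is the training step at which the report is produced, module_depth -1 profiles modules at every depth, and output_file None prints the report to stdout. A minimal sketch (not from the commit) of assembling just this section, with a stand-in for the parsed args:

import json
from types import SimpleNamespace

args = SimpleNamespace(flops_profiler=True)  # stand-in for the parsed CLI arguments

deepspeed_config = {
    "flops_profiler": {
        "enabled": args.flops_profiler,  # gated by the new --flops_profiler flag
        "profile_step": 200,             # emit the report at training step 200
        "module_depth": -1,              # aggregate over every module depth
        "top_modules": 1,                # limit aggregated output to the top module
        "detailed": True,
        "output_file": None,             # None prints to stdout (see the TODO above)
    },
}
print(json.dumps(deepspeed_config, indent=2))
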
@@ -477,12 +492,14 @@ def save_model(path):
if not avoid_model_calls:
log['image'] = wandb.Image(image, caption=decoded_text)


if i % 10 == 9 and distr_backend.is_root_worker():
sample_per_sec = BATCH_SIZE * 10 / (time.time() - t)
log["sample_per_sec"] = sample_per_sec
print(epoch, i, f'sample_per_sec - {sample_per_sec}')

if i == 201 and args.flops_profiler:
raise StopIteration("Profiler has finished running. Stopping training early.")

if distr_backend.is_root_worker():
wandb.log(log)

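
The final hunk stops training shortly after the profiler fires: with profile_step set to 200, the report is printed during step 200, so raising at i == 201 ends the run once the analysis is available. A minimal sketch of the same pattern outside the real training loop, assuming a stand-in flag:

flops_profiler = True  # stand-in for args.flops_profiler

for i in range(1000):
    # ... a training step would run here; the DeepSpeed profiler prints its
    # report when i reaches the configured profile_step (200 in the config above) ...
    if i == 201 and flops_profiler:
        # Uncaught, this ends the loop and the script one step after the report.
        raise StopIteration("Profiler has finished running. Stopping training early.")
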
