diff --git a/train_dalle.py b/train_dalle.py index 2b3894d8..f86be516 100644 --- a/train_dalle.py +++ b/train_dalle.py @@ -106,13 +106,13 @@ model_group.add_argument('--dim', default = 512, type = int, help = 'Model dimension') -model_group.add_argument('--text_seq_len', default = 128, type = int, help = 'Text sequence length') +model_group.add_argument('--text_seq_len', default = 256, type = int, help = 'Text sequence length') model_group.add_argument('--depth', default = 2, type = int, help = 'Model depth') -model_group.add_argument('--heads', default = 4, type = int, help = 'Model number of heads') +model_group.add_argument('--heads', default = 8, type = int, help = 'Model number of heads') -model_group.add_argument('--dim_head', default = 16, type = int, help = 'Model head dimension') +model_group.add_argument('--dim_head', default = 64, type = int, help = 'Model head dimension') train_group.add_argument('--ff_dropout', default = 0.0, type = float, help = 'Feed forward dropout.')