Skip to content

Commit

Permalink
Allow training up to 1024 for both v1 and v2 on T4
Browse files Browse the repository at this point in the history
  • Loading branch information
TheLastBen authored Apr 25, 2023
1 parent 1df8f11 commit b5e9131
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions fast-DreamBooth.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -961,11 +961,6 @@
"\n",
"precision=prec\n",
"\n",
"s = getoutput('nvidia-smi')\n",
"GCUNET=\"--gradient_checkpointing\"\n",
"if 'A100' in s or Res<=768:\n",
" GCUNET=\"\"\n",
"\n",
"resuming=\"\"\n",
"if Resume_Training and os.path.exists(OUTPUT_DIR+'/unet/diffusion_pytorch_model.bin'):\n",
" MODELT_NAME=OUTPUT_DIR\n",
Expand All @@ -978,6 +973,23 @@
" print('\u001b[1;31mNo model found, use the \"Model Download\" cell to download a model.')\n",
" time.sleep(5)\n",
"\n",
"V2=False\n",
"if os.path.getsize(MODELT_NAME+\"/text_encoder/pytorch_model.bin\") > 670901463:\n",
" V2=True\n",
"\n",
"s = getoutput('nvidia-smi')\n",
"GCUNET=\"--gradient_checkpointing\"\n",
"TexRes=Res\n",
"if 'A100' in s or Res<=768:\n",
" GCUNET=\"\"\n",
"\n",
"if V2: \n",
" if Res>704:\n",
" GCUNET=\"--gradient_checkpointing\"\n",
" if Res>576:\n",
" TexRes=576\n",
"\n",
"\n",
"Enable_text_encoder_training= True\n",
"Enable_Text_Encoder_Concept_Training= True\n",
"\n",
Expand Down Expand Up @@ -1027,7 +1039,7 @@
" --captions_dir=\"$CAPTIONS_DIR\" \\\n",
" --instance_prompt=\"$PT\" \\\n",
" --seed=$Seed \\\n",
" --resolution=$Res \\\n",
" --resolution=$TexRes \\\n",
" --mixed_precision=$precision \\\n",
" --train_batch_size=1 \\\n",
" --gradient_accumulation_steps=1 --gradient_checkpointing \\\n",
Expand Down

0 comments on commit b5e9131

Please sign in to comment.