Skip to content

Commit

Permalink
[Dreambooth] flax fixes (huggingface#1765)
Browse files Browse the repository at this point in the history
* Fail if there are less images than the effective batch size.

* Remove lr-scheduler arg as it's currently ignored.

* Make guidance_scale work for batch_size > 1.
  • Loading branch information
pcuenca authored Dec 19, 2022
1 parent 8331da4 commit 9f8c915
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions examples/dreambooth/train_dreambooth_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,15 +142,6 @@ def parse_args():
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
Expand Down Expand Up @@ -429,6 +420,13 @@ def collate_fn(examples):
return batch

total_train_batch_size = args.train_batch_size * jax.local_device_count()
if len(train_dataset) < total_train_batch_size:
raise ValueError(
f"Training batch size is {total_train_batch_size}, but your dataset only contains"
f" {len(train_dataset)} images. Please, use a larger dataset or reduce the effective batch size. Note that"
f" there are {jax.local_device_count()} parallel devices, so your batch size can't be smaller than that."
)

train_dataloader = torch.utils.data.DataLoader(
train_dataset, batch_size=total_train_batch_size, shuffle=True, collate_fn=collate_fn, drop_last=True
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def __call__(
guidance_scale = jnp.array([guidance_scale] * prompt_ids.shape[0])
if len(prompt_ids.shape) > 2:
# Assume sharded
guidance_scale = guidance_scale.reshape(prompt_ids.shape[:2])
guidance_scale = guidance_scale[:, None]

if jit:
images = _p_generate(
Expand Down

0 comments on commit 9f8c915

Please sign in to comment.