controlnet training resize inputs to multiple of 8 (huggingface#3135)
controlnet training center crop input images to multiple of 8

The pipeline code resizes inputs to multiples of 8.
Not doing this resizing in the training script is causing
the encoded image to have different height/width dimensions
than the encoded conditioning image (which uses a separate
encoder that's part of the controlnet model).

We resize and center crop the inputs to make sure they're the
same size (as well as all other images in the batch). We also
check that the initial resolution is a multiple of 8.
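
The size mismatch described above can be reproduced with plain convolution arithmetic. This is a minimal sketch, not code from the training script: it assumes both encoders halve the input three times with 3x3 stride-2 convolutions, that the VAE pads only on the bottom/right edge (so each halving rounds down), and that the controlnet conditioning encoder pads symmetrically by 1 (so each halving rounds up). The function names are hypothetical.

```python
def conv_out(size, kernel, stride, pad_before, pad_after):
    # Standard output-size formula for a strided convolution along one axis.
    return (size + pad_before + pad_after - kernel) // stride + 1


def vae_latent_size(size, num_downsamples=3):
    # Assumption: asymmetric (0, 1) padding, so each step is floor(size / 2).
    for _ in range(num_downsamples):
        size = conv_out(size, kernel=3, stride=2, pad_before=0, pad_after=1)
    return size


def conditioning_embedding_size(size, num_downsamples=3):
    # Assumption: symmetric padding of 1, so each step is ceil(size / 2).
    for _ in range(num_downsamples):
        size = conv_out(size, kernel=3, stride=2, pad_before=1, pad_after=1)
    return size


for resolution in (512, 516):
    print(resolution, vae_latent_size(resolution), conditioning_embedding_size(resolution))
```

Under these assumptions a resolution of 512 gives 64 for both encoders, while 516 gives 64 for the VAE and 65 for the conditioning encoder, which is the kind of disagreement the commit guards against.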
williamberman authored Apr 19, 2023
1 parent a4c91be commit 7e6886f
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions examples/controlnet/train_controlnet.py
@@ -525,6 +525,11 @@ def parse_args(input_args=None):
             " or the same number of `--validation_prompt`s and `--validation_image`s"
         )
 
+    if args.resolution % 8 != 0:
+        raise ValueError(
+            "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder."
+        )
+
     return args
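
To see the new check in isolation, here is a stripped-down sketch of the argument parser. It is hypothetical: only `--resolution` is modelled, the default of 512 is an assumption, and the error message is shortened. The real `parse_args` in `train_controlnet.py` defines many more options.

```python
import argparse


def parse_args(input_args=None):
    # Hypothetical minimal parser; only the option touched by this commit is kept.
    parser = argparse.ArgumentParser()
    parser.add_argument("--resolution", type=int, default=512)
    args = parser.parse_args(input_args)

    # Same condition as the check added in the diff.
    if args.resolution % 8 != 0:
        raise ValueError("`--resolution` must be divisible by 8")

    return args


args = parse_args(["--resolution", "512"])  # accepted
try:
    parse_args(["--resolution", "516"])  # 516 % 8 == 4, so this is rejected
except ValueError as err:
    print(f"rejected: {err}")
```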


@@ -607,6 +612,7 @@ def tokenize_captions(examples, is_train=True):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -615,6 +621,7 @@ def tokenize_captions(examples, is_train=True):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )
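
Why the added `CenterCrop` matters: `transforms.Resize` with a single integer scales the shorter edge to that size and keeps the aspect ratio, so non-square images in a batch come out with different longer edges. The sketch below only mimics the size arithmetic in pure Python; the helper names are hypothetical and the truncating rounding of the longer edge is an assumption about torchvision's behaviour.

```python
def resize_shorter_edge(width, height, size):
    # Scale so the shorter edge equals `size`, keeping the aspect ratio.
    # Assumption: the longer edge is truncated to an integer.
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop(width, height, size):
    # After the resize above both edges are at least `size`,
    # so a center crop always yields a square of side `size`.
    assert width >= size and height >= size
    return size, size


resolution = 512
for original in [(640, 480), (480, 640), (512, 512)]:
    resized = resize_shorter_edge(*original, resolution)
    cropped = center_crop(*resized, resolution)
    print(original, "->", resized, "->", cropped)
```

With the resize alone, a 640x480 image becomes 682x512, which is neither square nor a multiple of 8 on its longer edge. After the crop, every image and conditioning image is `resolution` x `resolution`.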