Allow resolutions that are not multiples of 64 (huggingface#505)
* Allow resolutions that are not multiples of 64

* ran black

* fix bug

* add test

* more explanation

* more comments

Co-authored-by: Patrick von Platen <[email protected]>
jachiam and patrickvonplaten authored Sep 30, 2022
1 parent 9ebaea5 commit a784be2
Showing 5 changed files with 106 additions and 12 deletions.
4 changes: 3 additions & 1 deletion setup.py
@@ -193,7 +193,9 @@ def run(self):
 else:
     extras["flax"] = deps_list("jax", "jaxlib", "flax")
 
-extras["dev"] = extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+extras["dev"] = (
+    extras["quality"] + extras["test"] + extras["training"] + extras["docs"] + extras["torch"] + extras["flax"]
+)
 
 install_requires = [
     deps["importlib_metadata"],
10 changes: 8 additions & 2 deletions src/diffusers/models/resnet.py
@@ -34,12 +34,18 @@ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_chann
         else:
             self.Conv2d_0 = conv
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, output_size=None):
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv_transpose:
             return self.conv(hidden_states)
 
-        hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        # if `output_size` is passed we force the interpolation output
+        # size and do not make use of `scale_factor=2`
+        if output_size is None:
+            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
+        else:
+            hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
 
         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
         if self.use_conv:
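For context, here is a minimal standalone sketch (not part of the commit) of why forcing the interpolation size matters: odd spatial sizes do not round-trip through a stride-2 downsample followed by a fixed 2x upsample, so the upsampled tensor can come out one pixel larger than the skip connection it has to concatenate with.

import torch
import torch.nn.functional as F

# A 67x67 feature map passed through a stride-2 conv becomes 34x34
# (the conv rounds up), so blindly doubling it back yields 68x68,
# which no longer matches the 67x67 skip connection.
x = torch.randn(1, 4, 34, 34)

doubled = F.interpolate(x, scale_factor=2.0, mode="nearest")
assert doubled.shape[-2:] == (68, 68)  # off by one vs. the skip

forced = F.interpolate(x, size=(67, 67), mode="nearest")
assert forced.shape[-2:] == (67, 67)  # matches the skip exactly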
48 changes: 42 additions & 6 deletions src/diffusers/models/unet_2d_condition.py
@@ -7,7 +7,7 @@
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
-from ..utils import BaseOutput
+from ..utils import BaseOutput, logging
 from .embeddings import TimestepEmbedding, Timesteps
 from .unet_blocks import (
     CrossAttnDownBlock2D,
@@ -20,6 +20,9 @@
 )
 
 
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
 @dataclass
 class UNet2DConditionOutput(BaseOutput):
     """
@@ -145,15 +148,25 @@ def __init__(
             resnet_groups=norm_num_groups,
         )
 
+        # count how many layers upsample the images
+        self.num_upsamplers = 0
+
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
         for i, up_block_type in enumerate(up_block_types):
+            is_final_block = i == len(block_out_channels) - 1
+
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
 
-            is_final_block = i == len(block_out_channels) - 1
+            # add upsample block for all BUT final layer
+            if not is_final_block:
+                add_upsample = True
+                self.num_upsamplers += 1
+            else:
+                add_upsample = False
 
             up_block = get_up_block(
                 up_block_type,
@@ -162,7 +175,7 @@ def __init__(
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
                 temb_channels=time_embed_dim,
-                add_upsample=not is_final_block,
+                add_upsample=add_upsample,
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
@@ -223,6 +236,20 @@ def forward(
             [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is the sample tensor.
         """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 0. center input if necessary
         if self.config.center_input_sample:
             sample = 2 * sample - 1.0
@@ -262,20 +289,29 @@ def forward(
         sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
 
         # 5. up
-        for upsample_block in self.up_blocks:
+        for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
 
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
             if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
                 )
             else:
-                sample = upsample_block(hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples)
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
 
         # 6. post-process
         # make sure hidden states is in float32
         # when running in half-precision
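A standalone sketch (with assumed numbers: a three-upsampler UNet and an 8x VAE, as in Stable Diffusion) of the shape bookkeeping the forwarded size enables, for a 536x536 request:

num_upsamplers = 3  # assumed: the UNet upsamples three times
default_overall_up_factor = 2**num_upsamplers  # 8

latent = 536 // 8  # the VAE downscales 8x -> 67, and 67 % 8 != 0
sizes = [latent]
for _ in range(num_upsamplers):
    sizes.append((sizes[-1] + 1) // 2)  # stride-2 convs round up

print(sizes)  # [67, 34, 17, 9]
# The down path stores skip tensors at 67, 34 and 17; naive 2x
# upsampling from 9 would produce 18, 36 and 72 and fail to
# concatenate, so the up path instead reads each target size off the
# matching skip tensor, exactly as `upsample_size` does above.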
7 changes: 4 additions & 3 deletions src/diffusers/models/unet_blocks.py
@@ -1126,6 +1126,7 @@ def forward(
         res_hidden_states_tuple,
         temb=None,
         encoder_hidden_states=None,
+        upsample_size=None,
     ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
@@ -1151,7 +1152,7 @@ def custom_forward(*inputs):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
@@ -1204,7 +1205,7 @@ def __init__(
 
         self.gradient_checkpointing = False
 
-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
+    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -1225,7 +1226,7 @@ def custom_forward(*inputs):
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)
 
         return hidden_states
 
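The threading here is mechanical: every up block accepts an optional `upsample_size` and hands it to its upsamplers unchanged. A minimal sketch of the pattern (a hypothetical stand-in, not the library code), where `None` preserves the old doubling behaviour:

import torch

class TinyUpBlock(torch.nn.Module):
    # Hypothetical stand-in for UpBlock2D, reduced to the new plumbing.
    def __init__(self, upsamplers):
        super().__init__()
        self.upsamplers = torch.nn.ModuleList(upsamplers)

    def forward(self, hidden_states, upsample_size=None):
        for upsampler in self.upsamplers:
            # upsample_size=None keeps the default scale_factor=2 path
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states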
49 changes: 49 additions & 0 deletions tests/test_pipelines.py
@@ -336,6 +336,55 @@ def test_stable_diffusion_ddim(self):
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
 
+    def test_stable_diffusion_ddim_factor_8(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        unet = self.dummy_cond_unet
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        # make sure here that pndm scheduler skips prk
+        sd_pipe = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=self.dummy_safety_checker,
+            feature_extractor=self.dummy_extractor,
+        )
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "A painting of a squirrel eating a burger"
+
+        generator = torch.Generator(device=device).manual_seed(0)
+        output = sd_pipe(
+            [prompt],
+            generator=generator,
+            guidance_scale=6.0,
+            height=536,
+            width=536,
+            num_inference_steps=2,
+            output_type="np",
+        )
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 134, 134, 3)
+        expected_slice = np.array([0.7834, 0.5488, 0.5781, 0.46, 0.3609, 0.5369, 0.542, 0.4855, 0.5557])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+
     def test_stable_diffusion_pndm(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         unet = self.dummy_cond_unet
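With the change in place, an end-to-end call at a resolution that is a multiple of 8 but not of 64 should now work. A hypothetical usage sketch (checkpoint name, device and prompt are illustrative assumptions, not part of the test suite):

from diffusers import StableDiffusionPipeline

# Hypothetical usage: checkpoint and prompt are placeholders.
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda")

# 536 = 8 * 67: a multiple of the VAE factor (8) but not of 64,
# a resolution the UNet previously rejected with a shape mismatch.
image = pipe("a photo of an astronaut", height=536, width=536).images[0]
image.save("astronaut_536.png")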
