Skip to content

Commit

Permalink
add callbacks to denoising step (huggingface#5427)
Browse files Browse the repository at this point in the history
* draft1

* update

* style

* move to the end of loop

* update

* update callbak_on_step_end_inputs

* Revert "update"

This reverts commit 5f9b153.

* Revert "update callbak_on_step_end_inputs"

This reverts commit 44889f4.

* update

* update test required_optional_params

* remove self.lora_scale

* img2img

* inpaint

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

* fix

* apply feedbacks on img2img + inpaint: keep only important pipeline attributes

* depth

* pix2pix

* make _callback_tensor_inputs an class variable so that we can use it for testing

* add a basic tst for callback

* add a read-only tensor input timesteps + fix tests

* add second test for callback cfg

* sdxl

* sdxl img2img

* sdxl inpaint

* kandinsky prior

* kandinsky decoder

* kandinsky img2img + combined

* kandinsky inpaint

* fix copies

* fix

* consistent default inputs

* fix copies

* wuerstchen_prior prior

* test_wuerstchen_decoder + fix test for prior

* wuerstchen_combined pipeline + skip tests

* skip test for kandinsky combined

* lcm

* remove timesteps etc

* add doc string

* copies

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Co-authored-by: Sayak Paul <[email protected]>

* Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

Co-authored-by: Sayak Paul <[email protected]>

* make style and improve tests

* up

* up

* fix more

* fix cfg test

* tests for callbacks

* fix for real

* update

* lcm img2img

* add doc

* add doc page to index

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
Co-authored-by: Patrick von Platen <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
  • Loading branch information
5 people authored Nov 5, 2023
1 parent 080081b commit 2b23ec8
Show file tree
Hide file tree
Showing 62 changed files with 2,514 additions and 582 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@
title: Kandinsky
- local: using-diffusers/controlnet
title: ControlNet
- local: using-diffusers/callback
title: Callback
- local: using-diffusers/shap-e
title: Shap-E
- local: using-diffusers/diffedit
Expand Down
60 changes: 60 additions & 0 deletions docs/source/en/using-diffusers/callback.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Using callback

[[open-in-colab]]

Most 🤗 Diffusers pipeline now accept a `callback_on_step_end` argument that allows you to change the default behavior of denoising loop with custom defined functions. Here is an example of a callback function we can write to disable classifier free guidance after 40% of inference steps to save compute with minimum tradeoff in performance.

```python
def callback_dynamic_cfg(pipe, step_index, timestep, callback_kwargs):
# adjust the batch_size of prompt_embeds according to guidance_scale
if step_index == int(pipe.num_timestep * 0.4):
prompt_embeds = callback_kwargs["prompt_embeds"]
prompt_embeds =prompt_embeds.chunk(2)[-1]

# update guidance_scale and prompt_embeds
pipe._guidance_scale = 0.0
callback_kwargs["prompt_embeds"] = prompt_embeds
return callback_kwargs
```

Your callback function has below arguments:
* `pipe` is the pipeline instance, which provides access to useful properties such as `num_timestep` and `guidance_scale`. You can modify these properties by updating the underlying attributes. In this example, we disable CFG by setting `pipe._guidance_scale` to be `0`.
* `step_index` and `timestep` tell you where you are in the denoising loop. In our example, we use `step_index` to decide when to turn off CFG.
* `callback_kwargs` is a dict that contains tensor variables you can modify during the denoising loop. It only includes variables specified in the `callback_on_step_end_tensor_inputs` argument passed to the pipeline's `__call__` method. Different pipelines may use different sets of variables so please check the pipeline class's `_callback_tensor_inputs` attribute for the list of variables that you can modify. Common variables include `latents` and `prompt_embeds`. In our example, we need to adjust the batch size of `prompt_embeds` after setting `guidance_scale` to be `0` in order for it to work properly.

You can pass the callback function as `callback_on_step_end` argument to the pipeline along with `callback_on_step_end_tensor_inputs`.

```
import torch
from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
prompt = "a photo of an astronaut riding a horse on mars"
generator = torch.Generator(device="cuda").manual_seed(1)
out= pipe(prompt, generator=generator, callback_on_step_end = callback_custom_cfg, callback_on_step_end_tensor_inputs=['prompt_embeds'])
out.images[0].save("out_custom_cfg.png")
```

Your callback function will be executed at the end of each denoising step and modify pipeline attributes and tensor variables for the next denoising step. We successfully added the "dynamic CFG" feature to the stable diffusion pipeline without having to modify the code at all.

<Tip>

Currently we only support `callback_on_step_end`. If you have a solid use case and require a callback function with a different execution point, please open an [feature request](https://github.com/huggingface/diffusers/issues/new/choose) so we can add it!

</Tip>
130 changes: 104 additions & 26 deletions src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class AltDiffusionPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraL
model_cpu_offload_seq = "text_encoder->unet->vae"
_optional_components = ["safety_checker", "feature_extractor"]
_exclude_from_cpu_offload = ["safety_checker"]
_callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]

def __init__(
self,
Expand Down Expand Up @@ -500,17 +501,23 @@ def check_inputs(
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)

if prompt is not None and prompt_embeds is not None:
raise ValueError(
Expand Down Expand Up @@ -581,6 +588,33 @@ def disable_freeu(self):
"""Disables the FreeU mechanism if enabled."""
self.unet.disable_freeu()

@property
def guidance_scale(self):
return self._guidance_scale

@property
def guidance_rescale(self):
return self._guidance_rescale

@property
def clip_skip(self):
return self._clip_skip

# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1

@property
def cross_attention_kwargs(self):
return self._cross_attention_kwargs

@property
def num_timesteps(self):
return self._num_timesteps

@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
Expand All @@ -599,11 +633,12 @@ def __call__(
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
guidance_rescale: float = 0.0,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
**kwargs,
):
r"""
The call function to the pipeline for generation.
Expand Down Expand Up @@ -647,12 +682,6 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that calls every `callback_steps` steps during inference. The function is called with the
following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function is called. If not specified, the callback is called at
every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
[`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
Expand All @@ -663,6 +692,15 @@ def __call__(
clip_skip (`int`, *optional*):
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
the output of the pre-final layer will be used for computing the prompt embeddings.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeine class.
Examples:
Expand All @@ -673,16 +711,47 @@ def __call__(
second element is a list of `bool`s indicating whether the corresponding generated image contains
"not-safe-for-work" (nsfw) content.
"""

callback = kwargs.pop("callback", None)
callback_steps = kwargs.pop("callback_steps", None)

if callback is not None:
deprecate(
"callback",
"1.0.0",
"Passing `callback` as an input argument to `__call__` is deprecated, consider using"
" `callback_on_step_end`",
)
if callback_steps is not None:
deprecate(
"callback_steps",
"1.0.0",
"Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using"
" `callback_on_step_end`",
)

# 0. Default height and width to unet
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
# to deal with lora scaling and other possible forward hooks

# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
prompt,
height,
width,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
callback_on_step_end_tensor_inputs,
)

self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._clip_skip = clip_skip
self._cross_attention_kwargs = cross_attention_kwargs

# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
Expand All @@ -692,29 +761,27 @@ def __call__(
batch_size = prompt_embeds.shape[0]

device = self._execution_device
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

# 3. Encode input prompt
lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
lora_scale = (
self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
)

prompt_embeds, negative_prompt_embeds = self.encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
self.do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=lora_scale,
clip_skip=clip_skip,
clip_skip=self.clip_skip,
)
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

# 4. Prepare timesteps
Expand All @@ -739,33 +806,44 @@ def __call__(

# 7. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
self._num_timesteps = len(timesteps)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
cross_attention_kwargs=self.cross_attention_kwargs,
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

if do_classifier_free_guidance and guidance_rescale > 0.0:
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
Expand Down
Loading

0 comments on commit 2b23ec8

Please sign in to comment.