Skip to content

Commit

Permalink
Docs: short section on changing the scheduler in Flax (huggingface#2181)
Browse files Browse the repository at this point in the history
* Short doc on changing the scheduler in Flax.

* Apply fix from @patil-suraj

Co-authored-by: Suraj Patil <[email protected]>

---------

Co-authored-by: Suraj Patil <[email protected]>
  • Loading branch information
pcuenca and patil-suraj authored Feb 2, 2023
1 parent 68ef066 commit 2bbd532
Showing 1 changed file with 52 additions and 0 deletions.
52 changes: 52 additions & 0 deletions docs/source/en/using-diffusers/schedulers.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ image
<br>
</p>

If you are a JAX/Flax user, please check [this section](#changing-the-scheduler-in-flax) instead.

## Compare schedulers

Expand Down Expand Up @@ -260,3 +261,54 @@ image

As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different
schedulers to compare results.

## Changing the Scheduler in Flax

If you are a JAX/Flax user, you can also change the default pipeline scheduler. This is a complete example of how to run inference using the Flax Stable Diffusion pipeline and the super-fast [DDPM-Solver++ scheduler](../api/schedulers/multistep_dpm_solver):

```Python
import jax
import numpy as np
from flax.jax_utils import replicate
from flax.training.common_utils import shard

from diffusers import FlaxStableDiffusionPipeline, FlaxDPMSolverMultistepScheduler

model_id = "runwayml/stable-diffusion-v1-5"
scheduler, scheduler_state = FlaxDPMSolverMultistepScheduler.from_pretrained(
model_id,
subfolder="scheduler"
)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
model_id,
scheduler=scheduler,
revision="bf16",
dtype=jax.numpy.bfloat16,
)
params["scheduler"] = scheduler_state

# Generate 1 image per parallel device (8 on TPUv2-8 or TPUv3-8)
prompt = "a photo of an astronaut riding a horse on mars"
num_samples = jax.device_count()
prompt_ids = pipeline.prepare_inputs([prompt] * num_samples)

prng_seed = jax.random.PRNGKey(0)
num_inference_steps = 25

# shard inputs and rng
params = replicate(params)
prng_seed = jax.random.split(prng_seed, jax.device_count())
prompt_ids = shard(prompt_ids)

images = pipeline(prompt_ids, params, prng_seed, num_inference_steps, jit=True).images
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
```

<Tip warning={true}>

The following Flax schedulers are _not yet compatible_ with the Flax Stable Diffusion Pipeline:

- `FlaxLMSDiscreteScheduler`
- `FlaxDDPMScheduler`

</Tip>

0 comments on commit 2bbd532

Please sign in to comment.