Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/huggingface/diffusers into …
Browse files Browse the repository at this point in the history
…main
  • Loading branch information
patrickvonplaten committed Nov 15, 2022
2 parents d5ab55e + a052019 commit 554b374
Show file tree
Hide file tree
Showing 76 changed files with 2,328 additions and 473 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,4 +163,6 @@ tags
*.lock

# DS_Store (MacOS)
.DS_Store
.DS_Store
# RL pipelines may produce mp4 outputs
*.mp4
10 changes: 1 addition & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`.
```python
from diffusers import LMSDiscreteScheduler

lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")

pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5",
revision="fp16",
torch_dtype=torch.float16,
scheduler=lms,
)
pipe = pipe.to("cuda")
pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config)

prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt).images[0]
Expand Down
2 changes: 2 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
- sections:
- local: using-diffusers/loading
title: "Loading Pipelines, Models, and Schedulers"
- local: using-diffusers/schedulers
title: "Using different Schedulers"
- local: using-diffusers/configuration
title: "Configuring Pipelines, Models, and Schedulers"
- local: using-diffusers/custom_pipeline_overview
Expand Down
4 changes: 2 additions & 2 deletions docs/source/api/configuration.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License.
In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are
passed to the respective `__init__` methods in a JSON-configuration file.

TODO(PVP) - add example and better info here

## ConfigMixin

[[autodoc]] ConfigMixin
- load_config
- from_config
- save_config
9 changes: 6 additions & 3 deletions docs/source/api/models.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## UNet2DOutput
[[autodoc]] models.unet_2d.UNet2DOutput

## UNet1DModel
[[autodoc]] UNet1DModel

## UNet2DModel
[[autodoc]] UNet2DModel

## UNet1DOutput
[[autodoc]] models.unet_1d.UNet1DOutput

## UNet1DModel
[[autodoc]] UNet1DModel

## UNet2DConditionOutput
[[autodoc]] models.unet_2d_condition.UNet2DConditionOutput

Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/pipelines/cycle_diffusion.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler
# load the pipeline
# make sure you're logged in with `huggingface-cli login`
model_id_or_path = "CompVis/stable-diffusion-v1-4"
scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler")
scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler")
pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda")

# let's download an initial image
Expand Down
2 changes: 1 addition & 1 deletion docs/source/api/pipelines/repaint.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256))
mask_image = download_image(mask_url).resize((256, 256))

# Load the RePaint scheduler and pipeline based on a pretrained DDPM model
scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256")
scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256")
pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler)
pipe = pipe.to("cuda")

Expand Down
12 changes: 8 additions & 4 deletions docs/source/api/pipelines/stable_diffusion.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,17 @@ For more details about how Stable Diffusion works and how it differs from the ba
### How to load and use different schedulers.

The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc.
To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:
To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following:

```python
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler
>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
>>> # or
>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler)
```
Expand Down
25 changes: 13 additions & 12 deletions docs/source/quicktour.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,21 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat
```python
>>> from diffusers import DiffusionPipeline

>>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256")
```
The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components.
Because the model consists of roughly 1.4 billion parameters, we strongly recommend running it on GPU.
You can move the generator object to GPU, just like you would in PyTorch.
```python
>>> generator.to("cuda")
>>> pipeline.to("cuda")
```

Now you can use the `generator` on your text prompt:
Now you can use the `pipeline` on your text prompt:

```python
>>> image = generator("An image of a squirrel in Picasso style").images[0]
>>> image = pipeline("An image of a squirrel in Picasso style").images[0]
```

The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class).
Expand All @@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`:
```python
>>> from diffusers import DiffusionPipeline

>>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)
```

If you do not pass your authentication token you will see that the diffusion system will not be correctly
Expand All @@ -102,7 +102,7 @@ token. Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned
you can also load the pipeline as follows:

```python
>>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5")
```

Running the pipeline is then identical to the code above as it's the same model architecture.
Expand All @@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model

Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their
pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to
use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler,
use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler,
you could use it as follows:

```python
>>> from diffusers import LMSDiscreteScheduler
>>> from diffusers import EulerDiscreteScheduler

>>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler")
>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN)

>>> generator = StableDiffusionPipeline.from_pretrained(
... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN
... )
>>> # change scheduler to Euler
>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config)
```

For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide.

[Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model
and can do much more than just generating images from text. We have dedicated a whole documentation page,
just for Stable Diffusion [here](./conceptual/stable_diffusion).
Expand Down
26 changes: 13 additions & 13 deletions docs/source/using-diffusers/loading.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load:

- *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`]
- *Diffusion Models* via [`ModelMixin.from_pretrained`]
- *Schedulers* via [`ConfigMixin.from_config`]
- *Schedulers* via [`SchedulerMixin.from_pretrained`]

## Loading pipelines

Expand Down Expand Up @@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis

repo_id = "runwayml/stable-diffusion-v1-5"

scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
# or
# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")

stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler)
```

Three things are worth paying attention to here.
- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights
- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`]
- Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler)
- Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`]

Expand Down Expand Up @@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id)

## Loading schedulers

Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
Therefore the loading method was given a different name here.
Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file.
For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case.

In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers.
For example, all of:
Expand Down Expand Up @@ -367,13 +367,13 @@ from diffusers import (

repo_id = "runwayml/stable-diffusion-v1-5"

ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler")
ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler")
ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler")
pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler")
lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler")
dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler")

# replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc`
pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm)
Expand Down
Loading

0 comments on commit 554b374

Please sign in to comment.