forked from huggingface/diffusers
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* initial * initial * added initial convert script for paella vqmodel * initial wuerstchen pipeline * add LayerNorm2d * added modules * fix typo * use model_v2 * embed clip caption amd negative_caption * fixed name of var * initial modules in one place * WuerstchenPriorPipeline * inital shape * initial denoising prior loop * fix output * add WuerstchenPriorPipeline to __init__.py * use the noise ratio in the Prior * try to save pipeline * save_pretrained working * Few additions * add _execution_device * shape is int * fix batch size * fix shape of ratio * fix shape of ratio * fix output dataclass * tests folder * fix formatting * fix float16 + started with generator * Update pipeline_wuerstchen.py * removed vqgan code * add WuerstchenGeneratorPipeline * fix WuerstchenGeneratorPipeline * fix docstrings * fix imports * convert generator pipeline * fix convert * Work on Generator Pipeline. WIP * Pipeline works with our diffuzz code * apply scale factor * removed vqgan.py * use cosine schedule * redo the denoising loop * Update src/diffusers/models/resnet.py Co-authored-by: Patrick von Platen <[email protected]> * use torch.lerp * use warp-diffusion org * clip_sample=False, * some refactoring * use model_v3_stage_c * c_cond size * use clip-bigG * allow stage b clip to be None * add dummy * würstchen scheduler * minor changes * set clip=None in the pipeline * fix attention mask * add attention_masks to text_encoder * make fix-copies * add back clip * add text_encoder * gen_text_encoder and tokenizer * fix import * updated pipeline test * undo changes to pipeline test * nip * fix typo * fix output name * set guidance_scale=0 and remove diffuze * fix doc strings * make style * nip * removed unused * initial docs * rename * toc * cleanup * remvoe test script * fix-copies * fix multi images * remove dup * remove unused modules * undo changes for debugging * no new line * remove dup conversion script * fix doc string * cleanup * pass default args * dup permute * fix some tests * fix prepare_latents * move Prior class to modules * offload only the text encoder and vqgan * fix resolution calculation for prior * nip * removed testing script * fix shape * fix argument to set_timesteps * do not change .gitignore * fix resolution calculations + readme * resolution calculation fix + readme * small fixes * Add combined pipeline * rename generator -> decoder * Update .gitignore Co-authored-by: Patrick von Platen <[email protected]> * removed efficient_net * create combined WuerstchenPipeline * make arguments consistent with VQ model * fix var names * no need to return text_encoder_hidden_states * add latent_dim_scale to config * split model into its own file * add WuerschenPipeline to docs * remove unused latent_size * register latent_dim_scale * update script * update docstring * use Attention preprocessor * concat with normed input * fix-copies * add docs * fix test * fix style * add to cpu_offloaded_model * updated type * remove 1-line func * updated type * initial decoder test * formatting * formatting * fix autodoc link * num_inference_steps is int * remove comments * fix example in docs * Update src/diffusers/pipelines/wuerstchen/diffnext.py Co-authored-by: Patrick von Platen <[email protected]> * rename layernorm to WuerstchenLayerNorm * rename DiffNext to WuerstchenDiffNeXt * added comment about MixingResidualBlock * move paella vq-vae to pipelines' folder * initial decoder test * increased test_float16_inference expected diff * self_attn is always true * more passing decoder tests * batch image_embeds * fix failing tests * set the correct dtype * relax inference test * update prior * added combined pipeline test * faster test * faster test * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * fix issues from review * update wuerstchen.md + change generator name * resolve issues * fix copied from usage and add back batch_size * fix API * fix arguments * fix combined test * Added timesteps argument + fixes * Update tests/pipelines/test_pipelines_common.py Co-authored-by: Patrick von Platen <[email protected]> * Update tests/pipelines/wuerstchen/test_wuerstchen_prior.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py Co-authored-by: Patrick von Platen <[email protected]> * Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py * up * Fix more * failing tests * up * up * correct naming * correct docs * correct docs * fix test params * correct docs * fix classifier free guidance * fix classifier free guidance * fix more * fix all * make tests faster --------- Co-authored-by: Dominic Rampas <[email protected]> Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Dominic Rampas <[email protected]>
- Loading branch information
1 parent
b76274c
commit 541bb6e
Showing
26 changed files
with
2,838 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
# Würstchen | ||
|
||
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/0617c863-165a-43ee-9303-2a17299a0cf9"> | ||
|
||
[Würstchen: Efficient Pretraining of Text-to-Image Models](https://huggingface.co/papers/2306.00637) is by Pablo Pernias, Dominic Rampas, and Marc Aubreville. | ||
|
||
The abstract from the paper is: | ||
|
||
*We introduce Würstchen, a novel technique for text-to-image synthesis that unites competitive performance with unprecedented cost-effectiveness and ease of training on constrained hardware. Building on recent advancements in machine learning, our approach, which utilizes latent diffusion strategies at strong latent image compression rates, significantly reduces the computational burden, typically associated with state-of-the-art models, while preserving, if not enhancing, the quality of generated images. Wuerstchen achieves notable speed improvements at inference time, thereby rendering real-time applications more viable. One of the key advantages of our method lies in its modest training requirements of only 9,200 GPU hours, slashing the usual costs significantly without compromising the end performance. In a comparison against the state-of-the-art, we found the approach to yield strong competitiveness. This paper opens the door to a new line of research that prioritizes both performance and computational accessibility, hence democratizing the use of sophisticated AI technologies. Through Wuerstchen, we demonstrate a compelling stride forward in the realm of text-to-image synthesis, offering an innovative path to explore in future research.* | ||
|
||
## Würstchen v2 comes to Diffusers | ||
|
||
After the initial paper release, we have improved numerous things in the architecture, training and sampling, making Würstchen competetive to current state-of-the-art models in many ways. We are excited to release this new version together with Diffusers. Here is a list of the improvements. | ||
|
||
- Higher resolution (1024x1024 up to 2048x2048) | ||
- Faster inference | ||
- Multi Aspect Resolution Sampling | ||
- Better quality | ||
|
||
We are releasing 3 checkpoints for the text-conditional image generation model (Stage C). Those are: | ||
- v2-base | ||
- v2-aesthetic | ||
- v2-interpolated (50% interpolation between v2-base and v2-aesthetic) | ||
|
||
We recommend to use v2-interpolated, as it has a nice touch of both photorealism and aesthetic. Use v2-base for finetunings as it does not have a style bias and use v2-aesthetic for very artistic generations. | ||
A comparison can be seen here: | ||
|
||
<img src="https://github.com/dome272/Wuerstchen/assets/61938694/2914830f-cbd3-461c-be64-d50734f4b49d" width=500> | ||
|
||
## Text-to-Image Generation | ||
|
||
For the sake of usability Würstchen can be used with a single pipeline. This pipeline is called `WuerstchenCombinedPipeline` and can be used as follows: | ||
|
||
```python | ||
import torch | ||
from diffusers import AutoPipelineForText2Image | ||
|
||
device = "cuda" | ||
dtype = torch.float16 | ||
num_images_per_prompt = 2 | ||
|
||
pipeline = AutoPipelineForText2Image.from_pretrained( | ||
"warp-diffusion/wuerstchen", torch_dtype=dtype | ||
).to(device) | ||
|
||
caption = "Anthropomorphic cat dressed as a fire fighter" | ||
negative_prompt = "" | ||
|
||
output = pipeline( | ||
prompt=caption, | ||
height=1024, | ||
width=1024, | ||
negative_prompt=negative_prompt, | ||
prior_guidance_scale=4.0, | ||
decoder_guidance_scale=0.0, | ||
num_images_per_prompt=num_images_per_prompt, | ||
output_type="pil", | ||
).images | ||
``` | ||
|
||
For explanation purposes, we can also initialize the two main pipelines of Würstchen individually. Würstchen consists of 3 stages: Stage C, Stage B, Stage A. They all have different jobs and work only together. When generating text-conditional images, Stage C will first generate the latents in a very compressed latent space. This is what happens in the `prior_pipeline`. Afterwards, the generated latents will be passed to Stage B, which decompresses the latents into a bigger latent space of a VQGAN. These latents can then be decoded by Stage A, which is a VQGAN, into the pixel-space. Stage B & Stage A are both encapsulated in the `decoder_pipeline`. For more details, take a look the [paper](https://huggingface.co/papers/2306.00637). | ||
|
||
```python | ||
import torch | ||
from diffusers import WuerstchenDecoderPipeline, WuerstchenPriorPipeline | ||
|
||
device = "cuda" | ||
dtype = torch.float16 | ||
num_images_per_prompt = 2 | ||
|
||
prior_pipeline = WuerstchenPriorPipeline.from_pretrained( | ||
"warp-diffusion/wuerstchen-prior", torch_dtype=dtype | ||
).to(device) | ||
decoder_pipeline = WuerstchenDecoderPipeline.from_pretrained( | ||
"warp-diffusion/wuerstchen", torch_dtype=dtype | ||
).to(device) | ||
|
||
caption = "A captivating artwork of a mysterious stone golem" | ||
negative_prompt = "" | ||
|
||
prior_output = prior_pipeline( | ||
prompt=caption, | ||
height=1024, | ||
width=1024, | ||
negative_prompt=negative_prompt, | ||
guidance_scale=4.0, | ||
num_images_per_prompt=num_images_per_prompt, | ||
) | ||
decoder_output = decoder_pipeline( | ||
image_embeddings=prior_output.image_embeddings, | ||
prompt=caption, | ||
negative_prompt=negative_prompt, | ||
num_images_per_prompt=num_images_per_prompt, | ||
guidance_scale=0.0, | ||
output_type="pil", | ||
).images | ||
``` | ||
|
||
## Speed-Up Inference | ||
You can make use of ``torch.compile`` function and gain a speed-up of about 2-3x: | ||
|
||
```python | ||
pipeline.prior = torch.compile(pipeline.prior, mode="reduce-overhead", fullgraph=True) | ||
pipeline.decoder = torch.compile(pipeline.decoder, mode="reduce-overhead", fullgraph=True) | ||
``` | ||
|
||
## Limitations | ||
- Due to the high compression employed by Würstchen, generations can lack a good amount | ||
of detail. To our human eye, this is especially noticeable in faces, hands etc. | ||
- **Images can only be generated in 128-pixel steps**, e.g. the next higher resolution | ||
after 1024x1024 is 1152x1152 | ||
- The model lacks the ability to render correct text in images | ||
- The model often does not achieve photorealism | ||
- Difficult compositional prompts are hard for the model | ||
|
||
|
||
The original codebase, as well as experimental ideas, can be found at [dome272/Wuerstchen](https://github.com/dome272/Wuerstchen). | ||
|
||
## WuerschenPipeline | ||
|
||
[[autodoc]] WuerstchenCombinedPipeline | ||
- all | ||
- __call__ | ||
|
||
## WuerstchenPriorPipeline | ||
|
||
[[autodoc]] WuerstchenDecoderPipeline | ||
|
||
- all | ||
- __call__ | ||
|
||
## WuerstchenPriorPipelineOutput | ||
|
||
[[autodoc]] pipelines.wuerstchen.pipeline_wuerstchen_prior.WuerstchenPriorPipelineOutput | ||
|
||
## WuerstchenDecoderPipeline | ||
|
||
[[autodoc]] WuerstchenDecoderPipeline | ||
- all | ||
- __call__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Run inside root directory of official source code: https://github.com/dome272/wuerstchen/ | ||
import os | ||
|
||
import torch | ||
from transformers import AutoTokenizer, CLIPTextModel | ||
from vqgan import VQModel | ||
|
||
from diffusers import ( | ||
DDPMWuerstchenScheduler, | ||
WuerstchenCombinedPipeline, | ||
WuerstchenDecoderPipeline, | ||
WuerstchenPriorPipeline, | ||
) | ||
from diffusers.pipelines.wuerstchen import PaellaVQModel, WuerstchenDiffNeXt, WuerstchenPrior | ||
|
||
|
||
model_path = "models/" | ||
device = "cpu" | ||
|
||
paella_vqmodel = VQModel() | ||
state_dict = torch.load(os.path.join(model_path, "vqgan_f4_v1_500k.pt"), map_location=device)["state_dict"] | ||
paella_vqmodel.load_state_dict(state_dict) | ||
|
||
state_dict["vquantizer.embedding.weight"] = state_dict["vquantizer.codebook.weight"] | ||
state_dict.pop("vquantizer.codebook.weight") | ||
vqmodel = PaellaVQModel(num_vq_embeddings=paella_vqmodel.codebook_size, latent_channels=paella_vqmodel.c_latent) | ||
vqmodel.load_state_dict(state_dict) | ||
|
||
# Clip Text encoder and tokenizer | ||
text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") | ||
tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k") | ||
|
||
# Generator | ||
gen_text_encoder = CLIPTextModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K").to("cpu") | ||
gen_tokenizer = AutoTokenizer.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") | ||
|
||
orig_state_dict = torch.load(os.path.join(model_path, "model_v2_stage_b.pt"), map_location=device)["state_dict"] | ||
state_dict = {} | ||
for key in orig_state_dict.keys(): | ||
if key.endswith("in_proj_weight"): | ||
weights = orig_state_dict[key].chunk(3, 0) | ||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] | ||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] | ||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] | ||
elif key.endswith("in_proj_bias"): | ||
weights = orig_state_dict[key].chunk(3, 0) | ||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] | ||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] | ||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] | ||
elif key.endswith("out_proj.weight"): | ||
weights = orig_state_dict[key] | ||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights | ||
elif key.endswith("out_proj.bias"): | ||
weights = orig_state_dict[key] | ||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights | ||
else: | ||
state_dict[key] = orig_state_dict[key] | ||
deocder = WuerstchenDiffNeXt() | ||
deocder.load_state_dict(state_dict) | ||
|
||
# Prior | ||
orig_state_dict = torch.load(os.path.join(model_path, "model_v3_stage_c.pt"), map_location=device)["ema_state_dict"] | ||
state_dict = {} | ||
for key in orig_state_dict.keys(): | ||
if key.endswith("in_proj_weight"): | ||
weights = orig_state_dict[key].chunk(3, 0) | ||
state_dict[key.replace("attn.in_proj_weight", "to_q.weight")] = weights[0] | ||
state_dict[key.replace("attn.in_proj_weight", "to_k.weight")] = weights[1] | ||
state_dict[key.replace("attn.in_proj_weight", "to_v.weight")] = weights[2] | ||
elif key.endswith("in_proj_bias"): | ||
weights = orig_state_dict[key].chunk(3, 0) | ||
state_dict[key.replace("attn.in_proj_bias", "to_q.bias")] = weights[0] | ||
state_dict[key.replace("attn.in_proj_bias", "to_k.bias")] = weights[1] | ||
state_dict[key.replace("attn.in_proj_bias", "to_v.bias")] = weights[2] | ||
elif key.endswith("out_proj.weight"): | ||
weights = orig_state_dict[key] | ||
state_dict[key.replace("attn.out_proj.weight", "to_out.0.weight")] = weights | ||
elif key.endswith("out_proj.bias"): | ||
weights = orig_state_dict[key] | ||
state_dict[key.replace("attn.out_proj.bias", "to_out.0.bias")] = weights | ||
else: | ||
state_dict[key] = orig_state_dict[key] | ||
prior_model = WuerstchenPrior(c_in=16, c=1536, c_cond=1280, c_r=64, depth=32, nhead=24).to(device) | ||
prior_model.load_state_dict(state_dict) | ||
|
||
# scheduler | ||
scheduler = DDPMWuerstchenScheduler() | ||
|
||
# Prior pipeline | ||
prior_pipeline = WuerstchenPriorPipeline( | ||
prior=prior_model, text_encoder=text_encoder, tokenizer=tokenizer, scheduler=scheduler | ||
) | ||
|
||
prior_pipeline.save_pretrained("warp-diffusion/wuerstchen-prior") | ||
|
||
decoder_pipeline = WuerstchenDecoderPipeline( | ||
text_encoder=gen_text_encoder, tokenizer=gen_tokenizer, vqgan=vqmodel, decoder=deocder, scheduler=scheduler | ||
) | ||
decoder_pipeline.save_pretrained("warp-diffusion/wuerstchen") | ||
|
||
# Wuerstchen pipeline | ||
wuerstchen_pipeline = WuerstchenCombinedPipeline( | ||
# Decoder | ||
text_encoder=gen_text_encoder, | ||
tokenizer=gen_tokenizer, | ||
decoder=deocder, | ||
scheduler=scheduler, | ||
vqgan=vqmodel, | ||
# Prior | ||
prior_tokenizer=tokenizer, | ||
prior_text_encoder=text_encoder, | ||
prior=prior_model, | ||
prior_scheduler=scheduler, | ||
) | ||
wuerstchen_pipeline.save_pretrained("warp-diffusion/WuerstchenCombinedPipeline") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.