[Hunyuan] allow Hunyuan DiT to run under 6GB for GPU VRAM (huggingface#8399)

* allow hunyuan dit to run under 6GB for GPU VRAM

* add section in the docs/
sayakpaul authored Jun 5, 2024
1 parent a0542c1 commit 2f6f426
Showing 2 changed files with 25 additions and 4 deletions.
4 changes: 4 additions & 0 deletions docs/source/en/api/pipelines/hunyuandit.md
@@ -29,6 +29,10 @@ HunyuanDiT has the following components:
* It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder


## Memory optimization

By loading the T5 text encoder in 8-bit precision, you can run the pipeline in just under 6 GB of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.

## HunyuanDiTPipeline

[[autodoc]] HunyuanDiTPipeline
25 changes: 21 additions & 4 deletions src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -228,16 +228,22 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)

-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.default_sample_size = self.transformer.config.sample_size
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
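The hunk above lets the pipeline be built without a VAE or transformer by falling back to constants. A minimal, self-contained sketch of the same guard pattern (the stub objects and function name are illustrative, not from the diffusers codebase):

```python
from types import SimpleNamespace

def compute_vae_scale_factor(vae=None):
    # Mirrors the fallback in the diff: derive the factor from the VAE config
    # when a VAE is present, otherwise default to 8 (SD-style downsampling).
    if vae is not None:
        return 2 ** (len(vae.config.block_out_channels) - 1)
    return 8

# A stub VAE with the usual four block_out_channels entries.
stub_vae = SimpleNamespace(config=SimpleNamespace(block_out_channels=[128, 256, 512, 512]))

print(compute_vae_scale_factor(stub_vae))  # 8, i.e. 2 ** (4 - 1)
print(compute_vae_scale_factor(None))      # 8, the fallback
```

This is what allows instantiating the pipeline with some components set to `None` (e.g. when staging loads to save memory) without an `AttributeError` in `__init__`.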

def encode_prompt(
self,
prompt: str,
-        device: torch.device,
-        dtype: torch.dtype,
+        device: torch.device = None,
+        dtype: torch.dtype = None,
num_images_per_prompt: int = 1,
do_classifier_free_guidance: bool = True,
negative_prompt: Optional[str] = None,
@@ -281,6 +287,17 @@ def encode_prompt(
text_encoder_index (`int`, *optional*):
Index of the text encoder to use. `0` for clip and `1` for T5.
"""
if dtype is None:
if self.text_encoder_2 is not None:
dtype = self.text_encoder_2.dtype
elif self.transformer is not None:
dtype = self.transformer.dtype
else:
dtype = None

if device is None:
device = self._execution_device

tokenizers = [self.tokenizer, self.tokenizer_2]
text_encoders = [self.text_encoder, self.text_encoder_2]
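The added block resolves a missing `dtype` by preferring the T5 encoder's dtype, then the transformer's, and otherwise leaving it `None`. A self-contained sketch of that fallback order (stub objects and the helper name are illustrative only):

```python
import torch
from types import SimpleNamespace

def resolve_dtype(text_encoder_2=None, transformer=None, dtype=None):
    # Mirrors the fallback added in the diff: an explicit dtype wins, then
    # text_encoder_2's dtype, then the transformer's, else None.
    if dtype is not None:
        return dtype
    if text_encoder_2 is not None:
        return text_encoder_2.dtype
    if transformer is not None:
        return transformer.dtype
    return None

enc = SimpleNamespace(dtype=torch.float16)
tr = SimpleNamespace(dtype=torch.bfloat16)
print(resolve_dtype(text_encoder_2=enc, transformer=tr))  # torch.float16
print(resolve_dtype(transformer=tr))                      # torch.bfloat16
print(resolve_dtype())                                    # None
```

Making `device` and `dtype` optional in this way means callers no longer have to pass them explicitly, which matters when the T5 encoder is quantized and lives on a device chosen by `device_map="auto"`.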

