diff --git a/docs/source/en/api/pipelines/hunyuandit.md b/docs/source/en/api/pipelines/hunyuandit.md
index f777d42e9bde..29604ac9f814 100644
--- a/docs/source/en/api/pipelines/hunyuandit.md
+++ b/docs/source/en/api/pipelines/hunyuandit.md
@@ -29,6 +29,10 @@ HunyuanDiT has the following components:
 * It combines two text encoders, a bilingual CLIP and a multilingual T5 encoder
 
+## Memory optimization
+
+By loading the T5 text encoder in 8-bit precision, you can run the pipeline in just under 6 GB of GPU VRAM. Refer to [this script](https://gist.github.com/sayakpaul/3154605f6af05b98a41081aaba5ca43e) for details.
+
 ## HunyuanDiTPipeline
 
 [[autodoc]] HunyuanDiTPipeline
 
diff --git a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
index 46e19f79cc80..86089abc07b4 100644
--- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
+++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py
@@ -228,16 +228,22 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = (
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
+        )
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
-        self.default_sample_size = self.transformer.config.sample_size
+        self.default_sample_size = (
+            self.transformer.config.sample_size
+            if hasattr(self, "transformer") and self.transformer is not None
+            else 128
+        )
 
     def encode_prompt(
         self,
         prompt: str,
-        device: torch.device,
-        dtype: torch.dtype,
+        device: torch.device = None,
+        dtype: torch.dtype = None,
         num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
         negative_prompt: Optional[str] = None,
@@ -281,6 +287,17 @@ def encode_prompt(
             text_encoder_index (`int`, *optional*):
                 Index of the text encoder to use. `0` for clip and `1` for T5.
         """
+        if dtype is None:
+            if self.text_encoder_2 is not None:
+                dtype = self.text_encoder_2.dtype
+            elif self.transformer is not None:
+                dtype = self.transformer.dtype
+            else:
+                dtype = None
+
+        if device is None:
+            device = self._execution_device
+
         tokenizers = [self.tokenizer, self.tokenizer_2]
         text_encoders = [self.text_encoder, self.text_encoder_2]
 
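
For reference, a minimal sketch of the two-stage, low-VRAM flow these changes enable: encode prompts once with an 8-bit T5 encoder while `transformer` and `vae` are `None` (now possible thanks to the fallbacks above), free the text encoders, then denoise with the precomputed embeddings. The checkpoint ID `Tencent-Hunyuan/HunyuanDiT-Diffusers`, the `load_in_8bit=True` flag (which requires `bitsandbytes`), and the sample prompt are assumptions drawn from the linked gist, not from this diff.

```python
import gc

import torch
from transformers import T5EncoderModel

from diffusers import HunyuanDiTPipeline


def flush():
    # Release Python references and free cached GPU memory between stages.
    gc.collect()
    torch.cuda.empty_cache()


prompt = "一个宇航员在骑马"

# Stage 1: load only the text encoders; T5 is quantized to 8-bit via bitsandbytes.
text_encoder_2 = T5EncoderModel.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",  # assumed checkpoint ID from the gist
    subfolder="text_encoder_2",
    load_in_8bit=True,
    device_map="auto",
)
pipeline = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",
    text_encoder_2=text_encoder_2,
    transformer=None,  # allowed by this diff: `default_sample_size` falls back to 128
    vae=None,  # allowed by this diff: `vae_scale_factor` falls back to 8
)

with torch.no_grad():
    # `device` and `dtype` are now optional, so `encode_prompt` can be called
    # standalone; index 0 is the bilingual CLIP encoder, index 1 is T5.
    (
        prompt_embeds,
        negative_prompt_embeds,
        prompt_attention_mask,
        negative_prompt_attention_mask,
    ) = pipeline.encode_prompt(prompt, text_encoder_index=0)
    (
        prompt_embeds_2,
        negative_prompt_embeds_2,
        prompt_attention_mask_2,
        negative_prompt_attention_mask_2,
    ) = pipeline.encode_prompt(prompt, text_encoder_index=1)

del text_encoder_2, pipeline
flush()

# Stage 2: load the transformer and VAE without any text encoders, then denoise
# with the embeddings computed above.
pipeline = HunyuanDiTPipeline.from_pretrained(
    "Tencent-Hunyuan/HunyuanDiT-Diffusers",
    text_encoder=None,
    text_encoder_2=None,
    tokenizer=None,
    tokenizer_2=None,
    torch_dtype=torch.float16,
).to("cuda")

image = pipeline(
    prompt_embeds=prompt_embeds,
    prompt_embeds_2=prompt_embeds_2,
    negative_prompt_embeds=negative_prompt_embeds,
    negative_prompt_embeds_2=negative_prompt_embeds_2,
    prompt_attention_mask=prompt_attention_mask,
    prompt_attention_mask_2=prompt_attention_mask_2,
    negative_prompt_attention_mask=negative_prompt_attention_mask,
    negative_prompt_attention_mask_2=negative_prompt_attention_mask_2,
    num_inference_steps=25,
).images[0]
image.save("hunyuan_dit.png")
```

The two `from_pretrained` calls never hold the text encoders and the transformer/VAE in memory at the same time, which is what keeps peak usage under the 6 GB figure cited in the doc change.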