diff --git a/tests/lora/test_lora_layers_peft.py b/tests/lora/test_lora_layers_peft.py index c139e0d6ea2c..0015a12012cd 100644 --- a/tests/lora/test_lora_layers_peft.py +++ b/tests/lora/test_lora_layers_peft.py @@ -22,7 +22,6 @@ import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from huggingface_hub import hf_hub_download from huggingface_hub.repocard import RepoCard from packaging import version @@ -41,8 +40,6 @@ StableDiffusionXLPipeline, UNet2DConditionModel, ) -from diffusers.loaders import AttnProcsLayers -from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0 from diffusers.utils.import_utils import is_accelerate_available, is_peft_available from diffusers.utils.testing_utils import ( floats_tensor, @@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2): return models_are_equal -def create_unet_lora_layers(unet: nn.Module): - lora_attn_procs = {} - for name in unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = unet.config.block_out_channels[block_id] - lora_attn_processor_class = ( - LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor - ) - lora_attn_procs[name] = lora_attn_processor_class( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim - ) - unet_lora_layers = AttnProcsLayers(lora_attn_procs) - return lora_attn_procs, unet_lora_layers - - @require_peft_backend class PeftLoraLoaderMixinTests: torch_device = "cuda" if torch.cuda.is_available() else "cpu" @@ -140,8 +115,6 @@ def get_dummy_components(self, scheduler_cls=None): r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False ) - unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet) - if self.has_two_text_encoders: pipeline_components = { "unet": unet, @@ -165,11 +138,8 @@ def get_dummy_components(self, scheduler_cls=None): "feature_extractor": None, "image_encoder": None, } - lora_components = { - "unet_lora_layers": unet_lora_layers, - "unet_lora_attn_procs": unet_lora_attn_procs, - } - return pipeline_components, lora_components, text_lora_config, unet_lora_config + + return pipeline_components, text_lora_config, unet_lora_config def get_dummy_inputs(self, with_generator=True): batch_size = 1 @@ -216,7 +186,7 @@ def test_simple_inference(self): Tests a simple inference and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -231,7 +201,7 @@ def test_simple_inference_with_text_lora(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -262,7 +232,7 @@ def test_simple_inference_with_text_lora_and_scale(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -309,7 +279,7 @@ def test_simple_inference_with_text_lora_fused(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -351,7 +321,7 @@ def test_simple_inference_with_text_lora_unloaded(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -394,7 +364,7 @@ def test_simple_inference_with_text_lora_save_load(self): Tests a simple usecase where users could use saving utilities for LoRA. """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -459,7 +429,7 @@ def test_simple_inference_save_pretrained(self): Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls) + components, text_lora_config, _ = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -510,7 +480,7 @@ def test_simple_inference_with_text_unet_lora_save_load(self): Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -583,7 +553,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -637,7 +607,7 @@ def test_simple_inference_with_text_lora_unet_fused(self): and makes sure it works as expected - with unet """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -683,7 +653,7 @@ def test_simple_inference_with_text_unet_lora_unloaded(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -730,7 +700,7 @@ def test_simple_inference_with_text_unet_lora_unfused(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -780,7 +750,7 @@ def test_simple_inference_with_text_unet_multi_adapter(self): multiple adapters and set them """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -848,7 +818,7 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self): multiple adapters and set/delete them """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -938,7 +908,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self): multiple adapters and set them """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -1010,7 +980,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self): def test_lora_fuse_nan(self): for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -1048,7 +1018,7 @@ def test_get_adapters(self): are the expected results """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -1075,7 +1045,7 @@ def test_get_list_adapters(self): are the expected results """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -1113,7 +1083,7 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self): and makes sure it works as expected - with unet and multi-adapter case """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None) @@ -1175,7 +1145,7 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self): and makes sure it works as expected """ for scheduler_cls in [DDIMScheduler, LCMScheduler]: - components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) + components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls) pipe = self.pipeline_class(**components) pipe = pipe.to(self.torch_device) pipe.set_progress_bar_config(disable=None)