Commit 6a376ce
[LoRA] remove unnecessary components from lora peft test suite (huggingface#6401)

remove unnecessary components from lora peft test suite

sayakpaul authored Dec 30, 2023
1 parent 9f283b0 commit 6a376ce
Showing 1 changed file with 22 additions and 52 deletions.
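
Background for the change: the deleted `create_unet_lora_layers` helper hand-built a `LoRAAttnProcessor` for every attention block, but this PEFT-backed test suite attaches LoRA through `peft.LoraConfig` and `add_adapter()` instead, so the helper and the components it returned were dead weight. Below is a minimal sketch of that PEFT path, not the suite's exact code: the toy `UNet2DConditionModel` hyperparameters and the `rank` value are illustrative assumptions, while the `target_modules` list mirrors the one visible in the diff.

```python
# Minimal sketch of the PEFT-backed path that supersedes the removed
# create_unet_lora_layers helper: one LoraConfig targets the attention
# projections, and diffusers wires it in via add_adapter() instead of
# hand-built per-block LoRAAttnProcessor objects.
from diffusers import UNet2DConditionModel
from peft import LoraConfig

# Toy UNet config, similar in spirit to the dummy model the tests construct
# (hyperparameters here are illustrative, not the suite's exact values).
unet = UNet2DConditionModel(
    block_out_channels=(32, 64),
    layers_per_block=2,
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
)

rank = 4  # illustrative; the suite parameterizes its own rank
unet_lora_config = LoraConfig(
    r=rank,
    lora_alpha=rank,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # as in the diff below
    init_lora_weights=False,
)
unet.add_adapter(unet_lora_config)  # requires a PEFT-enabled diffusers install
```

With the adapter attached this way, `get_dummy_components` no longer needs to return the `unet_lora_layers`/`unet_lora_attn_procs` pair, which is exactly what the diff below removes.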
74 changes: 22 additions & 52 deletions tests/lora/test_lora_layers_peft.py
@@ -22,7 +22,6 @@
 import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
 from huggingface_hub.repocard import RepoCard
 from packaging import version
@@ -41,8 +40,6 @@
     StableDiffusionXLPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers
-from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
 from diffusers.utils.import_utils import is_accelerate_available, is_peft_available
 from diffusers.utils.testing_utils import (
     floats_tensor,
@@ -78,28 +75,6 @@ def state_dicts_almost_equal(sd1, sd2):
     return models_are_equal
 
 
-def create_unet_lora_layers(unet: nn.Module):
-    lora_attn_procs = {}
-    for name in unet.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
-        if name.startswith("mid_block"):
-            hidden_size = unet.config.block_out_channels[-1]
-        elif name.startswith("up_blocks"):
-            block_id = int(name[len("up_blocks.")])
-            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
-        elif name.startswith("down_blocks"):
-            block_id = int(name[len("down_blocks.")])
-            hidden_size = unet.config.block_out_channels[block_id]
-        lora_attn_processor_class = (
-            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
-        )
-        lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
-    unet_lora_layers = AttnProcsLayers(lora_attn_procs)
-    return lora_attn_procs, unet_lora_layers
-
-
 @require_peft_backend
 class PeftLoraLoaderMixinTests:
     torch_device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -140,8 +115,6 @@ def get_dummy_components(self, scheduler_cls=None):
             r=rank, lora_alpha=rank, target_modules=["to_q", "to_k", "to_v", "to_out.0"], init_lora_weights=False
         )
 
-        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
-
         if self.has_two_text_encoders:
             pipeline_components = {
                 "unet": unet,
@@ -165,11 +138,8 @@ def get_dummy_components(self, scheduler_cls=None):
                 "feature_extractor": None,
                 "image_encoder": None,
             }
-        lora_components = {
-            "unet_lora_layers": unet_lora_layers,
-            "unet_lora_attn_procs": unet_lora_attn_procs,
-        }
-        return pipeline_components, lora_components, text_lora_config, unet_lora_config
+
+        return pipeline_components, text_lora_config, unet_lora_config
 
     def get_dummy_inputs(self, with_generator=True):
         batch_size = 1
@@ -216,7 +186,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -231,7 +201,7 @@ def test_simple_inference_with_text_lora(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -262,7 +232,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -309,7 +279,7 @@ def test_simple_inference_with_text_lora_fused(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -351,7 +321,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -394,7 +364,7 @@ def test_simple_inference_with_text_lora_save_load(self):
         Tests a simple usecase where users could use saving utilities for LoRA.
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -459,7 +429,7 @@ def test_simple_inference_save_pretrained(self):
         Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -510,7 +480,7 @@ def test_simple_inference_with_text_unet_lora_save_load(self):
         Tests a simple usecase where users could use saving utilities for LoRA for Unet + text encoder
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -583,7 +553,7 @@ def test_simple_inference_with_text_unet_lora_and_scale(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -637,7 +607,7 @@ def test_simple_inference_with_text_lora_unet_fused(self):
         and makes sure it works as expected - with unet
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -683,7 +653,7 @@ def test_simple_inference_with_text_unet_lora_unloaded(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -730,7 +700,7 @@ def test_simple_inference_with_text_unet_lora_unfused(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -780,7 +750,7 @@ def test_simple_inference_with_text_unet_multi_adapter(self):
         multiple adapters and set them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -848,7 +818,7 @@ def test_simple_inference_with_text_unet_multi_adapter_delete_adapter(self):
         multiple adapters and set/delete them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -938,7 +908,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
         multiple adapters and set them
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1010,7 +980,7 @@ def test_simple_inference_with_text_unet_multi_adapter_weighted(self):
 
     def test_lora_fuse_nan(self):
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1048,7 +1018,7 @@ def test_get_adapters(self):
         are the expected results
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1075,7 +1045,7 @@ def test_get_list_adapters(self):
         are the expected results
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1113,7 +1083,7 @@ def test_simple_inference_with_text_lora_unet_fused_multi(self):
         and makes sure it works as expected - with unet and multi-adapter case
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)
@@ -1175,7 +1145,7 @@ def test_simple_inference_with_text_unet_lora_unfused_torch_compile(self):
         and makes sure it works as expected
         """
         for scheduler_cls in [DDIMScheduler, LCMScheduler]:
-            components, _, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
+            components, text_lora_config, unet_lora_config = self.get_dummy_components(scheduler_cls)
             pipe = self.pipeline_class(**components)
             pipe = pipe.to(self.torch_device)
             pipe.set_progress_bar_config(disable=None)