Speed up the peft lora unload (huggingface#5741)
* Update peft_utils.py

* fix bug

* make the util backwards compatible.

Co-Authored-By: Younes Belkada <[email protected]>

* fix import issue

* refactor the backward compatibility condition

* rename the conditional variable

* address comments

Co-Authored-By: Benjamin Bossan <[email protected]>

* address comment

---------

Co-authored-by: Younes Belkada <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
3 people authored Nov 17, 2023
1 parent c6f90da commit 6f14353
Showing 1 changed file with 66 additions and 44 deletions.
src/diffusers/utils/peft_utils.py (110 changes: 66 additions & 44 deletions)
@@ -23,55 +23,77 @@
 from .import_utils import is_peft_available, is_torch_available


-def recurse_remove_peft_layers(model):
-    if is_torch_available():
-        import torch
+if is_torch_available():
+    import torch

+
+def recurse_remove_peft_layers(model):
     r"""
     Recursively replace all instances of `LoraLayer` with corresponding new layers in `model`.
     """
-    from peft.tuners.lora import LoraLayer
-
-    for name, module in model.named_children():
-        if len(list(module.children())) > 0:
-            ## compound module, go inside it
-            recurse_remove_peft_layers(module)
-
-        module_replaced = False
-
-        if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
-            new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
-                module.weight.device
-            )
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-        elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
-            new_module = torch.nn.Conv2d(
-                module.in_channels,
-                module.out_channels,
-                module.kernel_size,
-                module.stride,
-                module.padding,
-                module.dilation,
-                module.groups,
-            ).to(module.weight.device)
-
-            new_module.weight = module.weight
-            if module.bias is not None:
-                new_module.bias = module.bias
-
-            module_replaced = True
-
-        if module_replaced:
-            setattr(model, name, new_module)
-            del module
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+    from peft.tuners.tuners_utils import BaseTunerLayer
+
+    has_base_layer_pattern = False
+    for module in model.modules():
+        if isinstance(module, BaseTunerLayer):
+            has_base_layer_pattern = hasattr(module, "base_layer")
+            break
+
+    if has_base_layer_pattern:
+        from peft.utils import _get_submodules
+
+        key_list = [key for key, _ in model.named_modules() if "lora" not in key]
+        for key in key_list:
+            try:
+                parent, target, target_name = _get_submodules(model, key)
+            except AttributeError:
+                continue
+            if hasattr(target, "base_layer"):
+                setattr(parent, target_name, target.get_base_layer())
+    else:
+        # This is for backwards compatibility with PEFT <= 0.6.2.
+        # TODO can be removed once that PEFT version is no longer supported.
+        from peft.tuners.lora import LoraLayer
+
+        for name, module in model.named_children():
+            if len(list(module.children())) > 0:
+                ## compound module, go inside it
+                recurse_remove_peft_layers(module)
+
+            module_replaced = False
+
+            if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
+                new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
+                    module.weight.device
+                )
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+            elif isinstance(module, LoraLayer) and isinstance(module, torch.nn.Conv2d):
+                new_module = torch.nn.Conv2d(
+                    module.in_channels,
+                    module.out_channels,
+                    module.kernel_size,
+                    module.stride,
+                    module.padding,
+                    module.dilation,
+                    module.groups,
+                ).to(module.weight.device)
+
+                new_module.weight = module.weight
+                if module.bias is not None:
+                    new_module.bias = module.bias
+
+                module_replaced = True
+
+            if module_replaced:
+                setattr(model, name, new_module)
+                del module
+
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
     return model
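
Below is a minimal usage sketch, not part of the commit, showing how the refactored utility can be exercised: attach a LoRA adapter to a toy torch model with PEFT, then strip it again with recurse_remove_peft_layers. The toy model, rank, and target_modules are illustrative assumptions, and the snippet assumes torch, peft newer than 0.6.2, and diffusers at this commit are installed.

import torch
from peft import LoraConfig, inject_adapter_in_model

from diffusers.utils.peft_utils import recurse_remove_peft_layers

# Toy model: two Linear layers named "0" and "1" inside an nn.Sequential.
model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.Linear(32, 8))

# Wrap both Linear layers with LoRA modules (rank and targets chosen only for illustration).
lora_config = LoraConfig(r=4, lora_alpha=4, target_modules=["0", "1"])
model = inject_adapter_in_model(lora_config, model)
print(any("lora" in name for name, _ in model.named_modules()))  # True: adapters attached

# Unload: every LoRA-wrapped layer is swapped back to its original base layer.
model = recurse_remove_peft_layers(model)
print(any("lora" in name for name, _ in model.named_modules()))  # False: plain Linear layers again

The speedup referenced in the commit title comes from the first branch: with a recent PEFT, each wrapped module already stores the original layer as base_layer, so the function can simply reassign it via get_base_layer() instead of recreating and copying every Linear and Conv2d by hand. The old per-layer reconstruction is kept only as the fallback for PEFT <= 0.6.2.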


