[Lora] correct lora saving & loading (huggingface#2655)
* [Lora] correct lora saving & loading

* fix final

* Apply suggestions from code review
patrickvonplaten authored Mar 14, 2023
1 parent 7c1b347 commit d185c0d
Showing 1 changed file with 23 additions and 25 deletions.
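
In short, this commit standardizes on the `weight_name` argument for both saving and loading LoRA attention processors, keeps the old `weights_name` keyword alive through a deprecation shim, and fixes the fallback from `.safetensors` to `.bin` weights. A minimal round-trip sketch of the intended usage, assuming a Stable Diffusion checkpoint and local paths that are illustrative rather than part of this commit:

from diffusers import StableDiffusionPipeline

# Any pipeline whose UNet carries the attn-procs loader mixin works here;
# the checkpoint id below is just an example.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Save the UNet's LoRA attention processors; `weight_name` is the supported
# keyword after this change.
pipe.unet.save_attn_procs("./lora", weight_name="pytorch_lora_weights.bin")

# Reload them; if `weight_name` is omitted, safetensors weights are tried
# first and the .bin file is used as a fallback.
pipe.unet.load_attn_procs("./lora", weight_name="pytorch_lora_weights.bin")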
48 changes: 23 additions & 25 deletions src/diffusers/loaders.py
@@ -19,7 +19,7 @@

from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
- from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, is_safetensors_available, logging
+ from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, deprecate, is_safetensors_available, logging


if is_safetensors_available():
@@ -150,13 +150,14 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

model_file = None
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
- if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
- if weight_name is None:
- weight_name = LORA_WEIGHT_NAME_SAFE
+ # Let's first try to load .safetensors weights
+ if (is_safetensors_available() and weight_name is None) or (
+ weight_name is not None and weight_name.endswith(".safetensors")
+ ):
try:
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
- weights_name=weight_name,
+ weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
@@ -169,14 +170,13 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
)
state_dict = safetensors.torch.load_file(model_file, device="cpu")
except EnvironmentError:
- if weight_name == LORA_WEIGHT_NAME_SAFE:
- weight_name = None
+ # try loading non-safetensors weights
+ pass

if model_file is None:
- if weight_name is None:
- weight_name = LORA_WEIGHT_NAME
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
- weights_name=weight_name,
+ weights_name=weight_name or LORA_WEIGHT_NAME,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
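
The rewritten guard above is the core loading fix: the old single-line condition dereferenced `weight_name` even when it was `None`, so with safetensors not installed and no explicit filename it raised an `AttributeError` instead of falling back to the `.bin` weights, and the old `except` branch additionally mutated `weight_name` before the fallback lookup. A small standalone illustration (the `is_safetensors_available` stub simulates a missing safetensors install; the real helper lives in `diffusers.utils`):

def is_safetensors_available():
    return False  # pretend safetensors is not installed

weight_name = None  # caller did not request a specific file

# Old guard: the right-hand operand still runs and dereferences None.
try:
    if (is_safetensors_available() and weight_name is None) or weight_name.endswith(".safetensors"):
        print("try safetensors first")
except AttributeError as err:
    print(f"old guard crashed: {err}")

# New guard: `weight_name is not None` short-circuits safely, so execution
# falls through to the regular .bin lookup instead of crashing.
if (is_safetensors_available() and weight_name is None) or (
    weight_name is not None and weight_name.endswith(".safetensors")
):
    print("try safetensors first")
else:
    print("fall back to the .bin weights")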
@@ -225,9 +225,10 @@ def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
- weights_name: str = None,
+ weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
+ **kwargs,
):
r"""
Save an attention processor to a directory, so that it can be re-loaded using the
@@ -245,6 +246,12 @@ def save_attn_procs(
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
"""
+ weight_name = weight_name or deprecate(
+ "weights_name",
+ "0.18.0",
+ "`weights_name` is deprecated, please use `weight_name` instead.",
+ take_from=kwargs,
+ )
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
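
The `**kwargs` plus `deprecate(..., take_from=kwargs)` addition is what keeps old call sites working: as I read the diffusers `deprecate` helper, it pops the deprecated `weights_name` entry out of `kwargs`, emits a deprecation warning, and returns the popped value, so `weight_name = weight_name or deprecate(...)` prefers the new keyword and only falls back to the old one. Continuing the earlier sketch (the file name is illustrative):

# Preferred keyword after this change:
pipe.unet.save_attn_procs("./lora", weight_name="my_lora.safetensors", safe_serialization=True)

# Deprecated keyword: still honoured, but warns and is scheduled for removal
# in diffusers 0.18.0.
pipe.unet.save_attn_procs("./lora", weights_name="my_lora.safetensors", safe_serialization=True)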
@@ -265,22 +272,13 @@ def save_function(weights, filename):
# Save the model
state_dict = model_to_save.state_dict()

- # Clean the folder from a previous save
- for filename in os.listdir(save_directory):
- full_filename = os.path.join(save_directory, filename)
- # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
- # in distributed settings to avoid race conditions.
- weights_no_suffix = weights_name.replace(".bin", "")
- if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
- os.remove(full_filename)

- if weights_name is None:
+ if weight_name is None:
if safe_serialization:
- weights_name = LORA_WEIGHT_NAME_SAFE
+ weight_name = LORA_WEIGHT_NAME_SAFE
else:
- weights_name = LORA_WEIGHT_NAME
+ weight_name = LORA_WEIGHT_NAME

# Save the model
- save_function(state_dict, os.path.join(save_directory, weights_name))
+ save_function(state_dict, os.path.join(save_directory, weight_name))

- logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
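
When no `weight_name` is passed, the save path now derives the default filename from `safe_serialization` via the `LORA_WEIGHT_NAME_SAFE` / `LORA_WEIGHT_NAME` constants (which I assume resolve to the usual `pytorch_lora_weights.safetensors` / `pytorch_lora_weights.bin` pair), mirroring the loader's safetensors-first lookup. Continuing the earlier sketch:

import os

lora_dir = "./lora_out"  # illustrative output directory

# No weight_name: safe_serialization decides between the safetensors and
# .bin default filenames.
pipe.unet.save_attn_procs(lora_dir, safe_serialization=True)
print(os.listdir(lora_dir))

# Loading without weight_name then picks up the safetensors file first.
pipe.unet.load_attn_procs(lora_dir)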
