Skip to content

Commit

Permalink
Remove insecure torch.load calls (huggingface#7393)
Browse files Browse the repository at this point in the history
update
  • Loading branch information
DN6 authored Mar 19, 2024
1 parent 161c6e1 commit 4da810b
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/diffusers/loaders/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open

from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
_get_model_file,
is_accelerate_available,
Expand Down Expand Up @@ -182,7 +182,7 @@ def load_ip_adapter(
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
else:
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/loaders/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from torch import nn

from .. import __version__
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down Expand Up @@ -281,7 +281,7 @@ def lora_state_dict(
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/loaders/textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from ..models.modeling_utils import load_state_dict
from ..utils import _get_model_file, is_accelerate_available, is_transformers_available, logging


Expand Down Expand Up @@ -100,7 +101,7 @@ def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path

Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/loaders/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
IPAdapterPlusImageProjection,
MultiIPAdapterImageProjection,
)
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
_get_model_file,
Expand Down Expand Up @@ -214,7 +214,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
state_dict = load_state_dict(model_file)
else:
state_dict = pretrained_model_name_or_path_or_dict

Expand Down
7 changes: 6 additions & 1 deletion src/diffusers/models/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,12 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
if file_extension == SAFETENSORS_FILE_EXTENSION:
return safetensors.torch.load_file(checkpoint_file, device="cpu")
else:
return torch.load(checkpoint_file, map_location="cpu")
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
return torch.load(
checkpoint_file,
map_location="cpu",
**weights_only_kwarg,
)
except Exception as e:
try:
with open(checkpoint_file) as f:
Expand Down

0 comments on commit 4da810b

Please sign in to comment.