Skip to content

Commit 0e9ad27

Browse files
author
Yuwei Guo
committed
add dummy key
1 parent 57e7d14 commit 0e9ad27

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

animatediff/utils/util.py

+3
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def load_weights(
111111
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
112112
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
113113
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
114+
unet_state_dict.pop("animatediff_config", "")
114115

115116
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
116117
assert len(unexpected) == 0
@@ -154,6 +155,7 @@ def load_weights(
154155
print(f"load domain lora from {adapter_lora_path}")
155156
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
156157
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
158+
domain_lora_state_dict.pop("animatediff_config", "")
157159

158160
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
159161

@@ -163,6 +165,7 @@ def load_weights(
163165
print(f"load motion LoRA from {path}")
164166
motion_lora_state_dict = torch.load(path, map_location="cpu")
165167
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
168+
motion_lora_state_dict.pop("animatediff_config", "")
166169

167170
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
168171

scripts/animate.py

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def main(args):
6969
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
7070
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
7171
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
72+
controlnet_state_dict.pop("animatediff_config", "")
7273
controlnet.load_state_dict(controlnet_state_dict)
7374
controlnet.cuda()
7475

0 commit comments

Comments
 (0)