@@ -111,6 +111,7 @@ def load_weights(
111
111
motion_module_state_dict = torch .load (motion_module_path , map_location = "cpu" )
112
112
motion_module_state_dict = motion_module_state_dict ["state_dict" ] if "state_dict" in motion_module_state_dict else motion_module_state_dict
113
113
unet_state_dict .update ({name : param for name , param in motion_module_state_dict .items () if "motion_modules." in name })
114
+ unet_state_dict .pop ("animatediff_config" , "" )
114
115
115
116
missing , unexpected = animation_pipeline .unet .load_state_dict (unet_state_dict , strict = False )
116
117
assert len (unexpected ) == 0
@@ -154,6 +155,7 @@ def load_weights(
154
155
print (f"load domain lora from { adapter_lora_path } " )
155
156
domain_lora_state_dict = torch .load (adapter_lora_path , map_location = "cpu" )
156
157
domain_lora_state_dict = domain_lora_state_dict ["state_dict" ] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
158
+ domain_lora_state_dict .pop ("animatediff_config" , "" )
157
159
158
160
animation_pipeline = load_diffusers_lora (animation_pipeline , domain_lora_state_dict , alpha = adapter_lora_scale )
159
161
@@ -163,6 +165,7 @@ def load_weights(
163
165
print (f"load motion LoRA from { path } " )
164
166
motion_lora_state_dict = torch .load (path , map_location = "cpu" )
165
167
motion_lora_state_dict = motion_lora_state_dict ["state_dict" ] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
168
+ motion_lora_state_dict .pop ("animatediff_config" , "" )
166
169
167
170
animation_pipeline = load_diffusers_lora (animation_pipeline , motion_lora_state_dict , alpha )
168
171
0 commit comments