Fix some performance issues with weight loading and unloading.

Lower peak memory usage when changing models.

Fix case where model weights would be unloaded and reloaded.
comfyanonymous committed Mar 28, 2024
1 parent 327ca13 commit 5d8898c
Showing 2 changed files with 12 additions and 5 deletions.

comfy/model_management.py: 11 additions, 5 deletions

@@ -274,6 +274,7 @@ def __init__(self, model):
         self.model = model
         self.device = model.load_device
         self.weights_loaded = False
+        self.real_model = None
 
     def model_memory(self):
         return self.model.model_size()

@@ -312,6 +313,7 @@ def model_unload(self, unpatch_weights=True):
         self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
         self.model.model_patches_to(self.model.offload_device)
         self.weights_loaded = self.weights_loaded and not unpatch_weights
+        self.real_model = None
 
     def __eq__(self, other):
         return self.model is other.model

@@ -326,7 +328,7 @@ def unload_model_clones(model, unload_weights_only=True, force_unload=True):
             to_unload = [i] + to_unload
 
     if len(to_unload) == 0:
-        return None
+        return True
 
     same_weights = 0
     for i in to_unload:

@@ -408,8 +410,8 @@ def load_models_gpu(models, memory_required=0):

     total_memory_required = {}
     for loaded_model in models_to_load:
-        unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) #unload clones where the weights are different
-        total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
+        if unload_model_clones(loaded_model.model, unload_weights_only=True, force_unload=False) == True:#unload clones where the weights are different
+            total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
 
     for device in total_memory_required:
         if device != torch.device("cpu"):
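
Taken together, the `return True` change and the new guard mean that load_models_gpu only budgets device memory for models whose weights actually have to be (re)loaded; a clone that already shares loaded weights no longer inflates the per-device total, which appears to be where the lower peak usage comes from. A minimal standalone sketch of that accounting, using made-up names (FakeModel, plan_memory) rather than ComfyUI's own classes:

from dataclasses import dataclass

@dataclass
class FakeModel:          # stand-in for a loaded model entry; illustrative only
    device: str
    size: int             # bytes its weights would occupy on the device
    weights_loaded: bool  # True if a clone already holds these weights on-device

def plan_memory(models_to_load):
    # Budget memory only for models whose weights must actually be loaded,
    # mirroring the `if unload_model_clones(...) == True:` guard in the diff.
    total_memory_required = {}
    for m in models_to_load:
        if not m.weights_loaded:
            total_memory_required[m.device] = total_memory_required.get(m.device, 0) + m.size
    return total_memory_required

# A patched clone whose base weights are already on the GPU adds nothing:
print(plan_memory([FakeModel("cuda:0", 4_000_000_000, False),
                   FakeModel("cuda:0", 4_000_000_000, True)]))
# -> {'cuda:0': 4000000000}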

@@ -448,11 +450,15 @@ def load_models_gpu(models, memory_required=0):
 def load_model_gpu(model):
     return load_models_gpu([model])
 
-def cleanup_models():
+def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
         if sys.getrefcount(current_loaded_models[i].model) <= 2:
-            to_delete = [i] + to_delete
+            if not keep_clone_weights_loaded:
+                to_delete = [i] + to_delete
+            #TODO: find a less fragile way to do this.
+            elif sys.getrefcount(current_loaded_models[i].real_model) <= 3: #references from .real_model + the .model
+                to_delete = [i] + to_delete
 
     for i in to_delete:
         x = current_loaded_models.pop(i)
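
The `<= 2` and `<= 3` thresholds lean on CPython reference counting: `sys.getrefcount` counts its own argument as a temporary reference, so an object held only by the entry in current_loaded_models reports 2, and, per the inline comment, a real_model referenced by both the entry and its `.model` reports 3; hence the TODO about fragility. A small self-contained demonstration of that behaviour (plain Python, no ComfyUI objects):

import sys

class Weights:
    pass

w = Weights()
# getrefcount() counts its own argument as an extra, temporary reference,
# so a single binding reports 2.
print(sys.getrefcount(w))  # -> 2

holder = [w]  # one more reference, e.g. another object keeping the weights alive
print(sys.getrefcount(w))  # -> 3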

execution.py: 1 addition, 0 deletions

@@ -368,6 +368,7 @@ def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
                 d = self.outputs_ui.pop(x)
                 del d
 
+            comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)
             self.add_message("execution_cached",
                           { "nodes": list(current_outputs) , "prompt_id": prompt_id},
                           broadcast=False)
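
Passing keep_clone_weights_loaded=True at this call site, which runs while a prompt is being executed, lets entries whose weights are still referenced by a live clone survive the sweep, so the next sampling step does not have to reload them. A hedged illustration of the two call styles; that the default, stricter sweep is what gets used outside of prompt execution is an assumption, not something shown in this diff:

import comfy.model_management

# As in the diff: during prompt execution, drop entries nobody references any more,
# but keep entries whose weights are still held by a live clone (no reload later).
comfy.model_management.cleanup_models(keep_clone_weights_loaded=True)

# Default (keep_clone_weights_loaded=False), assumed to be used elsewhere: also
# drops entries whose weights are still referenced through real_model by a clone.
comfy.model_management.cleanup_models()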