Skip to content

Commit

Permalink
chore: add a cleaning utility to be useful during training. (huggingf…
Browse files Browse the repository at this point in the history
  • Loading branch information
sayakpaul authored Sep 3, 2024
1 parent 9d49b45 commit 8ba90aa
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
24 changes: 9 additions & 15 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import argparse
import copy
import gc
import itertools
import logging
import math
Expand Down Expand Up @@ -56,6 +55,7 @@
from diffusers.training_utils import (
_set_state_dict_into_text_encoder,
cast_training_params,
clear_objs_and_retain_memory,
compute_density_for_timestep_sampling,
compute_loss_weighting_for_sd3,
)
Expand Down Expand Up @@ -210,9 +210,7 @@ def log_validation(
}
)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])

return images

Expand Down Expand Up @@ -1107,9 +1105,7 @@ def main(args):
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)

del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(objs=[pipeline])

# Handle the repository creation
if accelerator.is_main_process:
Expand Down Expand Up @@ -1455,12 +1451,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):

# Clear the memory here
if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
del tokenizers, text_encoders
# Explicitly delete the objects as well, otherwise only the lists are deleted and the original references remain, preventing garbage collection
del text_encoder_one, text_encoder_two, text_encoder_three
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
clear_objs_and_retain_memory(
objs=[tokenizers, text_encoders, text_encoder_one, text_encoder_two, text_encoder_three]
)

# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
# pack the statically computed variables appropriately here. This is so that we don't
Expand Down Expand Up @@ -1795,11 +1789,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
pipeline_args=pipeline_args,
epoch=epoch,
)
objs = []
if not args.train_text_encoder:
del text_encoder_one, text_encoder_two, text_encoder_three
objs.extend([text_encoder_one, text_encoder_two, text_encoder_three])

torch.cuda.empty_cache()
gc.collect()
clear_objs_and_retain_memory(objs=objs)

# Save the lora layers
accelerator.wait_for_everyone()
Expand Down
17 changes: 17 additions & 0 deletions src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import contextlib
import copy
import gc
import math
import random
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -259,6 +260,22 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
return weighting


def clear_objs_and_retain_memory(objs: List[Any]):
"""Deletes `objs` and runs garbage collection. Then clears the cache of the available accelerator."""
if len(objs) >= 1:
for obj in objs:
del obj

gc.collect()

if torch.cuda.is_available():
torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
torch.mps.empty_cache()
elif is_torch_npu_available():
torch_npu.empty_cache()


# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
"""
Expand Down

0 comments on commit 8ba90aa

Please sign in to comment.