Skip to content

Commit

Permalink
minimal stable diffusion GPU memory usage with accelerate hooks (hugg…
Browse files Browse the repository at this point in the history
…ingface#850)

* add method to enable cuda with minimal gpu usage to stable diffusion

* add test to minimal cuda memory usage

* ensure all models but unet are on torch.float32

* move to cpu_offload along with minor internal changes to make it work

* make it test against accelerate master branch

* coming back, it's official: I don't know how to make it test against the master branch from accelerate

* make it install accelerate from master on tests

* go back to accelerate>=0.11

* undo prettier formatting on yml files

* undo prettier formatting on yml files again
  • Loading branch information
piEsposito authored Oct 26, 2022
1 parent 2f0fcf4 commit b2e2d14
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 2 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/pr_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment
run: |
Expand Down Expand Up @@ -80,6 +81,7 @@ jobs:
${CONDA_RUN} python -m pip install --upgrade pip
${CONDA_RUN} python -m pip install -e .[quality,test]
${CONDA_RUN} python -m pip install --pre torch==${MPS_TORCH_VERSION} --extra-index-url https://download.pytorch.org/whl/test/cpu
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment
shell: arch -arch arm64 bash {0}
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/push_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ jobs:
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
python -m pip install -e .[quality,test]
python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment
run: |
Expand All @@ -58,8 +59,6 @@ jobs:
name: torch_test_reports
path: reports



run_examples_single_gpu:
name: Examples tests
runs-on: [ self-hosted, docker-gpu, single-gpu ]
Expand All @@ -83,6 +82,7 @@ jobs:
python -m pip uninstall -y torch torchvision torchtext
python -m pip install torch --extra-index-url https://download.pytorch.org/whl/cu116
python -m pip install -e .[quality,test,training]
python -m pip install git+https://github.com/huggingface/accelerate
- name: Environment
run: |
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,8 @@ def device(self) -> torch.device:
for name in module_names.keys():
module = getattr(self, name)
if isinstance(module, torch.nn.Module):
if module.device == torch.device("meta"):
return torch.device("cpu")
return module.device
return torch.device("cpu")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

from diffusers.utils import is_accelerate_available
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
Expand Down Expand Up @@ -118,6 +119,18 @@ def disable_attention_slicing(self):
# set slice_size = `None` to disable `attention slicing`
self.enable_attention_slicing(None)

def cuda_with_minimal_gpu_usage(self):
    r"""
    Prepare the pipeline to run on CUDA while offloading its sub-models through `accelerate`.

    Calls `self.enable_attention_slicing(1)` and then applies `accelerate.cpu_offload` to the `unet`,
    `text_encoder`, `vae` and `safety_checker` components, using `torch.device("cuda")` as the device
    argument. Components that are `None` are skipped.

    Raises:
        ImportError: if `accelerate` is not installed.
    """
    if not is_accelerate_available():
        raise ImportError("Please install accelerate via `pip install accelerate`")
    # Imported lazily so that `accelerate` stays an optional dependency.
    from accelerate import cpu_offload

    device = torch.device("cuda")
    self.enable_attention_slicing(1)

    for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
        # A component may be unset (None); `cpu_offload` needs an actual module to hook.
        if cpu_offloaded_model is not None:
            cpu_offload(cpu_offloaded_model, device)

@torch.no_grad()
def __call__(
self,
Expand Down
20 changes: 20 additions & 0 deletions tests/test_pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,3 +535,23 @@ def test_stable_diffusion_accelerate_load_reduces_memory_footprint(self):
tracemalloc.stop()

assert peak_accelerate < peak_normal

@slow
@unittest.skipIf(torch_device == "cpu", "This test is supposed to run on GPU")
def test_stable_diffusion_pipeline_with_unet_on_gpu_only(self):
    """Run one generation after `cuda_with_minimal_gpu_usage()` and check peak CUDA memory is below 0.8 GB."""
    # Clear cached blocks and reset the peak counter before measuring.
    torch.cuda.empty_cache()
    torch.cuda.reset_max_memory_allocated()

    sd_pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        revision="fp16",
        torch_dtype=torch.float32,
        use_auth_token=True,
    )
    sd_pipe.cuda_with_minimal_gpu_usage()

    # The generated output is irrelevant here; only the memory footprint is checked.
    sd_pipe("Andromeda galaxy in a bottle")

    # Peak allocation must stay under 0.8 GB.
    memory_budget_bytes = 0.8 * 10**9
    assert torch.cuda.max_memory_allocated() < memory_budget_bytes

0 comments on commit b2e2d14

Please sign in to comment.