Skip to content

Commit

Permalink
fix OOM for test_vae_tiling (huggingface#7510)
Browse files Browse the repository at this point in the history
use float16 and add torch.no_grad()
  • Loading branch information
yiyixuxu authored Mar 29, 2024
1 parent e49c04d commit 34c90db
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
13 changes: 8 additions & 5 deletions tests/models/autoencoders/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,8 +1118,10 @@ def test_sd_f16(self):
assert torch_all_close(actual_output, expected_output, atol=5e-3)

def test_vae_tiling(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16)
pipe = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None, torch_dtype=torch.float16
)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

Expand All @@ -1143,6 +1145,7 @@ def test_vae_tiling(self):

# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)
with torch.no_grad():
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)
7 changes: 4 additions & 3 deletions tests/pipelines/test_pipelines_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def test_vae_tiling(self):

# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)
with torch.no_grad():
for shape in shapes:
zeros = torch.zeros(shape).to(torch_device)
pipe.vae.decode(zeros)

def test_freeu_enabled(self):
components = self.get_dummy_components()
Expand Down

0 comments on commit 34c90db

Please sign in to comment.