Skip to content

Commit

Permalink
hot fix
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Oct 28, 2022
1 parent c4ef1ef commit cbbb293
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions tests/models/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

from diffusers import AutoencoderKL
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import floats_tensor, require_torch_gpu, slow, torch_all_close, torch_device
from diffusers.utils import floats_tensor, load_numpy, require_torch_gpu, slow, torch_all_close, torch_device
from parameterized import parameterized

from ..test_modeling_common import ModelTesterMixin
Expand Down Expand Up @@ -136,18 +136,18 @@ def test_output_pretrained(self):

@slow
class AutoencoderKLIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_sd_image(self, seed=0, shape=(4, 3, 512, 512), fp16=False):
batch_size, channels, height, width = shape
generator = torch.Generator(device=torch_device).manual_seed(seed)
dtype = torch.float16 if fp16 else torch.float32
image = torch.randn(batch_size, channels, height, width, device=torch_device, generator=generator, dtype=dtype)

image = torch.from_numpy(load_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image

def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False):
Expand Down

0 comments on commit cbbb293

Please sign in to comment.