Commit 6ba2231

Reproducibility 3/3 (huggingface#1924)
* make tests deterministic
* run slow tests
* prepare for testing
* finish
* refactor
* add print statements
* finish more
* correct some test failures
* more fixes
* set up to correct tests
* more corrections
* up
* fix more
* more prints
* add
* up
* up
* up
* uP
* uP
* more fixes
* uP
* up
* up
* up
* up
* fix more
* up
* up
* clean tests
* up
* up
* up
* more fixes
* Apply suggestions from code review

Co-authored-by: Suraj Patil <[email protected]>

* make
* correct
* finish
* finish

Co-authored-by: Suraj Patil <[email protected]>
1 parent 008c22d commit 6ba2231

47 files changed: +673 −448 lines


docs/source/en/_toctree.yml (+2)

```diff
@@ -32,6 +32,8 @@
     title: Text-Guided Depth-to-Image
   - local: using-diffusers/reusing_seeds
     title: Reusing seeds for deterministic generation
+  - local: using-diffusers/reproducibility
+    title: Reproducibility
   - local: using-diffusers/custom_pipeline_examples
     title: Community Pipelines
   - local: using-diffusers/contribute_pipeline
```
docs/source/en/using-diffusers/reproducibility.mdx (new file, +159)

<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Reproducibility

Before reading about reproducibility for Diffusers, it is strongly recommended to take a look at
[PyTorch's statement about reproducibility](https://pytorch.org/docs/stable/notes/randomness.html).

PyTorch states that

> *completely reproducible results are not guaranteed across PyTorch releases, individual commits, or different platforms.*

While one can never expect identical results across platforms, one can expect results to be reproducible
across releases and platforms within a certain tolerance. However, this tolerance varies strongly
depending on the diffusion pipeline and checkpoint.
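In practice, "reproducible within a tolerance" means comparing outputs elementwise rather than bitwise. A minimal sketch of such a check (the `1e-2` threshold is an assumption, chosen to match the test suite shown further below):

```python
import numpy as np


def within_tolerance(image_a: np.ndarray, image_b: np.ndarray, tol: float = 1e-2) -> bool:
    # two runs count as "reproduced" if no element differs by more than the tolerance
    return float(np.abs(image_a - image_b).max()) < tol
```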
In the following, we show how to best control sources of randomness for diffusion models.
## Inference

During inference, diffusion pipelines heavily rely on random sampling operations, such as creating the
Gaussian noise tensors to be denoised and adding noise to the scheduling step.

Let's have a look at an example. We run the [DDIM pipeline](./api/pipelines/ddim.mdx)
for just two inference steps and return a numpy tensor to look into the numerical values of the output:
```python
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np").images
print(np.abs(image).sum())
```
Running the above prints a value of 1464.2076, but running it again prints a different
value of 1495.1768. What is going on here? Every time the pipeline is run, Gaussian noise
is created and step-wise denoised. To create the Gaussian noise with [`torch.randn`](https://pytorch.org/docs/stable/generated/torch.randn.html), a different random seed is taken every time, thus leading to a different result.
This is a desired property of diffusion pipelines, as it means that the pipeline can create a different random image every time it is run. In many cases, however, one would like to generate the exact same image as in a certain
run, in which case an instance of a [PyTorch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) has to be passed:
```python
import torch
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)

# create a generator for reproducibility
generator = torch.Generator(device="cpu").manual_seed(0)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```
Running the above always prints a value of 1491.1711, even upon running it again, because we
define the generator object and pass it to all random functions of the pipeline.

If you run this code snippet on your specific hardware and version, you should get a similar, if not the same, result.
<Tip>

It might be a bit unintuitive at first to pass `generator` objects to the pipelines instead of
just integer values representing the seed. However, this is the recommended design when dealing with
probabilistic models in PyTorch, as generators are *random states* that are advanced and can thus be
passed to multiple pipelines in a sequence.

</Tip>
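For example, because the generator's state advances with every sampling operation, a single generator can drive a whole sequence of pipeline calls, and re-seeding it reproduces the entire sequence. A minimal sketch reusing the DDIM pipeline from above:

```python
import torch
from diffusers import DDIMPipeline

ddim = DDIMPipeline.from_pretrained("google/ddpm-cifar10-32")

# one generator drives the whole sequence; its state advances after each call
generator = torch.Generator(device="cpu").manual_seed(0)
first = ddim(num_inference_steps=2, output_type="np", generator=generator).images
second = ddim(num_inference_steps=2, output_type="np", generator=generator).images

# `first` and `second` differ because the state advanced in between, but
# re-seeding with 0 and rerunning both calls reproduces both images exactly
```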
Great! Now we know how to write reproducible pipelines, but it gets a bit trickier since the above example only runs on the CPU. How do we also achieve reproducibility on GPU?
In short, one should not expect full reproducibility across different hardware when running pipelines on GPU,
as matrix multiplications are less deterministic on GPU than on CPU, and diffusion pipelines tend to require
a lot of matrix multiplications. Let's see what we can do to keep the randomness within limits across
different GPU hardware.
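As an aside, PyTorch's reproducibility notes linked above describe global switches that trade speed for more deterministic GPU kernels. These are not something the example above relies on, but one can opt in manually:

```python
import torch

# request deterministic kernels where they exist; this usually slows things
# down and raises an error for ops without a deterministic implementation
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)
```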
To achieve maximum speed performance, it is recommended to create the generator directly on GPU when running
the pipeline on GPU:
```python
import torch
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)
ddim.to("cuda")

# create a generator for reproducibility
generator = torch.Generator(device="cuda").manual_seed(0)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```
Running the above now prints a value of 1389.8634, even though we're using the exact same seed!
This is unfortunate, as it means we cannot reproduce on CPU the results we achieved on GPU.
Nevertheless, it is to be expected, since the GPU uses a different random number generator than the CPU.

To circumvent this problem, we created a [`randn_tensor`](#diffusers.utils.randn_tensor) function, which can create the random noise
on the CPU and then move the tensor to GPU if necessary. The function is used everywhere inside the pipelines, allowing the user to **always** pass a CPU generator even if the pipeline is run on GPU:
```python
import torch
from diffusers import DDIMPipeline
import numpy as np

model_id = "google/ddpm-cifar10-32"

# load model and scheduler
ddim = DDIMPipeline.from_pretrained(model_id)
ddim.to("cuda")

# create a generator for reproducibility; torch.manual_seed seeds and
# returns the default CPU generator
generator = torch.manual_seed(0)

# run pipeline for just two steps and return numpy tensor
image = ddim(num_inference_steps=2, output_type="np", generator=generator).images
print(np.abs(image).sum())
```
Running the above now prints a value of 1491.1713, much closer to the value of 1491.1711 obtained when
the pipeline is fully run on the CPU.

<Tip>

As a consequence, we recommend always passing a CPU generator if reproducibility is important.
The loss of performance is often negligible, and you can be sure to generate values much more
similar to the ones the pipeline would have produced had it been run entirely on the CPU.

</Tip>
Finally, we noticed that more complex pipelines, such as [`UnCLIPPipeline`], are often extremely
susceptible to precision error propagation, and thus one cannot expect even similar results across
different GPU hardware or PyTorch versions. In such cases, one has to make sure to run on
exactly the same hardware and PyTorch version for full reproducibility.
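When results depend on the exact setup like this, it helps to record the environment alongside the outputs. A sketch using standard PyTorch attributes:

```python
import torch

# record the exact software and hardware needed to reproduce sensitive results
print(torch.__version__)  # the PyTorch release
print(torch.version.cuda)  # CUDA toolkit version, or None on CPU-only builds
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))  # the GPU model
```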
## Randomness utilities

### randn_tensor

[[autodoc]] diffusers.utils.randn_tensor
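As a sketch of what this function does inside the pipelines (assuming the `shape`, `generator`, and `device` arguments shown here), it samples with a CPU generator and then places the tensor on the requested device:

```python
import torch
from diffusers.utils import randn_tensor

# sample with a CPU generator, but land the noise tensor on the GPU
generator = torch.manual_seed(0)
noise = randn_tensor((1, 3, 32, 32), generator=generator, device=torch.device("cuda"))
print(noise.device)  # cuda:0, even though the randomness came from the CPU generator
```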

src/diffusers/pipelines/ddim/pipeline_ddim.py (+1, −19)

```diff
@@ -17,7 +17,7 @@
 import torch

 from ...schedulers import DDIMScheduler
-from ...utils import deprecate, randn_tensor
+from ...utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


@@ -78,24 +78,6 @@ def __call__(
             True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
         """

-        if (
-            generator is not None
-            and isinstance(generator, torch.Generator)
-            and generator.device.type != self.device.type
-            and self.device.type != "mps"
-        ):
-            message = (
-                f"The `generator` device is `{generator.device}` and does not match the pipeline "
-                f"device `{self.device}`, so the `generator` will be ignored. "
-                f'Please use `generator=torch.Generator(device="{self.device}")` instead.'
-            )
-            deprecate(
-                "generator.device == 'cpu'",
-                "0.13.0",
-                message,
-            )
-            generator = None
-
         # Sample gaussian noise to begin loop
         if isinstance(self.unet.sample_size, int):
             image_shape = (batch_size, self.unet.in_channels, self.unet.sample_size, self.unet.sample_size)
```

src/diffusers/utils/__init__.py (+1)

```diff
@@ -76,6 +76,7 @@
     load_numpy,
     nightly,
     parse_flag_from_env,
+    print_tensor_test,
     require_torch_gpu,
     slow,
     torch_all_close,
```

src/diffusers/utils/testing_utils.py (+22, −3)

```diff
@@ -8,7 +8,7 @@
 from distutils.util import strtobool
 from io import BytesIO, StringIO
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union

 import numpy as np

@@ -45,6 +45,21 @@ def torch_all_close(a, b, *args, **kwargs):
     return True


+def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"):
+    test_name = os.environ.get("PYTEST_CURRENT_TEST")
+    if not torch.is_tensor(tensor):
+        tensor = torch.from_numpy(tensor)
+
+    tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
+    # format is usually:
+    # expected_slice = np.array([-0.5713, -0.3018, -0.9814, 0.04663, -0.879, 0.76, -1.734, 0.1044, 1.161])
+    output_str = tensor_str.replace("tensor", f"{expected_tensor_name} = np.array")
+    test_file, test_class, test_fn = test_name.split("::")
+    test_fn = test_fn.split()[0]
+    with open(filename, "a") as f:
+        print(";".join([test_file, test_class, test_fn, output_str]), file=f)
+
+
 def get_tests_dir(append_path=None):
     """
     Args:
@@ -150,9 +165,13 @@ def require_onnxruntime(test_case):
     return unittest.skipUnless(is_onnx_available(), "test requires onnxruntime")(test_case)


-def load_numpy(arry: Union[str, np.ndarray]) -> np.ndarray:
+def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -> np.ndarray:
     if isinstance(arry, str):
-        if arry.startswith("http://") or arry.startswith("https://"):
+        # local_path = "/home/patrick_huggingface_co/"
+        if local_path is not None:
+            # local_path can be passed to correct images of tests
+            return os.path.join(local_path, "/".join([arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]]))
+        elif arry.startswith("http://") or arry.startswith("https://"):
             response = requests.get(arry)
             response.raise_for_status()
             arry = np.load(BytesIO(response.content))
```
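The new `print_tensor_test` helper reads pytest's `PYTEST_CURRENT_TEST` environment variable, so it only works inside a running test. A hedged sketch of the intended call site (the test body around it is hypothetical):

```python
# inside a pytest test, after computing the slice under test
from diffusers.utils import print_tensor_test

image_slice = image[0, -3:, -3:, -1]  # hypothetical pipeline output slice
print_tensor_test(image_slice.flatten())
# appends "tests/...py;TestClass;test_fn;expected_slice = np.array([...])"
# to test_corrections.txt, ready to paste back into the test
```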

tests/models/test_models_vae.py (+1, −1)

```diff
@@ -166,7 +166,7 @@ def get_sd_vae_model(self, model_id="CompVis/stable-diffusion-v1-4", fp16=False)

     def get_generator(self, seed=0):
         if torch_device == "mps":
-            return torch.Generator().manual_seed(seed)
+            return torch.manual_seed(seed)
         return torch.Generator(device=torch_device).manual_seed(seed)

     @parameterized.expand(
```
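For the seeds used here the two forms are interchangeable: `torch.manual_seed(seed)` seeds and returns the default CPU generator, while `torch.Generator().manual_seed(seed)` seeds a fresh one, and both yield the same sample stream. A quick check:

```python
import torch

# the global default generator and a fresh generator, seeded identically,
# produce identical samples
g1 = torch.manual_seed(0)
g2 = torch.Generator().manual_seed(0)
assert torch.equal(torch.randn(3, generator=g1), torch.randn(3, generator=g2))
```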

tests/pipelines/altdiffusion/test_alt_diffusion.py (+9, −42)

```diff
@@ -188,6 +188,7 @@ def test_alt_diffusion_pndm(self):
         expected_slice = np.array(
             [0.51605093, 0.5707241, 0.47365507, 0.50578886, 0.5633877, 0.4642503, 0.5182081, 0.48763484, 0.49084237]
         )
+
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2


@@ -207,20 +208,16 @@ def test_alt_diffusion(self):
         alt_pipe.set_progress_bar_config(disable=None)

         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast("cuda"):
-            output = alt_pipe(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np"
-            )
+        generator = torch.manual_seed(0)
+        output = alt_pipe([prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="np")

         image = output.images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array(
-            [0.8720703, 0.87109375, 0.87402344, 0.87109375, 0.8779297, 0.8925781, 0.8823242, 0.8808594, 0.8613281]
-        )
+        expected_slice = np.array([0.1010, 0.0800, 0.0794, 0.0885, 0.0843, 0.0762, 0.0769, 0.0729, 0.0586])
+
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2

     def test_alt_diffusion_fast_ddim(self):
@@ -231,44 +228,14 @@ def test_alt_diffusion_fast_ddim(self):
         alt_pipe.set_progress_bar_config(disable=None)

         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)

-        with torch.autocast("cuda"):
-            output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
+        output = alt_pipe([prompt], generator=generator, num_inference_steps=2, output_type="numpy")
         image = output.images

         image_slice = image[0, -3:, -3:, -1]

         assert image.shape == (1, 512, 512, 3)
-        expected_slice = np.array(
-            [0.9267578, 0.9301758, 0.9013672, 0.9345703, 0.92578125, 0.94433594, 0.9423828, 0.9423828, 0.9160156]
-        )
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_alt_diffusion_text2img_pipeline_fp16(self):
-        torch.cuda.reset_peak_memory_stats()
-        model_id = "BAAI/AltDiffusion"
-        pipe = AltDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, safety_checker=None)
-        pipe = pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
-
-        prompt = "a photograph of an astronaut riding a horse"
+        expected_slice = np.array([0.4019, 0.4052, 0.3810, 0.4119, 0.3916, 0.3982, 0.4651, 0.4195, 0.5323])

-        generator = torch.Generator(device=torch_device).manual_seed(0)
-        output_chunked = pipe(
-            [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
-        )
-        image_chunked = output_chunked.images
-
-        generator = torch.Generator(device=torch_device).manual_seed(0)
-        with torch.autocast(torch_device):
-            output = pipe(
-                [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=10, output_type="numpy"
-            )
-        image = output.images
-
-        # Make sure results are close enough
-        diff = np.abs(image_chunked.flatten() - image.flatten())
-        # They ARE different since ops are not run always at the same precision
-        # however, they should be extremely close.
-        assert diff.mean() < 2e-2
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```

tests/pipelines/altdiffusion/test_alt_diffusion_img2img.py (+6, −4)

```diff
@@ -162,6 +162,7 @@ def test_stable_diffusion_img2img_default_case(self):
         expected_slice = np.array(
             [0.41293705, 0.38656747, 0.40876025, 0.4782187, 0.4656803, 0.41394007, 0.4142093, 0.47150758, 0.4570448]
         )
+
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1.5e-3
         assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1.5e-3

@@ -196,7 +197,7 @@ def test_stable_diffusion_img2img_fp16(self):
         alt_pipe.set_progress_bar_config(disable=None)

         prompt = "A painting of a squirrel eating a burger"
-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         image = alt_pipe(
             [prompt],
             generator=generator,
@@ -227,7 +228,7 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):

         prompt = "A fantasy landscape, trending on artstation"

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         output = pipe(
             prompt=prompt,
             image=init_image,
@@ -241,7 +242,8 @@ def test_stable_diffusion_img2img_pipeline_multiple_of_8(self):
         image_slice = image[255:258, 383:386, -1]

         assert image.shape == (504, 760, 3)
-        expected_slice = np.array([0.3252, 0.3340, 0.3418, 0.3263, 0.3346, 0.3300, 0.3163, 0.3470, 0.3427])
+        expected_slice = np.array([0.9358, 0.9397, 0.9599, 0.9901, 1.0000, 1.0000, 0.9882, 1.0000, 1.0000])
+
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3


@@ -275,7 +277,7 @@ def test_stable_diffusion_img2img_pipeline_default(self):

         prompt = "A fantasy landscape, trending on artstation"

-        generator = torch.Generator(device=torch_device).manual_seed(0)
+        generator = torch.manual_seed(0)
         output = pipe(
             prompt=prompt,
             image=init_image,
```
