
Commit cd991d1

a-r-r-o-w and sayakpaul authored
Fix TorchAO related bugs; revert device_map changes (huggingface#10371)
* Revert "Add support for sharded models when TorchAO quantization is enabled (huggingface#10256)" This reverts commit 41ba8c0. * update tests * udpate * update * update * update device map tests * apply review suggestions * update * make style * fix * update docs * update tests * update workflow * update * improve tests * allclose tolerance * Update src/diffusers/models/modeling_utils.py Co-authored-by: Sayak Paul <[email protected]> * Update tests/quantization/torchao/test_torchao.py Co-authored-by: Sayak Paul <[email protected]> * improve tests * fix * update correct slices --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 825979d · commit cd991d1

File tree

5 files changed (+350, -125 lines)


.github/workflows/nightly_tests.yml (+2)

@@ -359,6 +359,8 @@ jobs:
             test_location: "bnb"
           - backend: "gguf"
             test_location: "gguf"
+          - backend: "torchao"
+            test_location: "torchao"
     runs-on:
       group: aws-g6e-xlarge-plus
     container:

docs/source/en/quantization/torchao.md (+62)

@@ -25,6 +25,7 @@ Quantize a model by passing [`TorchAoConfig`] to [`~ModelMixin.from_pretrained`]
 The example below only quantizes the weights to int8.
 
 ```python
+import torch
 from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
 
 model_id = "black-forest-labs/FLUX.1-dev"
@@ -44,6 +45,10 @@ pipe = FluxPipeline.from_pretrained(
 )
 pipe.to("cuda")
 
+# Without quantization: ~31.447 GB
+# With quantization: ~20.40 GB
+print(f"Pipeline memory usage: {torch.cuda.max_memory_reserved() / 1024**3:.3f} GB")
+
 prompt = "A cat holding a sign that says hello world"
 image = pipe(
     prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
@@ -88,6 +93,63 @@ Some quantization methods are aliases (for example, `int8wo` is the commonly use
 
 Refer to the official torchao documentation for a better understanding of the available quantization methods and the exhaustive list of configuration options available.
 
+## Serializing and Deserializing quantized models
+
+To serialize a quantized model in a given dtype, first load the model with the desired quantization dtype and then save it using the [`~ModelMixin.save_pretrained`] method.
+
+```python
+import torch
+from diffusers import FluxTransformer2DModel, TorchAoConfig
+
+quantization_config = TorchAoConfig("int8wo")
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/Flux.1-Dev",
+    subfolder="transformer",
+    quantization_config=quantization_config,
+    torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_int8wo", safe_serialization=False)
+```
+
+To load a serialized quantized model, use the [`~ModelMixin.from_pretrained`] method.
+
+```python
+import torch
+from diffusers import FluxPipeline, FluxTransformer2DModel
+
+transformer = FluxTransformer2DModel.from_pretrained("/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False)
+pipe = FluxPipeline.from_pretrained("black-forest-labs/Flux.1-Dev", transformer=transformer, torch_dtype=torch.bfloat16)
+pipe.to("cuda")
+
+prompt = "A cat holding a sign that says hello world"
+image = pipe(prompt, num_inference_steps=30, guidance_scale=7.0).images[0]
+image.save("output.png")
+```
+
+Some quantization methods, such as `uint4wo`, cannot be loaded directly and may result in an `UnpicklingError` when trying to load the models, but work as expected when saving them. In order to work around this, one can load the state dict manually into the model. Note, however, that this requires using `weights_only=False` in `torch.load`, so it should be run only if the weights were obtained from a trustable source.
+
+```python
+import torch
+from accelerate import init_empty_weights
+from diffusers import FluxPipeline, FluxTransformer2DModel, TorchAoConfig
+
+# Serialize the model
+transformer = FluxTransformer2DModel.from_pretrained(
+    "black-forest-labs/Flux.1-Dev",
+    subfolder="transformer",
+    quantization_config=TorchAoConfig("uint4wo"),
+    torch_dtype=torch.bfloat16,
+)
+transformer.save_pretrained("/path/to/flux_uint4wo", safe_serialization=False, max_shard_size="50GB")
+# ...
+
+# Load the model
+state_dict = torch.load("/path/to/flux_uint4wo/diffusion_pytorch_model.bin", weights_only=False, map_location="cpu")
+with init_empty_weights():
+    transformer = FluxTransformer2DModel.from_config("/path/to/flux_uint4wo/config.json")
+transformer.load_state_dict(state_dict, strict=True, assign=True)
+```
+
 ## Resources
 
 - [TorchAO Quantization API](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md)
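
As a sanity check alongside the serialization docs added above, a sketch like the following (not part of the commit) can confirm that a reloaded checkpoint is still quantized. It assumes a torchao version that exposes `AffineQuantizedTensor` and the `/path/to/flux_int8wo` folder saved in the documented example.

```python
import torch
from diffusers import FluxTransformer2DModel
from torchao.dtypes import AffineQuantizedTensor  # assumed available in recent torchao releases

# Reload the serialized int8 weight-only checkpoint from the docs example.
transformer = FluxTransformer2DModel.from_pretrained(
    "/path/to/flux_int8wo", torch_dtype=torch.bfloat16, use_safetensors=False
)

# Collect the linear layers whose weights are torchao quantized tensor subclasses.
quantized = [
    name
    for name, module in transformer.named_modules()
    if isinstance(module, torch.nn.Linear) and isinstance(module.weight, AffineQuantizedTensor)
]
print(f"{len(quantized)} quantized linear layers, e.g. {quantized[:3]}")
```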

src/diffusers/models/modeling_utils.py (+4, -4)

@@ -718,10 +718,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
-            if is_bnb_quantization_method and device_map is not None:
+            if device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, providing `device_map` is not supported for quantized models. Providing `device_map` as an input will be added in the future."
                 )
 
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
@@ -820,7 +819,8 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                     revision=revision,
                     subfolder=subfolder or "",
                 )
-                if hf_quantizer is not None and is_bnb_quantization_method:
+                # TODO: https://github.com/huggingface/diffusers/issues/10013
+                if hf_quantizer is not None:
                     model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                     logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                     is_sharded = False
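
The change above broadens the old bitsandbytes-only guard: any quantized load that also passes `device_map` now raises `NotImplementedError`, and sharded checkpoints are merged for every quantizer, not just bitsandbytes. A minimal sketch of the user-facing effect (not part of the commit), reusing the FLUX checkpoint and `TorchAoConfig` from the docs example:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
    # device_map="auto",  # passing any device_map with a quantization config now raises NotImplementedError
)
transformer.to("cuda")  # move the quantized model explicitly instead
```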

src/diffusers/quantizers/torchao/torchao_quantizer.py (+1, -1)

@@ -132,7 +132,7 @@ def validate_environment(self, *args, **kwargs):
     def update_torch_dtype(self, torch_dtype):
         quant_type = self.quantization_config.quant_type
 
-        if quant_type.startswith("int"):
+        if quant_type.startswith("int") or quant_type.startswith("uint"):
             if torch_dtype is not None and torch_dtype != torch.bfloat16:
                 logger.warning(
                     f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
