From 6f4467188167a01d00e2eecaa4e624915116d678 Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 11 Mar 2025 22:16:49 -0700 Subject: [PATCH] hand-rolled --- .../quantization/test_config_serialization.py | 172 +++++++++++++ torchao/core/config.py | 229 +++++++++++++++++- torchao/quantization/__init__.py | 9 + 3 files changed, 409 insertions(+), 1 deletion(-) create mode 100644 test/quantization/test_config_serialization.py diff --git a/test/quantization/test_config_serialization.py b/test/quantization/test_config_serialization.py new file mode 100644 index 0000000000..694c0cdafc --- /dev/null +++ b/test/quantization/test_config_serialization.py @@ -0,0 +1,172 @@ +import json +import os +import tempfile +from dataclasses import dataclass +from unittest import mock + +import pytest +import torch + +from torchao.core.config import ( + AOBaseConfig, + VersionMismatchError, + config_from_dict, + config_to_dict, +) +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, + Float8WeightOnlyConfig, + FPXWeightOnlyConfig, + GemliteUIntXWeightOnlyConfig, + Int4DynamicActivationInt4WeightConfig, + Int4WeightOnlyConfig, + Int8DynamicActivationInt4WeightConfig, + Int8DynamicActivationInt8WeightConfig, + Int8WeightOnlyConfig, + PerRow, + UIntXWeightOnlyConfig, +) +from torchao.sparsity.sparse_api import BlockSparseWeightConfig, SemiSparseWeightConfig + +# Define test configurations as fixtures +configs = [ + Float8DynamicActivationFloat8WeightConfig(), + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + Float8WeightOnlyConfig( + weight_dtype=torch.float8_e4m3fn, + ), + UIntXWeightOnlyConfig(dtype=torch.uint1), + Int4DynamicActivationInt4WeightConfig(), + Int4WeightOnlyConfig( + group_size=32, + ), + Int8DynamicActivationInt4WeightConfig( + group_size=64, + ), + Int8DynamicActivationInt8WeightConfig(), + # Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout()), + Int8WeightOnlyConfig( + group_size=128, + ), + UIntXWeightOnlyConfig( + dtype=torch.uint3, + group_size=32, + use_hqq=True, + ), + GemliteUIntXWeightOnlyConfig( + group_size=128, # Optional, has default of 64 + bit_width=8, # Optional, has default of 4 + packing_bitwidth=8, # Optional, has default of 32 + contiguous=True, # Optional, has default of None + ), + FPXWeightOnlyConfig(ebits=4, mbits=8), + # Sparsity configs + SemiSparseWeightConfig(), + BlockSparseWeightConfig(blocksize=128), +] + + +# Create ids for better test naming +def get_config_ids(configs): + if not isinstance(configs, list): + configs = [configs] + return [config.__class__.__name__ for config in configs] + + +@pytest.mark.parametrize("config", configs, ids=get_config_ids) +def test_reconstructable_dict_file_round_trip(config): + """Test saving and loading reconstructable dicts to/from JSON files.""" + # Get a reconstructable dict + reconstructable = config_to_dict(config) + + # Create a temporary file to save the JSON + with tempfile.NamedTemporaryFile( + mode="w+", suffix=".json", delete=False + ) as temp_file: + # Write the reconstructable dict as JSON + json.dump(reconstructable, temp_file) + temp_file_path = temp_file.name + + try: + # Read back the JSON file + with open(temp_file_path, "r") as file: + loaded_dict = json.load(file) + + # Reconstruct from the loaded dict + reconstructed = config_from_dict(loaded_dict) + + # Check it's the right class + assert isinstance(reconstructed, config.__class__) + + # Verify attributes match + for attr_name in config.__dict__: + if not attr_name.startswith("_"): # Skip private attributes + original_value = getattr(config, attr_name) + reconstructed_value = getattr(reconstructed, attr_name) + + # Special handling for torch dtypes + if ( + hasattr(original_value, "__module__") + and original_value.__module__ == "torch" + ): + assert ( + str(original_value) == str(reconstructed_value) + ), f"Attribute {attr_name} mismatch after file round trip for {config.__class__.__name__}" + else: + assert ( + original_value == reconstructed_value + ), f"Attribute {attr_name} mismatch after file round trip for {config.__class__.__name__}" + + finally: + # Clean up the temporary file + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + + +# Define a dummy config in a non-allowed module +@dataclass +class DummyNonAllowedConfig(AOBaseConfig): + VERSION = 2 + value: int = 42 + + +def test_disallowed_modules(): + """Test that configs from non-allowed modules are rejected during reconstruction.""" + # Create a config from a non-allowed module + dummy_config = DummyNonAllowedConfig() + reconstructable = config_to_dict(dummy_config) + + with pytest.raises( + ValueError, + match="Failed to find class DummyNonAllowedConfig in any of the allowed modules", + ): + config_from_dict(reconstructable) + + # Use mock.patch as a context manager + with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}): + reconstructed = config_from_dict(reconstructable) + assert isinstance(reconstructed, DummyNonAllowedConfig) + assert reconstructed.value == 42 + assert reconstructed.VERSION == 2 + + +def test_version_mismatch(): + """Test that version mismatch raises an error during reconstruction.""" + # Create a config + dummy_config = DummyNonAllowedConfig() + reconstructable = config_to_dict(dummy_config) + + # Modify the version in the dict to create a mismatch + reconstructable["_version"] = 1 + + # Patch to allow the module but should still fail due to version mismatch + with mock.patch("torchao.core.config.ALLOWED_AO_MODULES", {__name__}): + with pytest.raises( + VersionMismatchError, + match="Version mismatch for DummyNonAllowedConfig: stored version 1 != current version 2", + ): + config_from_dict(reconstructable) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/torchao/core/config.py b/torchao/core/config.py index 14a7b8dc66..d97ae5987c 100644 --- a/torchao/core/config.py +++ b/torchao/core/config.py @@ -1,4 +1,11 @@ import abc +import dataclasses +import enum +import importlib +import json +from typing import Any, ClassVar, Dict + +import torch class AOBaseConfig(abc.ABC): @@ -26,4 +33,224 @@ def _transform( """ - pass + # Base Version of a config + VERSION: ClassVar[int] = 1 + + +class VersionMismatchError(Exception): + """Raised when trying to deserialize a config with a different version""" + + def __init__(self, type_path, stored_version, current_version): + self.type_path = type_path + self.stored_version = stored_version + self.current_version = current_version + message = ( + f"Version mismatch for {type_path}: " + f"stored version {stored_version} != current version {current_version}" + ) + super().__init__(message) + + +class ConfigJSONEncoder(json.JSONEncoder): + """Custom JSON encoder for AOBaseConfig objects""" + + def default(self, o): + # Handle AOBaseConfig subclasses first (most specific case) + if isinstance(o, AOBaseConfig): + data_dict = {} + # Process each attribute to handle nested objects + for k, v in o.__dict__.items(): + if not k.startswith("_") and k != "VERSION": + # Recursively encode each value (important for nested objects) + data_dict[k] = self.encode_value(v) + + return { + # Only store the class name, not the full module path + "_type": o.__class__.__name__, + "_version": getattr(o.__class__, "VERSION", 1), + "_data": data_dict, + } + + # Handle NamedTuple types + if hasattr(o, "_fields") and hasattr( + o, "_asdict" + ): # Check for NamedTuple characteristics + asdict_data = o._asdict() + # Process each field to handle nested objects + processed_data = {k: self.encode_value(v) for k, v in asdict_data.items()} + + return { + "_type": o.__class__.__name__, + "_version": getattr(o.__class__, "VERSION", 1), + "_data": processed_data, + } + + # Handle dataclasses + if dataclasses.is_dataclass(o) and not isinstance(o, type): + data_dict = {} + # Process each field to handle nested objects + for f in dataclasses.fields(o): + if f.name != "VERSION": + data_dict[f.name] = self.encode_value(getattr(o, f.name)) + + return { + # Only store the class name for dataclasses too + "_type": o.__class__.__name__, + "_version": getattr(o.__class__, "VERSION", 1), + "_data": data_dict, + } + + # Handle torch.dtype + if hasattr(o, "__module__") and o.__module__ == "torch" and isinstance(o, type): + return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]} + + # Handle Layout objects + if hasattr(o, "__class__") and "Layout" in o.__class__.__name__: + return { + "_type": o.__class__.__name__, + "_data": { + k: self.encode_value(v) + for k, v in o.__dict__.items() + if not k.startswith("_") + }, + } + + # Handle enum values + if isinstance(o, enum.Enum): + # Store the full path for enums to ensure uniqueness + return {"_type": f"{o.__class__.__name__}", "_data": o.name} + + if isinstance(o, torch.dtype): + return {"_type": "torch.dtype", "_data": str(o).split(".")[-1]} + + # For lists and dictionaries, recursively process their items + if isinstance(o, list): + return [self.encode_value(item) for item in o] + + if isinstance(o, dict): + return {k: self.encode_value(v) for k, v in o.items()} + + # Default case - let the parent class handle it + return super().default(o) + + def encode_value(self, value): + """Helper method to recursively encode a value""" + # Try to use default for custom type + try: + # This will handle all our special cases and raise TypeError + # if it can't handle the type + result = self.default(value) + return result + except TypeError: + pass + + # Default case - return as is + # (This will be processed by standard JSON encoder later) + return value + + +def config_to_dict(config: AOBaseConfig) -> Dict[str, Any]: + """ + Convert an AOBaseConfig instance to a dictionary suitable for serialization. + + Args: + config: An instance of AOBaseConfig subclass + + Returns: + Dict representation of the config + """ + if not isinstance(config, AOBaseConfig): + raise TypeError(f"Expected AOBaseConfig instance, got {type(config)}") + + # Use the existing JSON encoder but return the dict directly + return json.loads(json.dumps(config, cls=ConfigJSONEncoder)) + + +ALLOWED_AO_MODULES = {"torchao.quantization", "torchao.sparsity.sparse_api"} + + +def config_from_dict(data: Dict[str, Any]) -> AOBaseConfig: + """ + Create an AOBaseConfig subclass instance from a dictionary. + + Args: + data: Dictionary containing serialized config data + + Returns: + An instance of the appropriate AOBaseConfig subclass + + Raises: + VersionMismatchError: If the stored version doesn't match the class version + ValueError: If deserialization fails for other reasons + """ + if not isinstance(data, dict): + raise TypeError(f"Expected dictionary, got {type(data)}") + + if "_type" not in data or "_data" not in data: + raise ValueError("Input dictionary missing required '_type' or '_data' fields") + + type_path = data["_type"] + stored_version = data.get("_version", 1) + obj_data = data["_data"] + + # Handle torch.dtype + if type_path == "torch.dtype": + import torch + + return getattr(torch, obj_data) + # Try to find the class in any of the allowed modules + cls = None + for module_path in ALLOWED_AO_MODULES: + try: + module = importlib.import_module(module_path) + cls = getattr(module, type_path) + break # Found the class, exit the loop + except (ImportError, AttributeError): + continue # Try the next module + + # If we couldn't find the class in any allowed module, raise an error + if cls is None: + allowed_modules_str = ", ".join(ALLOWED_AO_MODULES) + raise ValueError( + f"Failed to find class {type_path} in any of the allowed modules: {allowed_modules_str}" + ) + + # Check version - require exact match + current_version = getattr(cls, "VERSION", 1) + if stored_version != current_version: + raise VersionMismatchError(type_path, stored_version, current_version) + + # Handle the case where obj_data is not a dictionary + if not isinstance(obj_data, dict): + if issubclass(cls, enum.Enum): + # For enums, convert string to enum value + return getattr(cls, obj_data) + else: + # For other primitive types, create an instance with the value + try: + return cls(obj_data) + except: + return obj_data + + # Process nested structures for dictionary obj_data + processed_data = {} + for key, value in obj_data.items(): + if isinstance(value, dict) and "_type" in value and "_data" in value: + # Recursively handle nested configs + processed_data[key] = config_from_dict(value) + elif isinstance(value, list): + # Handle lists of possible configs + processed_data[key] = [ + config_from_dict(item) + if isinstance(item, dict) and "_type" in item and "_data" in item + else item + for item in value + ] + else: + processed_data[key] = value + + # Create and return the instance + try: + return cls(**processed_data) + except Exception as e: + raise ValueError(f"Failed to create instance of {cls.__name__}: {e}") diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b816eb585e..c986084af4 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -46,8 +46,10 @@ AffineQuantizedObserverBase, ) from .quant_api import ( + CutlassInt4PackedLayout, Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, + Float8MMConfig, Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, @@ -57,6 +59,8 @@ Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, + PlainLayout, + TensorCoreTiledLayout, UIntXWeightOnlyConfig, float8_dynamic_activation_float8_weight, float8_static_activation_float8_weight, @@ -190,4 +194,9 @@ "WeightOnlyInt8QuantLinear", "TwoStepQuantizer", "Quantizer", + # Layouts for quant_api + "PlainLayout", + "TensorCoreTiledLayout", + "CutlassInt4PackedLayout", + "Float8MMConfig", ]