
Config Serilization No Deps #1875

Merged
merged 1 commit into main from the serialize-part2 branch on Mar 14, 2025

Conversation

drisspg
Contributor

@drisspg drisspg commented Mar 12, 2025

Summary

Similar to #1806 but without using Pydantic

This PR introduces functionality for serializing and deserializing AOBaseConfig objects to and from dictionaries

The new config_to_dict and config_from_dict functions allow these configurations to be:

  • Saved to disk
  • Shared between applications
  • Versioned for compatibility

Another key difference from that PR is that AOBaseConfig remains unchanged, apart from an overridable VERSION class var that defaults to 1.

This makes serialization opt-in: ideally, opting in is as cheap as adding your config to the test file and ensuring that all types needed for reconstruction can be found in the torchao.quantization module.

Implementation Details

Serialization Format

When an AOBaseConfig object is serialized, it creates a dictionary with:

{
    "_type": "ConfigClassName",  # Just the class name, not full module path
    "_version": 1,               # Version from the class's VERSION attribute
    "_data": {                   # Actual configuration parameters
        "param1": value1,
        "param2": value2,
        # Nested objects also get serialized with their types
    }
}

Special Cases

The ConfigJSONEncoder handles several types of objects (a sketch of the dispatch follows the list):

  1. AOBaseConfig subclasses: Primary config objects
  2. NamedTuple types: Common for grouped parameters
  3. Dataclasses: Needed for our granularity objects
  4. torch.dtype objects: the one native torch type that shows up in our configs so far
  5. Layout objects: tensor layout objects such as TensorCoreTiledLayout
  6. Enum values: For options like MappingType.SYMMETRIC
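
To make the dispatch concrete, here is a minimal, hypothetical sketch of how an encoder along these lines might branch on those cases. It is simplified relative to the actual ConfigJSONEncoder in this PR, and the exact field handling is an assumption.

import dataclasses
import enum
import json
import torch

class SketchConfigEncoder(json.JSONEncoder):
    # Illustrative stand-in for the PR's encoder; not the real implementation.

    def default(self, o):
        # torch.dtype: store its string form, e.g. "torch.bfloat16"
        if isinstance(o, torch.dtype):
            return {"_type": "torch.dtype", "_data": str(o)}
        # Enum values: store the member name, e.g. "SYMMETRIC"
        if isinstance(o, enum.Enum):
            return {"_type": type(o).__name__, "_data": o.name}
        # Dataclasses (configs, layouts, granularities): store their fields;
        # nested non-JSON values get routed back through default() recursively
        if dataclasses.is_dataclass(o) and not isinstance(o, type):
            return {
                "_type": type(o).__name__,
                "_version": getattr(o, "VERSION", 1),
                "_data": {f.name: getattr(o, f.name) for f in dataclasses.fields(o)},
            }
        # Note: NamedTuples serialize natively as JSON arrays, so preserving
        # their type (as the real encoder does) needs a pre-processing pass
        # that this sketch omits.
        return super().default(o)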

Version Management

Each config class defines a VERSION class variable (defaulting to 1). During deserialization, we enforce:

if stored_version != current_version:
    raise VersionMismatchError(type_path, stored_version, current_version)

This prevents using outdated configurations with newer code that might expect different parameters.
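
As a quick illustration of the opt-in versioning (the config name and field below are made up, and this assumes AOBaseConfig lives in torchao.core.config alongside config_to_dict):

from dataclasses import dataclass
from torchao.core.config import AOBaseConfig

@dataclass
class MyHypotheticalConfig(AOBaseConfig):
    # Bump this when fields are removed or change meaning; dicts serialized
    # under the old version will then raise VersionMismatchError on load.
    VERSION = 2
    group_size: int = 128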

Restricted Reconstruction

I only allow importing types from the torchao.quantization module path or from torch:

try:
    module = importlib.import_module("torchao.quantization")
    cls = getattr(module, type_path)
except (ImportError, AttributeError) as e:
    raise ValueError(f"Failed to find class {type_path} in torchao.quantization: {e}")

Example Usage

Serializing a Configuration

from torchao.quantization import Int4WeightOnlyConfig
from torchao.core.config import config_to_dict
import json

# Create a configuration
config = Int4WeightOnlyConfig(group_size=64)

# Convert to dictionary
config_dict = config_to_dict(config)

# Save to JSON file
with open("config.json", "w") as f:
    json.dump(config_dict, f)

This produces

{"_type": "Int4WeightOnlyConfig", "_version": 1, "_data": {"group_size": 64, "layout": {"_type": "TensorCoreTiledLayout", "_version": 1, "_data": {"inner_k_tiles": 8}}, "use_hqq": false, "zero_point_domain": {"_type": "ZeroPointDomain", "_data": "NONE"}}}

Deserializing a Configuration

from torchao.core.config import config_from_dict
import json

# Load from JSON file
with open("config.json", "r") as f:
    config_dict = json.load(f)

# Rebuild configuration object
config = config_from_dict(config_dict)

print(config)

Prints

Int4WeightOnlyConfig(group_size=64, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False, zero_point_domain=<ZeroPointDomain.NONE: 3>)

YAML Support

The serialized format uses basic types (strings, numbers, booleans, lists, dictionaries) that are compatible with YAML. Users can:

  1. Serialize to dict using config_to_dict
  2. Use PyYAML to convert the dict to YAML
  3. For deserialization, parse YAML to dict, then use config_from_dict

Example:

import yaml
from torchao.core.config import config_to_dict, config_from_dict

# Serialize to YAML
config_dict = config_to_dict(my_config)
yaml_str = yaml.dump(config_dict)

# Deserialize from YAML
loaded_dict = yaml.safe_load(yaml_str)
loaded_config = config_from_dict(loaded_dict)

The resulting YAML string:

_data:
  group_size: 64
  layout:
    _data:
      inner_k_tiles: 8
    _type: TensorCoreTiledLayout
    _version: 1
  use_hqq: false
  zero_point_domain:
    _data: NONE
    _type: ZeroPointDomain
_type: Int4WeightOnlyConfig
_version: 1

Future Improvements

  • Migration logic for handling version changes (one possible shape is sketched below)
  • We could add more machinery to make custom types in AOBaseConfigs easier to support

We should also ensure that all of quant_api's configs are added to the test files
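
One possible shape for that migration logic, purely as a sketch (none of these names exist yet):

# Hypothetical per-config migration table: maps an old version to a function
# that upgrades the stored "_data" dict to the next version's schema.
MIGRATIONS = {
    "Int4WeightOnlyConfig": {
        1: lambda data: {**data, "field_added_in_v2": None},
    },
}

def migrate(type_name: str, stored_version: int, current_version: int, data: dict) -> dict:
    # Apply migrations one version at a time instead of raising VersionMismatchError.
    while stored_version < current_version:
        step = MIGRATIONS.get(type_name, {}).get(stored_version)
        if step is None:
            raise ValueError(f"No migration from version {stored_version} for {type_name}")
        data = step(data)
        stored_version += 1
    return data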


pytorch-bot bot commented Mar 12, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1875

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (1 Unrelated Failure)

As of commit 6f44671 with merge base be09c1d:

BROKEN TRUNK - The following job failed but was already failing on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label Mar 12, 2025
@drisspg drisspg requested review from vkuzo and andrewor14 March 12, 2025 05:29
@drisspg drisspg added the topic: new feature label Mar 12, 2025


@pytest.mark.parametrize("config", configs, ids=get_config_ids)
def test_reconstructable_dict_file_round_trip(config):
Contributor

looks great!

@vkuzo
Contributor

vkuzo commented Mar 12, 2025

"ensuring that all types needed for reconstruction can be found in the torchao.quantization module", "I only allow importing types from the torchao.quantization module path or from torch"

Can this be relaxed to allow prototype and sparsity configs to be serializable?

Thoughts on how handling of static scales would work with this?

Overall looks great!

@drisspg
Contributor Author

drisspg commented Mar 12, 2025

@vkuzo

  1. Added a global allow list that can be used to allow more config locations

  2. For handling static scales, that's interesting. It is not really possible to save a tensor to JSON (well, you can, but it is not practical). So you would update the type-specific handlers to add special logic for tensors, probably emitting something like "scale = STATIC" (or whatever marker we choose). Then at reconstruction time you need to figure out where we put that data, which realistically should live on the subclasses stored in the state dict. So you would essentially init the config in an invalid state (scale = PLACEHOLDER) and then inject the right scale after the fact. A little hand-wavy, but I think it's possible.
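
A rough sketch of that placeholder-then-inject idea (every name below is hypothetical; nothing like this exists in torchao yet):

import torch

SCALE_PLACEHOLDER = "STATIC_SCALE_PLACEHOLDER"

def encode_scale_for_json(scale):
    # Tensors are not JSON-friendly, so serialize a marker instead of the data.
    return SCALE_PLACEHOLDER if isinstance(scale, torch.Tensor) else scale

def inject_static_scale(config, state_dict, key="static_scale"):
    # After config_from_dict, swap the marker for the real tensor, which would
    # realistically live alongside the quantized weights in the state dict.
    if getattr(config, "static_scale", None) == SCALE_PLACEHOLDER:
        config.static_scale = state_dict[key]
    return config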

Contributor

@andrewor14 andrewor14 left a comment

Looks great! I definitely prefer this to the Pydantic version, and prefer separating serde from the string API. Just a couple of questions:

  1. How do we want versioning to work exactly? We definitely need to bump the version when we remove or modify an existing field, but do we necessarily need to bump the version when adding a new field with a default value? In this case, should we still fail on version mismatch or just let it load the default value?

  2. Do we need to handle nested configs / data structures? I feel inner configs (e.g. Float8MMConfig) don't need to be TorchAOBaseConfig, so technically we just need to handle inner dataclasses/named tuples?

processed_data = {}
for key, value in obj_data.items():
    if isinstance(value, dict) and "_type" in value and "_data" in value:
        # Recursively handle nested configs
Contributor

Do we have any existing configs that have nested configs or fields with data structures? I feel our configs so far have been pretty simple?

Contributor Author

Vasiliy mentioned there might be one for float8 training, but I couldn't find a config that wraps another config.

We do have configs that wrap objects that need to be handled specially; the float8 inference configs have this, where MMConfig is internal to them and needs special care.

Contributor

"We do have configs that wrap objects that need to be handled specially; the float8 inference configs have this, where MMConfig is internal to them and needs special care."

I don't think MMConfig should be in the BC surface / serializable / etc, that seems like a code smell.

Contributor Author

@drisspg drisspg Mar 13, 2025

I disagree; it is a convenient way for users to modulate mm behavior on a per-instance level, and unpacking this struct into three separate booleans doesn't feel beneficial.

That being said, there are a lot of examples of these nested objects (Layouts, granularities, etc.) for which this behavior is necessary.

@drisspg drisspg force-pushed the serialize-part2 branch 3 times, most recently from 0e5aab0 to 896cb53, on March 13, 2025 03:13
@@ -56,6 +58,8 @@
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
PlainLayout,
Contributor

is this exposed in public API / on HF / etc?

I don't think this abstraction is clean and I don't want to expose it; this is not really a "layout" but rather a "layout with some other arguments mixed in for convenience". We should talk about what to do with this before exposing it.

@drisspg drisspg merged commit d258a11 into main Mar 14, 2025
17 of 18 checks passed
@drisspg drisspg mentioned this pull request Mar 19, 2025