Skip to content

Commit

Permalink
Partially revert google#4192 which sets back a bunch of previous merg…
Browse files Browse the repository at this point in the history
…ed pushes.

PiperOrigin-RevId: 675337465
  • Loading branch information
IvyZX authored and Flax Authors committed Sep 16, 2024
1 parent 03e034d commit 9eb0a61
Show file tree
Hide file tree
Showing 17 changed files with 274 additions and 158 deletions.
14 changes: 14 additions & 0 deletions flax/core/meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""

import abc
import dataclasses
import functools
from typing import Any, Generic, TypeVar
from collections.abc import Callable
Expand Down Expand Up @@ -287,6 +288,19 @@ def get_sharding(self, mesh: jax.sharding.Mesh) -> jax.sharding.Sharding:
"""Returns the ``NamedSharding`` for this partitioned value."""
return jax.sharding.NamedSharding(mesh, self.get_partition_spec())

def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
metadata['sharding'] = metadata.pop('names')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.Partitioned`."""
metadata['names'] = metadata.pop('sharding')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})


def with_partitioning(
fn: Callable[..., Any],
Expand Down
9 changes: 9 additions & 0 deletions flax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ def __reduce__(self):
return (FlaxError, (str(self),))


#################################################
# NNX errors #
#################################################


class TraceContextError(FlaxError):
pass


#################################################
# lazy_init.py errors #
#################################################
Expand Down
15 changes: 15 additions & 0 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,21 @@ def unbox(self, apply_constraint=True) -> Any:
else:
return self.value

def to_nnx_metadata(self) -> dict[str, Any]:
"""Return a dict of metadata that can translate into an `nnx.Variable`."""
metadata = vars(self)
metadata['sharding'] = metadata.pop('names')
metadata['sharding_rules'] = metadata.pop('rules')
return metadata

@classmethod
def from_nnx_metadata(cls, metadata: dict[str, Any]):
"""Given a dict of `nnx.Variable` format metadata, create a `nn.LogicallyPartitioned`."""
metadata['names'] = metadata.pop('sharding')
metadata['rules'] = metadata.pop('sharding_rules')
fields = {x.name for x in dataclasses.fields(cls)}
return cls(**{k: v for k, v in metadata.items() if k in fields})


def with_logical_partitioning(
fn: Callable[..., Any],
Expand Down
30 changes: 15 additions & 15 deletions flax/nnx/bridge/variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ def variable_type_name(typ: tp.Type[variableslib.Variable[tp.Any]]) -> str:


def register_variable_name_type_pair(name, typ, overwrite = False):
"""Register a pair of variable type name (like Linen collections) and its NNX type."""
"""Register a pair of Linen collection name and its NNX type."""
if not overwrite and name in VariableTypeCache:
raise ValueError(f'Name {name} already mapped to type {VariableTypeCache[name]}. '
'To overwrite, call with `overwrite=True`.')
'To overwrite, call register_variable_name_type_pair() with `overwrite=True`.')
VariableTypeCache[name] = typ


Expand All @@ -85,8 +85,7 @@ def _variable_parents_count(t: type):


class NNXMeta(struct.PyTreeNode, meta.AxisMetadata[A]):
"""Default Flax metadata class for `nnx.VariableState`.
"""
"""Default Flax metadata class for `nnx.VariableState`."""

var_type: type[variableslib.Variable[tp.Any]] = struct.field(pytree_node=False)
value: Any = struct.field(pytree_node=True)
Expand All @@ -110,10 +109,11 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
metadata = vs.get_metadata()
if 'linen_meta_type' in metadata:
if metadata['linen_meta_type'] is not meta.Partitioned:
raise ValueError('Not supporting Linen metadata types other than nn.Partitioned')
return meta.Partitioned(vs.value, names=metadata['sharding'], mesh=metadata['mesh'])
return NNXMeta(vs.type, vs.value, vs.get_metadata())
linen_type = metadata['linen_meta_type']
if hasattr(linen_type, 'from_nnx_metadata'):
return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
return linen_type(vs.value, **metadata)
return NNXMeta(vs.type, vs.value, metadata)


def get_col_name(keypath: tp.Sequence[Any]) -> str:
Expand All @@ -124,15 +124,15 @@ def get_col_name(keypath: tp.Sequence[Any]) -> str:


def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
"""Convert a Linen variable to an NNX variable.
This process needs the collection name,
"""
"""Convert a Linen variable to an NNX variable."""
vtype = variable_type(col)
if isinstance(x, NNXMeta):
assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
return x.var_type(x.value, **x.metadata)
if isinstance(x, meta.AxisMetadata):
if isinstance(x, meta.Partitioned):
return vtype(x.value, sharding=x.names, mesh=x.mesh, linen_meta_type=meta.Partitioned)
raise ValueError('Not yet supporting metadata types other than nn.Partitioned and NNXMeta')
return vtype(x)
x_metadata = vars(x)
if hasattr(x, 'to_nnx_metadata'):
x_metadata = x.to_nnx_metadata()
assert hasattr(x, 'value')
return vtype(**x_metadata, linen_meta_type=type(x))
return vtype(x)
53 changes: 27 additions & 26 deletions flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def lazy_init(fn: Module | tp.Callable[..., tp.Any], *args, **kwargs):
module = fn
assert callable(fn)
else:
if not (hasattr(fn, '__self__') and isinstance(fn.__self__, Module)):
if not hasattr(fn, '__self__') and isinstance(fn.__self__, Module):
raise ValueError(f'{fn = } needs to be a method of an NNX Module.')
module = fn.__self__
_set_initializing(module, True)
Expand Down Expand Up @@ -124,6 +124,7 @@ def __init__(
self.linen_collections: tuple[str, ...] = ()

def lazy_init(self, *args, **kwargs):
"""A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
return lazy_init(self, *args, **kwargs)

def __call__(
Expand Down Expand Up @@ -224,28 +225,6 @@ class ToLinen(linen.Module):
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

def update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)

@linen.compact
def __call__(self, *args, **kwargs):
# init codepath
Expand All @@ -255,7 +234,7 @@ def __call__(self, *args, **kwargs):
module_kwargs |= dict(rngs=nnx.Rngs(**linen_rngs_dict(self)))
module = self.nnx_class(*self.args, **module_kwargs)
# TODO: add lazy_init here in case there's an `ToNNX` submodule under `module`.
self.update_variables(module)
self._update_variables(module)
return module(*args, **kwargs)

# apply codepath
Expand All @@ -270,11 +249,33 @@ def __call__(self, *args, **kwargs):
module = nnx.merge(gdef, nnx_state)
nnx.reseed(module, **linen_rngs_dict(self)) # reseed with keys from linen apply call.
out = module(*args, **kwargs)
self.update_variables(module)
self._update_variables(module)
return out

def _update_variables(self, module):
"""Store the NNX module's graph def and state inside Linen module variables."""
gdef, state = nnx.split(module)
# Save the graph def.
if self.is_mutable_collection('nnx'):
self.put_variable('nnx', 'graphdef', gdef)
# Sort all the variable types.
types = set(jax.tree.leaves(
jax.tree.map(lambda x: x.type, state,
is_leaf=lambda x: isinstance(x, nnx.VariableState))))
types = bv.sort_variable_types(types)
_, *state_by_types = nnx.split(module, *types)
# Each variable type goes to its own linen collection, and
# each attribute goes to its own linen variable
for typ, state in zip(types, state_by_types):
collection = bv.variable_type_name(typ)
if self.is_mutable_collection(collection):
for k, v in state.raw_mapping.items():
v = jax.tree.map(bv.to_linen_var, v,
is_leaf=lambda x: isinstance(x, nnx.VariableState))
self.put_variable(collection, k, v)


def to_linen(nnx_class: tp.Callable[..., Module], *args,
name: str | None = None, **kwargs):
"""Shortcut of `ToLinen` if user is not changing any of `ToLinen` default fields."""
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
17 changes: 0 additions & 17 deletions flax/nnx/errors.py

This file was deleted.

18 changes: 10 additions & 8 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from flax import struct
from flax.nnx.object import Object
from flax.typing import MISSING, PathParts
from flax.typing import Missing, PathParts
from flax.nnx import graph


Expand Down Expand Up @@ -59,7 +59,7 @@ def extract_graph_nodes(
pytree: A,
/,
*,
prefix: tp.Any = MISSING,
prefix: tp.Any = Missing,
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> (
tuple[A, tuple[tp.Any, ...]]
Expand Down Expand Up @@ -101,7 +101,7 @@ def extract_graph_nodes(

pytree_out = jax.tree.unflatten(treedef, leaves)

if prefix is MISSING:
if prefix is Missing:
return pytree_out, tuple(nodes) # type: ignore[bad-return-type]
else:
return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type]
Expand Down Expand Up @@ -330,12 +330,13 @@ def to_tree(
tree,
/,
*,
prefix: tp.Any = MISSING,
prefix: tp.Any = Missing,
split_fn: tp.Callable[
[graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any
] = default_split_fn,
map_non_graph_nodes: bool = False,
ctxtag: str | None = None,
check_aliasing: bool = True,
) -> tp.Any:
leaf_prefixes = broadcast_prefix(
prefix,
Expand All @@ -351,9 +352,10 @@ def to_tree(
with graph.split_context(ctxtag) as split_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
if graph.is_graph_node(leaf):
check_consistent_aliasing(
leaf, leaf_prefix, node_prefixes=node_prefixes
)
if check_aliasing:
check_consistent_aliasing(
leaf, leaf_prefix, node_prefixes=node_prefixes
)
tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf)
leaves_out.append(tree_node)
else:
Expand Down Expand Up @@ -381,7 +383,7 @@ def from_tree(
tree: tp.Any,
/,
*,
prefix: tp.Any = MISSING,
prefix: tp.Any = Missing,
merge_fn: tp.Callable[
[graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any
] = merge_tree_node,
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@
import numpy as np

from flax.nnx import (
errors,
reprlib,
tracers,
)
from flax.nnx import graph
from flax.nnx.variables import Variable, VariableState
from flax.typing import Key
from flax import errors

G = tp.TypeVar('G', bound='Object')

Expand Down
10 changes: 8 additions & 2 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _add_axis(x: tp.Any):
sharding.insert(index, axis_name)
x.sharding = tuple(sharding) # type: ignore

x.add_axis(axis_name, index)
x.add_axis(index, axis_name)
return x

return jax.tree.map(
Expand All @@ -61,7 +61,7 @@ def _remove_axis(x: tp.Any):
sharding = list(x.sharding)
assert sharding.pop(index) == axis_name
x.sharding = tuple(sharding)
x.remove_axis(axis_name, index)
x.remove_axis(index, axis_name)
return x

return jax.tree.map(
Expand Down Expand Up @@ -89,9 +89,15 @@ def _maybe_replicate(x):
else:
return None

def from_rules(sharding, sharding_rules):
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
return (rules[s] if s in rules else s for s in sharding)

def f(x):
if isinstance(x, (variables.VariableState, variables.Variable)):
if hasattr(x, 'sharding') and x.sharding:
if hasattr(x, 'sharding_rules') and x.sharding_rules:
return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
return x.replace(PartitionSpec(*x.sharding))
else:
return x.replace(_maybe_replicate(x.value))
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/transforms/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def jit_wrapper(*args, **kwargs):
(args, kwargs),
prefix=(in_shardings, kwarg_shardings),
split_fn=_jit_split_fn,
check_aliasing=in_shardings is not None,
ctxtag='jit',
)
pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
Expand Down
Loading

0 comments on commit 9eb0a61

Please sign in to comment.