Skip to content

Commit

Permalink
avoid nested closures in module (ml-explore#759)
Browse files Browse the repository at this point in the history
  • Loading branch information
awni authored Feb 29, 2024
1 parent 776c3d2 commit 4494970
Showing 1 changed file with 46 additions and 27 deletions.
73 changes: 46 additions & 27 deletions python/mlx/nn/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,42 @@
from mlx.utils import tree_flatten, tree_unflatten


def _unwrap(model, value_key, value, filter_fn, map_fn, is_leaf_fn):
if is_leaf_fn(model, value_key, value):
return map_fn(value)

elif isinstance(value, Module):
return {
k: _unwrap(value, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in value.items()
if filter_fn(value, k, v)
}

elif isinstance(value, dict):
nd = {}
for k, v in v.items():
tk = f"{value_key}.{k}"
nd[k] = (
_unwrap(model, tk, v, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, v)
else {}
)
return nd

elif isinstance(value, list):
nl = []
for i, vi in enumerate(value):
tk = f"{value_key}.{i}"
nl.append(
_unwrap(model, tk, vi, filter_fn, map_fn, is_leaf_fn)
if filter_fn(model, tk, vi)
else {}
)
return nl

raise RuntimeError("Unexpected leaf found while traversing the module")


class Module(dict):
"""Base class for building neural networks with MLX.
Expand Down Expand Up @@ -98,10 +134,13 @@ def __getattr__(self, key: str):
if key in self:
return self[key]
else:
raise AttributeError(f"{type(self)!r} has no attribute {key!r}")
super(Module, self).__getattr__(key, val)

def __setattr__(self, key: str, val: Any):
self[key] = val
if isinstance(val, (mx.array, dict, list, tuple)):
self[key] = val
else:
super(Module, self).__setattr__(key, val)

def load_weights(
self,
Expand Down Expand Up @@ -245,31 +284,11 @@ def filter_and_map(
is_leaf_fn = is_leaf_fn or (
lambda m, k, v: not isinstance(v, (Module, dict, list))
)

def unwrap(vk, v):
if is_leaf_fn(self, vk, v):
return map_fn(v)

if isinstance(v, Module):
return v.filter_and_map(filter_fn, map_fn, is_leaf_fn)

if isinstance(v, dict):
nd = {}
for k, v in v.items():
tk = f"{vk}.{k}"
nd[k] = unwrap(tk, v) if filter_fn(self, tk, v) else {}
return nd

if isinstance(v, list):
nl = []
for i, vi in enumerate(v):
tk = f"{vk}.{i}"
nl.append(unwrap(tk, vi) if filter_fn(self, tk, vi) else {})
return nl

raise RuntimeError("Unexpected leaf found while traversing the module")

return {k: unwrap(k, v) for k, v in self.items() if filter_fn(self, k, v)}
return {
k: _unwrap(self, k, v, filter_fn, map_fn, is_leaf_fn)
for k, v in self.items()
if filter_fn(self, k, v)
}

def parameters(self):
"""Recursively return all the :class:`mlx.core.array` members of this Module
Expand Down

0 comments on commit 4494970

Please sign in to comment.