[pytorch][export] Move is_param and get_param out of exir and into export (pytorch#107264)

Summary: These don't feel edge-specific, so moving them out of exir.
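
For illustration only, a minimal sketch of how the relocated helpers might be used; the Basic module and shapes mirror the test added below rather than anything new.

import torch
from torch._export import export
from torch._export.utils import get_param, is_param

class Basic(torch.nn.Module):  # toy module, mirrors the test below
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.lin(x)

# Export the module, then collect the parameters backing its graph nodes.
ep = export(Basic(), (torch.randn(5, 10),))
params = [get_param(ep, node) for node in ep.graph.nodes if is_param(ep, node)]
# params now holds the Linear weight (shape [1, 10]) and bias (shape [1]).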

Test Plan: ci

Differential Revision: D48361384

Pull Request resolved: pytorch#107264
Approved by: https://github.com/angelayi
JacobSzwejbka authored and pytorchmergebot committed Aug 22, 2023
1 parent 8fb6416 commit c14f4d6
Showing 2 changed files with 50 additions and 1 deletion.
23 changes: 22 additions & 1 deletion test/export/test_export.py
@@ -10,7 +10,7 @@
from torch import Tensor
from torch._export import DEFAULT_EXPORT_DYNAMO_CONFIG, dynamic_dim, export
from torch._export.constraints import constrain_as_size, constrain_as_value
from torch._export.utils import register_dataclass_as_pytree_node
from torch._export.utils import register_dataclass_as_pytree_node, is_param, get_param
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -437,6 +437,27 @@ class Outer:
unflat = tree_unflatten(flat, spec)
self.assertEqual(unflat, inp)

def test_param_util(self):
class Basic(torch.nn.Module):
def __init__(self):
super().__init__()
self.lin = torch.nn.Linear(10, 1)

def forward(self, x):
return self.lin(x)

ep = export(Basic(), (torch.randn(5, 10),))
num_params = 0
params = []
for node in ep.graph.nodes:
if is_param(ep, node):
num_params += 1
params.append(get_param(ep, node))
self.assertEqual(num_params, 2)
self.assertEqual(params[0].shape, [1, 10]) # weight
self.assertEqual(params[1].shape, [1]) # bias


def test_export_dynamo_config(self):
class MyModule(torch.nn.Module):
def __init__(self):
28 changes: 28 additions & 0 deletions torch/_export/utils.py
@@ -2,6 +2,10 @@

from typing import Any, List, Optional, Tuple

import torch

from torch._export import ExportedProgram

from torch.utils._pytree import (
_register_pytree_node,
Context,
@@ -52,3 +56,27 @@ def default_unflatten_fn(values: List[Any], context: Context) -> Any:
None,
None,
)


def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
"""
Checks if the given node is a parameter within the exported program
"""

return node.name in program.graph_signature.inputs_to_parameters


def get_param(
program: ExportedProgram,
node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
"""
Returns the parameter associated with the given node in the exported program.
Returns None if the node is not a parameter within the exported program
"""

if is_param(program, node):
parameter_name = program.graph_signature.inputs_to_parameters[node.name]
return program.state_dict[parameter_name]

return None
