# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import List, Tuple, Union

import torch
from executorch.exir.sym_util import eval_shape
from executorch.exir.tensor import TensorSpec
from torch.utils import _pytree as pytree
from typing_extensions import TypeAlias

# A single tensor allocation spec: a shape tuple plus a dtype. An AllocSpec
# may also be a list of such specs.
TensorAllocSpec: TypeAlias = Tuple[Tuple[int, ...], torch.dtype]
AllocSpec: TypeAlias = Union[
    TensorAllocSpec,
    List[TensorAllocSpec],
]


def alloc(spec: AllocSpec) -> pytree.PyTree:
    if isinstance(spec, list):
        return [alloc(s) for s in spec]

    shape, dtype = spec
    # Evaluate the shape to concrete ints so we can run the traced module
    # in Python for testing.
    shape = eval_shape(shape)
    return torch.empty(shape, dtype=dtype)
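

# Hedged usage sketch (not part of the original module): alloc accepts either
# a single (shape, dtype) pair or a list of them, returning uninitialized
# tensors of matching shape and dtype. The helper name _demo_alloc is
# hypothetical, and the example assumes concrete integer shapes so that
# eval_shape is a pass-through.
def _demo_alloc() -> None:
    single = alloc(((2, 3), torch.float32))
    assert single.shape == (2, 3) and single.dtype == torch.float32

    many = alloc([((4,), torch.int64), ((1, 2), torch.float32)])
    assert [t.shape for t in many] == [(4,), (1, 2)]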


def free(spec: TensorSpec) -> None:
    """
    This function is a no-op. Its main purpose is to appear in the FX IR,
    e.g. as the target of a call_function node.
    """
    pass
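

# Hedged sketch (not part of the original module): because free is a no-op
# with a stable Python target, a pass can record it in an FX graph as a
# call_function node marking the end of a TensorSpec's lifetime, without
# changing runtime behavior. The helper name _demo_free_node is hypothetical.
def _demo_free_node(graph: torch.fx.Graph, spec: TensorSpec) -> torch.fx.Node:
    # At runtime the node does nothing; passes can key off its target.
    return graph.call_function(free, (spec,))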


def view(base: torch.Tensor, size: List[int]) -> torch.Tensor:
    """
    This function mimics torch.ops.aten.view.default.

    It is used to elide view_copy nodes.
    """
    return base.view(size)
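

# Hedged usage sketch (not part of the original module): view returns a
# reshaped alias that shares storage with base, which is what lets it stand
# in for an elided view_copy. The helper name _demo_view is hypothetical.
def _demo_view() -> None:
    base = torch.arange(6, dtype=torch.float32)
    v = view(base, [2, 3])
    assert v.shape == (2, 3)
    assert v.data_ptr() == base.data_ptr()  # same underlying storage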