-
Notifications
You must be signed in to change notification settings - Fork 509
/
Copy pathop_rescale.py
72 lines (61 loc) · 2.29 KB
/
op_rescale.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-unsafe
from typing import cast, List
import executorch.backends.arm.tosa_quant_utils as tosa_quant_utils
import serializer.tosa_serializer as ts # type: ignore
import torch
import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
from torch.fx import Node
@register_node_visitor
class RescaleVisitor(NodeVisitor):
    """Node visitor that serializes ``_rescale.default`` as a TOSA RESCALE op."""

    target = "_rescale.default"

    def define_node(
        self,
        node: Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Add a RESCALE operator for ``node`` to ``tosa_graph``.

        The rescale parameters are read positionally from ``node.args``:
        ``args[1]`` is the output dtype, ``args[2]`` the scale, ``args[3]``
        the input zero point and ``args[4]`` the output zero point.
        (``args[0]`` is presumably the input tensor, which is consumed here
        through ``inputs[0]`` instead -- confirm against the op schema.)

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer that receives the new operator.
            inputs: TOSA arguments of the node; only ``inputs[0]`` is used.
            output: TOSA argument naming the operator's result.

        Raises:
            ValueError: If a non-zero zero point is combined with a dtype
                other than int8, on either the input or the output side.
        """
        input_dtype = inputs[0].dtype
        output_dtype = cast(torch.dtype, node.args[1])
        scale = cast(float, node.args[2])
        input_zp = cast(int, node.args[3])
        output_zp = cast(int, node.args[4])

        # The input dtype is a TOSA dtype, so the torch int8 dtype is mapped
        # before comparing; the output dtype comes straight from the node
        # arguments as a torch dtype and is compared directly.
        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
            raise ValueError(
                "If input dtype is not int8, input_zp must be 0. "
                f"Got input_dtype={ts.DTypeNames[input_dtype]}, {input_zp=}"
            )
        if output_dtype != torch.int8 and output_zp != 0:
            raise ValueError(
                f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}"
            )

        # scale32 gives higher accuracy but for a higher HW cost.
        # For now, always go for scale32.
        scale_32 = True
        scale_width = 32 if scale_32 else 16

        # The helper takes a list of scales; a single (per-tensor) scale is
        # passed, matching per_channel=False below.
        multiplier, shift = tosa_quant_utils.compute_multiplier_and_shift(
            [scale], scale_width
        )

        attr_rescale = ts.TosaSerializerAttribute()
        attr_rescale.RescaleAttribute(
            input_zp=input_zp,
            output_zp=output_zp,
            multiplier=multiplier,
            shift=shift,
            scale32=scale_32,
            double_round=False,
            per_channel=False,
            input_unsigned=False,
            output_unsigned=False,
        )

        tosa_graph.addOperator(
            TosaOp.Op().RESCALE, [inputs[0].name], [output.name], attr_rescale
        )