Skip to content

Commit

Permalink
[Frontend][PaddlePaddle] Support conv2d when data_format is NHWC (apa…
Browse files Browse the repository at this point in the history
…che#16616)

* support conv2d when data_format is NHWC

* modify the annotation
  • Loading branch information
Zheng-Bicheng authored Feb 22, 2024
1 parent ad3dfb4 commit 9fd3461
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,7 @@ def convert_conv2d(g, op, block):
strides = op.attr("strides")

kernel = g.get_node(op.input("Filter")[0])
kernel_layout = "OIHW"
input_x = g.get_node(op.input("Input")[0])
data_layout = op.attr("data_format")
out_channels, _, k_h, k_w = infer_shape(kernel)
Expand All @@ -335,6 +336,16 @@ def convert_conv2d(g, op, block):
msg = f'Value {padding_algorithm} in attribute "padding" of operator Conv is not "valid."'
raise tvm.error.OpAttributeInvalid(msg)

if data_layout == "NHWC":
kernel_layout = "HWIO"
# PaddlePaddle wieght layout is "OIHW", tvm need "HWIO" when op data_format is "NHWC".
kernel_data = g.get_params(op.input("Filter")[0])
kernel_data = kernel_data.asnumpy()
kernel_data = kernel_data.transpose((2, 3, 1, 0))
kernel_data = _nd.array(kernel_data)
g.modify_node(op.input("Filter")[0], kernel_data)
kernel = g.get_node(op.input("Filter")[0])

out = _op.nn.conv2d(
input_x,
kernel,
Expand All @@ -345,6 +356,7 @@ def convert_conv2d(g, op, block):
channels=out_channels,
kernel_size=[k_h, k_w],
data_layout=data_layout,
kernel_layout=kernel_layout,
)
g.add_node(op.output("Output")[0], out)

Expand Down Expand Up @@ -2915,6 +2927,12 @@ def add_node(self, name, node):

self.nodes[name] = fold_constant(node)

def modify_node(self, name, params):
"""modify node from graph"""

self.params[name] = params
self.nodes[name] = new_var(name, shape=params.shape, dtype=params.dtype)

def get_params(self, name=None):
"""Get params from graph."""

Expand Down

0 comments on commit 9fd3461

Please sign in to comment.