Skip to content

Commit

Permalink
support concat in recast (apache#8028)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored May 13, 2021
1 parent 76fb2af commit ed283b8
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 4 deletions.
18 changes: 14 additions & 4 deletions python/tvm/relay/transform/recast.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,20 @@ def visit_call(self, call):
# Downcast this op if its the correct type and not skipped.
if call.op in self.valid_ops and current_layer not in self.skip_layers:
# Recast inputs to specified type.
args = [self.visit(arg) for arg in call.args]
new_args = list()
for arg in args:
new_args.append(relay.cast(arg, dtype=self.dtype))
if call.op == relay.op.get("concatenate"):
if len(call.args) != 1 or not isinstance(call.args[0], relay.expr.Tuple):
return Call(new_fn, args, call.attrs)

tuple_args = [self.visit(arg) for arg in call.args[0].fields]
new_args = list()
for arg in tuple_args:
new_args.append(relay.cast(arg, dtype=self.dtype))
new_args = [relay.expr.Tuple(new_args)]
else:
args = [self.visit(arg) for arg in call.args]
new_args = list()
for arg in args:
new_args.append(relay.cast(arg, dtype=self.dtype))

# If out_dtype is in the attributes, we need to update it.
orig_dtype = None
Expand Down
24 changes: 24 additions & 0 deletions tests/python/relay/test_recast.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,30 @@ def expected():
assert tvm.ir.structural_equal(expected, post)


def test_recast_concat():
def before():
x = relay.var("x", shape=[1, 4])
y = relay.var("y", shape=[1, 4])
t = relay.Tuple([x, y])
c = relay.op.concatenate(t, axis=1)
return relay.Function([x, y], c)

def expected():
xv = relay.var("x", shape=[1, 4])
yv = relay.var("y", shape=[1, 4])
x = relay.cast(xv, "float16")
y = relay.cast(yv, "float16")
t = relay.Tuple([x, y])
c = relay.op.concatenate(t, axis=1)
c = relay.cast(c, "float32")
return relay.Function([xv, yv], c)

pre = before()
post = recast(pre, "float16", "float32", ops=["concatenate"])
expected = expected()
assert tvm.ir.structural_equal(expected, post)


if __name__ == "__main__":
test_recast_simple()
test_recast_medium()
Expand Down

0 comments on commit ed283b8

Please sign in to comment.