Skip to content

Commit

Permalink
[Relay] Remove in-place modification of attributes in layout transform (
Browse files Browse the repository at this point in the history
apache#8309)

* stub

* mnist test working

* porting InferCorrectLayout

* compiles with new infer layout

* remove log

* fix qnn concat

* do not run dense pack alter op test on gpu targets

* cleanup

* add test

* cpplint

* CHECK -> ICHECK

* doc update

* restore try catch

* split inferred_layout into seperate fields

* Update InferCorrectLayout functions following struct field change

* fix cpplint
  • Loading branch information
masahi authored Jun 26, 2021
1 parent 5177729 commit c25b8fa
Show file tree
Hide file tree
Showing 24 changed files with 359 additions and 260 deletions.
17 changes: 10 additions & 7 deletions src/relay/op/dyn/nn/upsampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,14 @@ namespace relay {
namespace dyn {

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());
InferCorrectLayoutOutput UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<T>();
ICHECK(attrs_ptr);
ObjectPtr<T> params = make_object<T>(*attrs_ptr);

if (new_in_layouts.defined()) {
ICHECK_GT(new_in_layouts.size(), 0);

Expand All @@ -59,7 +61,8 @@ Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,

Layout inferred_layout(params->layout);
Layout param_layout("NCHW");
return Array<Array<Layout> >{{inferred_layout, param_layout, param_layout}, {inferred_layout}};
return InferCorrectLayoutOutput({inferred_layout, param_layout, param_layout}, {inferred_layout},
Attrs(params));
}

} // namespace dyn
Expand Down
12 changes: 6 additions & 6 deletions src/relay/op/image/dilation2d.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(Dilation2DAttrs);

template <typename T>
Array<Array<Layout> > Dilation2DInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
InferCorrectLayoutOutput Dilation2DInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const T* params = attrs.as<T>();

return Array<Array<Layout> >{{params->data_layout, params->kernel_layout}, {params->data_layout}};
return InferCorrectLayoutOutput({params->data_layout, params->kernel_layout},
{params->data_layout}, attrs);
}

// Positional relay function to create dilation2d operator
Expand Down
16 changes: 8 additions & 8 deletions src/relay/op/image/resize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(ResizeAttrs);

template <typename T>
Array<Array<Layout> > ResizeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());
InferCorrectLayoutOutput ResizeInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<T>();
CHECK(attrs_ptr);
ObjectPtr<T> params = make_object<T>(*attrs_ptr);

if (new_in_layouts.defined()) {
ICHECK_EQ(new_in_layouts.size(), 1);
Expand All @@ -54,8 +55,7 @@ Array<Array<Layout> > ResizeInferCorrectLayout(const Attrs& attrs,
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params));
}

bool ResizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
10 changes: 5 additions & 5 deletions src/relay/op/nn/bitserial.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ namespace relay {
TVM_REGISTER_NODE_TYPE(BitPackAttrs);

template <typename T>
Array<Array<Layout>> BinaryConv2DInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
InferCorrectLayoutOutput BinaryConv2DInferCorrectLayout(
const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const T* params = attrs.as<T>();

// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
return Array<Array<Layout>>{{params->data_layout, params->kernel_layout}, {params->data_layout}};
return InferCorrectLayoutOutput({params->data_layout, params->kernel_layout},
{params->data_layout}, attrs);
}

bool BitPackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
21 changes: 9 additions & 12 deletions src/relay/op/nn/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -1243,29 +1243,26 @@ bool DeformableConv2DRel(const Array<Type>& types, int num_inputs, const Attrs&
}

template <typename AttrType>
Array<Array<Layout> > DeformableConvInferCorrectLayout(
InferCorrectLayoutOutput DeformableConvInferCorrectLayout(
const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const AttrType* params = attrs.as<AttrType>();

// Layout of {data, offet, kernel}, {out}
return Array<Array<Layout> >{
return InferCorrectLayoutOutput(
{params->data_layout, params->data_layout, params->kernel_layout},
{params->out_layout == "" ? params->data_layout : params->out_layout}};
{params->out_layout == "" ? params->data_layout : params->out_layout}, attrs);
}

template <typename T>
Array<Array<Layout> > ConvInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
InferCorrectLayoutOutput ConvInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const T* params = attrs.as<T>();

// We always make other operators to fit the layouts of convolution layers
// So this inference ignores all inputs
return Array<Array<Layout> >{
return InferCorrectLayoutOutput(
{params->data_layout, params->kernel_layout},
{params->out_layout == "" ? params->data_layout : params->out_layout}};
{params->out_layout == "" ? params->data_layout : params->out_layout}, attrs);
}

} // namespace relay
Expand Down
9 changes: 4 additions & 5 deletions src/relay/op/nn/correlation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,12 @@ namespace relay {
// relay.nn.correlation
TVM_REGISTER_NODE_TYPE(CorrelationAttrs);

Array<Array<Layout>> CorrelationInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
InferCorrectLayoutOutput CorrelationInferCorrectLayout(
const Attrs& attrs, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* params = attrs.as<CorrelationAttrs>();
Layout layout{params->layout};
return Array<Array<Layout>>{{layout, layout}, {layout}};
return InferCorrectLayoutOutput({layout, layout}, {layout}, attrs);
}

// Positional relay function to create correlation operator
Expand Down
27 changes: 14 additions & 13 deletions src/relay/op/nn/nn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -274,17 +274,17 @@ bool PReluRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
}

template <typename T>
Array<Array<Layout>> PReluInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
InferCorrectLayoutOutput PReluInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
ICHECK_EQ(old_in_layouts.size(), 2U);
ICHECK_EQ(old_in_types.size(), 2U);
Layout data_layout = old_in_layouts[0];
if (new_in_layouts.defined()) {
ICHECK_EQ(new_in_layouts.size(), 2U);
}
return Array<Array<Layout>>{{data_layout, Layout("C")}, {data_layout}};
return InferCorrectLayoutOutput({data_layout, Layout("C")}, {data_layout}, attrs);
}

// Positional relay function to create prelu operator used by frontend FFI.
Expand Down Expand Up @@ -598,11 +598,13 @@ The whole array is rescaled by ``1/(1-p)`` to keep the expected sum of the input
// batch_norm
TVM_REGISTER_NODE_TYPE(BatchNormAttrs);

Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
BatchNormAttrs* param = const_cast<BatchNormAttrs*>(attrs.as<BatchNormAttrs>());
InferCorrectLayoutOutput BatchNormInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<BatchNormAttrs>();
ICHECK(attrs_ptr);
ObjectPtr<BatchNormAttrs> param = make_object<BatchNormAttrs>(*attrs_ptr);

Array<Array<IndexExpr>> old_in_shapes;
for (auto old_in_t : old_in_types) {
Expand All @@ -627,9 +629,8 @@ Array<Array<Layout>> BatchNormInferCorrectLayout(const Attrs& attrs,
}
// BN has 5 inputs, 3 outputs. The last 4 inputs and last 2 outputs have "C" layout.
Layout c_layout = Layout("C");

return Array<Array<Layout>>{{ret, c_layout, c_layout, c_layout, c_layout},
{ret, c_layout, c_layout}};
return InferCorrectLayoutOutput({ret, c_layout, c_layout, c_layout, c_layout},
{ret, c_layout, c_layout}, Attrs(param));
}

bool BatchNormRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
14 changes: 8 additions & 6 deletions src/relay/op/nn/pad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,13 @@ namespace relay {
// relay.nn.pad
TVM_REGISTER_NODE_TYPE(PadAttrs);

Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
PadAttrs* params = const_cast<PadAttrs*>(attrs.as<PadAttrs>());
InferCorrectLayoutOutput PadInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<PadAttrs>();
CHECK(attrs_ptr);
ObjectPtr<PadAttrs> params = make_object<PadAttrs>(*attrs_ptr);

Layout ret_data;
// If new_in_layouts are defined, this code tries to modify the layout.
Expand Down Expand Up @@ -112,7 +114,7 @@ Array<Array<Layout>> PadInferCorrectLayout(const Attrs& attrs, const Array<Layou

// The pad value is always a scalar
Layout ret_pad_value = Layout("1");
return Array<Array<Layout>>{{ret_data, ret_pad_value}, {ret_data}};
return InferCorrectLayoutOutput({ret_data, ret_pad_value}, {ret_data}, Attrs(params));
}

bool PadRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
Expand Down
16 changes: 8 additions & 8 deletions src/relay/op/nn/pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,21 @@ TVM_REGISTER_NODE_TYPE(MaxPool2DAttrs);
TVM_REGISTER_NODE_TYPE(AvgPool2DAttrs);

template <typename T>
Array<Array<Layout> > PoolInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());
InferCorrectLayoutOutput PoolInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<T>();
ICHECK(attrs_ptr);
ObjectPtr<T> params = make_object<T>(*attrs_ptr);

if (new_in_layouts.defined()) {
// Set the pool with the new layout.
ICHECK_EQ(new_in_layouts.size(), 1);
params->layout = new_in_layouts[0].name();
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params));
}

IndexExpr calculate_pool_dimension(IndexExpr in_dimension, IndexExpr pad_amount,
Expand Down
16 changes: 8 additions & 8 deletions src/relay/op/nn/upsampling.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,13 @@ namespace tvm {
namespace relay {

template <typename T>
Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
T* params = const_cast<T*>(attrs.as<T>());
InferCorrectLayoutOutput UpsamplingInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<T>();
ICHECK(attrs_ptr);
ObjectPtr<T> params = make_object<T>(*attrs_ptr);

if (new_in_layouts.defined()) {
ICHECK_EQ(new_in_layouts.size(), 1);
Expand All @@ -57,8 +58,7 @@ Array<Array<Layout> > UpsamplingInferCorrectLayout(const Attrs& attrs,
}
}

Layout inferred_layout(params->layout);
return Array<Array<Layout> >{{inferred_layout}, {inferred_layout}};
return InferCorrectLayoutOutput({params->layout}, {params->layout}, Attrs(params));
}

} // namespace relay
Expand Down
15 changes: 8 additions & 7 deletions src/relay/op/tensor/reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,13 @@ Array<Integer> GetExcludeAxes(size_t indim, const Array<Integer>& inaxis) {
}

// Return the modified layout for AlterOpLayout pass.
Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
// NOTE: Discard "const" qualifier here.
ReduceAttrs* params = const_cast<ReduceAttrs*>(attrs.as<ReduceAttrs>());
InferCorrectLayoutOutput ReduceInferCorrectLayout(const Attrs& attrs,
const Array<Layout>& new_in_layouts,
const Array<Layout>& old_in_layouts,
const Array<tvm::relay::Type>& old_in_types) {
const auto* attrs_ptr = attrs.as<ReduceAttrs>();
ICHECK(attrs_ptr);
ObjectPtr<ReduceAttrs> params = make_object<ReduceAttrs>(*attrs_ptr);

// Get the reduce axes.
Array<Array<IndexExpr>> old_in_shapes;
Expand Down Expand Up @@ -188,7 +189,7 @@ Array<Array<Layout>> ReduceInferCorrectLayout(const Attrs& attrs,
}
}

return Array<Array<Layout>>{{inferred_in}, {inferred_out}};
return InferCorrectLayoutOutput({inferred_in}, {inferred_out}, Attrs(params));
}

template <typename F>
Expand Down
Loading

0 comments on commit c25b8fa

Please sign in to comment.