[REFACTOR][PY] relay.op.Op -> tvm.ir.Op (apache#5705)
* [REFACTOR][PY] relay.op.Op -> tvm.ir.Op

* Improve the error check
tqchen authored Jun 1, 2020
1 parent f280883 commit afc239a
Showing 31 changed files with 215 additions and 198 deletions.
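
At the Python level, the commit moves the primitive operator class out of relay.op and into the core tvm.ir namespace, so operator lookups no longer go through Relay. A minimal before/after sketch of the lookup path ("add" is used purely as an illustrative operator name):

import tvm
from tvm import relay  # importing relay still registers its operators

# Before this commit, the class was reached through the relay namespace:
#   op = relay.op.Op.get("add")

# After this commit, it lives in the core IR:
op = tvm.ir.Op.get("add")
assert isinstance(op, tvm.ir.Op)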
6 changes: 3 additions & 3 deletions include/tvm/ir/op.h
@@ -121,7 +121,7 @@ class OpNode : public RelayExprNode {
return is_primitive_ != 0;
}

- static constexpr const char* _type_key = "relay.Op";
+ static constexpr const char* _type_key = "Op";
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);

private:
@@ -180,7 +180,7 @@ class Op : public RelayExpr {
* \tparam ValueType The type of the attribute.
*/
template <typename ValueType>
- inline static OpAttrMap<ValueType> GetAttrMap(const std::string& attr_name);
+ inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
/*!
* \brief Checks if an attr map is present in the registry.
* \param attr_name The name of the attribute.
@@ -374,7 +374,7 @@ class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); }

template <typename ValueType>
- inline OpAttrMap<ValueType> Op::GetAttrMap(const std::string& key) {
+ inline OpAttrMap<ValueType> Op::GetAttrMap(const String& key) {
return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
}

2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
@@ -81,7 +81,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
- target_ops : List of relay.op.Op
+ target_ops : List of tvm.ir.Op
Target tuning operators.
target : str or tvm.target
4 changes: 2 additions & 2 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
@@ -38,7 +38,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
expr : tvm.relay.Expr.Function
Input relay function expression.
- target_ops: List of relay.op.Op
+ target_ops: List of tvm.ir.Op
List of target relay ops
node_dict : dictionary from tvm.relay.Expr to int
@@ -157,7 +157,7 @@ def _traverse_expr(node):
elif isinstance(node, Constant):
node_entry["name"] = "Constant_" + str(node_index)
node_entry["types"] = [node.checked_type]
- elif isinstance(node, relay.op.op.Op):
+ elif isinstance(node, tvm.ir.Op):
return
else:
raise RuntimeError("Not supported relay node type in graph tuning: %s"
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/relay_integration.py
@@ -78,7 +78,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
The compilation target
target_host: tvm.target.Target
The host compilation target
- ops: List[relay.op.Op] or None
+ ops: List[tvm.ir.Op] or None
List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
Returns
@@ -105,7 +105,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
The compilation target
target_host: tvm.target.Target
The host compilation target
- ops: List[relay.op.Op] or None
+ ops: List[tvm.ir.Op] or None
List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
Returns
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/topi_integration.py
@@ -61,7 +61,7 @@ def reset(self, wanted_relay_ops=None):
Parameters
----------
- wanted_relay_ops: List of relay.op.Op
+ wanted_relay_ops: List of tvm.ir.Op
The relay ops to be extracted
"""
self.task_collection = []
1 change: 1 addition & 0 deletions python/tvm/ir/__init__.py
@@ -23,6 +23,7 @@
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
+ from .op import Op, register_op_attr
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
2 changes: 1 addition & 1 deletion python/tvm/ir/json_compact.py
@@ -109,7 +109,7 @@ def _convert(item, nodes):
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.Op": [_update_global_key, _rename("Op")],
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Id": [_update_from_std_str("name_hint")],
"relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
114 changes: 114 additions & 0 deletions python/tvm/ir/op.py
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Primitive operators in the TVM IR."""
import tvm._ffi
from .expr import RelayExpr
from . import _ffi_api


@tvm._ffi.register_object("Op")
class Op(RelayExpr):
    """Primitive operator in the IR."""
    def __init__(self):
        raise RuntimeError("Cannot create op, use get instead")

    @staticmethod
    def get(op_name):
        """Get the Op for a given name.

        Parameters
        ----------
        op_name : str
            The operator name

        Returns
        -------
        op : Op
            The op of the corresponding name
        """
        return _ffi_api.GetOp(op_name)

    def get_attr(self, attr_name):
        """Get additional attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name.

        Returns
        -------
        value : object
            The attribute value
        """
        return _ffi_api.OpGetAttr(self, attr_name)

    def set_attr(self, attr_name, value, plevel=10):
        """Set attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name

        value : object
            The attribute value

        plevel : int
            The priority level
        """
        _ffi_api.OpSetAttr(self, attr_name, value, plevel)

    def reset_attr(self, attr_name):
        """Reset attribute about the operator.

        Parameters
        ----------
        attr_name : str
            The attribute name
        """
        _ffi_api.OpResetAttr(self, attr_name)


def register_op_attr(op_name, attr_key, value=None, level=10):
    """Register an operator property of an operator by name.

    Parameters
    ----------
    op_name : str
        The name of operator

    attr_key : str
        The attribute name.

    value : object, optional
        The value to set

    level : int, optional
        The priority level

    Returns
    -------
    fregister : function
        Register function if value is not specified.
    """
    def _register(v):
        """internal register function"""
        _ffi_api.RegisterOpAttr(op_name, attr_key, v, level)
        return v

    return _register(value) if value is not None else _register
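
Taken together, the new module exposes a small surface: Op.get for lookup, get_attr/set_attr/reset_attr for per-operator attributes, and register_op_attr for decorator-style registration. A hedged usage sketch (the attribute key custom.fcheck is an illustrative name, not one this commit defines):

import tvm

# Look up a registered primitive operator by name.
conv2d = tvm.ir.Op.get("nn.conv2d")

# Decorator form: when no value is given, register_op_attr returns a registrar.
@tvm.ir.register_op_attr("nn.conv2d", "custom.fcheck")
def _fcheck(attrs, args):
    return True

# The registered attribute can then be read back from the op itself.
print(conv2d.get_attr("custom.fcheck"))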
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
@@ -40,7 +40,6 @@
from .backend import vm

# Root operators
- from .op import Op
from .op import nn
from .op import image
from .op import annotation
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
@@ -378,7 +378,7 @@ def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]:
return self.module

# Exprs
def visitOpIdent(self, ctx) -> op.Op:
def visitOpIdent(self, ctx) -> tvm.ir.Op:
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
4 changes: 2 additions & 2 deletions python/tvm/relay/analysis/annotated_regions.py
@@ -31,9 +31,9 @@ def __init__(self, expr, region_begin_op, region_end_op):
----------
expr : tvm.relay.Expr
The expression from which to construct the regions.
- region_begin_op : tvm.relay.Op
+ region_begin_op : tvm.ir.Op
The region begin annotation.
- region_end_op : tvm.relay.Op
+ region_end_op : tvm.ir.Op
The region end annotation.
"""
7 changes: 3 additions & 4 deletions python/tvm/relay/backend/compile_engine.py
@@ -26,7 +26,6 @@
from ... import target as _target
from ... import autotvm
from .. import function as _function
- from .. import op as _op
from .. import ty as _ty
from . import _backend

@@ -98,7 +97,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
Parameters
----------
- op : relay.op.Op
+ op : tvm.ir.Op
Relay operator.
attrs : object
@@ -157,7 +156,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
Parameters
----------
- op : relay.op.Op
+ op : tvm.ir.Op
Relay operator.
attrs : object
@@ -215,7 +214,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
"""Lower the call expression to op implementation and tensor outputs."""
- assert isinstance(call.op, _op.Op)
+ assert isinstance(call.op, tvm.ir.Op)
op = call.op

# Prepare the call_node->checked_type(). For the call node inputs, we ensure that
2 changes: 1 addition & 1 deletion python/tvm/relay/expr.py
@@ -234,7 +234,7 @@ class Call(ExprWithOp):
Parameters
----------
- op: tvm.relay.Op or any tvm.relay.Expr with function type.
+ op: tvm.ir.Op or any tvm.relay.Expr with function type.
The operation to be called.
args: List[tvm.relay.Expr]
2 changes: 1 addition & 1 deletion python/tvm/relay/expr_functor.py
@@ -16,13 +16,13 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""
+ from tvm.ir import Op

from .function import Function
from .expr import Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant
from .expr import RefCreate, RefRead, RefWrite
from .adt import Constructor, Match, Clause
- from .op import Op

class ExprFunctor:
"""
4 changes: 2 additions & 2 deletions python/tvm/relay/op/__init__.py
@@ -17,9 +17,9 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
- from .op import get, register, register_compute, register_gradient, \
+ from .op import get, register_compute, register_gradient, \
      register_pattern, register_alter_op_layout, register_legalize, \
- Op, OpPattern, OpStrategy, debug, register_external_compiler
+ OpPattern, OpStrategy, debug, register_external_compiler
from . import strategy

# Operators
4 changes: 2 additions & 2 deletions python/tvm/relay/op/contrib/dnnl.py
@@ -32,7 +32,7 @@
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
- from ... import op as _op
+ import tvm.ir
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table

@@ -51,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True):
f : callable
A function that returns if the operator is supported by DNNL.
"""
@_op.register(op_name, "target.dnnl")
@tvm.ir.register_op_attr(op_name, "target.dnnl")
def _func_wrapper(attrs, args):
return supported
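
With the helper now built on tvm.ir.register_op_attr, declaring DNNL support for an operator remains a one-liner per op. A hedged sketch of how such a helper is typically invoked (the operator names are illustrative):

_register_external_op_helper("nn.conv2d")
_register_external_op_helper("nn.relu")
_register_external_op_helper("add", supported=False)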
