Skip to content

Commit

Permalink
[REFACTOR] Separate ArgTypeCode from DLDataTypeCode (apache#5730)
Browse files Browse the repository at this point in the history
We use a single enum(TypeCode) to represent ArgTypeCode and DLDataTypeCode.
However, as we start to expand more data types, it is clear that argument
type code(in the FFI convention) and data type code needs to evolve separately.
So that we can add first class for data types without having changing the FFI ABI.

This PR makes the distinction clear and refactored the code to separate the two.

- [PY] Separate ArgTypeCode from DataTypeCode
- [WEB] Separate ArgTypeCode from DataTypeCode
- [JAVA] Separate ArgTypeCode from DataTypeCode
  • Loading branch information
tqchen authored Jun 4, 2020
1 parent 34c95a8 commit 8a98782
Show file tree
Hide file tree
Showing 38 changed files with 284 additions and 236 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/dlpack
8 changes: 4 additions & 4 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,14 @@ typedef enum {
} TVMDeviceExtType;

/*!
* \brief The type code in used in the TVM FFI.
* \brief The type code in used in the TVM FFI for argument passing.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by TVM API calls.
kTVMArgInt = kDLInt,
kTVMArgFloat = kDLFloat,
kTVMOpaqueHandle = 3U,
kTVMNullptr = 4U,
kTVMDataType = 5U,
Expand All @@ -115,9 +117,7 @@ typedef enum {
// The following section of code is used for non-reserved types.
kTVMExtReserveEnd = 64U,
kTVMExtEnd = 128U,
// The rest of the space is used for custom, user-supplied datatypes
kTVMCustomBegin = 129U,
} TVMTypeCode;
} TVMArgTypeCode;

/*!
* \brief The Device information, abstract away common device types.
Expand Down
37 changes: 8 additions & 29 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ class DataType {
kInt = kDLInt,
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMTypeCode::kTVMOpaqueHandle,
kHandle = TVMArgTypeCode::kTVMOpaqueHandle,
kCustomBegin = 129
};
/*! \brief default constructor */
DataType() {}
Expand Down Expand Up @@ -248,7 +249,7 @@ TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan);
* \param type_code The type code .
* \return The name of type code.
*/
inline const char* TypeCode2Str(int type_code);
inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code);

/*!
* \brief convert a string to TVM type.
Expand All @@ -265,38 +266,16 @@ inline DLDataType String2DLDataType(std::string s);
inline std::string DLDataType2String(DLDataType t);

// implementation details
inline const char* TypeCode2Str(int type_code) {
switch (type_code) {
inline const char* DLDataTypeCode2Str(DLDataTypeCode type_code) {
switch (static_cast<int>(type_code)) {
case kDLInt:
return "int";
case kDLUInt:
return "uint";
case kDLFloat:
return "float";
case kTVMStr:
return "str";
case kTVMBytes:
return "bytes";
case kTVMOpaqueHandle:
case DataType::kHandle:
return "handle";
case kTVMNullptr:
return "NULL";
case kTVMDLTensorHandle:
return "ArrayHandle";
case kTVMDataType:
return "DLDataType";
case kTVMContext:
return "TVMContext";
case kTVMPackedFuncHandle:
return "FunctionHandle";
case kTVMModuleHandle:
return "ModuleHandle";
case kTVMNDArrayHandle:
return "NDArrayContainer";
case kTVMObjectHandle:
return "Object";
case kTVMObjectRValueRefArg:
return "ObjectRValueRefArg";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
Expand All @@ -311,8 +290,8 @@ inline std::ostream& operator<<(std::ostream& os, DLDataType t) { // NOLINT(*)
if (DataType(t).is_void()) {
return os << "void";
}
if (t.code < kTVMCustomBegin) {
os << TypeCode2Str(t.code);
if (t.code < DataType::kCustomBegin) {
os << DLDataTypeCode2Str(static_cast<DLDataTypeCode>(t.code));
} else {
os << "custom[" << GetCustomTypeName(t.code) << "]";
}
Expand Down
49 changes: 47 additions & 2 deletions include/tvm/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -327,9 +327,16 @@ class TVMArgs {
inline TVMArgValue operator[](int i) const;
};

/*!
* \brief Convert argument type code to string.
* \param type_code The input type code.
* \return The corresponding string repr.
*/
inline const char* ArgTypeCode2Str(int type_code);

// macro to check type code.
#define TVM_CHECK_TYPE_CODE(CODE, T) \
CHECK_EQ(CODE, T) << " expected " << TypeCode2Str(T) << " but get " << TypeCode2Str(CODE)
CHECK_EQ(CODE, T) << " expected " << ArgTypeCode2Str(T) << " but get " << ArgTypeCode2Str(CODE)

/*!
* \brief Type traits for runtime type check during FFI conversion.
Expand Down Expand Up @@ -394,7 +401,7 @@ class TVMPODValue_ {
} else {
if (type_code_ == kTVMNullptr) return nullptr;
LOG(FATAL) << "Expect "
<< "DLTensor* or NDArray but get " << TypeCode2Str(type_code_);
<< "DLTensor* or NDArray but get " << ArgTypeCode2Str(type_code_);
return nullptr;
}
}
Expand Down Expand Up @@ -982,6 +989,44 @@ inline void PackedFunc::CallPacked(TVMArgs args, TVMRetValue* rv) const { body_(
inline PackedFunc::FType PackedFunc::body() const { return body_; }

// internal namespace
inline const char* ArgTypeCode2Str(int type_code) {
switch (type_code) {
case kDLInt:
return "int";
case kDLUInt:
return "uint";
case kDLFloat:
return "float";
case kTVMStr:
return "str";
case kTVMBytes:
return "bytes";
case kTVMOpaqueHandle:
return "handle";
case kTVMNullptr:
return "NULL";
case kTVMDLTensorHandle:
return "ArrayHandle";
case kTVMDataType:
return "DLDataType";
case kTVMContext:
return "TVMContext";
case kTVMPackedFuncHandle:
return "FunctionHandle";
case kTVMModuleHandle:
return "ModuleHandle";
case kTVMNDArrayHandle:
return "NDArrayContainer";
case kTVMObjectHandle:
return "Object";
case kTVMObjectRValueRefArg:
return "ObjectRValueRefArg";
default:
LOG(FATAL) << "unknown type_code=" << static_cast<int>(type_code);
return "";
}
}

namespace detail {

template <bool stop, std::size_t I, typename F>
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/tir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kTVMCustomBegin)) {
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(DataType::kCustomBegin)) {
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
package org.apache.tvm;

// Type code used in API calls
public enum TypeCode {
public enum ArgTypeCode {
INT(0), UINT(1), FLOAT(2), HANDLE(3), NULL(4), TVM_TYPE(5),
TVM_CONTEXT(6), ARRAY_HANDLE(7), NODE_HANDLE(8), MODULE_HANDLE(9),
FUNC_HANDLE(10), STR(11), BYTES(12), NDARRAY_CONTAINER(13);

public final int id;

private TypeCode(int id) {
private ArgTypeCode(int id) {
this.id = id;
}

Expand Down
14 changes: 7 additions & 7 deletions jvm/core/src/main/java/org/apache/tvm/Function.java
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ private static Function getGlobalFunc(String name, boolean isResident, boolean a
* @param isResident Whether this is a resident function in jvm
*/
Function(long handle, boolean isResident) {
super(TypeCode.FUNC_HANDLE);
super(ArgTypeCode.FUNC_HANDLE);
this.handle = handle;
this.isResident = isResident;
}
Expand Down Expand Up @@ -187,7 +187,7 @@ public Function pushArg(String arg) {
* @return this
*/
public Function pushArg(NDArrayBase arg) {
int id = arg.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
int id = arg.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(arg.handle, id);
return this;
}
Expand All @@ -198,7 +198,7 @@ public Function pushArg(NDArrayBase arg) {
* @return this
*/
public Function pushArg(Module arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.MODULE_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.MODULE_HANDLE.id);
return this;
}

Expand All @@ -208,7 +208,7 @@ public Function pushArg(Module arg) {
* @return this
*/
public Function pushArg(Function arg) {
Base._LIB.tvmFuncPushArgHandle(arg.handle, TypeCode.FUNC_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(arg.handle, ArgTypeCode.FUNC_HANDLE.id);
return this;
}

Expand Down Expand Up @@ -249,12 +249,12 @@ private static void pushArgToStack(Object arg) {
Base._LIB.tvmFuncPushArgBytes((byte[]) arg);
} else if (arg instanceof NDArrayBase) {
NDArrayBase nd = (NDArrayBase) arg;
int id = nd.isView ? TypeCode.ARRAY_HANDLE.id : TypeCode.NDARRAY_CONTAINER.id;
int id = nd.isView ? ArgTypeCode.ARRAY_HANDLE.id : ArgTypeCode.NDARRAY_CONTAINER.id;
Base._LIB.tvmFuncPushArgHandle(nd.handle, id);
} else if (arg instanceof Module) {
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, TypeCode.MODULE_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(((Module) arg).handle, ArgTypeCode.MODULE_HANDLE.id);
} else if (arg instanceof Function) {
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, TypeCode.FUNC_HANDLE.id);
Base._LIB.tvmFuncPushArgHandle(((Function) arg).handle, ArgTypeCode.FUNC_HANDLE.id);
} else if (arg instanceof TVMValue) {
TVMValue tvmArg = (TVMValue) arg;
switch (tvmArg.typeCode) {
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/Module.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ private static Function getApi(String name) {
}

Module(long handle) {
super(TypeCode.MODULE_HANDLE);
super(ArgTypeCode.MODULE_HANDLE);
this.handle = handle;
}

Expand Down Expand Up @@ -138,7 +138,7 @@ public String typeKey() {
*/
public static Module load(String path, String fmt) {
TVMValue ret = getApi("ModuleLoadFromFile").pushArg(path).pushArg(fmt).invoke();
assert ret.typeCode == TypeCode.MODULE_HANDLE;
assert ret.typeCode == ArgTypeCode.MODULE_HANDLE;
return ret.asModule();
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/NDArrayBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class NDArrayBase extends TVMValue {
private boolean isReleased = false;

NDArrayBase(long handle, boolean isView) {
super(TypeCode.ARRAY_HANDLE);
super(ArgTypeCode.ARRAY_HANDLE);
this.handle = handle;
this.isView = isView;
}
Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/TVMValue.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@
package org.apache.tvm;

public class TVMValue {
public final TypeCode typeCode;
public final ArgTypeCode typeCode;

public TVMValue(TypeCode tc) {
public TVMValue(ArgTypeCode tc) {
typeCode = tc;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueBytes.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class TVMValueBytes extends TVMValue {
public final byte[] value;

public TVMValueBytes(byte[] value) {
super(TypeCode.BYTES);
super(ArgTypeCode.BYTES);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueDouble.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class TVMValueDouble extends TVMValue {
public final double value;

public TVMValueDouble(double value) {
super(TypeCode.FLOAT);
super(ArgTypeCode.FLOAT);
this.value = value;
}

Expand Down
4 changes: 2 additions & 2 deletions jvm/core/src/main/java/org/apache/tvm/TVMValueHandle.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
package org.apache.tvm;

/**
* Java class related to TVM handles (TypeCode.HANDLE)
* Java class related to TVM handles (ArgTypeCode.HANDLE)
*/
public class TVMValueHandle extends TVMValue {
public final long value;

public TVMValueHandle(long value) {
super(TypeCode.HANDLE);
super(ArgTypeCode.HANDLE);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueLong.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class TVMValueLong extends TVMValue {
public final long value;

public TVMValueLong(long value) {
super(TypeCode.INT);
super(ArgTypeCode.INT);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueNull.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@

public class TVMValueNull extends TVMValue {
public TVMValueNull() {
super(TypeCode.NULL);
super(ArgTypeCode.NULL);
}
}
2 changes: 1 addition & 1 deletion jvm/core/src/main/java/org/apache/tvm/TVMValueString.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ public class TVMValueString extends TVMValue {
public final String value;

public TVMValueString(String value) {
super(TypeCode.STR);
super(ArgTypeCode.STR);
this.value = value;
}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# top-level alias
# tvm._ffi
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, DataType
from ._ffi.runtime_ctypes import DataTypeCode, DataType
from ._ffi import register_object, register_func, register_extension, get_global_func

# top-level alias
Expand Down
12 changes: 6 additions & 6 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"""Runtime Object api"""
import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .types import ArgTypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
from .ndarray import _register_ndarray, NDArrayBase


Expand Down Expand Up @@ -60,12 +60,12 @@ def _return_object(x):
obj.handle = handle
return obj

RETURN_SWITCH[TypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_HANDLE)
RETURN_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _return_object
C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_HANDLE] = _wrap_arg_func(
_return_object, ArgTypeCode.OBJECT_HANDLE)

C_TO_PY_ARG_SWITCH[TypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, TypeCode.OBJECT_RVALUE_REF_ARG)
C_TO_PY_ARG_SWITCH[ArgTypeCode.OBJECT_RVALUE_REF_ARG] = _wrap_arg_func(
_return_object, ArgTypeCode.OBJECT_RVALUE_REF_ARG)


class PyNativeObject:
Expand Down
Loading

0 comments on commit 8a98782

Please sign in to comment.