Skip to content

Commit

Permalink
[REFACTOR][PY] tvm._ffi (apache#4813)
Browse files Browse the repository at this point in the history
* [REFACTOR][PY] tvm._ffi

- Remove from __future__ import absolute_import in the related files as they are no longer needed if the code only runs in python3
- Remove reverse dependency of _ctypes _cython to object_generic.
- function.py -> packed_func.py
- Function -> PackedFunc
- all registry related logics goes to tvm._ffi.registry
- Use absolute references for FFI related calls.
  - tvm._ffi.register_object
  - tvm._ffi.register_func
  - tvm._ffi.get_global_func

* Move get global func to the ffi side
  • Loading branch information
tqchen authored Feb 5, 2020
1 parent 4a39e52 commit f9b46c4
Show file tree
Hide file tree
Showing 81 changed files with 705 additions and 740 deletions.
17 changes: 7 additions & 10 deletions python/tvm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
# under the License.
# pylint: disable=redefined-builtin, wildcard-import
"""TVM: Low level DSL/IR stack for tensor computation."""
from __future__ import absolute_import as _abs

import multiprocessing
import sys
import traceback

from . import _pyversion
# import ffi related features
from ._ffi.base import TVMError, __version__
from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.packed_func import PackedFunc as Function
from ._ffi.registry import register_object, register_func, register_extension
from ._ffi.object import Object

from . import tensor
from . import arith
Expand All @@ -34,7 +38,6 @@
from . import container
from . import schedule
from . import module
from . import object
from . import attrs
from . import ir_builder
from . import target
Expand All @@ -48,15 +51,9 @@
from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl
from .ndarray import vpi, rocm, opengl, ext_dev, micro_dev

from ._ffi.runtime_ctypes import TypeCode, TVMType
from ._ffi.ndarray import TVMContext
from ._ffi.function import Function
from ._ffi.base import TVMError, __version__
from .api import *
from .intrin import *
from .tensor_intrin import decl_tensor_intrin
from .object import register_object
from .ndarray import register_extension
from .schedule import create_schedule
from .build_module import build, lower, build_config
from .tag import tag_scope
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/_ffi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,7 @@
Some performance critical functions are implemented by cython
and have a ctypes fallback implementation.
"""
from . import _pyversion
from .base import register_error
from .registry import register_object, register_func, register_extension
from .registry import _init_api, get_global_func
2 changes: 0 additions & 2 deletions python/tvm/_ffi/_ctypes/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Runtime NDArray api"""
from __future__ import absolute_import

import ctypes
from ..base import _LIB, check_call, c_str
from ..runtime_ctypes import TVMArrayHandle
Expand Down
2 changes: 0 additions & 2 deletions python/tvm/_ffi/_ctypes/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
# pylint: disable=invalid-name
"""Runtime Object api"""
from __future__ import absolute_import

import ctypes
from ..base import _LIB, check_call
from .types import TypeCode, RETURN_SWITCH, C_TO_PY_ARG_SWITCH, _wrap_arg_func
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,12 @@
# coding: utf-8
# pylint: disable=invalid-name, protected-access, too-many-branches, global-statement, unused-import
"""Function configuration API."""
from __future__ import absolute_import

import ctypes
import traceback
from numbers import Number, Integral

from ..base import _LIB, get_last_ffi_error, py2cerror
from ..base import _LIB, get_last_ffi_error, py2cerror, check_call
from ..base import c_str, string_types
from ..object_generic import convert_to_object, ObjectGeneric
from ..runtime_ctypes import TVMType, TVMByteArray, TVMContext
from . import ndarray as _nd
from .ndarray import NDArrayBase, _make_array
Expand All @@ -35,7 +32,7 @@
from .object import ObjectBase, _set_class_object
from . import object as _object

FunctionHandle = ctypes.c_void_p
PackedFuncHandle = ctypes.c_void_p
ModuleHandle = ctypes.c_void_p
ObjectHandle = ctypes.c_void_p
TVMRetValueHandle = ctypes.c_void_p
Expand All @@ -49,6 +46,15 @@ def _ctypes_free_resource(rhandle):
TVM_FREE_PYOBJ = TVMCFuncFinalizer(_ctypes_free_resource)
ctypes.pythonapi.Py_IncRef(ctypes.py_object(TVM_FREE_PYOBJ))


def _make_packed_func(handle, is_global):
"""Make a packed function class"""
obj = _CLASS_PACKED_FUNC.__new__(_CLASS_PACKED_FUNC)
obj.is_global = is_global
obj.handle = handle
return obj


def convert_to_tvm_func(pyfunc):
"""Convert a python function to TVM function
Expand Down Expand Up @@ -89,7 +95,7 @@ def cfun(args, type_codes, num_args, ret, _):
_ = rv
return 0

handle = FunctionHandle()
handle = PackedFuncHandle()
f = TVMPackedCFunc(cfun)
# NOTE: We will need to use python-api to increase ref count of the f
# TVM_FREE_PYOBJ will be called after it is no longer needed.
Expand All @@ -98,7 +104,7 @@ def cfun(args, type_codes, num_args, ret, _):
if _LIB.TVMFuncCreateFromCFunc(
f, pyobj, TVM_FREE_PYOBJ, ctypes.byref(handle)) != 0:
raise get_last_ffi_error()
return _CLASS_FUNCTION(handle, False)
return _make_packed_func(handle, False)


def _make_tvm_args(args, temp_args):
Expand Down Expand Up @@ -144,15 +150,15 @@ def _make_tvm_args(args, temp_args):
elif isinstance(arg, string_types):
values[i].v_str = c_str(arg)
type_codes[i] = TypeCode.STR
elif isinstance(arg, (list, tuple, dict, ObjectGeneric)):
arg = convert_to_object(arg)
elif isinstance(arg, (list, tuple, dict, _CLASS_OBJECT_GENERIC)):
arg = _FUNC_CONVERT_TO_OBJECT(arg)
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.OBJECT_HANDLE
temp_args.append(arg)
elif isinstance(arg, _CLASS_MODULE):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.MODULE_HANDLE
elif isinstance(arg, FunctionBase):
elif isinstance(arg, PackedFuncBase):
values[i].v_handle = arg.handle
type_codes[i] = TypeCode.PACKED_FUNC_HANDLE
elif isinstance(arg, ctypes.c_void_p):
Expand All @@ -168,7 +174,7 @@ def _make_tvm_args(args, temp_args):
return values, type_codes, num_args


class FunctionBase(object):
class PackedFuncBase(object):
"""Function base."""
__slots__ = ["handle", "is_global"]
# pylint: disable=no-member
Expand All @@ -177,7 +183,7 @@ def __init__(self, handle, is_global):
Parameters
----------
handle : FunctionHandle
handle : PackedFuncHandle
the handle to the underlying function.
is_global : bool
Expand Down Expand Up @@ -238,9 +244,22 @@ def _return_module(x):
def _handle_return_func(x):
"""Return function"""
handle = x.v_handle
if not isinstance(handle, FunctionHandle):
handle = FunctionHandle(handle)
return _CLASS_FUNCTION(handle, False)
if not isinstance(handle, PackedFuncHandle):
handle = PackedFuncHandle(handle)
return _CLASS_PACKED_FUNC(handle, False)


def _get_global_func(name, allow_missing=False):
handle = PackedFuncHandle()
check_call(_LIB.TVMFuncGetGlobal(c_str(name), ctypes.byref(handle)))

if handle.value:
return _make_packed_func(handle, False)

if allow_missing:
return None

raise ValueError("Cannot find global function %s" % name)

# setup return handle for function type
_object.__init_by_constructor__ = __init_handle_by_constructor__
Expand All @@ -255,13 +274,22 @@ def _handle_return_func(x):
C_TO_PY_ARG_SWITCH[TypeCode.NDARRAY_HANDLE] = lambda x: _make_array(x.v_handle, False, True)

_CLASS_MODULE = None
_CLASS_FUNCTION = None
_CLASS_PACKED_FUNC = None
_CLASS_OBJECT_GENERIC = None
_FUNC_CONVERT_TO_OBJECT = None


def _set_class_module(module_class):
"""Initialize the module."""
global _CLASS_MODULE
_CLASS_MODULE = module_class

def _set_class_function(func_class):
global _CLASS_FUNCTION
_CLASS_FUNCTION = func_class
def _set_class_packed_func(packed_func_class):
global _CLASS_PACKED_FUNC
_CLASS_PACKED_FUNC = packed_func_class

def _set_class_object_generic(object_generic_class, func_convert_to_object):
global _CLASS_OBJECT_GENERIC
global _FUNC_CONVERT_TO_OBJECT
_CLASS_OBJECT_GENERIC = object_generic_class
_FUNC_CONVERT_TO_OBJECT = func_convert_to_object
2 changes: 0 additions & 2 deletions python/tvm/_ffi/_ctypes/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@
# under the License.
"""The C Types used in API."""
# pylint: disable=invalid-name
from __future__ import absolute_import as _abs

import ctypes
import struct
from ..base import py_str, check_call, _LIB
Expand Down
10 changes: 6 additions & 4 deletions python/tvm/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ ctypedef int64_t tvm_index_t
ctypedef DLTensor* DLTensorHandle
ctypedef void* TVMStreamHandle
ctypedef void* TVMRetValueHandle
ctypedef void* TVMFunctionHandle
ctypedef void* TVMPackedFuncHandle
ctypedef void* ObjectHandle

ctypedef struct TVMObject:
Expand All @@ -96,21 +96,23 @@ ctypedef void (*TVMPackedCFuncFinalizer)(void* resource_handle)
cdef extern from "tvm/runtime/c_runtime_api.h":
void TVMAPISetLastError(const char* msg)
const char *TVMGetLastError()
int TVMFuncCall(TVMFunctionHandle func,
int TVMFuncGetGlobal(const char* name,
TVMPackedFuncHandle* out);
int TVMFuncCall(TVMPackedFuncHandle func,
TVMValue* arg_values,
int* type_codes,
int num_args,
TVMValue* ret_val,
int* ret_type_code)
int TVMFuncFree(TVMFunctionHandle func)
int TVMFuncFree(TVMPackedFuncHandle func)
int TVMCFuncSetReturn(TVMRetValueHandle ret,
TVMValue* value,
int* type_code,
int num_ret)
int TVMFuncCreateFromCFunc(TVMPackedCFunc func,
void* resource_handle,
TVMPackedCFuncFinalizer fin,
TVMFunctionHandle *out)
TVMPackedFuncHandle *out)
int TVMCbArgToReturn(TVMValue* value, int code)
int TVMArrayAlloc(tvm_index_t* shape,
tvm_index_t ndim,
Expand Down
4 changes: 1 addition & 3 deletions python/tvm/_ffi/_cython/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,5 @@

include "./base.pxi"
include "./object.pxi"
# include "./node.pxi"
include "./function.pxi"
include "./packed_func.pxi"
include "./ndarray.pxi"

2 changes: 1 addition & 1 deletion python/tvm/_ffi/_cython/object.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,6 @@ cdef class ObjectBase:
self.chandle = NULL
cdef void* chandle
ConstructorCall(
(<FunctionBase>fconstructor).chandle,
(<PackedFuncBase>fconstructor).chandle,
kTVMObjectHandle, args, &chandle)
self.chandle = chandle
Loading

0 comments on commit f9b46c4

Please sign in to comment.