
[RUNTIME][PASS] Allow declare vector type array (apache#302)
* [RUNTIME][PASS] Allow declare vector type array

* fix bcast

* [BUFFER] Enable vload/store function in buffer

* ok
tqchen authored Aug 8, 2017
1 parent 1e48b02 commit 1146495
Showing 19 changed files with 347 additions and 73 deletions.
25 changes: 12 additions & 13 deletions include/tvm/buffer.h
@@ -32,19 +32,6 @@ class Buffer : public NodeRef {
public:
Buffer() {}
explicit Buffer(std::shared_ptr<Node> n) : NodeRef(n) {}
/*!
* \brief Generate a load expression loading the index location of buffer.
* \param index The index to the buffer.
* \return The load expression.
*/
Expr MakeLoad(Array<Expr> index) const;
/*!
* \brief Generate a store statement.
* \param index The index to the buffer.
* \param value The value to be stored.
* \return The load expression.
*/
Stmt MakeStore(Array<Expr> index, Expr value) const;
/*!
* \brief Return a new buffer that is equivalent with current one
* but always add stride field.
@@ -66,6 +53,18 @@ class Buffer : public NodeRef {
* \param ptr_type The type of the pointer.
*/
Expr access_ptr(int access_mask, Type ptr_type = Handle()) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index.
* \param dtype The data type to be loaded.
* \return The load expression.
*/
Expr vload(Array<Expr> begin, Type dtype) const;
/*!
* \brief Create a Stmt that does a vector store at begin index.
* \param begin The beginning index.
* \param value The value to be stored.
* \return The store statement.
*/
Stmt vstore(Array<Expr> begin, Expr value) const;
/*!
* \brief access the internal node container
* \return the pointer to the internal node container
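The new vload/vstore pair replaces the removed MakeLoad/MakeStore and is lane-aware. A minimal sketch of the intended Python-side usage once the methods are exposed through the FFI (the buffer name and shape are illustrative, not from this patch):

import tvm

# Hypothetical buffer; the Python wrappers appear in schedule.py further down.
Ab = tvm.decl_buffer((128,), "float32", name="Ab")
ld = Ab.vload([0], "float32x4")   # Expr: loads 4 float32 lanes starting at index 0
st = Ab.vstore([0], ld)           # Stmt: the matching 4-lane store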
33 changes: 26 additions & 7 deletions python/tvm/_ffi/ndarray.py
@@ -142,31 +142,44 @@ def __setitem__(self, in_slice, value):
if value.handle is not self.handle:
value.copyto(self)
elif isinstance(value, (np.ndarray, np.generic)):
self._sync_copyfrom(value)
self.copyfrom(value)
else:
raise TypeError('type %s not supported' % str(type(value)))

def _sync_copyfrom(self, source_array):
def copyfrom(self, source_array):
"""Peform an synchronize copy from the array.
Parameters
----------
source_array : array_like
The data source we would like to copy from.
Returns
-------
arr : NDArray
Reference to self.
"""
if not isinstance(source_array, np.ndarray):
try:
source_array = np.array(source_array, dtype=self.dtype)
except:
raise TypeError('array must be an array_like data, ' +
'type %s is not supported' % str(type(source_array)))
source_array = np.ascontiguousarray(source_array, dtype=self.dtype)
if source_array.shape != self.shape:
raise ValueError('array shape do not match the shape of NDArray')
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
source_array = np.ascontiguousarray(source_array, dtype=dtype)
if source_array.shape != shape:
raise ValueError("array shape do not match the shape of NDArray {0} vs {1}".format(
source_array.shape, shape))
assert source_array.flags['C_CONTIGUOUS']
data = source_array.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(source_array.shape) * source_array.dtype.itemsize)
check_call(_LIB.TVMArrayCopyFromBytes(self.handle, data, nbytes))
return self

def asnumpy(self):
"""Convert this array to numpy array
@@ -176,7 +189,13 @@ def asnumpy(self):
np_arr : numpy.ndarray
The corresponding numpy array.
"""
np_arr = np.empty(self.shape, dtype=self.dtype)
t = TVMType(self.dtype)
shape, dtype = self.shape, self.dtype
if t.lanes > 1:
shape = shape + (t.lanes,)
t.lanes = 1
dtype = str(t)
np_arr = np.empty(shape, dtype=dtype)
assert np_arr.flags['C_CONTIGUOUS']
data = np_arr.ctypes.data_as(ctypes.c_void_p)
nbytes = ctypes.c_size_t(np.prod(np_arr.shape) * np_arr.dtype.itemsize)
@@ -188,7 +207,7 @@ def copyto(self, target):
Parameters
----------
target : tvm.NDArray
target : NDArray
The target array to be copied, must have same shape as this array.
"""
if isinstance(target, TVMContext):
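With a vector dtype, copyfrom and asnumpy fold the lanes into an extra trailing numpy axis. A small usage sketch, assuming the tvm.nd helpers of this era:

import numpy as np
import tvm

a = tvm.nd.empty((8,), "float32x4")           # 8 elements, 4 lanes each
a.copyfrom(np.ones((8, 4), dtype="float32"))  # copyfrom now returns self
assert a.asnumpy().shape == (8, 4)            # lanes become the last axis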
26 changes: 16 additions & 10 deletions python/tvm/_ffi/runtime_ctypes.py
@@ -41,30 +41,36 @@ class TVMType(ctypes.Structure):
2 : 'float',
4 : 'handle'
}
def __init__(self, type_str, lanes=1):
def __init__(self, type_str):
super(TVMType, self).__init__()
if isinstance(type_str, np.dtype):
type_str = str(type_str)
if type_str.startswith("int"):
arr = type_str.split("x")
head = arr[0]
self.lanes = int(arr[1]) if len(arr) > 1 else 1
bits = 32

if head.startswith("int"):
self.type_code = 0
bits = int(type_str[3:])
elif type_str.startswith("uint"):
head = head[3:]
elif head.startswith("uint"):
self.type_code = 1
bits = int(type_str[4:])
elif type_str.startswith("float"):
head = head[4:]
elif head.startswith("float"):
self.type_code = 2
bits = int(type_str[5:])
elif type_str.startswith("handle"):
head = head[5:]
elif head.startswith("handle"):
self.type_code = 4
bits = 64
head = ""
else:
raise ValueError("Donot know how to handle type %s" % type_str)

bits = 32 if bits == 0 else bits
bits = int(head) if head else bits
if (bits & (bits - 1)) != 0 or bits < 8:
raise ValueError("Donot know how to handle type %s" % type_str)
self.bits = bits
self.lanes = lanes


def __repr__(self):
x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits)
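The reworked parser splits the type string on "x" to pick up lanes, and rejects bit widths that are not powers of two of at least 8. A sketch of the resulting fields:

from tvm._ffi.runtime_ctypes import TVMType

t = TVMType("float32x4")
# t.type_code == 2 ('float'), t.bits == 32, t.lanes == 4
s = TVMType("uint8")
# s.type_code == 1 ('uint'), s.bits == 8, s.lanes == 1
# TVMType("float31") would raise ValueError: 31 is not a power of two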
17 changes: 1 addition & 16 deletions python/tvm/api.py
@@ -10,6 +10,7 @@
from ._ffi.function import Function
from ._ffi.function import _init_api, register_func, get_global_func
from ._ffi.function import convert_to_tvm_func as _convert_tvm_func
from ._ffi.runtime_ctypes import TVMType
from . import _api_internal
from . import make as _make
from . import expr as _expr
@@ -546,22 +547,6 @@ def reduce_axis(dom, name="rv"):
"""
return _IterVar(dom, name, 2)

def cast(dtype, expr):
"""Cast an expression to other type
Parameters
----------
dtype : str, optional
The type of new expression
expr : Expr
The expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, expr)


def select(cond, t, f):
"""Construct a select branch
7 changes: 4 additions & 3 deletions python/tvm/expr.py
@@ -97,18 +97,19 @@ def equal(self, other):
return _make.EQ(self, other)

def astype(self, dtype):
"""Cast the expression to other type
"""Cast the expression to other type.
Parameters
----------
dtype : str, optional
dtype : str
The type of new expression
Returns
-------
expr : Expr
Expression with new type
"""
return _make.Cast(dtype, self)
return _make.static_cast(dtype, self)


class Expr(NodeBase, ExprOp):
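Since astype now routes through make.static_cast, casting a scalar expression to a vector type broadcasts instead of producing an ill-typed Cast. A sketch:

import tvm

x = tvm.var("x", dtype="float32")
v = x.astype("float32x4")   # Broadcast(x, 4), not Cast("float32x4", x)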
11 changes: 11 additions & 0 deletions python/tvm/ir_builder.py
@@ -9,6 +9,7 @@
from . import container as _container
from ._ffi.base import string_types
from ._ffi.node import NodeGeneric
from ._ffi.runtime_ctypes import TVMType
from .expr import Call as _Call

class WithScope(object):
@@ -56,7 +57,14 @@ def __init__(self, builder, buffer_var, content_type):
def asnode(self):
return self._buffer_var

@property
def dtype(self):
return self._content_type

def __getitem__(self, index):
t = TVMType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
return _make.Load(self._content_type, self._buffer_var, index)

def __setitem__(self, index, value):
@@ -65,6 +73,9 @@ def __setitem__(self, index, value):
raise ValueError(
"data type does not match content type %s vs %s" % (
value.dtype, self._content_type))
t = TVMType(self._content_type)
if t.lanes > 1:
index = _make.Ramp(index * t.lanes, 1, t.lanes)
self._builder.emit(_make.Store(self._buffer_var, value, index))


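When the content type carries lanes, a scalar index is widened on both load and store to a Ramp covering every lane. A sketch, assuming the allocate signature of this era:

import tvm

ib = tvm.ir_builder.create()
buf = ib.allocate("float32x4", 16, name="vbuf")       # 4-lane content type
buf[2] = tvm.const(1, "float32").astype("float32x4")  # Store at Ramp(8, 1, 4)
val = buf[2]                                          # Load at Ramp(8, 1, 4)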
29 changes: 29 additions & 0 deletions python/tvm/make.py
@@ -7,6 +7,7 @@
You can use make function to build the IR node.
"""
from ._ffi.function import _init_api
from ._ffi.runtime_ctypes import TVMType
from . import stmt as _stmt

def range_by_min_extent(min_value, extent):
@@ -30,6 +31,34 @@ def range_by_min_extent(min_value, extent):
return _range_by_min_extent(min_value, extent)


def static_cast(dtype, expr):
"""Cast expr to dtype.
If expr is a scalar and dtype is a corresponding vector
type, a Broadcast is generated; otherwise, a Cast.
Parameters
----------
dtype : str
The target data type.
expr : Expr
The expression to be cast.
Returns
-------
casted : Expr
The casted expression.
"""
target_type = TVMType(dtype)
src_type = TVMType(expr.dtype)
if target_type.type_code == src_type.type_code\
and src_type.lanes == 1\
and target_type.lanes > 1:
return Broadcast(expr, target_type.lanes)
return Cast(dtype, expr)


def node(type_key, **kwargs):
"""Make a new DSL node by its type key and fields
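static_cast only broadcasts when the type codes match and the source is a scalar; every other combination stays a Cast. A quick sketch of both branches:

import tvm
from tvm import make as _make

f = tvm.const(1, "float32")
_make.static_cast("float32x4", f)  # same type code, scalar -> vector: Broadcast
_make.static_cast("int32", f)      # different type code: plain Cast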
4 changes: 1 addition & 3 deletions python/tvm/ndarray.py
@@ -126,8 +126,6 @@ def array(arr, ctx=cpu(0)):
"""
if not isinstance(arr, _np.ndarray):
arr = _np.array(arr)
ret = empty(arr.shape, arr.dtype, ctx)
ret[:] = arr
return ret
return empty(arr.shape, arr.dtype, ctx).copyfrom(arr)

_set_class_ndarray(NDArray)
40 changes: 40 additions & 0 deletions python/tvm/schedule.py
@@ -65,6 +65,46 @@ def access_ptr(self, access_mask, ptr_type="handle"):
access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type)

def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
Parameters
----------
begin : Array of Expr
The beginning index, in units of Buffer.dtype.
dtype : str, optional
The data type to be loaded; it can be a vector type
whose lanes are a multiple of the lanes of Buffer.dtype.
Returns
-------
load : Expr
The corresponding load expression.
"""
begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
dtype = dtype if dtype else self.dtype
return _api_internal._BufferVLoad(self, begin, dtype)

def vstore(self, begin, value):
"""Generate a Stmt that store value into begin index.
Parameters
----------
begin : Array of Expr
The beginning index in unit of Buffer.dtype
value : Expr
The value to be stored.
Returns
-------
store : Stmt
The corresponding store stmt.
"""
begin = (begin,) if isinstance(begin, (int, _expr.Expr)) else begin
return _api_internal._BufferVStore(self, begin, value)


@register_node
class Split(NodeBase):
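Because begin is normalized to a tuple and dtype falls back to the buffer's own dtype, vload/vstore also accept bare indices. Continuing the hypothetical Ab buffer from the buffer.h sketch above:

x = Ab.vload(3)                  # same as Ab.vload([3], "float32"): scalar load
s = Ab.vstore(3, x)              # scalar store of the loaded value
v = Ab.vload([0], "float32x4")   # explicit vector dtype, list-form begin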
12 changes: 12 additions & 0 deletions src/api/api_lang.cc
@@ -162,6 +162,18 @@ TVM_REGISTER_API("_BufferAccessPtr")
.access_ptr(args[1], args[2]);
});

TVM_REGISTER_API("_BufferVLoad")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.vload(args[1], args[2]);
});

TVM_REGISTER_API("_BufferVStore")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.vstore(args[1], args[2]);
});

TVM_REGISTER_API("_Tensor")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = TensorNode::make(args[0],
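These registrations are what back Buffer.vload/Buffer.vstore above; calling the packed functions directly through the internal namespace is equivalent. A sketch, reusing the hypothetical Ab:

from tvm import _api_internal

ld = _api_internal._BufferVLoad(Ab, [0], "float32x4")  # == Ab.vload([0], "float32x4")
st = _api_internal._BufferVStore(Ab, [0], ld)          # == Ab.vstore([0], ld)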