Skip to content

Commit

Permalink
support using pointer with an original offset (apache#826)
Browse files Browse the repository at this point in the history
* when there is no intrin func, using body for initialization. For issue 714.

* Refine code per review comments, and add a test case.

* Fix lint issues.

* Re-organize the tensorize test cases, and add a new case for none-reset
mode.

* Fix a typo.

* Delete the unit case because merged it into test_schedule_tensorize.py already.

* always use new tensor in its stage when rewrite for cache read

* revert previous changes to sync up with master

* support using the ptr with an original offset

* update test case and fix CI error
  • Loading branch information
kun-zh authored and tqchen committed Jan 27, 2018
1 parent 0b54952 commit 293dac3
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 6 deletions.
3 changes: 2 additions & 1 deletion include/tvm/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ class Buffer : public NodeRef {
* \param access_mask The access mask
* \param ptr_type The type of the pointer.
* \param content_lanes The number of lanes for the (data) type.
* \param offset The offset of ptr.
*/
TVM_DLL Expr access_ptr(int access_mask, Type ptr_type = Handle(),
int content_lanes = 1) const;
int content_lanes = 1, int offset = 0) const;
/*!
* \brief Create an Expr that does a vector load at begin index.
* \param begin The beginning index
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class Buffer(NodeBase):
READ = 1
WRITE = 2

def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1, offset=0):
"""Get an access pointer to the head of buffer.
This is the recommended method to get buffer data
Expand All @@ -45,6 +45,10 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
The number of lanes for the data type. This value
is greater than one for vector types.
offset: int, optional
The offset of pointer. We can use it to offset by
the number of elements from the address of ptr.
Examples
--------
.. code-block:: python
Expand All @@ -68,7 +72,7 @@ def access_ptr(self, access_mask, ptr_type="handle", content_lanes=1):
raise ValueError("Unknown access_mask %s" % access_mask)
access_mask = mask
return _api_internal._BufferAccessPtr(self, access_mask, ptr_type,
content_lanes)
content_lanes, offset)

def vload(self, begin, dtype=None):
"""Generate an Expr that loads dtype from begin index.
Expand Down
2 changes: 1 addition & 1 deletion src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ TVM_REGISTER_API("_Buffer")
TVM_REGISTER_API("_BufferAccessPtr")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = args[0].operator Buffer()
.access_ptr(args[1], args[2], args[3]);
.access_ptr(args[1], args[2], args[3], args[4]);
});

TVM_REGISTER_API("_BufferVLoad")
Expand Down
4 changes: 2 additions & 2 deletions src/lang/buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ Buffer Buffer::MakeSlice(Array<Expr> begins, Array<Expr> extents) const {
0);
}

Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const {
Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes, int offset) const {
const BufferNode* self = operator->();
Expr e_dtype;
Expr extent;
Expand All @@ -348,7 +348,7 @@ Expr Buffer::access_ptr(int access_mask, Type ptr_type, int content_lanes) const
} else {
extent = arith::ComputeReduce<ir::Mul>(self->shape, Expr());
}
Expr elem_offset = self->elem_offset;
Expr elem_offset = self->elem_offset + offset;
if (content_lanes > 1) {
e_dtype = make_zero(self->dtype.with_lanes(content_lanes));
extent = extent / make_const(self->elem_offset.type(), content_lanes);
Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_lang_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@ def test_buffer_access_ptr():
aptr = Ab.access_ptr("w")
assert aptr.args[4].value == Buffer.WRITE

def test_buffer_access_ptr_offset():
m = tvm.var('m')
n = tvm.var('n')
Ab = tvm.decl_buffer((m, n), tvm.float32)
aptr = Ab.access_ptr("rw", offset=100)
offset = tvm.ir_pass.Simplify(aptr.args[2])
assert tvm.ir_pass.Equal(offset, 100)
assert aptr.args[4].value == Buffer.READ | Buffer.WRITE

def test_buffer_index_merge_mult_mod():
m = tvm.var('m')
n = tvm.var('n')
Expand Down Expand Up @@ -57,4 +66,5 @@ def assert_simplified_equal(index_simplified, index_direct):
if __name__ == "__main__":
test_buffer()
test_buffer_access_ptr()
test_buffer_access_ptr_offset()
test_buffer_index_merge_mult_mod()

0 comments on commit 293dac3

Please sign in to comment.