[Performance] Cacheline-aligned access for UnifiedTensor (dmlc#3254)
* Add pytorch-direct version

* remove

* add documentation for UnifiedTensor

* Revert "add documentation for UnifiedTensor"

This reverts commit 63ba426.

* alignment fix for UnifiedTensor access

* fix linting issue

Co-authored-by: shhssdm <[email protected]>
Co-authored-by: xiang song(charlie.song) <[email protected]>
3 people authored Aug 17, 2021
1 parent 76af2a2 commit 2613f7f
Showing 2 changed files with 16 additions and 15 deletions.
10 changes: 8 additions & 2 deletions src/array/cuda/uvm/array_index_select_uvm.cu
@@ -5,6 +5,7 @@
  */
 #include <dgl/array.h>
 #include "../../../runtime/cuda/cuda_common.h"
+#include "../array_index_select.cuh"
 #include "./array_index_select_uvm.cuh"
 #include "../utils.h"

@@ -48,8 +49,13 @@ NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
       block.y *= 2;
     }
     const dim3 grid((len+block.y-1)/block.y);
-    CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream,
-                     array_data, num_feat, idx_data, len, ret_data);
+    if (num_feat * sizeof(DType) < 2 * CACHE_LINE_SIZE) {
+      CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0,
+          thr_entry->stream, array_data, num_feat, idx_data, len, ret_data);
+    } else {
+      CUDA_KERNEL_CALL(IndexSelectMultiKernelAligned, grid, block, 0,
+          thr_entry->stream, array_data, num_feat, idx_data, len, ret_data);
+    }
   }
   return ret;
 }
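The dispatch above only switches to the aligned kernel when one feature row spans at least two cache lines; narrower rows keep the plain IndexSelectMultiKernel, since shifting columns cannot buy extra coalescing for a row that already fits within a line or two. A minimal host-side sketch of that heuristic (UseAlignedKernel and kCacheLineSize are illustrative names, not the DGL API), which compiles with any C++/CUDA toolchain:

    #include <cstddef>
    #include <cstdio>

    constexpr std::size_t kCacheLineSize = 128;  // bytes, mirrors CACHE_LINE_SIZE in the commit

    // True when one feature row covers at least two cache lines, i.e. when the
    // dispatch above would pick IndexSelectMultiKernelAligned over the plain kernel.
    bool UseAlignedKernel(std::size_t num_feat, std::size_t elem_size) {
      return num_feat * elem_size >= 2 * kCacheLineSize;
    }

    int main() {
      std::printf("32 x float  -> %s\n", UseAlignedKernel(32, sizeof(float))  ? "aligned" : "plain");
      std::printf("64 x float  -> %s\n", UseAlignedKernel(64, sizeof(float))  ? "aligned" : "plain");
      std::printf("20 x double -> %s\n", UseAlignedKernel(20, sizeof(double)) ? "aligned" : "plain");
      return 0;
    }

With CACHE_LINE_SIZE = 128, float32 rows cross the threshold at 64 features (256 bytes).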
21 changes: 8 additions & 13 deletions src/array/cuda/uvm/array_index_select_uvm.cuh
@@ -7,23 +7,14 @@
 #ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
 #define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
 
+#define CACHE_LINE_SIZE 128
+
 namespace dgl {
 namespace aten {
 namespace impl {
 
-template <typename DType, typename IdType>
-__global__ void IndexSelectSingleKernel(const DType* array, const IdType* index,
-                                        int64_t length, DType* out) {
-  int tx = blockIdx.x * blockDim.x + threadIdx.x;
-  int stride_x = gridDim.x * blockDim.x;
-  while (tx < length) {
-    out[tx] = array[index[tx]];
-    tx += stride_x;
-  }
-}
-
 template <typename DType, typename IdType>
-__global__ void IndexSelectMultiKernel(
+__global__ void IndexSelectMultiKernelAligned(
     const DType* const array,
     const int64_t num_feat,
     const IdType* const index,
@@ -36,8 +27,12 @@ __global__ void IndexSelectMultiKernel(
   while (out_row < length) {
     int64_t col = threadIdx.x;
     const int64_t in_row = index[out_row];
+    const int64_t idx_offset =
+      ((uint64_t)(&array[in_row*num_feat]) % CACHE_LINE_SIZE) / sizeof(DType);
+    col = col - idx_offset;
     while (col < num_feat) {
-      out[out_row*num_feat+col] = array[in_row*num_feat+col];
+      if (col >= 0)
+        out[out_row*num_feat+col] = array[in_row*num_feat+col];
       col += blockDim.x;
     }
     out_row += stride;
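The new idx_offset computes how many elements the selected source row starts past the previous 128-byte cache-line boundary; subtracting it from each thread's starting column makes the warp's first loads begin on that boundary, and the if (col >= 0) guard skips the columns that fall before the row. A self-contained sketch of the same idea, assuming a hypothetical GatherAligned kernel and test driver with managed memory standing in for UnifiedTensor (not the DGL build):

    #include <cstdint>
    #include <cstdio>
    #include <cuda_runtime.h>

    #define CACHE_LINE_SIZE 128  // bytes, same constant the commit introduces

    // Hypothetical re-implementation of the cacheline-aligned gather, for illustration.
    template <typename DType, typename IdType>
    __global__ void GatherAligned(const DType* __restrict__ array, int64_t num_feat,
                                  const IdType* __restrict__ index, int64_t length,
                                  DType* __restrict__ out) {
      int64_t out_row = blockIdx.x * blockDim.y + threadIdx.y;
      const int64_t stride = blockDim.y * gridDim.x;
      while (out_row < length) {
        const int64_t in_row = index[out_row];
        // Element offset of this source row past the previous cache-line boundary.
        const int64_t idx_offset =
            ((uintptr_t)(&array[in_row * num_feat]) % CACHE_LINE_SIZE) / sizeof(DType);
        // Shift the whole thread block back by that offset so loads start aligned;
        // columns pushed below zero are simply skipped on the first iteration.
        int64_t col = (int64_t)threadIdx.x - idx_offset;
        while (col < num_feat) {
          if (col >= 0)
            out[out_row * num_feat + col] = array[in_row * num_feat + col];
          col += blockDim.x;
        }
        out_row += stride;
      }
    }

    int main() {
      const int64_t num_rows = 1024, num_feat = 256, num_idx = 128;
      float *array, *out;
      int64_t *index;
      // Managed (unified) memory plays the role of the UnifiedTensor backing store.
      cudaMallocManaged(&array, num_rows * num_feat * sizeof(float));
      cudaMallocManaged(&out, num_idx * num_feat * sizeof(float));
      cudaMallocManaged(&index, num_idx * sizeof(int64_t));
      for (int64_t i = 0; i < num_rows * num_feat; ++i) array[i] = (float)i;
      for (int64_t i = 0; i < num_idx; ++i) index[i] = (i * 7) % num_rows;

      const dim3 block(64, 4);  // 64 threads sweep one row; 4 rows per block
      const dim3 grid((num_idx + block.y - 1) / block.y);
      GatherAligned<float, int64_t><<<grid, block>>>(array, num_feat, index, num_idx, out);
      cudaDeviceSynchronize();

      printf("row 5, col 3: got %.0f, expected %.0f\n",
             out[5 * num_feat + 3], array[index[5] * num_feat + 3]);
      cudaFree(array); cudaFree(out); cudaFree(index);
      return 0;
    }

Running it gathers 128 rows of 256 floats; the shift changes only the order in which columns are touched, not the values written, so the result matches an unshifted gather.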
