[Feature] Tvm integration (dmlc#2367)
Co-authored-by: Zihao Ye <[email protected]>
kira-lin and yzh119 authored Dec 31, 2020
1 parent 035f1ae commit 4208ce2
Showing 15 changed files with 435 additions and 17 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
@@ -26,3 +26,6 @@
[submodule "third_party/xbyak"]
path = third_party/xbyak
url = https://github.com/herumi/xbyak
[submodule "third_party/tvm"]
path = third_party/tvm
url = https://github.com/apache/incubator-tvm
40 changes: 29 additions & 11 deletions CMakeLists.txt
@@ -25,6 +25,7 @@ endif()
dgl_option(USE_CUDA "Build with CUDA" OFF)
dgl_option(USE_OPENMP "Build with OpenMP" ON)
dgl_option(USE_AVX "Build with AVX optimization" ON)
dgl_option(USE_TVM "Build with TVM kernels" OFF)
dgl_option(BUILD_CPP_TEST "Build cpp unittest executables" OFF)
dgl_option(LIBCXX_ENABLE_PARALLEL_ALGORITHMS "Enable the parallel algorithms library. This requires the PSTL to be available." OFF)
dgl_option(USE_S3 "Build with S3 support" OFF)
@@ -52,17 +53,6 @@ if(USE_CUDA)
endif()
endif(USE_CUDA)

# include directories
include_directories("include")
include_directories("third_party/dlpack/include")
include_directories("third_party/METIS/include/")
include_directories("third_party/dmlc-core/include")
include_directories("third_party/minigun/minigun")
include_directories("third_party/minigun/third_party/moderngpu/src")
include_directories("third_party/phmap/")
include_directories("third_party/xbyak/")
include_directories("tensoradapter/include")

# initial variables
if(NOT MSVC)
set(DGL_LINKER_LIBS "dl")
@@ -165,6 +155,17 @@ else(USE_CUDA)
add_library(dgl SHARED ${DGL_SRC})
endif(USE_CUDA)

# include directories
target_include_directories(dgl PRIVATE "include")
target_include_directories(dgl PRIVATE "third_party/dlpack/include")
target_include_directories(dgl PRIVATE "third_party/dmlc-core/include")
target_include_directories(dgl PRIVATE "third_party/minigun/minigun")
target_include_directories(dgl PRIVATE "third_party/minigun/third_party/moderngpu/src")
target_include_directories(dgl PRIVATE "third_party/phmap/")
target_include_directories(dgl PRIVATE "third_party/xbyak/")
target_include_directories(dgl PRIVATE "third_party/METIS/include/")
target_include_directories(dgl PRIVATE "tensoradapter/include")

# For serialization
if (USE_HDFS)
option(DMLC_HDFS_SHARED "dgl has to build with dynamic hdfs library" ON)
@@ -178,10 +179,23 @@ if(NOT MSVC)
set(GKLIB_PATH "${CMAKE_SOURCE_DIR}/third_party/METIS/GKlib")
include(${GKLIB_PATH}/GKlibSystem.cmake)
include_directories(${GKLIB_PATH})
include_directories("third_party/METIS/include/")
add_subdirectory("third_party/METIS/libmetis/")
list(APPEND DGL_LINKER_LIBS metis)
endif(NOT MSVC)

# Compile TVM Runtime and Featgraph
# (NOTE) We compile a dynamic library called featgraph_runtime, which the DGL library links to.
# Kernels are packed into a separate dynamic library called featgraph_kernels, which DGL
# loads at runtime.
if(USE_TVM)
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DUSE_TVM")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DUSE_TVM")
target_include_directories(dgl PRIVATE "featgraph/include")
add_subdirectory("featgraph/")
list(APPEND DGL_LINKER_LIBS featgraph_runtime)
endif(USE_TVM)

# support PARALLEL_ALGORITHMS
if (LIBCXX_ENABLE_PARALLEL_ALGORITHMS)
add_definitions(-DPARALLEL_ALGORITHMS)
@@ -238,6 +252,10 @@ if(BUILD_CPP_TEST)
add_subdirectory(./third_party/googletest)
enable_testing()
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
include_directories("include")
include_directories("third_party/dlpack/include")
include_directories("third_party/xbyak")
include_directories("third_party/dmlc-core/include")
file(GLOB_RECURSE TEST_SRC_FILES ${PROJECT_SOURCE_DIR}/tests/cpp/*.cc)
add_executable(runUnitTests ${TEST_SRC_FILES})
target_link_libraries(runUnitTests gtest gtest_main)
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -21,6 +21,7 @@ Contributors
* [Gongze Cao](https://github.com/Zardinality): Cluster GCN
* [Yicheng Wu](https://github.com/MilkshakeForReal): RotatE in PyTorch
* [Hao Xiong](https://github.com/ShawXh): DeepWalk in PyTorch
* [Zhi Lin](https://github.com/kira-lin): Integrate FeatGraph into DGL

Other improvement
* [Brett Koonce](https://github.com/brettkoonce)
12 changes: 8 additions & 4 deletions cmake/config.cmake
@@ -34,14 +34,18 @@ set(USE_CUDA OFF)
#---------------------------------------------
# Misc.
#---------------------------------------------
# Whether to build cpp unittest executables
# Whether to build cpp unittest executables.
set(BUILD_CPP_TEST OFF)

# Whether to enable OpenMP
# Whether to enable OpenMP.
set(USE_OPENMP ON)

# Whether to enable Intel's avx optimized kernel
# Whether to enable Intel's avx optimized kernel.
set(USE_AVX ON)

# Whether to build PyTorch plugins
# Whether to build PyTorch plugins.
set(BUILD_TORCH ON)

# Whether to enable CUDA kernels compiled with TVM.
set(USE_TVM OFF)

34 changes: 34 additions & 0 deletions featgraph/CMakeLists.txt
@@ -0,0 +1,34 @@
cmake_minimum_required(VERSION 3.5)

project(featgraph C CXX)
message(STATUS "Start configuring project ${PROJECT_NAME}")

# Find CUDA
include(../cmake/util/FindCUDA.cmake)
find_cuda(ON)
message(STATUS "${CUDA_INCLUDE_DIRS}")

add_custom_target(
featgraph_kernel
COMMAND python ${CMAKE_CURRENT_SOURCE_DIR}/pack_featgraph.py
COMMENT "Creating featgraph kernels..."
)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++14 -O2 -fPIC")
file(GLOB FEATGRAPH_SRC
src/featgraph.cc
src/tvm_runtime_pack.cc
)
add_library(featgraph_runtime SHARED ${FEATGRAPH_SRC})
target_include_directories(featgraph_runtime PRIVATE ${CUDA_INCLUDE_DIRS})
target_include_directories(featgraph_runtime PRIVATE "./include")
target_include_directories(featgraph_runtime PRIVATE "../third_party/tvm/include")
target_include_directories(featgraph_runtime PRIVATE "../third_party/tvm/3rdparty/dmlc-core/include")
target_include_directories(featgraph_runtime PRIVATE "../third_party/tvm/3rdparty/dlpack/include")
target_link_libraries(featgraph_runtime "dl" # dynamic linking
${CUDA_CUDART_LIBRARY}
${CUDA_CUDA_LIBRARY}
${CUDA_NVRTC_LIBRARY})
add_dependencies(featgraph_runtime featgraph_kernel)

install(TARGETS featgraph_runtime LIBRARY DESTINATION lib)
21 changes: 21 additions & 0 deletions featgraph/README.md
@@ -0,0 +1,21 @@
# FeatGraph-DGL

FeatGraph is an efficient backend for Graph Neural Networks based on TVM.

- Original repo: https://github.com/amazon-research/FeatGraph
- SC2020 Paper: https://www.csl.cornell.edu/~zhiruz/pdfs/featgraph-sc2020.pdf

This folder contains the code for integrating FeatGraph kernels into DGL.

## Usage

After building DGL with `USE_TVM=ON`, you should be able to run:
```bash
python test.py
```
to verify correctness.

## Reference

- [TVM tutorial on deploying a TVM module using the C++ API](https://tvm.apache.org/docs/deploy/cpp_deploy.html).

25 changes: 25 additions & 0 deletions featgraph/include/featgraph.h
@@ -0,0 +1,25 @@
/*!
* Copyright (c) 2020 by Contributors
* \file featgraph/include/featgraph.h
* \brief FeatGraph kernel headers.
*/
#ifndef FEATGRAPH_H_
#define FEATGRAPH_H_

#include <dlpack/dlpack.h>
#include <string>

namespace dgl {
namespace featgraph {

/* \brief Load Featgraph module from given path. */
void LoadFeatGraphModule(const std::string& path);

/* \brief Call Featgraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col,
DLManagedTensor* lhs, DLManagedTensor* rhs,
DLManagedTensor* out);

} // namespace featgraph
} // namespace dgl

#endif // FEATGRAPH_H_
40 changes: 40 additions & 0 deletions featgraph/pack_featgraph.py
@@ -0,0 +1,40 @@
""" Export featgraph kernels to a shared library. """
import tvm
from sddmm import sddmm_tree_reduction_gpu


def get_sddmm_kernels_gpu(idtypes, dtypes):
"""
Parameters
----------
idtypes: List[str]
Possible index types.
dtypes: List[str]
Possible data types.
Returns
-------
List[IRModule]:
The list of IRModules.
"""
ret = []
# SDDMM Tree Reduction
for dtype in dtypes:
for idtype in idtypes:
ret.append(sddmm_tree_reduction_gpu(idtype, dtype))

return ret


if __name__ == '__main__':
binary_path = 'libfeatgraph_kernels.so'
kernels = []
idtypes = ['int32', 'int64']
dtypes = ['float16', 'float64', 'float32', 'int32', 'int64']

kernels += get_sddmm_kernels_gpu(idtypes, dtypes)

# build kernels and export the module to libfeatgraph_kernels.so
module = tvm.build(kernels, target='cuda', target_host='llvm')
module.export_library(binary_path)
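
For reference, a minimal sketch (not part of this commit) of how the exported `libfeatgraph_kernels.so` can be inspected from Python, assuming it was built by the script above; kernel names follow the `SDDMMTreeReduction_<index dtype>_<feature dtype>` convention used in `sddmm.py`.

```python
# Illustration only: load the packed kernel library and look up one kernel by name.
import tvm

mod = tvm.runtime.load_module('libfeatgraph_kernels.so')
# Kernel names follow sddmm.py: SDDMMTreeReduction_<index dtype>_<feature dtype>.
func = mod.get_function('SDDMMTreeReduction_int32_float32')
print(func is not None)  # True if the kernel was packed successfully
```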

53 changes: 53 additions & 0 deletions featgraph/sddmm.py
@@ -0,0 +1,53 @@
""" The compute function and schedules for SDDMM kernels written in TVM. """
import tvm
from tvm import te


def sddmm_tree_reduction_gpu(idx_type, feat_type):
""" SDDMM kernels on GPU optimized with Tree Reduction.
Parameters
----------
idx_type : str
The data type for indexing tensors.
feat_type : str
The data type of feature tensor.
Returns
-------
IRModule
The result IRModule.
"""
# define vars and placeholders
nnz = te.var('nnz', idx_type)
num_rows = te.var('num_rows', idx_type)
num_cols = te.var('num_cols', idx_type)
H = te.var('num_heads', idx_type)
D = te.var('feat_len', idx_type)
row = te.placeholder((nnz,), idx_type, 'row')
col = te.placeholder((nnz,), idx_type, 'col')
ufeat = te.placeholder((num_rows, H, D), feat_type, 'ufeat')
vfeat = te.placeholder((num_cols, H, D), feat_type, 'vfeat')
# define edge computation function
def edge_func(eid, h, i):
k = te.reduce_axis((0, D), name='k')
return te.sum(ufeat[row[eid], h, k] * vfeat[col[eid], h, k], axis=k)
out = te.compute((nnz, H, tvm.tir.IntImm(idx_type, 1)), edge_func, name='out')
# define schedules
sched = te.create_schedule(out.op)
edge_axis, head_axis, _ = out.op.axis
reduce_axis = out.op.reduce_axis[0]
_, red_inner = sched[out].split(reduce_axis, factor=32)
edge_outer, edge_inner = sched[out].split(edge_axis, factor=32)
sched[out].bind(red_inner, te.thread_axis('threadIdx.x'))
sched[out].bind(edge_inner, te.thread_axis('threadIdx.y'))
sched[out].bind(edge_outer, te.thread_axis('blockIdx.x'))
sched[out].bind(head_axis, te.thread_axis('blockIdx.y'))
return tvm.lower(sched, [row, col, ufeat, vfeat, out],
name='SDDMMTreeReduction_{}_{}'.format(idx_type, feat_type))


if __name__ == '__main__':
kernel0 = sddmm_tree_reduction_gpu('int32', 'float32')
print(kernel0)
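
As a point of reference (not part of the commit), `edge_func` above defines the SDDMM tree-reduction compute rule; a plain NumPy sketch makes it explicit: for each nonzero edge `e` and head `h`, the output is the dot product of the row and column feature vectors.

```python
# NumPy reference (illustration only) for the compute rule defined by edge_func:
# out[e, h, 0] = sum_k ufeat[row[e], h, k] * vfeat[col[e], h, k]
import numpy as np

def sddmm_tree_reduction_reference(row, col, ufeat, vfeat):
    nnz, num_heads = row.shape[0], ufeat.shape[1]
    out = np.zeros((nnz, num_heads, 1), dtype=ufeat.dtype)
    for e in range(nnz):
        for h in range(num_heads):
            out[e, h, 0] = np.dot(ufeat[row[e], h, :], vfeat[col[e], h, :])
    return out
```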

78 changes: 78 additions & 0 deletions featgraph/src/featgraph.cc
@@ -0,0 +1,78 @@
/*!
* Copyright (c) 2020 by Contributors
* \file featgraph/src/featgraph.cc
* \brief FeatGraph kernels.
*/
#include <tvm/runtime/module.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/registry.h>
#include <dmlc/logging.h>
#include <featgraph.h>

namespace dgl {
namespace featgraph {

/* \brief Singleton that loads the featgraph module. */
class FeatGraphModule {
public:
static FeatGraphModule* Global() {
static FeatGraphModule inst;
return &inst;
}

void Load(const std::string& path) {
mod = tvm::runtime::Module::LoadFromFile(path);
}

inline tvm::runtime::ModuleNode* Get() {
auto ret = mod.operator->();
if (!ret) {
LOG(FATAL) << "FeatGraph module have not been loaded. "
<< "Please set path of featgraph shared library.";
}
return ret;
}
private:
tvm::runtime::Module mod;
FeatGraphModule() {}
};

/* \brief Load Featgraph module from given path. */
void LoadFeatGraphModule(const std::string& path) {
FeatGraphModule::Global()->Load(path);
}

/* \brief Convert DLDataType to string. */
inline std::string DTypeAsStr(const DLDataType& t) {
switch(t.code) {
case 0U: return "int" + std::to_string(t.bits);
case 1U: return "uint" + std::to_string(t.bits);
case 2U: return "float" + std::to_string(t.bits);
case 3U: return "bfloat" + std::to_string(t.bits);
default: LOG(FATAL) << "Type code " << t.code << " not recognized";
}
}

/* \brief Get operator filename. */
inline std::string GetOperatorName(
const std::string& base_name,
const DLDataType& dtype,
const DLDataType& idtype) {
return base_name + "_" + DTypeAsStr(dtype) + "_" + DTypeAsStr(idtype);
}

/* \brief Call FeatGraph's SDDMM kernel. */
void SDDMMTreeReduction(DLManagedTensor* row, DLManagedTensor* col,
DLManagedTensor* lhs, DLManagedTensor* rhs,
DLManagedTensor* out) {
tvm::runtime::ModuleNode* mod = FeatGraphModule::Global()->Get();
std::string f_name = GetOperatorName("SDDMMTreeReduction",
(row->dl_tensor).dtype,
(lhs->dl_tensor).dtype);
tvm::runtime::PackedFunc f = mod->GetFunction(f_name);
if (f != nullptr)
f(row, col, lhs, rhs, out);
}

} // namespace featgraph
} // namespace dgl
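
A small illustration (not part of this commit) of the name-based dispatch above: `SDDMMTreeReduction` builds its lookup name from the index dtype of `row` and the feature dtype of `lhs`, so the resulting string must match the kernel names generated in `sddmm.py`.

```python
# Illustration only: reproduce the lookup name used for int32 indices and
# float32 features, matching the export names in sddmm.py.
def operator_name(base_name, idtype, dtype):
    return '{}_{}_{}'.format(base_name, idtype, dtype)

assert operator_name('SDDMMTreeReduction', 'int32', 'float32') == \
       'SDDMMTreeReduction_int32_float32'
```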