examples: graph: sdpa: generalize the sdpa example
TaoLv committed Jun 22, 2024
1 parent 33f862b commit 236ac41
Showing 4 changed files with 424 additions and 206 deletions.
2 changes: 1 addition & 1 deletion examples/CMakeLists.txt
@@ -75,7 +75,7 @@ if(NOT ONEDNN_BUILD_GRAPH)
${CMAKE_CURRENT_SOURCE_DIR}/graph/cpu_single_op_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph/sycl_single_op_partition.cpp
${CMAKE_CURRENT_SOURCE_DIR}/graph/gpu_opencl_getting_started.cpp
-        ${CMAKE_CURRENT_SOURCE_DIR}/graph/gpu_opencl_sdpa.cpp)
+        ${CMAKE_CURRENT_SOURCE_DIR}/graph/sdpa.cpp)
endif()

if(DNNL_SYCL_HIP)
205 changes: 0 additions & 205 deletions examples/graph/gpu_opencl_sdpa.cpp

This file was deleted.

146 changes: 146 additions & 0 deletions examples/graph/graph_example_utils.hpp
@@ -17,8 +17,12 @@
#ifndef GRAPH_EXAMPLE_UTILS_HPP
#define GRAPH_EXAMPLE_UTILS_HPP

#include <unordered_set>

#include "oneapi/dnnl/dnnl_graph.hpp"

#include "example_utils.hpp"

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
#include "dnnl_ocl.hpp"
#elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
@@ -316,6 +320,148 @@ void allocate_ocl_graph_mem(std::vector<dnnl::graph::tensor> &tensors,
if (!is_input) global_outputs_ts_map[lt_id] = tensors.back();
}
}

// Copy `size` bytes from `src` to `dst` on the engine's OpenCL device.
// Both pointers are expected to be USM allocations; the copy goes through
// the clEnqueueMemcpyINTEL extension (cl_intel_unified_shared_memory).
void ocl_memcpy(dnnl::engine &eng, void *dst, const void *src, size_t size) {
    using F = cl_int (*)(cl_command_queue, cl_bool, void *, const void *,
            size_t, cl_uint, const cl_event *, cl_event *);
    if (!src || !dst) return;
    cl_platform_id platform;
    cl_context ctx = dnnl::ocl_interop::get_context(eng);
    cl_device_id dev = dnnl::ocl_interop::get_device(eng);
    cl_int err = 0;

// clCreateCommandQueue was deprecated in OpenCL 2.0.
#ifdef CL_VERSION_2_0
    cl_command_queue queue
            = clCreateCommandQueueWithProperties(ctx, dev, nullptr, &err);
#else
    cl_command_queue queue = clCreateCommandQueue(ctx, dev, {}, &err);
#endif
    if (err != CL_SUCCESS)
        throw std::runtime_error("cannot create a cl_command_queue");

    err = clGetDeviceInfo(
            dev, CL_DEVICE_PLATFORM, sizeof(platform), &platform, nullptr);
    if (err != CL_SUCCESS) throw std::runtime_error("clGetDeviceInfo failed");

    const char *f_name = "clEnqueueMemcpyINTEL";
    auto f = reinterpret_cast<F>(
            clGetExtensionFunctionAddressForPlatform(platform, f_name));
    if (!f)
        throw std::runtime_error(
                "clEnqueueMemcpyINTEL is not available on this platform");

    // Use a blocking copy so the data is in place when the function returns.
    err = f(queue, CL_TRUE, dst, src, size, 0, nullptr, nullptr);
    clReleaseCommandQueue(queue);
    if (err != CL_SUCCESS)
        throw std::runtime_error("clEnqueueMemcpyINTEL failed");
}
#endif

inline dnnl::memory::desc make_md(const dnnl::graph::logical_tensor &lt,
        dnnl::memory::data_type dt = dnnl::memory::data_type::undef) {
    using layout_type = dnnl::graph::logical_tensor::layout_type;
    using dims = dnnl::memory::dims;

    // If not specified, use the data type of the logical tensor.
    if (dt == dnnl::memory::data_type::undef)
        dt = static_cast<dnnl::memory::data_type>(lt.get_data_type());

    if (lt.get_layout_type() != layout_type::strided)
        throw std::runtime_error("make_md: bad layout type");

    const auto sz = lt.get_dims();
    const auto st = lt.get_strides();
    if (sz.empty()) {
        // Map a 0-dimensional tensor to a single-element memory desc.
        return dnnl::memory::desc(dims {1}, dt, dims {1});
    }
    return dnnl::memory::desc(sz, dt, st);
}
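
// Illustrative note (assumed tensor, not part of this commit): for a strided
// bf16 logical tensor `lt`, make_md(lt) yields a matching bf16 memory desc,
// while make_md(lt, dnnl::memory::data_type::f32) yields an f32 desc with the
// same dims and strides; write_to_dnnl_tensor() below uses the latter to
// stage an f32 -> bf16 reorder.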

// Copy the tensor's full size in bytes from the host buffer `handle` into
// the buffer of `ts`, dispatching on the engine kind and runtime. No data
// type conversion is performed; see write_to_dnnl_tensor() for conversion.
inline void write_dt(void *handle, dnnl::graph::tensor &ts) {
dnnl::engine eng = ts.get_engine();
size_t size = ts.get_logical_tensor().get_mem_size();

if (!handle) throw std::runtime_error("handle is nullptr.");

#ifdef DNNL_WITH_SYCL
bool is_cpu_sycl = (DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
&& eng.get_kind() == dnnl::engine::kind::cpu);
bool is_gpu_sycl = (DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
&& eng.get_kind() == dnnl::engine::kind::gpu);
if (is_cpu_sycl || is_gpu_sycl) {
        // Only USM memory is supported in the oneDNN Graph API.
uint8_t *dst_ptr = (uint8_t *)ts.get_data_handle();
if (!dst_ptr)
throw std::runtime_error("get_data_handle returned nullptr.");
if (is_cpu_sycl) {
for (size_t i = 0; i < size; ++i)
dst_ptr[i] = ((uint8_t *)handle)[i];
} else {
auto sycl_queue = dnnl::sycl_interop::get_queue(dnnl::stream(eng));
sycl_queue.memcpy(dst_ptr, handle, size).wait();
}
return;
}
#endif
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
if (eng.get_kind() == dnnl::engine::kind::gpu) {
        // Only USM memory is supported in the oneDNN Graph API.
uint8_t *dst_ptr = (uint8_t *)ts.get_data_handle();
if (!dst_ptr)
throw std::runtime_error("get_data_handle returned nullptr.");
ocl_memcpy(eng, dst_ptr, handle, size);
return;
}
#endif

if (eng.get_kind() == dnnl::engine::kind::cpu) {
uint8_t *dst = static_cast<uint8_t *>(ts.get_data_handle());
if (!dst) throw std::runtime_error("get_data_handle returned nullptr.");
for (size_t i = 0; i < size; ++i)
dst[i] = ((uint8_t *)handle)[i];
return;
}

assert(!"not expected");
}

// Read from handle, write to tensor. Assume handle contains f32 data.
inline void write_to_dnnl_tensor(void *handle, dnnl::graph::tensor &ts) {
if (!handle) throw std::runtime_error("handle is nullptr.");

dnnl::engine eng = ts.get_engine();
const dnnl::graph::logical_tensor lt = ts.get_logical_tensor();
const dnnl::graph::logical_tensor::data_type dt = lt.get_data_type();

if (dt != dnnl::graph::logical_tensor::data_type::f32) {
        // For a non-f32 data type, use a reorder to convert from f32.
const auto f32_md = make_md(lt, dnnl::memory::data_type::f32);
auto f32_mem = dnnl::memory(f32_md, eng);
write_to_dnnl_memory(handle, f32_mem);

const auto dt_md = make_md(lt);
if (dt_md.get_size() != lt.get_mem_size()) {
throw std::runtime_error("incorrect memory size.");
}

        dnnl::memory dt_mem;
#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
        // With the OpenCL runtime, the tensor handle is a USM pointer, so
        // the destination memory must be created via the OCL interop API.
        if (eng.get_kind() == dnnl::engine::kind::gpu) {
            dt_mem = dnnl::ocl_interop::make_memory(dt_md, eng,
                    dnnl::ocl_interop::memory_kind::usm, ts.get_data_handle());
        } else
#endif
            dt_mem = dnnl::memory(dt_md, eng, ts.get_data_handle());

dnnl::stream strm(eng);
dnnl::reorder(f32_mem, dt_mem).execute(strm, f32_mem, dt_mem);
strm.wait();
} else {
// directly write to ts.
write_dt(handle, ts);
}
}

#endif
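
For context, a minimal sketch of how the new helpers can be driven from an example (illustrative only; the strided bf16 logical tensor `lt` is an assumption, not code from this commit):

// Assumed setup: `lt` is a strided bf16 logical tensor created elsewhere.
dnnl::engine eng(dnnl::engine::kind::cpu, 0);
std::vector<uint8_t> buf(lt.get_mem_size());
dnnl::graph::tensor ts(lt, eng, buf.data());

// Prepare f32 host data, one value per element of `lt`.
int64_t nelems = 1;
for (auto d : lt.get_dims()) nelems *= d;
std::vector<float> host_data(nelems, 0.125f);

// write_to_dnnl_tensor() stages the data in an f32 memory and reorders it
// into the tensor's bf16 buffer.
write_to_dnnl_tensor(host_data.data(), ts);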