Skip to content

Commit

Permalink
graph: backend: dnnl: kernels: ocl path for select
Browse files Browse the repository at this point in the history
  • Loading branch information
TaoLv committed Apr 27, 2024
1 parent 8f030ca commit 893c2ed
Showing 1 changed file with 80 additions and 0 deletions.
80 changes: 80 additions & 0 deletions src/graph/backend/dnnl/kernels/select.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,86 @@ struct select_t : public kernel_base_t {
}
#endif

#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
/// Executes the compiled subgraph on an OpenCL stream.
///
/// Ops flagged in `subgraph_->is_constant_` are executed only when the
/// constant cache is enabled and the cache lookup does not return an already
/// valid entry; all other ops are executed on every call. Ops are submitted
/// one after another, each one depending on the event returned by the
/// previously submitted op (the first one depends on `cl_deps`).
///
/// @param g_stream  Graph stream used to create the dnnl stream.
/// @param inputs    Input tensors forwarded to prepare_args_set().
/// @param outputs   Output tensors forwarded to prepare_args_set().
/// @param cl_deps   Events the first submitted op has to wait for.
/// @param ret_event Optional output; when non-null it receives the event of
///                  the last submitted op (a null event if nothing was
///                  submitted).
/// @return status::success.
status_t ocl_execute_impl(const stream_t *g_stream,
        const std::vector<tensor_t> &inputs,
        const std::vector<tensor_t> &outputs,
        const std::vector<cl_event> &cl_deps,
        cl_event *ret_event) override {

    auto deps = cl_deps;
    // Value-initialize the event: if no op is submitted below, the handle
    // would otherwise be read while indeterminate when it is handed to the
    // scratchpad and to the caller.
    cl_event returned_event {};
    dnnl::stream p_stream = make_dnnl_stream(p_engine_, *g_stream);

    // each thread's own local resource
    thread_local_cache_t<execution_args_set_t> res_cache;
    execution_args_set_t *res = res_cache.get_or_add(
            reinterpret_cast<size_t>(this), resource_ctor_);

    temporary_scratchpad_t scratchpad(
            memory_planner_.total_internal_temporary_size(), p_engine_,
            *g_alloc_);
    assertm(scratchpad.size()
                    >= memory_planner_.total_internal_temporary_size(),
            "no enough scratchpad memory");
    prepare_args_set(res, inputs, outputs, scratchpad);

    // Declared outside the branch so the persistent buffer stays referenced
    // until this function returns.
    constant_cache_t::cached_t c_buffer;
    if (enabled_constant_cache()) {
        std::promise<constant_cache_t::cached_t> c_promise;
        constant_cache_t::value_t cached_value
                = dnnl_constant_cache_get_or_add(p_engine_, constant_key_,
                        memory_planner_.total_internal_persistent_size(),
                        c_promise.get_future());
        const bool is_from_cache = cached_value.valid();

        // Cache hit: reuse the buffer published through the cached future.
        // Cache miss: allocate a fresh buffer that is filled in below.
        if (is_from_cache) {
            c_buffer = cached_value.get();
        } else {
            c_buffer = std::make_shared<dnnl_constant_buffer_t>(
                    memory_planner_.total_internal_persistent_size(),
                    p_engine_, g_alloc_);
        }

        // In both cases the memories that use internal persistent storage
        // are pointed at their slots inside the persistent buffer.
        grantor_t c_grantor = memory_planner_.internal_persistent_grantor(
                c_buffer->data<char>());
        for (auto &mem_offkey : res->get_mems_use_internal_persistent()) {
            mem_offkey.first.set_data_handle(
                    c_grantor.get(mem_offkey.second));
        }

        if (!is_from_cache) {
            // Run the constant ops to populate the new buffer, then publish
            // it through the promise whose future was passed to the cache.
            for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
                if (!subgraph_->is_constant_[i]) continue;
                returned_event = subgraph_->execs_[i]->execute_ocl(
                        p_stream, res->get_exec_args()[i], deps);
                deps = {returned_event};
            }

            c_promise.set_value(c_buffer);
        }
    }

    // Non-constant ops run on every execution.
    for (size_t i = 0; i < subgraph_->execs_.size(); i++) {
        if (subgraph_->is_constant_[i]) continue;
        returned_event = subgraph_->execs_[i]->execute_ocl(
                p_stream, res->get_exec_args()[i], deps);
        deps = {returned_event};
    }

    // Hand the last event to the scratchpad and, if requested, the caller.
    scratchpad.set_deps(returned_event);
    if (ret_event) *ret_event = returned_event;

    return status::success;
}
#endif

status_t prepare_inplace_pairs_impl() override {
inplace_pairs_ = memory_planner_.get_subgraph_inplace_pairs();
return status::success;
Expand Down

0 comments on commit 893c2ed

Please sign in to comment.