diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 3bc12cf073b..5a25d585a2b 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -75,7 +75,7 @@ if(NOT ONEDNN_BUILD_GRAPH)
         ${CMAKE_CURRENT_SOURCE_DIR}/graph/cpu_single_op_partition.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/graph/sycl_single_op_partition.cpp
         ${CMAKE_CURRENT_SOURCE_DIR}/graph/gpu_opencl_getting_started.cpp
-        ${CMAKE_CURRENT_SOURCE_DIR}/graph/gpu_opencl_sdpa.cpp)
+        ${CMAKE_CURRENT_SOURCE_DIR}/graph/sdpa.cpp)
 endif()
 
 if(DNNL_SYCL_HIP)
diff --git a/examples/graph/gpu_opencl_sdpa.cpp b/examples/graph/gpu_opencl_sdpa.cpp
deleted file mode 100644
index c9f7498b8c5..00000000000
--- a/examples/graph/gpu_opencl_sdpa.cpp
+++ /dev/null
@@ -1,205 +0,0 @@
-/*******************************************************************************
-* Copyright 2024 Intel Corporation
-*
-* Licensed under the Apache License, Version 2.0 (the "License");
-* you may not use this file except in compliance with the License.
-* You may obtain a copy of the License at
-*
-*     http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*******************************************************************************/
-
-#include <assert.h>
-#include <iostream>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
-#include <CL/cl.h>
-
-#include "oneapi/dnnl/dnnl_graph.hpp"
-#include "oneapi/dnnl/dnnl_graph_ocl.hpp"
-
-#include "example_utils.hpp"
-#include "graph_example_utils.hpp"
-
-using namespace dnnl::graph;
-using data_type = logical_tensor::data_type;
-using layout_type = logical_tensor::layout_type;
-using dim = logical_tensor::dim;
-using dims = logical_tensor::dims;
-
-void gpu_float_sdpa(data_type dtype, int batch_size, int seq_len, int num_head,
-        int head_dim) {
-    const engine::kind ekind = engine::kind::gpu;
-    allocator alloc = ocl_interop::make_allocator(ocl_malloc_shared, ocl_free);
-
-    cl_uint num_platforms = 0;
-    OCL_CHECK(clGetPlatformIDs(0, NULL, &num_platforms));
-    std::vector<cl_platform_id> platforms(num_platforms);
-    if (num_platforms > 0) {
-        OCL_CHECK(clGetPlatformIDs(num_platforms, platforms.data(), NULL));
-    } else {
-        throw "Cannot find any openCL platform!";
-    }
-
-    std::vector<cl_device_id> gpu_device_ids;
-    for (cl_platform_id &platform_id : platforms) {
-        cl_uint num_devices;
-        if (!clGetDeviceIDs(
-                    platform_id, CL_DEVICE_TYPE_GPU, 0, NULL, &num_devices)) {
-            std::vector<cl_device_id> device_ids(num_devices);
-            OCL_CHECK(clGetDeviceIDs(platform_id, CL_DEVICE_TYPE_GPU,
-                    num_devices, device_ids.data(), NULL));
-            gpu_device_ids.insert(
-                    gpu_device_ids.end(), device_ids.begin(), device_ids.end());
-        }
-    }
-    if (gpu_device_ids.empty()) { throw "Cannot find any OpenCL device!"; }
-
-    cl_device_id device_id = gpu_device_ids[0];
-    cl_int err = 0;
-    auto ctx = clCreateContext(NULL, 1, &device_id, NULL, NULL, &err);
-    OCL_CHECK(err);
-
-// clCreateCommandQueue is deprecated in OpenCL.
-#ifdef CL_VERSION_2_0
-    cl_command_queue q
-            = clCreateCommandQueueWithProperties(ctx, device_id, nullptr, &err);
-#else
-    cl_command_queue q = clCreateCommandQueue(ctx, device_id, {}, &err);
-#endif
-    OCL_CHECK(err);
-    dnnl::engine eng = dnnl::graph::ocl_interop::make_engine_with_allocator(
-            device_id, ctx, alloc);
-    stream strm = dnnl::ocl_interop::make_stream(eng, q);
-
-    int size_per_head = head_dim / num_head;
-    dims qkv_input_shape = {batch_size, num_head, seq_len, size_per_head};
-    dims qk_output_shape = {batch_size, num_head, seq_len, seq_len};
-    dims scale_shape = {1};
-    dims attention_mask_shape = {batch_size, 1, 1, seq_len};
-
-    size_t lt_id = 0;
-
-    logical_tensor query_input {
-            lt_id++, dtype, qkv_input_shape, layout_type::strided};
-    logical_tensor key_input {
-            lt_id++, dtype, qkv_input_shape, layout_type::strided};
-    logical_tensor matmul_qk_out {
-            lt_id++, dtype, qk_output_shape, layout_type::strided};
-    op matmul_qk {0, op::kind::MatMul, {query_input, key_input},
-            {matmul_qk_out}, "matmul_qk"};
-    matmul_qk.set_attr<bool>(op::attr::transpose_b, true);
-
-    logical_tensor scale_factor {lt_id++, dtype, scale_shape,
-            layout_type::strided, logical_tensor::property_type::constant};
-    logical_tensor scaled_qk_out {
-            lt_id++, dtype, qk_output_shape, layout_type::strided};
-    op scale_div {1, op::kind::Divide, {matmul_qk_out, scale_factor},
-            {scaled_qk_out}, "scale_div"};
-
-    logical_tensor attention_mask {
-            lt_id++, dtype, attention_mask_shape, layout_type::strided};
-    logical_tensor masked_qk_out {
-            lt_id++, dtype, qk_output_shape, layout_type::strided};
-    op mask_add {2, op::kind::Add, {scaled_qk_out, attention_mask},
-            {masked_qk_out}, "mask_add"};
-
-    logical_tensor softmax_out {
-            lt_id++, dtype, qk_output_shape, layout_type::strided};
-    op softmax {
-            3, op::kind::SoftMax, {masked_qk_out}, {softmax_out}, "softmax"};
-    softmax.set_attr<int64_t>(op::attr::axis, -1);
-
-    logical_tensor value_input {
-            lt_id++, dtype, qkv_input_shape, layout_type::strided};
-    logical_tensor matmul_v_out {
-            lt_id++, dtype, qkv_input_shape, layout_type::strided};
-    op matmul_v {4, op::kind::MatMul, {softmax_out, value_input},
-            {matmul_v_out}, "matmul_v"};
-
-    graph g(ekind);
-    g.add_op(matmul_qk);
-    g.add_op(scale_div);
-    g.add_op(mask_add);
-    g.add_op(softmax);
-    g.add_op(matmul_v);
-    g.finalize();
-
-    std::vector<partition> partitions = g.get_partitions();
-    // just for testing purpose. User code should not make assertion for it.
-    assert(partitions.size() == 1);
-
-    std::vector<logical_tensor> inputs = partitions[0].get_input_ports();
-    std::vector<logical_tensor> outputs = partitions[0].get_output_ports();
-    compiled_partition sdp_cpartition
-            = partitions[0].compile(inputs, outputs, eng);
-
-    std::vector<tensor> inputs_ts, outputs_ts;
-    std::vector<std::shared_ptr<void>> data_buffer;
-    std::unordered_map<size_t, tensor> global_outputs_ts_map;
-    // Input/output memory should be prepared by users. This helper function is
-    // for testing purpose and not part of API.
-    allocate_ocl_graph_mem(
-            inputs_ts, inputs, data_buffer, global_outputs_ts_map, eng, true);
-    allocate_ocl_graph_mem(outputs_ts, outputs, data_buffer,
-            global_outputs_ts_map, eng, false);
-
-    sdp_cpartition.execute(strm, inputs_ts, outputs_ts);
-    strm.wait();
-}
-
-data_type str2data_type(const std::string &v) {
-    if (v == "f32")
-        return data_type::f32;
-    else if (v == "bf16")
-        return data_type::bf16;
-    else if (v == "f16")
-        return data_type::f16;
-    else
-        return data_type::undef;
-}
-
-int main(int argc, char **argv) {
-    if (argc > 2) {
-        std::cout << "One parameter (dtype) is needed: f32 / bf16 / f16 \n";
-        return 0;
-    }
-    // if dtype is not provide, use f32 as default.
-    const std::string dtype_str = argc == 2 ? argv[1] : "f32";
-    data_type dtype = str2data_type(dtype_str);
-
-    int batch_size = 1;
-    int seq_len = 384;
-    int num_head = 16;
-    int head_dim = 1024;
-
-    std::cout << "Running SDPA with data_type: " << dtype_str
-              << ", batch_size: " << batch_size << ", seq_len: " << seq_len
-              << ", num_head: " << num_head << ", head_dim: " << head_dim
-              << std::endl;
-
-    int exit_code = 0;
-    try {
-        gpu_float_sdpa(dtype, batch_size, seq_len, num_head, head_dim);
-    } catch (dnnl::error &e) {
-        std::cout << "oneDNN error caught: " << std::endl
-                  << "\tStatus: " << dnnl_status2str(e.status) << std::endl
-                  << "\tMessage: " << e.what() << std::endl;
-        exit_code = 1;
-    } catch (std::exception &e) {
-        std::cout << "Error in the example: " << e.what() << "." << std::endl;
-        exit_code = 2;
-    }
-
-    std::cout << "Example " << (exit_code ? "failed" : "passed") << " with "
-              << dtype_str << "." << std::endl;
-
-    return exit_code;
-}
diff --git a/examples/graph/graph_example_utils.hpp b/examples/graph/graph_example_utils.hpp
index 8276fe3f9ed..c5748ebf5ab 100644
--- a/examples/graph/graph_example_utils.hpp
+++ b/examples/graph/graph_example_utils.hpp
@@ -17,8 +17,12 @@
 #ifndef GRAPH_EXAMPLE_UTILS_HPP
 #define GRAPH_EXAMPLE_UTILS_HPP
 
+#include <stdexcept>
+
 #include "oneapi/dnnl/dnnl_graph.hpp"
+
+#include "example_utils.hpp"
 
 #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
 #include "dnnl_ocl.hpp"
 #elif DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
@@ -316,6 +320,148 @@ void allocate_ocl_graph_mem(std::vector<tensor> &tensors,
         if (!is_input) global_outputs_ts_map[lt_id] = tensors.back();
     }
 }
+
+void ocl_memcpy(dnnl::engine &eng, void *dst, const void *src, size_t size) {
+    using F = cl_int (*)(cl_command_queue, cl_bool, void *, const void *,
+            size_t, cl_uint, const cl_event *, cl_event *);
+    if (!src || !dst) return;
+    cl_platform_id platform;
+    cl_context ctx = dnnl::ocl_interop::get_context(eng);
+    cl_device_id dev = dnnl::ocl_interop::get_device(eng);
+    cl_int err = 0;
+
+// clCreateCommandQueue is deprecated in OpenCL.
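+// Prefer clCreateCommandQueueWithProperties when the OpenCL 2.0+ headers are
+// available and fall back to the deprecated 1.x entry point otherwise. The
+// queue only lives long enough to drive the USM copy below.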
+#ifdef CL_VERSION_2_0
+    cl_command_queue queue
+            = clCreateCommandQueueWithProperties(ctx, dev, nullptr, &err);
+#else
+    cl_command_queue queue = clCreateCommandQueue(ctx, dev, {}, &err);
+#endif
+    if (err != CL_SUCCESS)
+        throw std::runtime_error("cannot create a cl_command_queue");
+
+    err = clGetDeviceInfo(
+            dev, CL_DEVICE_PLATFORM, sizeof(platform), &platform, nullptr);
+    if (err != CL_SUCCESS) throw std::runtime_error("clGetDeviceInfo failed");
+
+    const char *f_name = "clEnqueueMemcpyINTEL";
+    auto f = reinterpret_cast<F>(
+            clGetExtensionFunctionAddressForPlatform(platform, f_name));
+    if (!f)
+        throw std::runtime_error("clEnqueueMemcpyINTEL is not available");
+    // Use a blocking copy so the data is in place when the helper returns.
+    err = f(queue, CL_TRUE, dst, src, size, 0, nullptr, nullptr);
+    if (err != CL_SUCCESS)
+        throw std::runtime_error("clEnqueueMemcpyINTEL failed");
+
+    clReleaseCommandQueue(queue);
+    return;
+}
+#endif
+
+inline dnnl::memory::desc make_md(const dnnl::graph::logical_tensor &lt,
+        dnnl::memory::data_type dt = dnnl::memory::data_type::undef) {
+    using layout_type = dnnl::graph::logical_tensor::layout_type;
+    using dims = dnnl::memory::dims;
+
+    // If not specified, use the tensor data type.
+    if (dt == dnnl::memory::data_type::undef)
+        dt = static_cast<dnnl::memory::data_type>(lt.get_data_type());
+
+    if (lt.get_layout_type() != layout_type::strided) {
+        throw std::runtime_error("make_md: bad layout type");
+    } else {
+        const auto sz = lt.get_dims();
+        const auto st = lt.get_strides();
+        const auto nd = sz.size();
+        if (nd > 0) {
+            return dnnl::memory::desc(sz, dt, st);
+        } else {
+            // nd == 0: create a one-element descriptor for a scalar tensor.
+            return dnnl::memory::desc(dims {1}, dt, dims {1});
+        }
+    }
+}
+
+inline void write_dt(void *handle, dnnl::graph::tensor &ts) {
+    dnnl::engine eng = ts.get_engine();
+    size_t size = ts.get_logical_tensor().get_mem_size();
+
+    if (!handle) throw std::runtime_error("handle is nullptr.");
+
+#ifdef DNNL_WITH_SYCL
+    bool is_cpu_sycl = (DNNL_CPU_RUNTIME == DNNL_RUNTIME_SYCL
+            && eng.get_kind() == dnnl::engine::kind::cpu);
+    bool is_gpu_sycl = (DNNL_GPU_RUNTIME == DNNL_RUNTIME_SYCL
+            && eng.get_kind() == dnnl::engine::kind::gpu);
+    if (is_cpu_sycl || is_gpu_sycl) {
+        // Only USM is supported in the graph API.
+        uint8_t *dst_ptr = (uint8_t *)ts.get_data_handle();
+        if (!dst_ptr)
+            throw std::runtime_error("get_data_handle returned nullptr.");
+        if (is_cpu_sycl) {
+            for (size_t i = 0; i < size; ++i)
+                dst_ptr[i] = ((uint8_t *)handle)[i];
+        } else {
+            auto sycl_queue = dnnl::sycl_interop::get_queue(dnnl::stream(eng));
+            sycl_queue.memcpy(dst_ptr, handle, size).wait();
+        }
+        return;
+    }
+#endif
+#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
+    if (eng.get_kind() == dnnl::engine::kind::gpu) {
+        // Only USM is supported in the graph API.
+        uint8_t *dst_ptr = (uint8_t *)ts.get_data_handle();
+        if (!dst_ptr)
+            throw std::runtime_error("get_data_handle returned nullptr.");
+        ocl_memcpy(eng, dst_ptr, handle, size);
+        return;
+    }
+#endif
+
+    if (eng.get_kind() == dnnl::engine::kind::cpu) {
+        uint8_t *dst = static_cast<uint8_t *>(ts.get_data_handle());
+        if (!dst) throw std::runtime_error("get_data_handle returned nullptr.");
+        for (size_t i = 0; i < size; ++i)
+            dst[i] = ((uint8_t *)handle)[i];
+        return;
+    }
+
+    assert(!"not expected");
+}
+
+// Read from handle, write to tensor. Assume handle contains f32 data.
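+// For non-f32 tensors the f32 source data is staged in a temporary f32
+// dnnl::memory and converted into the tensor's data type with a dnnl::reorder.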
+inline void write_to_dnnl_tensor(void *handle, dnnl::graph::tensor &ts) {
+    if (!handle) throw std::runtime_error("handle is nullptr.");
+
+    dnnl::engine eng = ts.get_engine();
+    const dnnl::graph::logical_tensor lt = ts.get_logical_tensor();
+    const dnnl::graph::logical_tensor::data_type dt = lt.get_data_type();
+
+    if (dt != dnnl::graph::logical_tensor::data_type::f32) {
+        // For non-f32 data types, use a reorder to convert.
+        const auto f32_md = make_md(lt, dnnl::memory::data_type::f32);
+        auto f32_mem = dnnl::memory(f32_md, eng);
+        write_to_dnnl_memory(handle, f32_mem);
+
+        const auto dt_md = make_md(lt);
+        if (dt_md.get_size() != lt.get_mem_size()) {
+            throw std::runtime_error("incorrect memory size.");
+        }
+
+        dnnl::memory dt_mem;
+#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
+        if (eng.get_kind() == dnnl::engine::kind::gpu) {
+            dt_mem = dnnl::ocl_interop::make_memory(dt_md, eng,
+                    dnnl::ocl_interop::memory_kind::usm, ts.get_data_handle());
+        } else
+#endif
+            dt_mem = dnnl::memory(dt_md, eng, ts.get_data_handle());
+
+        dnnl::stream strm(eng);
+        dnnl::reorder(f32_mem, dt_mem).execute(strm, f32_mem, dt_mem);
+        strm.wait();
+    } else {
+        // Directly write to ts.
+        write_dt(handle, ts);
+    }
+}
 #endif
diff --git a/examples/graph/sdpa.cpp b/examples/graph/sdpa.cpp
new file mode 100644
index 00000000000..20e9ef8bc13
--- /dev/null
+++ b/examples/graph/sdpa.cpp
@@ -0,0 +1,277 @@
+/*******************************************************************************
+* Copyright 2024 Intel Corporation
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+*     http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*******************************************************************************/
+
+#include <algorithm>
+#include <chrono>
+#include <cmath>
+#include <cstdlib>
+#include <cstring>
+#include <iomanip>
+#include <iostream>
+#include <limits>
+#include <random>
+#include <stdexcept>
+#include <vector>
+
+#include "oneapi/dnnl/dnnl_graph.hpp"
+
+#include "graph_example_utils.hpp"
+
+using namespace dnnl::graph;
+using data_type = logical_tensor::data_type;
+using layout_type = logical_tensor::layout_type;
+using dim = logical_tensor::dim;
+using dims = logical_tensor::dims;
+
+struct sdpa_dims_t {
+    dim mb;
+    dim seq_len;
+    dim head_num;
+    dim head_size;
+};
+
+static const int min_runs = 4;
+
+// Adapted from the fill_random() function in matmul_perf.cpp.
+void fill_random(std::vector<float> &out) {
+    static std::vector<float> random_data_f;
+    constexpr size_t nrand = 1037;
+
+    if (random_data_f.empty()) {
+        std::mt19937 generator;
+        std::uniform_real_distribution<float> dist_f(-1.0f, 1.0f);
+
+        random_data_f.resize(nrand);
+        for (auto &d : random_data_f)
+            d = dist_f(generator);
+    }
+
+    for (size_t i = 0; i < out.size(); i += nrand) {
+        size_t chunk = std::min(nrand, out.size() - i);
+        std::memcpy(&out[i], random_data_f.data(), chunk * sizeof(float));
+    }
+}
+
+// Initialize the mask so that the first 3/4 of each row is 0 and the last 1/4
+// is -inf.
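+// This mimics a padding mask: adding -inf to the trailing positions before
+// SoftMax drives their attention probabilities to zero.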
+void fill_mask(std::vector<float> &mask, size_t seq_len) {
+    const size_t pos = seq_len * 3 / 4;
+    for (size_t i = 0; i < mask.size(); ++i) {
+        if (i % seq_len < pos)
+            mask[i] = 0.f;
+        else
+            mask[i] = -1 * std::numeric_limits<float>::infinity();
+    }
+}
+
+const char *get_type_string(data_type dt) {
+    const char *type_string = "unknown";
+
+#define TYPE_CASE(T) \
+    if (dt == data_type::T) type_string = #T;
+    TYPE_CASE(f16);
+    TYPE_CASE(f32);
+    TYPE_CASE(bf16);
+#undef TYPE_CASE
+
+    return type_string;
+}
+
+void print_test_case(data_type dt, const sdpa_dims_t &p) {
+    std::cout << '[' << std::setw(4) << get_type_string(dt);
+    std::cout << " mb = " << p.mb << ", seq_len = " << p.seq_len
+              << ", head_num = " << p.head_num
+              << ", head_size = " << p.head_size;
+    std::cout << "] " << std::flush;
+}
+
+void bench_sdpa(engine::kind ekind, data_type dt, const sdpa_dims_t &p,
+        double time_limit = 0.) {
+    const bool quick_test = (time_limit == 0.);
+    if (!quick_test) print_test_case(dt, p);
+
+    // Create execution dnnl::engine.
+    dnnl::engine eng(ekind, 0);
+    // Create dnnl::stream.
+    dnnl::stream strm(eng);
+
+    // Prepare input and output shapes to construct the SDPA graph.
+    const dims qkv_sz = {p.mb, p.head_num, p.seq_len, p.head_size};
+    const dims score_sz = {p.mb, p.head_num, p.seq_len, p.seq_len};
+    const dims scale_sz = {1};
+    const dims mask_sz = {p.mb, 1, 1, p.seq_len};
+
+    // Incremental IDs used to create logical tensors and operations.
+    size_t id = 0;
+
+    // score = query x key.T
+    auto query = logical_tensor(id++, dt, qkv_sz, layout_type::strided);
+    auto key = logical_tensor(id++, dt, qkv_sz, layout_type::strided);
+    auto score = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto bmm1 = op(id++, op::kind::MatMul, "bmm1");
+    bmm1.set_attr<bool>(op::attr::transpose_b, true);
+    bmm1.add_inputs({query, key});
+    bmm1.add_outputs({score});
+
+    // scaled_score = score / scale
+    auto scale = logical_tensor(id++, dt, scale_sz, layout_type::strided);
+    auto scaled_score
+            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto scale_div = op(id++, op::kind::Divide, "scale_div");
+    scale_div.add_inputs({score, scale});
+    scale_div.add_outputs({scaled_score});
+
+    // masked_score = scaled_score + mask
+    auto mask = logical_tensor(id++, dt, mask_sz, layout_type::strided);
+    auto masked_score
+            = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto mask_add = op(id++, op::kind::Add, "mask_add");
+    mask_add.add_inputs({scaled_score, mask});
+    mask_add.add_outputs({masked_score});
+
+    // attention_probs = softmax(masked_score)
+    auto probs = logical_tensor(id++, dt, score_sz, layout_type::strided);
+    auto softmax = op(id++, op::kind::SoftMax, "softmax");
+    softmax.set_attr<int64_t>(op::attr::axis, -1);
+    softmax.add_inputs({masked_score});
+    softmax.add_outputs({probs});
+
+    // attention_output = attention_probs x value
+    auto value = logical_tensor(id++, dt, qkv_sz, layout_type::strided);
+    auto output = logical_tensor(id++, dt, qkv_sz, layout_type::strided);
+    auto bmm2 = op(id++, op::kind::MatMul, "bmm2");
+    bmm2.add_inputs({probs, value});
+    bmm2.add_outputs({output});
+
+    // Construct the SDPA graph with the engine kind and operations.
+    graph sdpa(ekind);
+    sdpa.add_op(bmm1);
+    sdpa.add_op(scale_div);
+    sdpa.add_op(mask_add);
+    sdpa.add_op(softmax);
+    sdpa.add_op(bmm2);
+    sdpa.finalize();
+
+    // Get partitions from the SDPA graph.
+    std::vector<partition> partitions = sdpa.get_partitions();
+    // This check is for oneDNN testing purposes only.
+    if (partitions.size() != 1) {
+        std::cout << "unsupported sdpa" << std::endl;
+        return;
+    }
+
+    // Compile the partition with inputs, outputs, and an engine.
+    compiled_partition cp = partitions[0].compile(
+            {query, key, scale, mask, value}, {output}, eng);
+
+    // Create tensor objects.
+    auto ts_query = tensor(query, eng);
+    auto ts_key = tensor(key, eng);
+    auto ts_scale = tensor(scale, eng);
+    auto ts_mask = tensor(mask, eng);
+    auto ts_value = tensor(value, eng);
+    auto ts_output = tensor(output, eng);
+
+    // Allocate user data.
+    std::vector<float> query_data(product(qkv_sz));
+    std::vector<float> key_data(product(qkv_sz));
+    std::vector<float> scale_data(product(scale_sz), std::sqrt(p.head_size));
+    std::vector<float> mask_data(product(mask_sz));
+    std::vector<float> value_data(product(qkv_sz));
+    std::vector<float> output_data(product(qkv_sz));
+
+    fill_random(query_data);
+    fill_random(key_data);
+    fill_random(value_data);
+    fill_mask(mask_data, static_cast<size_t>(p.seq_len));
+
+    // Write data to the tensor objects' handles.
+    write_to_dnnl_tensor(query_data.data(), ts_query);
+    write_to_dnnl_tensor(key_data.data(), ts_key);
+    write_to_dnnl_tensor(scale_data.data(), ts_scale);
+    write_to_dnnl_tensor(mask_data.data(), ts_mask);
+    write_to_dnnl_tensor(value_data.data(), ts_value);
+
+    // Warmup run.
+    // Execute the compiled partition of sdpa.
+    cp.execute(
+            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
+
+    // Wait for the computation to finish.
+    strm.wait();
+
+    // First run.
+    auto start_first = std::chrono::steady_clock::now();
+    cp.execute(
+            strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value}, {ts_output});
+    strm.wait();
+    auto end_first = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> dur_first
+            = end_first - start_first;
+
+    if (quick_test) return;
+
+    // Timing runs: scale the iteration count so the loop takes roughly
+    // time_limit milliseconds, but never fewer than min_runs iterations.
+    const int runs = std::max(min_runs, int(time_limit / dur_first.count()));
+    auto start = std::chrono::steady_clock::now();
+    for (int i = 0; i <= runs; i++)
+        cp.execute(strm, {ts_query, ts_key, ts_scale, ts_mask, ts_value},
+                {ts_output});
+    strm.wait();
+    auto end = std::chrono::steady_clock::now();
+    std::chrono::duration<double, std::milli> duration = end - start;
+
+    // Display the results. The loop executes runs + 1 times; an estimate of
+    // one first-run time is subtracted before averaging over runs iterations.
+    double avg_time = (duration.count() - dur_first.count()) / runs;
+    std::cout << "runs: " << runs + 1 << " ";
+    std::cout << "avg_time: " << avg_time << " ms" << std::endl;
+
+    return;
+}
+
+void bad_args() {
+    std::cerr << "Usage: graph-sdpa-cpp [cpu|gpu]\n"
+                 "       graph-sdpa-cpp [cpu|gpu] <mb> <seq_len> "
+                 "<head_num> <head_size>\n";
+    throw std::invalid_argument("Incorrect input arguments.");
+}
+
+void sdpa_perf(engine::kind ekind, int argc, char **argv) {
+    // Default testing parameters.
+    sdpa_dims_t params = {1, 128, 16, 64};
+
+    if (argc > 2) {
+        if (argc == 6) {
+            params.mb = std::atoi(argv[2]);
+            params.seq_len = std::atoi(argv[3]);
+            params.head_num = std::atoi(argv[4]);
+            params.head_size = std::atoi(argv[5]);
+        } else {
+            bad_args();
+        }
+
+        if (params.mb <= 0 || params.seq_len <= 0 || params.head_num <= 0
+                || params.head_size <= 0) {
+            bad_args();
+        }
+    }
+
+    bench_sdpa(ekind, data_type::f32, params, 4000.0 /*ms*/);
+    bench_sdpa(ekind, data_type::bf16, params, 4000.0 /*ms*/);
+    bench_sdpa(ekind, data_type::f16, params, 4000.0 /*ms*/);
+}
+
+int main(int argc, char **argv) {
+    return handle_example_errors(
+            sdpa_perf, parse_engine_kind(argc, argv, 4), argc, argv);
+}
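For reference, a typical invocation after building the oneDNN examples (the
binary name follows the usage string in bad_args(); the build-tree path is an
assumption):

    ./examples/graph-sdpa-cpp cpu              # defaults: mb=1, seq_len=128, head_num=16, head_size=64
    ./examples/graph-sdpa-cpp gpu 2 384 16 64  # explicit mb, seq_len, head_num, head_size

Each data type (f32, bf16, f16) is benchmarked for roughly 4 seconds and the
average time per run is printed in milliseconds.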