Skip to content

Commit

Permalink
[aot] Support push_arg for Kernel in C++ wrapper (taichi-dev#6419)
Browse files Browse the repository at this point in the history
fixes taichi-dev#6413 

### Brief Summary
Note we also changed the way to set vec/matrix args the previous `set`
didn't save users from counting index manually. cc: @YuCrazing
  • Loading branch information
ailzhang authored Oct 24, 2022
1 parent 5bb7a4c commit ecc7e8d
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
22 changes: 16 additions & 6 deletions c_api/include/taichi/cpp/taichi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -507,17 +507,27 @@ class Kernel {
return at(i);
}

// Temporary workaround for setting vec/matrix arguments in a flattened way.
template <typename T>
void set(uint32_t i, const std::vector<T> &v) {
if (i + v.size() >= args_.size()) {
args_.resize(i + v.size());
}
void push_arg(const std::vector<T> &v) {
int idx = args_.size();
// Temporary workaround for setting vec/matrix arguments in a flattened way.
args_.resize(args_.size() + v.size());
for (int j = 0; j < v.size(); ++j) {
at(i + j) = v[j];
at(idx + j) = v[j];
}
}

template <typename T>
void push_arg(const T &arg) {
int idx = args_.size();
args_.resize(idx + 1);
at(idx) = arg;
}

void clear_args() {
args_.clear();
}

void launch(uint32_t argument_count, const TiArgument *arguments) {
ti_launch_kernel(runtime_, kernel_, argument_count, arguments);
}
Expand Down
10 changes: 7 additions & 3 deletions c_api/tests/c_api_aot_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@ static void kernel_aot_test(TiArch arch) {

std::vector<int> arg2_v = {1, 2, 3};

k_run[0] = arg0_val;
k_run[1] = arg1_array;
k_run.set(2, arg2_v);
// This is just to make sure clear_args() does its work.
k_run.push_arg(arg0_val);
k_run.clear_args();

k_run.push_arg(arg0_val);
k_run.push_arg(arg1_array);
k_run.push_arg(arg2_v);
k_run.launch();
runtime.wait();

Expand Down

0 comments on commit ecc7e8d

Please sign in to comment.