Skip to content

Commit

Permalink
Add OP_JIT_MAX_THREADS
Browse files Browse the repository at this point in the history
  • Loading branch information
bozbez committed Jan 17, 2025
1 parent 10e6b68 commit 74c98bc
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion op2/include/op_f2c_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <sstream>
#include <thread>
#include <mutex>
#include <atomic>


#define NVRTC_SAFE_CALL(x) \
Expand Down Expand Up @@ -73,6 +74,9 @@ static bool jit_initialized = false;
static bool jit_enable = true;
static bool jit_seq_compile = false;
static bool jit_debug = false;
static int jit_max_threads = INT32_MAX;

static std::atomic_int jit_active_threads = 0;

static std::string jit_arch = "";

Expand Down Expand Up @@ -113,6 +117,20 @@ static void jit_init() {
jit_seq_compile = true;
}

char *max_threads_str = getenv("OP_JIT_MAX_THREADS");
if (max_threads_str != nullptr) {
int max_threads_int = -1;

try {
max_threads_int = std::stoi(max_threads_str);
} catch (...) {};

if (max_threads_int < 0)
std::printf("warning: OP_JIT_MAX_THREADS set to unsupported value: %s\n", max_threads_str);
else
jit_max_threads = max_threads_int;
}

int device;
CUDA_SAFE_CALL(cudaGetDevice(&device));

Expand Down Expand Up @@ -361,6 +379,8 @@ class KernelInfo {
}

std::thread compile(uint64_t hash) {
++jit_active_threads;

std::string jit_src = std::string("#include <op_f2c_prelude.h>\n") +
std::string("#include <op_f2c_params.h>\n") +
std::string("\nnamespace f2c = op::f2c;\n") +
Expand Down Expand Up @@ -409,6 +429,7 @@ class KernelInfo {
std::forward_as_tuple(hash), std::forward_as_tuple(cubin, m_name));

assert(inserted);
--jit_active_threads;
};

std::thread compilation_thread(do_compile, jit_src, hash);
Expand Down Expand Up @@ -495,7 +516,7 @@ class KernelInfo {

m_jit_kernels_mutex.unlock();

if (hash_elem->second.count > 8 && !hash_elem->second.jit_started) {
if (hash_elem->second.count > 8 && !hash_elem->second.jit_started && jit_active_threads < jit_max_threads) {
if (jit_debug) std::printf("compiling %s for hash %lx\n", m_name.c_str(), hash);

hash_elem->second.jit_started = true;
Expand Down

0 comments on commit 74c98bc

Please sign in to comment.