Skip to content

Commit

Permalink
[Bugfix] Cython CAPI holding GIL causes deadlock when Python callback…
Browse files Browse the repository at this point in the history
… is asynchronous (dmlc#4036)

* cython nogil

* move APIs to internal and add unit test

* fix lint

* disable callback array test
  • Loading branch information
jermainewang authored May 25, 2022
1 parent 230b886 commit 3c129ad
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 13 deletions.
2 changes: 1 addition & 1 deletion python/dgl/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ cdef extern from "dgl/runtime/c_runtime_api.h":
int* type_codes,
int num_args,
DGLValue* ret_val,
int* ret_type_code)
int* ret_type_code) nogil
int DGLFuncFree(DGLFunctionHandle func)
int DGLCFuncSetReturn(DGLRetValueHandle ret,
DGLValue* value,
Expand Down
14 changes: 10 additions & 4 deletions python/dgl/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -208,8 +208,11 @@ cdef inline int FuncCall3(void* chandle,
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(DGLFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode))
with nogil:
ret = DGLFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode)
if ret != 0:
raise DGLError(py_str(DGLGetLastError()))
return 0

cdef inline int FuncCall(void* chandle,
Expand All @@ -229,8 +232,11 @@ cdef inline int FuncCall(void* chandle,
temp_args = []
for i in range(nargs):
make_arg(args[i], &values[i], &tcodes[i], temp_args)
CALL(DGLFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode))
with nogil:
ret = DGLFuncCall(chandle, &values[0], &tcodes[0],
nargs, ret_val, ret_tcode)
if ret != 0:
raise DGLError(py_str(DGLGetLastError()))
return 0


Expand Down
50 changes: 50 additions & 0 deletions src/api/api_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*!
* Copyright (c) 2022 by Contributors
* \file api/api_test.cc
* \brief C APIs for testing FFI
*/
#include <dgl/runtime/ndarray.h>
#include <dgl/runtime/container.h>
#include <dgl/runtime/registry.h>
#include <dgl/packed_func_ext.h>
#include <thread>

namespace dgl {
namespace runtime {

// Register an internal API for testing python callback.
// It receives two arguments:
// - The python callback function.
// - The argument to pass to the python callback
// It returns what python callback returns
DGL_REGISTER_GLOBAL("_TestPythonCallback")
.set_body([](DGLArgs args, DGLRetValue* rv) {
LOG(INFO) << "Inside C API";
PackedFunc fn = args[0];
DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1);
fn.CallPacked(cb_args, rv);
});

// Register an internal API for testing python callback.
// It receives two arguments:
// - The python callback function.
// - The argument to pass to the python callback
// It returns what python callback returns
//
// The API runs the python callback in a separate thread to test
// python GIL is properly released.
DGL_REGISTER_GLOBAL("_TestPythonCallbackThread")
.set_body([](DGLArgs args, DGLRetValue* rv) {
LOG(INFO) << "Inside C API";
PackedFunc fn = args[0];
auto thr = std::make_shared<std::thread>(
[fn, args, rv]() {
LOG(INFO) << "Callback thread " << std::this_thread::get_id();
DGLArgs cb_args(args.values + 1, args.type_codes + 1, 1);
fn.CallPacked(cb_args, rv);
});
thr->join();
});

} // namespace runtime
} // namespace dgl
8 changes: 0 additions & 8 deletions tests/compute/test_cython.py

This file was deleted.

39 changes: 39 additions & 0 deletions tests/compute/test_ffi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import dgl
import numpy as np
import backend as F
import unittest, pytest
import os

@unittest.skipIf(os.name == 'nt', reason='Cython only works on linux')
def test_cython():
import dgl._ffi._cy3.core

@pytest.mark.parametrize('arg', [1, 2.3])
def test_callback(arg):
def cb(x):
return x + 1
ret = dgl._api_internal._TestPythonCallback(cb, arg)
assert ret == arg + 1

@pytest.mark.parametrize('dtype', [F.float32, F.float64, F.int32, F.int64])
def _test_callback_array(dtype):
def cb(x):
return F.to_dgl_nd(F.from_dgl_nd(x) + 1)
arg = F.copy_to(F.tensor([1, 2, 3], dtype=dtype), F.ctx())
ret = F.from_dgl_nd(dgl._api_internal._TestPythonCallback(cb, F.to_dgl_nd(arg)))
assert np.allclose(F.asnumpy(ret), F.asnumpy(arg) + 1)

@pytest.mark.parametrize('arg', [1, 2.3])
def test_callback_thread(arg):
def cb(x):
return x + 1
ret = dgl._api_internal._TestPythonCallbackThread(cb, arg)
assert ret == arg + 1

@pytest.mark.parametrize('dtype', [F.float32, F.float64, F.int32, F.int64])
def _test_callback_array_thread(dtype):
def cb(x):
return F.to_dgl_nd(F.from_dgl_nd(x) + 1)
arg = F.copy_to(F.tensor([1, 2, 3], dtype=dtype), F.ctx())
ret = F.from_dgl_nd(dgl._api_internal._TestPythonCallbackThread(cb, F.to_dgl_nd(arg)))
assert np.allclose(F.asnumpy(ret), F.asnumpy(arg) + 1)

0 comments on commit 3c129ad

Please sign in to comment.