Skip to content

Commit

Permalink
fix: fix torch's pre_hook
Browse files Browse the repository at this point in the history
  • Loading branch information
uchuhimo committed Sep 24, 2021
1 parent b8a6fee commit e0cb605
Show file tree
Hide file tree
Showing 13 changed files with 217 additions and 307 deletions.
16 changes: 13 additions & 3 deletions cc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,25 @@
cmake_minimum_required(VERSION 3.17)
cmake_minimum_required(VERSION 3.17...3.21)
project(Amanda)

set(CMAKE_CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard to use")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

set(default_build_type "Release")
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
message(STATUS "Setting build type to '${default_build_type}' as none was specified.")
set(CMAKE_BUILD_TYPE "${default_build_type}" CACHE
STRING "Choose the type of build." FORCE)
# Set the possible values of build type for cmake-gui
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS
"Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/../cmake/modules)
list(APPEND CMAKE_FIND_LIBRARY_SUFFIXES .so.1)

find_package(Python COMPONENTS Interpreter Development NumPy)
find_package(TensorFlow REQUIRED)
# find_package(CUDA 9 REQUIRED)

# set necessary flags
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${SSE_FLAGS} -march=native -fopenmp -D_GLIBCXX_USE_CXX11_ABI=${TensorFlow_ABI}")
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ dependencies:
- pip>=19.3
- cudatoolkit=10.0
- pip:
- poetry==1.1.7
- poetry==1.1.10
243 changes: 90 additions & 153 deletions poetry.lock

Large diffs are not rendered by default.

39 changes: 19 additions & 20 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
absl-py==0.13.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.2.0" and python_version >= "3.7"
absl-py==0.14.0; python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.2.0" and python_version >= "3.7"
aiofiles==0.7.0; python_version >= "3.6" and python_version < "4.0"
aiohttp==3.7.4.post0; python_version >= "3.6"
alabaster==0.7.12; python_version >= "3.6"
Expand All @@ -21,7 +21,7 @@ certifi==2021.5.30; python_version >= "3.6" and python_full_version < "3.0.0" or
cffi==1.14.6; sys_platform == "linux" and python_version >= "3.6"
cfgv==3.3.1; python_full_version >= "3.6.1"
chardet==4.0.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
charset-normalizer==2.0.5; python_full_version >= "3.6.0" and python_version >= "3.6"
charset-normalizer==2.0.6; python_full_version >= "3.6.0" and python_version >= "3.6"
click==8.0.1; python_version >= "3.6"
colorama==0.4.4; platform_system == "Windows" and python_version >= "3.7" and python_full_version >= "3.6.2" and sys_platform == "win32" and (python_version >= "3.6" and python_full_version < "3.0.0" and sys_platform == "win32" or sys_platform == "win32" and python_version >= "3.6" and python_full_version >= "3.5.0") and (python_version >= "3.6" and python_full_version < "3.0.0" and platform_system == "Windows" or python_full_version >= "3.5.0" and python_version >= "3.6" and platform_system == "Windows")
coverage==5.5; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.5.0" and python_version < "4")
Expand All @@ -41,7 +41,7 @@ dephell-specifier==0.2.2; python_version >= "3.6"
dephell-venvs==0.1.18; python_version >= "3.6"
dephell-versioning==0.1.2; python_version >= "3.6"
dephell==0.8.3; python_version >= "3.6"
distlib==0.3.2; python_full_version >= "3.6.1"
distlib==0.3.3; python_full_version >= "3.6.1"
docker==5.0.2; python_version >= "3.6"
dockerpty==0.4.1; python_version >= "3.6"
docutils==0.17.1; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
Expand All @@ -56,12 +56,11 @@ google-pasta==0.2.0
googledrivedownloader==0.4; python_version >= "3.6"
gprof2dot==2021.2.21
graphviz==0.17; python_version >= "3.6"
grpcio-tools==1.40.0; python_version >= "3.6"
grpcio==1.40.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.2.0" and python_version >= "3.6"
grpcio==1.40.0; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.2.0"
guppy3==3.1.1; python_version >= "3.6"
h5py==2.10.0
html5lib==1.1; python_version >= "3.6" and python_full_version < "3.0.0" or python_version >= "3.6" and python_full_version >= "3.5.0"
identify==2.2.14; python_full_version >= "3.6.1"
identify==2.2.15; python_full_version >= "3.6.1"
idna==3.2; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.6"
imagesize==1.2.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6"
immutables==0.16; python_version >= "3.6"
Expand Down Expand Up @@ -90,18 +89,18 @@ matplotlib-inline==0.1.3; python_version >= "3.7"
mccabe==0.6.1; python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0"
mistune==0.8.4; python_version >= "3.6"
mmdnn==0.3.1
more-itertools==8.9.0; python_version >= "3.5"
more-itertools==8.10.0; python_version >= "3.5"
moreorless==0.4.0; python_version >= "3.6"
multidict==5.1.0; python_version >= "3.6"
mypy-extensions==0.4.3; python_full_version >= "3.6.2" and python_version >= "3.5"
mypy-protobuf==2.9; python_version >= "3.6"
mypy-protobuf==2.10; python_version >= "3.6"
mypy==0.910; python_version >= "3.5"
networkx==2.6.3; python_version >= "3.7"
nodeenv==1.6.0; python_full_version >= "3.6.1"
numba==0.51.2; python_version >= "3.6"
numpy==1.18.5; python_version >= "3.5"
onnx==1.10.1
onnxruntime==1.8.1
onnxruntime==1.9.0
opt-einsum==3.3.0; python_version >= "3.7"
packaging==21.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
pandas==1.1.5; python_full_version >= "3.6.1" and python_version >= "3.6"
Expand All @@ -115,7 +114,7 @@ platformdirs==2.3.0; python_version >= "3.6" and python_full_version >= "3.6.2"
pluggy==1.0.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
pre-commit==2.15.0; python_full_version >= "3.6.1"
prompt-toolkit==3.0.20; python_full_version >= "3.6.2" and python_version >= "3.7"
protobuf==3.17.3
protobuf==3.18.0
ptyprocess==0.7.0; sys_platform != "win32" and python_version >= "3.7"
py==1.10.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
pybind11==2.7.1; (python_version >= "2.7" and python_version < "3.0") or (python_version > "3.0" and python_version < "3.1") or (python_version > "3.1" and python_version < "3.2") or (python_version > "3.2" and python_version < "3.3") or (python_version > "3.3" and python_version < "3.4") or (python_version > "3.4")
Expand All @@ -128,15 +127,15 @@ pyparsing==2.4.7; python_version >= "3.7" and python_full_version < "3.0.0" or p
pytest-dependency==0.5.1
pytest-forked==1.3.0; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
pytest-profiling==1.7.0
pytest-xdist==2.3.0; python_version >= "3.6"
pytest-xdist==2.4.0; python_version >= "3.6"
pytest==6.2.5; python_version >= "3.6"
python-dateutil==2.8.2; python_full_version >= "3.6.1" and python_version >= "3.6"
python-gnupg==0.4.7; python_version >= "3.6"
pytz==2021.1; python_full_version >= "3.6.1" and python_version >= "3.6"
pywin32-ctypes==0.2.0; sys_platform == "win32" and python_version >= "3.6"
pywin32==227; python_version >= "3.6" and sys_platform == "win32"
pyyaml==5.4.1; python_full_version >= "3.6.1" and python_version >= "3.6" and (python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.6.0") and (python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version >= "3.6")
rdflib==6.0.0; python_version >= "3.7"
rdflib==6.0.1; python_version >= "3.7"
readme-renderer==29.0; python_version >= "3.6"
regex==2021.8.28; python_full_version >= "3.6.2"
requests-toolbelt==0.9.1; python_version >= "3.6"
Expand All @@ -148,7 +147,7 @@ scikit-learn==0.24.2; python_version >= "3.6"
scipy==1.6.1; python_version >= "3.7"
secretstorage==3.3.1; sys_platform == "linux" and python_version >= "3.6"
shellingham==1.4.0; python_version >= "3.6"
six==1.16.0; python_full_version >= "3.6.1" and python_version >= "3.6" and (python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.6") and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.7")
six==1.16.0; python_version >= "3.6" and python_full_version >= "3.6.1" and (python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.6") and (python_version >= "3.7" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.7")
snowballstemmer==2.1.0; python_version >= "3.6"
sphinx==4.2.0; python_version >= "3.6"
sphinxcontrib-applehelp==1.0.2; python_version >= "3.6"
Expand All @@ -166,25 +165,25 @@ threadpoolctl==2.2.0; python_version >= "3.6"
toml==0.10.2; python_full_version >= "3.6.1" and python_version >= "3.5" and (python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.3.0" and python_version >= "3.6") and (python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0") and (python_version >= "2.7" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version < "4")
tomli==1.2.1; python_version >= "3.6" and python_full_version >= "3.6.2"
tomlkit==0.7.2; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.5.0" and python_version >= "3.6"
torch-geometric==2.0.0; python_version >= "3.6"
torch-geometric==2.0.1; python_version >= "3.6"
torch-scatter @ https://pytorch-geometric.com/whl/torch-1.8.0+cu102/torch_scatter-2.0.8-cp37-cp37m-linux_x86_64.whl ; python_version >= "3.6"
torch-sparse @ https://pytorch-geometric.com/whl/torch-1.8.0+cu102/torch_sparse-0.6.11-cp37-cp37m-linux_x86_64.whl ; python_version >= "3.6"
torch==1.8.1; python_full_version >= "3.6.2"
torchvision==0.9.1
tox==3.24.3; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.5.0")
tqdm==4.62.2; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6"
tox==3.24.4; (python_version >= "2.7" and python_full_version < "3.0.0") or (python_full_version >= "3.5.0")
tqdm==4.62.3; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.4.0" and python_version >= "3.6"
traitlets==5.1.0; python_version >= "3.7"
twine==3.4.2; python_version >= "3.6"
typed-ast==1.4.3; python_version < "3.8" and python_full_version >= "3.6.2" and python_version >= "3.5"
types-filelock==0.1.5
types-futures==3.3.0; python_version >= "3.6"
types-protobuf==3.17.4
types-setuptools==57.0.2
types-setuptools==57.4.0
types-six==1.16.1
types-toml==0.1.5
types-toml==0.10.0
typing-extensions==3.10.0.2
urllib3==1.26.6; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version < "4" and python_version >= "3.6"
virtualenv==20.7.2; python_full_version >= "3.6.1"
urllib3==1.26.7; python_version >= "3.6" and python_full_version < "3.0.0" or python_full_version >= "3.6.0" and python_version < "4" and python_version >= "3.6"
virtualenv==20.8.0; python_full_version >= "3.6.1"
volatile==2.1.0; python_version >= "3.6"
watchdog==2.1.5; python_version >= "3.6"
wcwidth==0.2.5; python_full_version >= "3.6.2" and python_version >= "3.7"
Expand Down
35 changes: 32 additions & 3 deletions src/amanda/conversion/listener/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,17 +1,46 @@
cmake_minimum_required(VERSION 3.17)
cmake_minimum_required(VERSION 3.17...3.21)
project(amanda_pybind)

set(CMAKE_CXX_STANDARD 14)
include(CMakePrintHelpers)

set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ standard to use")
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

find_program(CCACHE_PROGRAM ccache)
if(CCACHE_PROGRAM)
set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}") # CMake 3.9+
cmake_print_variables(CMAKE_CXX_COMPILER_LAUNCHER)
endif()

set(default_build_type "Release")
if(NOT CMAKE_BUILD_TYPE AND NOT CMAKE_CONFIGURATION_TYPES)
message(STATUS "Setting build type to '${default_build_type}' as none was specified.")
set(CMAKE_BUILD_TYPE "${default_build_type}" CACHE
STRING "Choose the type of build." FORCE)
# Set the possible values of build type for cmake-gui
set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS
"Debug" "Release" "MinSizeRel" "RelWithDebInfo")
endif()

find_package(Python COMPONENTS Interpreter Development)
find_package(Python REQUIRED COMPONENTS Interpreter Development)
find_package(Torch REQUIRED)
find_package(pybind11 CONFIG)

find_library(TORCH_PYTHON_LIBRARY torch_python PATHS "${TORCH_INSTALL_PREFIX}/lib")
cmake_print_variables(TORCH_PYTHON_LIBRARY)

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

include_directories(SYSTEM ${Python_INCLUDE_DIRS})

pybind11_add_module(amanda_pybind amanda_pybind.cpp)
target_link_libraries(amanda_pybind PRIVATE ${PYTHON_LIBRARIES})
target_link_libraries(amanda_pybind PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(amanda_pybind PRIVATE ${TORCH_PYTHON_LIBRARY})

add_executable(listener_test listener.cpp)
target_link_libraries(listener_test PRIVATE ${PYTHON_LIBRARIES})
target_link_libraries(listener_test PRIVATE ${TORCH_LIBRARIES})
target_link_libraries(listener_test PRIVATE ${TORCH_PYTHON_LIBRARY})
6 changes: 2 additions & 4 deletions src/amanda/conversion/listener/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
- `listener\` contains src for pybind listener.
- build it first
```
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=`python -c "import torch;print(torch.utils.cmake_prefix_path)"`\;`python -m pybind11 --cmakedir` ..
make amanda_pybind
cmake -DCMAKE_PREFIX_PATH=`python -c "import torch;print(torch.utils.cmake_prefix_path)"`\;`python -m pybind11 --cmakedir` -S src/amanda/conversion/listener -B src/amanda/conversion/listener/build
cmake --build src/amanda/conversion/listener/build
```
107 changes: 20 additions & 87 deletions src/amanda/conversion/listener/function_pre_hook.cpp
Original file line number Diff line number Diff line change
@@ -1,111 +1,39 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/autograd/function_hook.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_hook.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/function.h>

#include <torch/torch.h>
#include <torch/extension.h>
#include <iostream>

#include <pybind11/pybind11.h>
#include <pybind11/functional.h>

using torch::autograd::variable_list;

template<class T>
class amanda_THPPointer {
public:
amanda_THPPointer(): ptr(nullptr) {};
explicit amanda_THPPointer(T *ptr) noexcept : ptr(ptr) {};
amanda_THPPointer(amanda_THPPointer &&p) noexcept { free(); ptr = p.ptr; p.ptr = nullptr; };

~amanda_THPPointer() { free(); };
T * get() { return ptr; }
const T * get() const { return ptr; }
T * release() { T *tmp = ptr; ptr = nullptr; return tmp; }
operator T*() { return ptr; }
amanda_THPPointer& operator =(T *new_ptr) noexcept { free(); ptr = new_ptr; return *this; }
amanda_THPPointer& operator =(amanda_THPPointer &&p) noexcept { free(); ptr = p.ptr; p.ptr = nullptr; return *this; }
T * operator ->() { return ptr; }
explicit operator bool() const { return ptr != nullptr; }

private:
void free() {if (ptr) Py_DECREF(ptr);}
T *ptr = nullptr;
};

PyObject* THPVariableClass = nullptr;
bool amanda_pybind_init = false;

// auto tensor_module = amanda_THPPointer<PyObject>(PyImport_ImportModule("torch.tensor"));
// PyObject* THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
// bool amanda_pybind_init = true;

void init_THPVariableClass()
{
if (amanda_pybind_init) {
return ;
}
else {
auto tensor_module = amanda_THPPointer<PyObject>(PyImport_ImportModule("torch.tensor"));
THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
amanda_pybind_init = true;
return ;
}
}

PyObject* amanda_THPVariable_NewWithVar(PyTypeObject* type, torch::autograd::Variable var)
{
PyObject* obj = type->tp_alloc(type, 0);
if (obj) {
auto v = (THPVariable*) obj;
new (&v->cdata) torch::autograd::Variable(std::move(var));
torch::autograd::impl::set_pyobj(v->cdata, obj);
}
return obj;
auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch.tensor"));
THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
}

PyObject * amanda_THPVariable_Wrap(torch::autograd::Variable var)
{
if (!var.defined()) {
Py_RETURN_NONE;
}

if (auto obj = torch::autograd::impl::pyobj(var)) {
Py_INCREF(obj);
return obj;
}
return amanda_THPVariable_NewWithVar((PyTypeObject *)THPVariableClass, std::move(var));
}

PyObject* amanda_wrap_variables(const variable_list& c_variables)
static PyObject* amanda_wrap_variables(const variable_list& c_variables)
{
size_t num_vars = c_variables.size();
amanda_THPPointer<PyObject> tuple(PyTuple_New(num_vars));
THPObjectPtr tuple(PyTuple_New(num_vars));
if (!tuple) throw python_error();
for (size_t i = 0; i < num_vars; ++i) {
amanda_THPPointer<PyObject> var(amanda_THPVariable_Wrap(c_variables[i]));
THPObjectPtr var(THPVariable_Wrap(c_variables[i]));
if (!var) throw python_error();
PyTuple_SET_ITEM(tuple.get(), i, var.release());
}
return tuple.release();

}

inline bool amanda_THPVariable_Check(PyObject *obj)
static variable_list amanda_unwrap_variables(PyObject *py_variables)
{
return THPVariableClass && PyObject_IsInstance(obj, THPVariableClass);
}

variable_list amanda_unwrap_variables(PyObject* py_variables) {
variable_list results(PyTuple_GET_SIZE(py_variables));
for (size_t i = 0; i < results.size(); i++) {
PyObject* item = PyTuple_GET_ITEM(py_variables, i);
if (item == Py_None) {
continue;
} else if (amanda_THPVariable_Check(item)) {
} else if (THPVariable_Check(item)) {
results[i] = ((THPVariable*)item)->cdata;
} else {
// this should never happen, but just in case...
Expand All @@ -120,18 +48,25 @@ variable_list amanda_unwrap_variables(PyObject* py_variables) {
class AmandaPreHook : public torch::autograd::FunctionPreHook
{
public:
AmandaPreHook(PyObject* fn): fn_(fn) {}
AmandaPreHook(PyObject* fn): fn_(fn) {
Py_INCREF(fn_);
}

~AmandaPreHook()
{
pybind11::gil_scoped_acquire gil;
Py_DECREF(fn_);
}

variable_list operator()(const variable_list& _inputs) override
{
pybind11::gil_scoped_acquire gil;
// wrap cpp vector<torch::autograd::Variable> _inputs -> PyObject inputs
amanda_THPPointer<PyObject> inputs(amanda_wrap_variables(_inputs));
THPObjectPtr inputs(amanda_wrap_variables(_inputs));
// call python function from cpp as PyObject
amanda_THPPointer<PyObject> res(PyObject_CallFunctionObjArgs(fn_, inputs.get(), nullptr));
THPObjectPtr res(PyObject_CallFunctionObjArgs(fn_, inputs.get(), nullptr));
// unwarp PyObject into vector<torch::autograd::Variable>
amanda_THPPointer<PyObject> outputs = std::move(res);
return amanda_unwrap_variables(outputs.get());
return amanda_unwrap_variables(res.get());
}

protected:
Expand All @@ -140,8 +75,6 @@ class AmandaPreHook : public torch::autograd::FunctionPreHook

int amanda_add_pre_hook(const pybind11::object &grad_fn, const pybind11::object & hook)
{
assert(amanda_pybind_init==true);

PyObject *raw_grad_fn = grad_fn.ptr();
PyObject *raw_hook = hook.ptr();
torch::autograd::THPCppFunction *cast_grad_fn = (torch::autograd::THPCppFunction *)raw_grad_fn;
Expand Down
Loading

0 comments on commit e0cb605

Please sign in to comment.