Skip to content

Commit

Permalink
Use at most one shared_ptr block at a time to manage THPFunctions (py…
Browse files Browse the repository at this point in the history
…torch#1454)

* Fix failing ln in build_all.sh

* Use at most one shared_ptr block at a time to manage THPFunctions
  • Loading branch information
apaszke authored and soumith committed May 3, 2017
1 parent e1278d4 commit 72e8190
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 4 deletions.
22 changes: 21 additions & 1 deletion torch/csrc/autograd/python_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ static void THPFunction_dealloc(THPFunction* self)
{
PyObject_GC_UnTrack(self);
THPFunction_clear(self);
self->cdata_ptr.~weak_ptr();
self->cdata.~PyFunction();
Py_TYPE(self)->tp_free((PyObject*)self);
}
Expand All @@ -282,6 +283,7 @@ PyObject *THPFunction_new(PyTypeObject *type, PyObject *args, PyObject *kwargs)
// most fields
THPFunction* self = (THPFunction*)obj;
new (&self->cdata) torch::autograd::PyFunction(obj);
new (&self->cdata_ptr) std::weak_ptr<torch::autograd::PyFunction>();
self->cdata.num_inputs = -1;
self->cdata.is_stochastic = PyObject_IsInstance(obj, THPStochasticFunctionClass);
return obj;
Expand Down Expand Up @@ -998,11 +1000,29 @@ struct Decref {
}
};

// Similar to shared_from_this. There's a problem that the Python object
// and its cdata depend on each other being alive, so we can't keep
// shared_ptrs as members, but we'd like to be able to manage the lifetime of
// the objects using shared_ptrs in the C++ graph. The only way to get a new
// shared_ptr that references them is through THPFunction_asFunction. When
// called for the first time it will allocate a new shared_ptr and save a
// weak_ptr in cdata_ptr attr. Later, when we try to take another reference,
// we'll try to lock cdata_ptr and return its value if successful. Otherwise it
// means that all shared_ptrs returned previously have been freed, so we can
// create a new one. This ensures that this object is managed by at most one
// shared_ptr control block at any time - a guarantee we depend on in other places
// (e.g. we use weak_ptrs in SavedVariable because we know it won't go out of scope).
std::shared_ptr<PyFunction> THPFunction_asFunction(THPFunction* self)
{
if (!self) {
return std::shared_ptr<PyFunction>();
}
Py_INCREF((PyObject*)self);
return std::shared_ptr<PyFunction>(&self->cdata, Decref());

auto ptr = self->cdata_ptr.lock();
if (ptr) return ptr;

ptr = std::shared_ptr<PyFunction>(&self->cdata, Decref());
self->cdata_ptr = ptr;
return ptr;
}
2 changes: 2 additions & 0 deletions torch/csrc/autograd/python_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ struct THPFunction {
std::vector<bool> *is_variable_input;
char has_freed_buffers;

// See a comment in THPFucntion_asFunction for details about this field.
std::weak_ptr<torch::autograd::PyFunction> cdata_ptr;
torch::autograd::PyFunction cdata;
};

Expand Down
2 changes: 1 addition & 1 deletion torch/csrc/autograd/variable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ auto SavedVariable::unpack() -> std::shared_ptr<Variable> {
// should have saved the grad accumulator. Even if the Variable no longer
// alive, the accumulator should be kept alive by the references in the graph).
if (requires_grad && !grad_fn && weak_grad_fn.expired() && grad_accumulator.expired())
throw std::logic_error("No grad accumulator for a saved leaf!");
throw std::logic_error("No grad accumulator for a saved leaf!");
new_var->grad_accumulator = grad_accumulator;

return new_var;
Expand Down
6 changes: 4 additions & 2 deletions torch/lib/build_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ BASIC_C_FLAGS=" -DTH_INDEX_BASE=0 -I$INSTALL_DIR/include \
LDFLAGS="-L$INSTALL_DIR/lib "
LD_POSTFIX=".so.1"
LD_POSTFIX_UNVERSIONED=".so"
if [[ $(uname) == 'Darwin' ]]; then
if [[ $(uname) == 'Darwin' ]]; then
LDFLAGS="$LDFLAGS -Qunused-arguments -Wl,-rpath,@loader_path"
LD_POSTFIX=".1.dylib"
LD_POSTFIX_UNVERSIONED=".dylib"
Expand Down Expand Up @@ -93,7 +93,9 @@ function build_nccl() {
-DCMAKE_CXX_FLAGS="$C_FLAGS $CPP_FLAGS"
make install
cp "lib/libnccl.so.1" "${INSTALL_DIR}/lib/libnccl.so.1"
ln -s "${INSTALL_DIR}/lib/libnccl.so.1" "${INSTALL_DIR}/lib/libnccl.so"
if [ ! -f "${INSTALL_DIR}/lib/libnccl.so" ]; then
ln -s "${INSTALL_DIR}/lib/libnccl.so.1" "${INSTALL_DIR}/lib/libnccl.so"
fi
cd ../..
}

Expand Down

0 comments on commit 72e8190

Please sign in to comment.