Replace async with non_blocking for Python 3.7 (pytorch#4999)
* Replace async with non_blocking for Python 3.7 upgrade

* Remove trailing whitespace

* Give _cuda and _type kwargs and accept async for compatibility

* Rename async to non_blocking in all C++ code

* Add entries for async in python_variable_methods

* Friendlier backward compatibility for cuda and type
goldsborough authored and soumith committed Feb 2, 2018
1 parent 8e22f84 commit 86fd5fd
Showing 16 changed files with 91 additions and 71 deletions.
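For context, the user-facing effect of the rename, sketched in Python (hypothetical usage, not part of the commit; assumes a CUDA build, and spells the legacy keyword via a kwargs dict so the snippet itself stays valid on Python 3.7, where `async` is a reserved word):

```python
import warnings
import torch

t = torch.FloatTensor(1000).pin_memory()  # pinned memory enables non-blocking copies

gpu_t = t.cuda(non_blocking=True)   # new spelling, valid on every Python version

# The legacy spelling is still accepted for backward compatibility via **kwargs
# (see _get_async_or_non_blocking in torch/_utils.py below), but now warns:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    gpu_t2 = t.cuda(**{"async": True})
```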
2 changes: 1 addition & 1 deletion aten/src/ATen/UndefinedType.cpp
@@ -64,7 +64,7 @@ const char * UndefinedType::typeString() {
return "UndefinedType";
}

-Tensor & UndefinedType::s_copy_(Tensor & self, const Tensor & src, bool async) const {
+Tensor & UndefinedType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
runtime_error("s_copy not defined for UndefinedType");
}

2 changes: 1 addition & 1 deletion aten/src/ATen/UndefinedType.h
@@ -33,7 +33,7 @@ struct UndefinedType final : public Type {
virtual std::unique_ptr<Storage> unsafeStorageFromTH(void * th_pointer, bool retain) const override;
virtual Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const override;

-  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool async) const override;
+  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const override;
};

} // namespace at
6 changes: 3 additions & 3 deletions aten/src/ATen/copy_wrapper.py
@@ -24,14 +24,14 @@
""")

COPY_ASYNC_CPU = CodeTemplate("""\
-if (async) {
+if (non_blocking) {
${THTensor}_copyAsyncCPU(${state,}self_->tensor, static_cast<${src_tensor}*>(src.pImpl)->tensor);
break;
}
""")

COPY_ASYNC_CUDA = CodeTemplate("""\
-if (async) {
+if (non_blocking) {
${THTensor}_copyAsyncCuda(${state,}self_->tensor, static_cast<${src_tensor}*>(src.pImpl)->tensor);
break;
}
@@ -44,7 +44,7 @@
""")

FUNCTION = CodeTemplate("""\
-Tensor & ${Type}::s_copy_(Tensor & self, const Tensor & src, bool async) const {
+Tensor & ${Type}::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
// code generated by function_wrapper
auto self_ = checked_cast_tensor<${Tensor}>(self.pImpl, "self", 0,false);
(void) self_; //silence unused warning
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/Tensor.h
@@ -71,7 +71,7 @@ struct Tensor : public detail::TensorBase {
return pImpl->storage();
}
inline Tensor toType(const Type & t) const;
-  inline Tensor & copy_(const Tensor & src, bool async=false);
+  inline Tensor & copy_(const Tensor & src, bool non_blocking=false);
inline Tensor toType(ScalarType t) const;
inline Tensor toBackend(Backend b) const;

4 changes: 2 additions & 2 deletions aten/src/ATen/templates/TensorMethods.h
@@ -16,8 +16,8 @@ inline Tensor Tensor::toType(const Type & t) const {
return t.copy(*this);
}

-inline Tensor & Tensor::copy_(const Tensor & src, bool async) {
-  return type().copy_(*this, src, async);
+inline Tensor & Tensor::copy_(const Tensor & src, bool non_blocking) {
+  return type().copy_(*this, src, non_blocking);
}

inline Tensor Tensor::toType(ScalarType t) const {
12 changes: 6 additions & 6 deletions aten/src/ATen/templates/Type.cpp
@@ -17,25 +17,25 @@ void Type::registerAll(Context * context) {
context->type_registry[static_cast<int>(Backend::Undefined)][static_cast<int>(ScalarType::Undefined)].reset(new UndefinedType(context));
}

-Tensor & Type::copy_(Tensor & self, const Tensor & src, bool async) const {
+Tensor & Type::copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
Tensor b_src;
std::tie(b_src) = expand_inplace(self, src, "copy");
-  return s_copy_(self, b_src, async);
+  return s_copy_(self, b_src, non_blocking);
}

-Tensor Type::copy(const Tensor & src, bool async) const {
+Tensor Type::copy(const Tensor & src, bool non_blocking) const {
AT_ASSERT(src.defined(), "attempt to copy an undefined tensor");
if (is_sparse()) {
auto indices = src._indices();
auto values = src._values();
auto & this_dense = toBackend(is_cuda() ? Backend::CUDA : Backend::CPU);
auto & this_dense_idx = this_dense.toScalarType(ScalarType::Long);
-    auto indices_copy = this_dense_idx.copy(indices, async);
-    auto values_copy = this_dense.copy(values, async);
+    auto indices_copy = this_dense_idx.copy(indices, non_blocking);
+    auto values_copy = this_dense.copy(values, non_blocking);
return sparse_coo_tensor(indices_copy, values_copy, src.sizes());
} else {
Tensor r = this->tensor(src.sizes());
-    r.copy_(src, async);
+    r.copy_(src, non_blocking);
return r;
}
}
6 changes: 3 additions & 3 deletions aten/src/ATen/templates/Type.h
@@ -78,9 +78,9 @@ struct AT_API Type {
// for external dispatch
virtual TypeID ID() const = 0;

-  Tensor copy(const Tensor & src, bool async=false) const;
-  Tensor & copy_(Tensor & self, const Tensor & src, bool async=false) const;
-  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool async) const = 0;
+  Tensor copy(const Tensor & src, bool non_blocking=false) const;
+  Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const;
+  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;

Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const;
Tensor tensorFromBlob(void * data, IntList sizes, IntList strides, const std::function<void(void*)> & deleter=noop_deleter) const;
2 changes: 1 addition & 1 deletion aten/src/ATen/templates/TypeDerived.h
@@ -34,7 +34,7 @@ struct ${Type} final : public Type {
// example
// virtual Tensor * add(Tensor & a, Tensor & b) override;

-  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool async) const override;
+  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const override;
${type_derived_method_declarations}
};

10 changes: 5 additions & 5 deletions test/test_cuda.py
@@ -1058,7 +1058,7 @@ def test_streams(self):
self.assertTrue(user_stream.query())
# copy 10 MB tensor from CPU-GPU which should take some time
tensor1 = torch.ByteTensor(10000000).pin_memory()
-        tensor2 = tensor1.cuda(async=True)
+        tensor2 = tensor1.cuda(non_blocking=True)
self.assertFalse(default_stream.query())
default_stream.synchronize()
self.assertTrue(default_stream.query())
@@ -1106,7 +1106,7 @@ def test_record_stream(self):
# Performs the CPU->GPU copy in a background stream
def perform_copy():
with torch.cuda.stream(stream):
-                tmp = t.cuda(async=True)
+                tmp = t.cuda(non_blocking=True)
ptr[0] = tmp.data_ptr()
torch.cuda.current_stream().wait_stream(stream)
tmp.record_stream(torch.cuda.current_stream())
@@ -1145,7 +1145,7 @@ def test_caching_pinned_memory(self):
# check that the allocation is not re-used if it's in-use by a copy
gpu_tensor = torch.cuda.FloatTensor([0])
torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the copy
-        gpu_tensor.copy_(t, async=True)
+        gpu_tensor.copy_(t, non_blocking=True)
del t
t = torch.FloatTensor([1]).pin_memory()
self.assertNotEqual(t.data_ptr(), ptr, 'allocation re-used too soon')
@@ -1164,14 +1164,14 @@ def test_caching_pinned_memory_multi_gpu(self):

with torch.cuda.device(1):
torch.cuda._sleep(int(50 * cycles_per_ms)) # delay the copy
-            gpu_tensor1.copy_(t, async=True)
+            gpu_tensor1.copy_(t, non_blocking=True)

del t
t = torch.FloatTensor([2]).pin_memory()
self.assertNotEqual(t.data_ptr(), ptr, 'allocation re-used too soon')

with torch.cuda.device(0):
-            gpu_tensor0.copy_(t, async=True)
+            gpu_tensor0.copy_(t, non_blocking=True)

self.assertEqual(gpu_tensor1[0], 1)
self.assertEqual(gpu_tensor0[0], 2)
4 changes: 2 additions & 2 deletions tools/autograd/templates/VariableType.cpp
@@ -376,7 +376,7 @@ static bool isFloatingPoint(ScalarType s) {
return s == kFloat || s == kDouble || s == kHalf;
}

-Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool async) const {
+Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const {
// TODO: once copy is exposed in Declarations.yaml we may be able to bind
// it automatically
auto& self_ = unpack(self, "self", 0);
@@ -392,7 +392,7 @@ Tensor & VariableType::s_copy_(Tensor & self, const Tensor & src, bool async) co
grad_fn->src_type = &src.type();
grad_fn->src_device = src.is_cuda() ? src.get_device() : -1;
}
-  baseType->s_copy_(self_, src_, async);
+  baseType->s_copy_(self_, src_, non_blocking);
increment_version(self);
rebase_history(self, std::move(grad_fn));
return self;
2 changes: 1 addition & 1 deletion tools/autograd/templates/VariableType.h
@@ -45,7 +45,7 @@ struct VariableType final : public at::Type {
static bool isVariableType(const at::Type& type);
static std::vector<at::Type*> allTypes();

-  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool async) const override;
+  virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const override;
${type_derived_method_declarations}

private:
16 changes: 9 additions & 7 deletions tools/autograd/templates/python_variable_methods.cpp
@@ -191,17 +191,18 @@ static PyObject * THPVariable_contiguous(PyObject* self, PyObject* args)
END_HANDLE_TH_ERRORS
}

-static Tensor dispatch_copy_(Tensor & self, const Tensor & other, bool async) {
+static Tensor dispatch_copy_(Tensor & self, const Tensor & other, bool non_blocking) {
AutoNoGIL no_gil;
AutoGPU auto_gpu(self);
-  return self.copy_(other, async);
+  return self.copy_(other, non_blocking);
}

static PyObject * THPVariable_copy_(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"copy_(Tensor other, bool async=False)"
"copy_(Tensor other, bool non_blocking=False)",
"copy_(Tensor other, bool async=False)|deprecated"
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
PyObject* parsed_args[2];
@@ -299,7 +300,7 @@ static void lazy_init_cuda() {
});
}

-static Tensor dispatch_type(const Tensor & self, const at::Type & type, int device, bool async) {
+static Tensor dispatch_type(const Tensor & self, const at::Type & type, int device, bool non_blocking) {
if (type.is_cuda()) {
lazy_init_cuda();
}
@@ -308,7 +309,7 @@ static Tensor dispatch_type(const Tensor & self, const at::Type & type, int devi
int64_t tensor_device = self.is_cuda() ? self.get_device() : -1;
if (tensor_device != at::current_device()) {
// copy if the devices are different even if the types are the same
-    return type.copy(self, async);
+    return type.copy(self, non_blocking);
}
return self.toType(type);
}
@@ -332,7 +333,7 @@ static PyObject * THPVariable_cuda(PyObject* self, PyObject* args, PyObject* kwa
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"cuda(int64_t device=-1, bool async=False)"
"cuda(int64_t device=-1, bool non_blocking=False)"
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
PyObject* parsed_args[2];
@@ -501,7 +502,8 @@ static PyObject * THPVariable_type(PyObject* self, PyObject* args, PyObject* kwa
{
HANDLE_TH_ERRORS
static PythonArgParser parser({
"type(PyObject* new_type=None, bool async=False)"
"type(PyObject* new_type=None, bool non_blocking=False)",
"type(PyObject* new_type=None, bool async=False)|deprecated"
});
auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
PyObject* parsed_args[2];
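The `|deprecated` suffix registers a second signature that the argument parser continues to accept, so existing call sites keep working. A rough sketch of the Python-visible behavior (hypothetical usage, not from the commit; assumes a CUDA build, and uses a kwargs dict so the old spelling stays syntactically valid on Python 3.7):

```python
import torch

src = torch.FloatTensor([1.0, 2.0, 3.0]).pin_memory()
dst = torch.cuda.FloatTensor(3)

dst.copy_(src, non_blocking=True)   # preferred signature
dst.copy_(src, **{"async": True})   # still parsed via the |deprecated entry
```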
8 changes: 4 additions & 4 deletions torch/_tensor_docs.py
@@ -326,7 +326,7 @@ def add_docstr_all(method, docstr):

add_docstr_all('copy_',
r"""
-copy_(src, async=False, broadcast=True) -> Tensor
+copy_(src, non_blocking=False, broadcast=True) -> Tensor
Copies the elements from :attr:`src` into :attr:`self` tensor and returns
:attr:`self`.
@@ -339,9 +339,9 @@ def add_docstr_all(method, docstr):
Args:
src (Tensor): the source tensor to copy from
-    async (bool): if ``True`` and this copy is between CPU and GPU, the copy may
-        occur asynchronously with respect to the host. For other cases, this
-        argument has no effect.
+    non_blocking (bool): if ``True`` and this copy is between CPU and GPU,
+        the copy may occur asynchronously with respect to the host. For other
+        cases, this argument has no effect.
broadcast (bool): if ``True``, :attr:`src` will be broadcast to the shape of
the underlying tensor.
""")
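As the docstring above notes, `copy_` also broadcasts by default; a minimal CPU-only sketch (values are illustrative, not from the commit):

```python
import torch

dst = torch.zeros(2, 3)
src = torch.FloatTensor([1, 2, 3])

dst.copy_(src)   # src is broadcast to dst's 2x3 shape (broadcast=True default)
print(dst)       # both rows are now [1, 2, 3]
```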
48 changes: 33 additions & 15 deletions torch/_utils.py
@@ -1,9 +1,10 @@
import torch
import importlib
+import warnings
from collections import defaultdict


-def _type(self, new_type=None, async=False):
+def _type(self, new_type=None, non_blocking=False, **kwargs):
"""Returns the type if `new_type` is not provided, else casts this object to
the specified type.
@@ -12,11 +13,14 @@ def _type(self, new_type=None, async=False):
Args:
new_type (type or string): The desired type
-        async (bool): If ``True``, and the source is in pinned memory and
-                      destination is on the GPU or vice versa, the copy is
-                      performed asynchronously with respect to the host.
-                      Otherwise, the argument has no effect.
+        non_blocking (bool): If ``True``, and the source is in pinned memory
+            and destination is on the GPU or vice versa, the copy is performed
+            asynchronously with respect to the host. Otherwise, the argument
+            has no effect.
+        **kwargs: For compatibility, may contain the key ``async`` in place of
+            the ``non_blocking`` argument.
"""
+    non_blocking = _get_async_or_non_blocking('type', non_blocking, kwargs)
if new_type is None:
return self.__module__ + '.' + self.__class__.__name__

@@ -29,27 +33,30 @@ def _type(self, new_type=None, async=False):
raise RuntimeError("Cannot cast sparse tensor to dense tensor")
new_module_name = new_type.__module__.replace('.sparse', '')
new_values_type_name = new_module_name + '.' + new_type.__name__
-        new_values = self._values().type(new_values_type_name, async)
+        new_values = self._values().type(new_values_type_name, non_blocking)
new_indices_type_name = new_module_name + '.LongTensor'
-        new_indices = self._indices().type(new_indices_type_name, async)
+        new_indices = self._indices().type(new_indices_type_name, non_blocking)
return new_type(new_indices, new_values, self.size())
if new_type.is_sparse:
raise RuntimeError("Cannot cast dense tensor to sparse tensor")
-    return new_type(self.size()).copy_(self, async)
+    return new_type(self.size()).copy_(self, non_blocking)


-def _cuda(self, device=None, async=False):
+def _cuda(self, device=None, non_blocking=False, **kwargs):
"""Returns a copy of this object in CUDA memory.
If this object is already in CUDA memory and on the correct device, then
no copy is performed and the original object is returned.
Args:
device (int): The destination GPU id. Defaults to the current device.
-        async (bool): If ``True`` and the source is in pinned memory, the copy will
-                      be asynchronous with respect to the host. Otherwise, the
-                      argument has no effect.
+        non_blocking (bool): If ``True`` and the source is in pinned memory,
+            the copy will be asynchronous with respect to the host. Otherwise,
+            the argument has no effect.
+        **kwargs: For compatibility, may contain the key ``async`` in place of
+            the ``non_blocking`` argument.
"""
+    non_blocking = _get_async_or_non_blocking('cuda', non_blocking, kwargs)
if self.is_cuda:
if device is None:
device = torch.cuda.current_device()
@@ -61,12 +68,23 @@ def _cuda(self, device=None, async=False):
with torch.cuda.device(device):
if self.is_sparse:
new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
-            indices = self._indices().cuda(device, async)
-            values = self._values().cuda(device, async)
+            indices = self._indices().cuda(device, non_blocking)
+            values = self._values().cuda(device, non_blocking)
return new_type(indices, values, self.size())
else:
new_type = getattr(torch.cuda, self.__class__.__name__)
-            return new_type(self.size()).copy_(self, async)
+            return new_type(self.size()).copy_(self, non_blocking)


+def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
+    if not kwargs:
+        return non_blocking
+    if len(kwargs) != 1 or 'async' not in kwargs:
+        message = "{}() got an unexpected keyword argument '{}'"
+        argument = list(kwargs.keys()).pop()
+        raise TypeError(message.format(function_name, argument))
+    warnings.warn("'async' is deprecated; use 'non_blocking'")
+    return kwargs['async']


def _rebuild_tensor(storage, storage_offset, size, stride):
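A quick sketch of how the new helper behaves, assuming it is imported as `from torch._utils import _get_async_or_non_blocking` (the three cases mirror the three code paths shown above):

```python
import warnings
from torch._utils import _get_async_or_non_blocking

# No extra kwargs: the explicit non_blocking value passes through unchanged.
assert _get_async_or_non_blocking('cuda', True, {}) is True

# Legacy spelling: warns and returns the value passed as 'async'.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert _get_async_or_non_blocking('cuda', False, {'async': True}) is True
assert 'deprecated' in str(caught[0].message)

# Any other keyword is rejected, mimicking normal Python keyword checking.
try:
    _get_async_or_non_blocking('cuda', False, {'asink': True})
except TypeError as exc:
    print(exc)  # cuda() got an unexpected keyword argument 'asink'
```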