Fix type signature of in-place NN functions (#4389)
This is a step towards removing the special casing of NN functions in gen_variable_type.py. It fixes the signature of in-place NN functions so that they return Tensor & instead of Tensor.
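
To make the change concrete, the following is a minimal standalone sketch of the difference between the two signatures. It uses a toy Tensor struct rather than ATen, so the names and bodies are purely illustrative; the point is that returning Tensor & hands back the very object that was modified, which is what lets the new selu_ below simply forward the result of at::elu_.

// Minimal standalone sketch (toy Tensor, not ATen) of the return-type change.
#include <cassert>

struct Tensor { float value; };

// Old style: in-place NN functions returned Tensor by value, so the caller
// received a copy and the identity of the modified tensor was lost.
Tensor relu_returning_value(Tensor & self) {
  if (self.value < 0) self.value = 0;
  return self;  // returns a copy
}

// New style: return Tensor &, matching the convention used by other in-place
// ATen functions, so the result can be forwarded or chained without a copy.
Tensor & relu_(Tensor & self) {
  if (self.value < 0) self.value = 0;
  return self;  // returns the same object
}

int main() {
  Tensor t{-1.0f};
  Tensor & r = relu_(t);
  assert(&r == &t);  // the reference is to the tensor that was modified in place
  return 0;
}
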
colesbury authored Dec 28, 2017
1 parent af3bffb commit 98f7191
Showing 7 changed files with 45 additions and 21 deletions.

aten/src/ATen/native/Activation.cpp (9 additions, 3 deletions)

@@ -11,9 +11,15 @@ Tensor selu(const Tensor & self) {
 }
 
 Tensor & selu_(Tensor & self) {
-  // TODO: at::elu_ should return `Tensor &`
-  at::elu_(self, SELU_ALPHA, SELU_SCALE);
-  return self;
+  return at::elu_(self, SELU_ALPHA, SELU_SCALE);
 }
 
+Tensor rrelu(const Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
+  return at::rrelu_with_noise(self, self.type().tensor(), lower, upper, training, generator);
+}
+
+Tensor & rrelu_(Tensor & self, Scalar lower, Scalar upper, bool training, Generator* generator) {
+  return at::rrelu_with_noise_(self, self.type().tensor(), lower, upper, training, generator);
+}
+
 }} // namespace at::native

aten/src/ATen/native/native_functions.yaml (6 additions, 0 deletions)

@@ -218,6 +218,12 @@
     CPU: RoiPooling2d_backward_cpu
     CUDA: RoiPooling2d_backward_cuda
 
+- func: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) -> Tensor
+  variants: function
+
+- func: rrelu_(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr) -> Tensor
+  variants: function
+
 - func: select(Tensor self, int64_t dim, int64_t index) -> Tensor
 
 - func: selu(Tensor self) -> Tensor

aten/src/ATen/nn.yaml (3 additions, 2 deletions)

@@ -67,9 +67,10 @@
 - name: prelu(Tensor self, Tensor weight)
   cname: PReLU
 
-- name: rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr)
+# NOTE: we treat noise as an input (it's really a buffer) because the codegen
+# can't handle in-place functions that have buffers
+- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=false, Generator* generator=nullptr)
   cname: RReLU
-  buffers: [noise]
   has_inplace: True
 
 - name: softmax(Tensor self, int64_t dim)
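
The NOTE above is the crux of the rrelu to rrelu_with_noise rename: since the codegen cannot manage a buffer for an in-place function, the noise tensor is threaded through as an explicit argument that the out-of-place and in-place variants share. Below is a standalone toy sketch of that shape; it is not THNN or ATen code, and the kernel body (uniform slope sampling, no training/eval switch) is made up for illustration.

#include <cstdlib>
#include <vector>

struct Tensor { std::vector<float> data; };

// Toy stand-in for the RReLU kernel: samples a slope per element into `noise`
// and applies it. `noise` is an explicit argument, not a hidden buffer.
void rrelu_kernel(const Tensor & in, Tensor & noise, Tensor & out, float lower, float upper) {
  noise.data.resize(in.data.size());
  out.data.resize(in.data.size());
  for (std::size_t i = 0; i < in.data.size(); ++i) {
    float slope = lower + (upper - lower) * (std::rand() / float(RAND_MAX));
    noise.data[i] = in.data[i] < 0 ? slope : 1.0f;
    out.data[i] = in.data[i] * noise.data[i];
  }
}

// Out-of-place variant: writes into a fresh output tensor.
Tensor rrelu_with_noise(const Tensor & self, Tensor & noise, float lower, float upper) {
  Tensor output;
  rrelu_kernel(self, noise, output, lower, upper);
  return output;
}

// In-place variant: shares the same explicit `noise` argument, uses `self` as
// the output, and returns it by reference, as elsewhere in this commit.
Tensor & rrelu_with_noise_(Tensor & self, Tensor & noise, float lower, float upper) {
  rrelu_kernel(self, noise, self, lower, upper);
  return self;
}

int main() {
  Tensor x{{-1.0f, 2.0f}};
  Tensor noise;
  rrelu_with_noise_(x, noise, 0.125f, 1.0f / 3.0f);
  return 0;
}
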

aten/src/ATen/nn_parse.py (9 additions, 6 deletions)

@@ -160,6 +160,8 @@ def arg_expr(prefix, suffix):
         name = arg.name
         if name == 'state':
             continue
+        if inplace and name == 'output':
+            name = 'self'
         aten_name = camel_to_snake(SUBSTITUTIONS.get(name, name))
         parts = aten_name.split('_')
         if aten_name in params_by_name:
@@ -213,7 +215,7 @@ def unique_args(argslist):
     return result
 
 
-def function_info(name, arguments, cimpls, buffers, backends):
+def function_info(name, arguments, cimpls, buffers, backends, inplace):
     """
     cimpls contains information use to call into THNN:
         cname: THNN function name
@@ -225,7 +227,7 @@ def function_info(name, arguments, cimpls, buffers, backends):
         'name': name,
         'types': ['Float', 'Double', 'Half'], # Half will be stripped for CPU backend
         'arguments': arguments,
-        'return': get_return(arguments),
+        'return': 'argument 0' if inplace else get_return(arguments),
         'buffers': buffers,
         'backends': backends,
         'cimpls': cimpls,
@@ -240,14 +242,15 @@ def base_declaration(func, thnn_function, backends, inplace=False):
         name += '_'
     params = params.split(', ')
     arguments = [argument_to_declaration(a, func) for a in params]
-    arguments += output_arguments(thnn_function)
+    if not inplace:
+        arguments += output_arguments(thnn_function)
     buffers = [argument_to_declaration('Tensor ' + buf)
               for buf in func.get('buffers', [])]
 
     thnn_args = get_thnn_args(thnn_function, arguments + buffers, inplace)
     cimpl = {'cname': thnn_function.name, 'arguments': thnn_args}
 
-    return function_info(name, arguments, [cimpl], buffers, backends)
+    return function_info(name, arguments, [cimpl], buffers, backends, inplace)
 
 
 def forward_declaration(base, thnn_function, inplace=False):
@@ -265,7 +268,7 @@ def forward_declaration(base, thnn_function, inplace=False):
     arguments = remove_unused_args(arguments, thnn_args)
     cimpl = {'cname': thnn_function.name, 'arguments': thnn_args}
 
-    return function_info(name, arguments, [cimpl], [], base['backends'])
+    return function_info(name, arguments, [cimpl], [], base['backends'], inplace)
 
 
 def backward_declaration(base, thnn_functions):
@@ -327,7 +330,7 @@ def get_condition(func):
             cimpl['condition'] = get_condition(func)
         cimpls.append(cimpl)
 
-    return function_info(name, arguments, cimpls, [], base['backends'])
+    return function_info(name, arguments, cimpls, [], base['backends'], False)
 
 
 def parse_nn_yaml(filename):
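
Taken together, the nn_parse.py changes above mean that for an in-place declaration the THNN output argument is bound to self (rather than to a freshly allocated output), no separate output argument is appended to the signature, and the generated function's return is "argument 0", i.e. self as Tensor &. Below is a standalone toy sketch of that wrapper shape; it is not the actual generated code, and the op name and kernel body are invented for illustration.

#include <cassert>
#include <vector>

struct Tensor { std::vector<float> data; };

// Toy stand-in for a THNN-style updateOutput kernel: reads `input`, writes `output`.
void thnn_toy_update_output(const Tensor & input, Tensor & output) {
  output.data.resize(input.data.size());
  for (std::size_t i = 0; i < input.data.size(); ++i)
    output.data[i] = input.data[i] > 0 ? input.data[i] : 0.0f;
}

// Out-of-place wrapper: the declaration carries an `output` argument, which the
// wrapper allocates and returns by value.
Tensor toy_op(const Tensor & self) {
  Tensor output;
  thnn_toy_update_output(self, output);
  return output;
}

// In-place wrapper: no separate output; `self` is passed where the kernel
// expects `output`, and "argument 0" is returned as Tensor &.
Tensor & toy_op_(Tensor & self) {
  thnn_toy_update_output(self, self);
  return self;
}

int main() {
  Tensor t{{-1.0f, 2.0f}};
  Tensor & r = toy_op_(t);
  assert(&r == &t && r.data[0] == 0.0f);
  return 0;
}
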

tools/autograd/derivatives.yaml (6 additions, 6 deletions)

@@ -731,11 +731,11 @@
 - name: prelu(Tensor self, Tensor weight)
   self, weight: prelu_backward(grad, self, weight, grad_input_mask)
 
-- name: rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator)
-  self: rrelu_backward(grad, output, lower, upper, training, noise)
+- name: rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator)
+  self: rrelu_with_noise_backward(grad, output, noise, lower, upper, training)
 
-- name: rrelu_(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator)
-  self: rrelu_backward(grad, output, lower, upper, training, noise)
+- name: rrelu_with_noise_(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator generator)
+  self: rrelu_with_noise_backward(grad, output, noise, lower, upper, training)
 
 - name: softmax(Tensor self, int64_t dim)
   self: softmax_backward(grad, self, dim, output)
@@ -958,8 +958,8 @@
 - name: prelu_backward(Tensor grad_output, Tensor self, Tensor weight, std::array<bool,2> output_mask)
   grad_output, self, weight: prelu_double_backward(grads[0], grads[1], grad_output, self, weight, grad_input_mask)
 
-- name: rrelu_backward(Tensor grad_output, Tensor self, Scalar lower, Scalar upper, bool training, Tensor noise)
-  grad_output: rrelu_backward(grad, self, lower, upper, training, noise)
+- name: rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training)
+  grad_output: rrelu_with_noise_backward(grad, self, noise, lower, upper, training)
   self: zeros_like(grad)
 
 - name: reflection_pad1d_backward(Tensor grad_output, Tensor self, IntList padding)

tools/autograd/load_derivatives.py (9 additions, 1 deletion)

@@ -35,7 +35,15 @@ def create_autograd_function(name, derivatives, num_inputs, buffers=None):
 
 
 def create_derivative(declaration, formula, output_indices, var_names):
-    returns = [r for r in declaration['returns'] if r.get('name') != 'self']
+    def transform_return(r):
+        # In-place functions take in and return self. Call the modified version
+        # "output" so that it can be referred to in derivative definitions.
+        if r['name'] == 'self':
+            r = copy.deepcopy(r)
+            r['name'] = 'output'
+        return r
+
+    returns = [transform_return(r) for r in declaration['returns']]
     arguments = declaration['arguments']
     formula, saved_inputs = saved_variables(formula, arguments)
     formula, saved_outputs = saved_variables(formula, returns)

torch/nn/functional.py (3 additions, 3 deletions)

@@ -730,11 +730,11 @@ def rrelu(input, lower=1. / 8, upper=1. / 3, training=False, inplace=False):
     Randomized leaky ReLU.
     """
     if inplace:
-        return torch._C._nn.rrelu_(input, lower, upper, training)
-    return torch._C._nn.rrelu(input, lower, upper, training)
+        return torch._C._VariableBase.rrelu_(input, lower, upper, training)
+    return torch._C._VariableBase.rrelu(input, lower, upper, training)
 
 
-rrelu_ = _add_docstr(torch._C._nn.rrelu_, r"""
+rrelu_ = _add_docstr(torch._C._VariableBase.rrelu_, r"""
 rrelu_(input, lower=1./8, upper=1./3, training=False) -> Variable
 
 In-place version of :func:`~rrelu`.
