Delete unused autograd functions (#3856)
colesbury authored Nov 24, 2017
1 parent 9bbf4ee commit ed64001
Showing 14 changed files with 116 additions and 1,403 deletions.
1 change: 0 additions & 1 deletion test/run_test.bat
@@ -7,7 +7,6 @@ echo Running torch tests

echo Running autograd tests
%PYCMD% test_autograd.py
%PYCMD% test_potrf.py

echo Running sparse tests
%PYCMD% test_sparse.py
1 change: 0 additions & 1 deletion test/run_test.sh
@@ -25,7 +25,6 @@ $PYCMD test_torch.py $@

echo "Running autograd tests"
$PYCMD test_autograd.py $@
$PYCMD test_potrf.py $@

echo "Running torch.distributions tests"
$PYCMD test_distributions.py $@
41 changes: 38 additions & 3 deletions test/test_autograd.py
@@ -4,21 +4,19 @@
import math
import torch
import unittest
import warnings
import random
from copy import deepcopy
from collections import OrderedDict
from itertools import product
from operator import mul
from functools import reduce
import torch.nn.functional as F
from torch.autograd.gradcheck import gradgradcheck, gradcheck
from torch.autograd.function import once_differentiable
from torch.autograd.profiler import profile

from common import TestCase, run_tests, skipIfNoLapack
from torch.autograd._functions import *
from torch.autograd import Variable, Function
from torch.autograd.function import InplaceFunction

if sys.version_info[0] == 2:
    import cPickle as pickle
@@ -1571,6 +1569,43 @@ def test_dir(self):
        for key in keys:
            self.assertTrue(hasattr(x, key))

    @skipIfNoLapack
    def test_potrf_gradient(self):
        def _calc_deriv_numeric(A, L, upper):
            # numerical forward derivative
            dA = Variable(_make_cov(5))
            eps = 1e-6
            outb = torch.potrf(A + (eps / 2) * dA, upper)
            outa = torch.potrf(A - (eps / 2) * dA, upper)
            dL = (outb - outa) / eps

            return dA, dL

        def _calc_deriv_sym(A, L, upper):
            # reverse mode
            Lbar = Variable(torch.rand(5, 5).tril())
            if upper:
                Lbar = Lbar.t()
            L.backward(Lbar)
            Abar = A.grad

            return Abar, Lbar

        def _check_total_variation(A, L, upper):
            dA, dL = _calc_deriv_numeric(A, L, upper)
            Abar, Lbar = _calc_deriv_sym(A, L, upper)

            # compare df = Tr(dA^T Abar) = Tr(dL^T Lbar)
            df1 = (dL * Lbar).sum()
            df2 = (dA * Abar).sum()

            self.assertEqual(df1, df2, prec=1e-3)

        for upper in [True, False]:
            A = Variable(_make_cov(5), requires_grad=True)
            L = torch.potrf(A, upper)
            _check_total_variation(A, L, upper)

    def test_as_strided(self):
        x = Variable(torch.arange(0, 25).view(5, 5), requires_grad=True)

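A note on the identity the new test_potrf_gradient exercises (context added here, not part of the commit): seeding L.backward(Lbar) with a random cotangent \(\bar{L}\) fills A.grad with \(\bar{A}\), and a correct potrf backward means that for a perturbation \(dA\) of the input,

\[
\operatorname{Tr}\big(dA^{\top}\bar{A}\big) \;=\; \operatorname{Tr}\big(dL^{\top}\bar{L}\big),
\qquad
dL \;\approx\; \frac{\operatorname{potrf}\big(A + \tfrac{\varepsilon}{2}\,dA\big) - \operatorname{potrf}\big(A - \tfrac{\varepsilon}{2}\,dA\big)}{\varepsilon}.
\]

The test evaluates both Frobenius inner products, taking \(dL\) from the central difference above with \(\varepsilon = 10^{-6}\) and comparing the two values to a tolerance of 1e-3. The _make_cov(5) helper is assumed here to build a 5x5 positive-definite matrix, as torch.potrf (Cholesky factorization) requires.
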
52 changes: 0 additions & 52 deletions test/test_potrf.py

This file was deleted.

6 changes: 0 additions & 6 deletions torch/autograd/_functions/__init__.py
@@ -1,8 +1,2 @@
from .basic_ops import *
from .tensor import *
from .pointwise import *
from .reduce import *
from .linalg import *
from .blas import *
from .compare import *
from .initializers import *
35 changes: 1 addition & 34 deletions torch/autograd/_functions/basic_ops.py
@@ -1,24 +1,12 @@
import torch
from ..function import Function, InplaceFunction, traceable
from .utils import maybe_unexpand, maybe_unexpand_or_view
from ..function import Function, traceable
import math


def sort_args(a, b, key=torch.is_tensor):
    return (a, b, True) if key(a) else (b, a, False)


def gen_inputs(g, a, b):
    tensor, constant, tensor_first = sort_args(a, b, key=is_node)
    assert tensor.hasType()
    type = str(tensor.type().scalarType())
    broadcast = False
    if len(tensor.type().sizes()) > 1:
        broadcast = True
    constant = g.constant(constant, [0], type).setTypeAs(tensor)
    return tensor, constant, broadcast, tensor_first


@traceable
class PowConstant(Function):

@@ -41,24 +29,3 @@ def backward(ctx, grad_output):
        else:
            var_result, = ctx.saved_variables
            return None, grad_output.mul(var_result).mul_(math.log(ctx.constant))


@traceable
class Negate(InplaceFunction):

    @staticmethod
    def symbolic(g, i, inplace=False):
        # See Note [Export inplace]
        return g.op("Scale", i, scale_f=-1)

    @staticmethod
    def forward(ctx, i, inplace=False):
        if inplace:
            ctx.mark_dirty(i)
            return i.neg_()
        else:
            return i.neg()

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output.neg(), None
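
For context (added here, not part of the commit): the deleted Negate shows the ctx-based custom Function style, which remains available to user code through torch.autograd.Function. Below is a minimal, hypothetical sketch in the same spirit: the MyNegate name is invented for illustration, only the out-of-place path is covered, and it is not how PyTorch itself computes negation gradients after this change.

import torch
from torch.autograd import Function, Variable


class MyNegate(Function):
    # Illustrative stand-in for the deleted Negate (out-of-place only).

    @staticmethod
    def forward(ctx, i):
        # Nothing needs saving for backward: d(-x)/dx is -1 everywhere.
        return i.neg()

    @staticmethod
    def backward(ctx, grad_output):
        # Chain rule: the incoming gradient is simply negated.
        return grad_output.neg()


x = Variable(torch.randn(4), requires_grad=True)
y = MyNegate.apply(x)
y.sum().backward()
print(x.grad)  # every entry is -1

Custom Functions written this way are invoked through .apply(), as above, rather than by constructing an instance.
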
181 changes: 0 additions & 181 deletions torch/autograd/_functions/blas.py

This file was deleted.

(Diffs for the remaining changed files are not shown.)
