Skip to content

Commit

Permalink
Merge pull request tkipf#6 from rusty1s/master
Browse files Browse the repository at this point in the history
PyTorch 0.3 Fix
  • Loading branch information
tkipf authored Mar 7, 2018
2 parents d3e4f39 + 49c9400 commit 7474897
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 14 deletions.
23 changes: 10 additions & 13 deletions pygcn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,18 @@ class SparseMM(torch.autograd.Function):
does-pytorch-support-autograd-on-sparse-matrix/6156/7
"""

def forward(self, matrix1, matrix2):
self.save_for_backward(matrix1, matrix2)
return torch.mm(matrix1, matrix2)
def __init__(self, sparse):
super(SparseMM, self).__init__()
self.sparse = sparse

def backward(self, grad_output):
matrix1, matrix2 = self.saved_tensors
grad_matrix1 = grad_matrix2 = None
def forward(self, dense):
return torch.mm(self.sparse, dense)

def backward(self, grad_output):
grad_input = None
if self.needs_input_grad[0]:
grad_matrix1 = torch.mm(grad_output, matrix2.t())

if self.needs_input_grad[1]:
grad_matrix2 = torch.mm(matrix1.t(), grad_output)

return grad_matrix1, grad_matrix2
grad_input = torch.mm(self.sparse.t(), grad_output)
return grad_input


class GraphConvolution(Module):
Expand All @@ -56,7 +53,7 @@ def reset_parameters(self):

def forward(self, input, adj):
support = torch.mm(input, self.weight)
output = SparseMM()(adj, support)
output = SparseMM(adj)(support)
if self.bias is not None:
return output + self.bias
else:
Expand Down
2 changes: 1 addition & 1 deletion pygcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
idx_val = idx_val.cuda()
idx_test = idx_test.cuda()

features, adj, labels = Variable(features), Variable(adj), Variable(labels)
features, labels = Variable(features), Variable(labels)


def train(epoch):
Expand Down

0 comments on commit 7474897

Please sign in to comment.