Skip to content

Commit

Permalink
fix broadcast_to
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanchengFang committed Oct 2, 2022
1 parent 75520fb commit cb43bab
Showing 1 changed file with 4 additions and 5 deletions.
9 changes: 4 additions & 5 deletions hw1/python/needle/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,14 +183,13 @@ def compute(self, a):

def gradient(self, out_grad: Tensor, node: Tensor):
### BEGIN YOUR SOLUTION
shape = node.inputs[0].shape
j = 0
shape = list(node.inputs[0].shape) # (10, ) -> (2, 10)
axes = []
shape = [1] * (len(self.shape) - len(shape)) + shape
for i, s in enumerate(self.shape):
if j >= len(shape) or s != shape[j]:
if i >= len(shape) or s != shape[i]:
axes.append(i)
j += 1
return reshape(summation(out_grad, tuple(axes)), shape)
return reshape(summation(out_grad, tuple(axes)), node.inputs[0].shape)
### END YOUR SOLUTION


Expand Down

0 comments on commit cb43bab

Please sign in to comment.