Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
VoVAllen authored Oct 26, 2021
1 parent 579cd3e commit a9c83bc
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/dgl/backend/pytorch/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ def spmm_cache_Y(binary_op, reduce_op, req_grad_X, req_grad_Y):

def spmm_cache_argX(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argX in SpMM forward stage."""
if req_grad_X:
if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']:
return True
return False


def spmm_cache_argY(binary_op, reduce_op, req_grad_X, req_grad_Y):
"""Rules to identify whether to cache argY in SpMM forward stage."""
if req_grad_Y:
if req_grad_X or req_grad_Y:
if reduce_op in ['min', 'max']:
return True
return False
Expand Down

0 comments on commit a9c83bc

Please sign in to comment.