[MPS] Add scalar params to the softplus key. (pytorch#94256)
kulinseth authored and pytorchmergebot committed Feb 7, 2023
1 parent 9358726 commit ca74105
Showing 2 changed files with 8 additions and 6 deletions.
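For context on the fix: the MPS backend caches compiled graphs under string keys, and before this change the softplus keys encoded only the tensor signature from getTensorsStringKey. Two calls that differ only in beta or threshold could therefore look up the same cached graph. A minimal sketch of the failure mode this guards against, assuming an MPS-capable build (hypothetical repro, not taken from the commit):

import torch
import torch.nn.functional as F

x = torch.randn(8, device="mps")

# First call compiles and caches a softplus graph for beta=1, threshold=20.
a = F.softplus(x, beta=1, threshold=20)

# With beta/threshold missing from the cache key, a call with different
# scalars could map to the same key and reuse the graph built above.
b = F.softplus(x, beta=2, threshold=20)

# Cross-check against the CPU implementation; with the fixed keys this holds.
assert torch.allclose(b.cpu(), F.softplus(x.cpu(), beta=2, threshold=20), atol=1e-5)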
6 changes: 4 additions & 2 deletions aten/src/ATen/native/mps/operations/Activation.mm
@@ -1419,7 +1419,8 @@ Tensor glu_backward_mps (const Tensor& grad_output,
   MPSScalar threshold_scalar = getMPSScalar(threshold, ScalarType::Float);

   @autoreleasepool {
-    string key = "softplus_out_mps:" + getTensorsStringKey({self});
+    string key = "softplus_out_mps:" + getTensorsStringKey({self}) + ":" +
+                 std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());

     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
@@ -1524,7 +1525,8 @@ Tensor glu_backward_mps (const Tensor& grad_output,
   MPSStream* stream = getCurrentMPSStream();

   @autoreleasepool {
-    string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self});
+    string key = "softplus_backward_out_mps:" + getTensorsStringKey({grad_output, self}) + ":" +
+                 std::to_string(beta.to<double>()) + ":" + std::to_string(threshold.to<double>());

     CachedGraph* cachedGraph = static_cast<CachedGraph *>(cache_->LookUp(key));
     if(!cachedGraph) {
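For illustration, a rough mock-up of how the new keys compose (an assumption based on std::to_string's default six-decimal formatting; the tensor portion is a placeholder for whatever getTensorsStringKey returns):

# Hypothetical Python sketch of the key composition; the real code is C++.
beta, threshold = 1.0, 20.0
key = "softplus_out_mps:" + "<tensor_key>" + ":" + f"{beta:.6f}" + ":" + f"{threshold:.6f}"
# -> "softplus_out_mps:<tensor_key>:1.000000:20.000000"
# Distinct scalar values now yield distinct keys, so each parameterization
# gets its own cached graph instead of colliding.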
8 changes: 4 additions & 4 deletions test/test_mps.py
@@ -4651,7 +4651,7 @@ def helper(shape, dim=0):

     # Test softplus
     def test_softplus(self):
-        def helper(shape, beta=0.5, threshold=0.5):
+        def helper(shape, beta=1, threshold=20):
             cpu_x = torch.randn(shape, device='cpu', dtype=torch.float, requires_grad=True)
             x = cpu_x.detach().clone().to('mps').requires_grad_()

@@ -4669,9 +4669,9 @@ def helper(shape, beta=0.5, threshold=0.5):

         # Test empty shape too
         for shape in [(), (2, 3), (10, 10), (2, 3, 4, 5)]:
-            helper(shape)
-            helper(shape, beta=0.6, threshold=0.6)  # relu path
-            helper(shape, beta=1, threshold=20)  # softplus path
+            for beta in [0.5, 1, 2, 3, 4]:
+                for threshold in [0.5, 20, 30, 40, 50]:
+                    helper(shape, beta, threshold)

     # Test silu

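For reference, the semantics the widened test sweep exercises (the standard softplus definition from the PyTorch docs, not text from this diff): softplus(x) = (1/beta) * log(1 + exp(beta * x)), with a fallback to the identity x wherever beta * x > threshold for numerical stability. A small cross-check against that definition, assuming an MPS-capable build:

import torch
import torch.nn.functional as F

def softplus_ref(x, beta, threshold):
    # Smooth approximation of relu, reverting to the identity where
    # beta * x exceeds the threshold.
    soft = (1.0 / beta) * torch.log1p(torch.exp(beta * x))
    return torch.where(beta * x > threshold, x, soft)

x = torch.linspace(-5, 5, steps=11)
for beta in [0.5, 1, 2, 3, 4]:
    for threshold in [0.5, 20, 30, 40, 50]:
        got = F.softplus(x.to("mps"), beta=beta, threshold=threshold).cpu()
        # Loose tolerance for float32 on the MPS backend.
        assert torch.allclose(got, softplus_ref(x, beta, threshold), atol=1e-4)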
