Skip to content

Commit

Permalink
fix dtype checking for softmax (PaddlePaddle#51929)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangting2020 authored Mar 22, 2023
1 parent 2b98993 commit 5984144
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions python/paddle/nn/functional/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,15 +1110,15 @@ def softmax(x, axis=-1, dtype=None, name=None):
use_cudnn = True
if dtype is None:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'softmax'
x, 'x', ['float16', 'bfloat16', 'float32', 'float64'], 'softmax'
)
else:
check_dtype(
dtype,
'dtype',
['float32', 'float64'],
['float16', 'bfloat16', 'float32', 'float64'],
'softmax',
'If dtype is not None, it only support float32 or float64.',
'If dtype is not None, it only supports float16, bfloat16, float32 or float64.',
)

helper = LayerHelper("softmax", **locals())
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/nn/layer/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1324,7 +1324,7 @@ def __init__(self, axis=-1, name=None):
self._name = name

def forward(self, x):
return F.softmax(x, self._axis, self._dtype, self._name)
return F.softmax(x, self._axis, name=self._name)

def extra_repr(self):
name_str = ', name={}'.format(self._name) if self._name else ''
Expand Down

0 comments on commit 5984144

Please sign in to comment.