diff --git a/python/paddle/nn/functional/input.py b/python/paddle/nn/functional/input.py index ab5a000a2bfc62..40b9441c2dc001 100644 --- a/python/paddle/nn/functional/input.py +++ b/python/paddle/nn/functional/input.py @@ -192,6 +192,13 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): x=label, weight=weight, sparse=True, name="embedding") """ + padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( + weight.shape[0] + padding_idx) + + if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]: + raise ValueError("padding_idx must be within [-{}, {})".format( + weight.shape[0], weight.shape[0])) + if in_dygraph_mode(): return core.ops.lookup_table_v2( weight, x, 'is_sparse', sparse, 'is_distributed', False, @@ -206,12 +213,6 @@ def embedding(x, weight, padding_idx=None, sparse=False, name=None): remote_prefetch = sparse and (not is_distributed) tmp = helper.create_variable_for_type_inference(dtype) - padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( - weight.shape[0] + padding_idx) - - if padding_idx >= weight.shape[0] or padding_idx < -weight.shape[0]: - raise ValueError("padding_idx must be within [-{}, {})".format( - weight.shape[0], weight.shape[0])) helper.append_op( type='lookup_table_v2', diff --git a/python/paddle/nn/layer/common.py b/python/paddle/nn/layer/common.py index 6e3910745e1579..cf8aa7a66e3a76 100644 --- a/python/paddle/nn/layer/common.py +++ b/python/paddle/nn/layer/common.py @@ -1103,8 +1103,7 @@ def __init__(self, self._embedding_dim = embedding_dim self._sparse = sparse self._is_distributed = False - self._padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( - num_embeddings + padding_idx) + self._padding_idx = padding_idx if self._num_embeddings <= 0: raise ValueError("num_embeddings must be gather than 0") @@ -1112,7 +1111,10 @@ def __init__(self, if self._embedding_dim <= 0: raise ValueError("embedding_dim must be gather than 0") - if self._padding_idx >= num_embeddings or self._padding_idx < -num_embeddings: + padding_idx = -1 if padding_idx is None else padding_idx if padding_idx >= 0 else ( + num_embeddings + padding_idx) + + if padding_idx >= num_embeddings or padding_idx < -num_embeddings: raise ValueError("padding_idx must be within [-{}, {})".format( num_embeddings, num_embeddings))