Skip to content

Commit

Permalink
fix Uniform sample method (PaddlePaddle#37823)
Browse files Browse the repository at this point in the history
  • Loading branch information
pangyoki authored Dec 9, 2021
1 parent 34a06cf commit 491d4f0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/paddle/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def sample(self, shape, seed=0):
else:
output_shape = shape + batch_shape
output = nn.uniform_random(
output_shape, seed=seed, dtype=self.dtype) * (tensor.zeros(
output_shape, dtype=self.dtype, min=0., max=1.,
seed=seed) * (tensor.zeros(
output_shape, dtype=self.dtype) + (self.high - self.low))
output = elementwise_add(output, self.low, name=name)
if self.all_arg_is_float:
Expand Down
23 changes: 23 additions & 0 deletions python/paddle/fluid/tests/unittests/test_distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,29 @@ def init_static_data(self, batch_size, dims):
name='values', shape=[dims], dtype='float32')


class UniformTestSample(unittest.TestCase):
def setUp(self):
self.init_param()

def init_param(self):
self.low = 3.0
self.high = 4.0

def test_uniform_sample(self):
paddle.disable_static()
uniform = Uniform(low=self.low, high=self.high)
s = uniform.sample([100])
self.assertTrue((s >= self.low).all())
self.assertTrue((s < self.high).all())
paddle.enable_static()


class UniformTestSample2(UniformTestSample):
def init_param(self):
self.low = -5.0
self.high = 2.0


class NormalNumpy(DistributionNumpy):
def __init__(self, loc, scale):
self.loc = np.array(loc)
Expand Down

0 comments on commit 491d4f0

Please sign in to comment.