Skip to content

Commit

Permalink
TST: Implemented an unused test for np.random.randint
Browse files Browse the repository at this point in the history
In numpy/random/tests/test_random.py, a class called TestSingleEltArrayInput had a method called test_randint that was commented out, with the instructions to uncomment it once np.random.randint was able to broadcast arguments. Since np.random.randint has been able to broadcast arguments for a while now, I uncommented the test. The only modification I made to the code was fixing a small error, where the author incorrectly tried to call "assert_equal" as a method of the TestSingleEltArrayInput instead of a function that was imported from numpy.testing. I ran runtests.py, and the new test passed.
  • Loading branch information
MatteoRaso committed Aug 26, 2022
1 parent 50a74fb commit 1647f46
Showing 1 changed file with 16 additions and 17 deletions.
33 changes: 16 additions & 17 deletions numpy/random/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -1712,23 +1712,22 @@ def test_two_arg_funcs(self):
out = func(self.argOne, argTwo[0])
assert_equal(out.shape, self.tgtShape)

# TODO: Uncomment once randint can broadcast arguments
# def test_randint(self):
# itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
# np.int32, np.uint32, np.int64, np.uint64]
# func = np.random.randint
# high = np.array([1])
# low = np.array([0])
#
# for dt in itype:
# out = func(low, high, dtype=dt)
# self.assert_equal(out.shape, self.tgtShape)
#
# out = func(low[0], high, dtype=dt)
# self.assert_equal(out.shape, self.tgtShape)
#
# out = func(low, high[0], dtype=dt)
# self.assert_equal(out.shape, self.tgtShape)
def test_randint(self):
itype = [bool, np.int8, np.uint8, np.int16, np.uint16,
np.int32, np.uint32, np.int64, np.uint64]
func = np.random.randint
high = np.array([1])
low = np.array([0])

for dt in itype:
out = func(low, high, dtype=dt)
assert_equal(out.shape, self.tgtShape)

out = func(low[0], high, dtype=dt)
assert_equal(out.shape, self.tgtShape)

out = func(low, high[0], dtype=dt)
assert_equal(out.shape, self.tgtShape)

def test_three_arg_funcs(self):
funcs = [np.random.noncentral_f, np.random.triangular,
Expand Down

0 comments on commit 1647f46

Please sign in to comment.