
Commit

Use local RandomState instead of seeding the global RNG (keras-team#12259)

* Use local RandomState instead of seeding the global RNG

* Create a unit test module for datasets and move tests there

* Move initializer test to the proper file
YuriyGuts authored and fchollet committed Feb 24, 2019
1 parent 7dd1c00 commit 91ccb28
Showing 6 changed files with 112 additions and 9 deletions.
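
For context, a minimal sketch (not part of the commit) contrasting the two approaches: np.random.seed mutates the process-wide RNG that every other np.random call shares, whereas a local np.random.RandomState produces the same reproducible shuffle without touching that global state.

import numpy as np

# Global seeding: reproducible, but every later np.random call is affected.
np.random.seed(113)
indices_global = np.arange(10)
np.random.shuffle(indices_global)

# Local RandomState: same algorithm and seed, so the same permutation,
# but the randomness is isolated to this one object.
state_before = np.random.get_state()
rng = np.random.RandomState(113)
indices_local = np.arange(10)
rng.shuffle(indices_local)
state_after = np.random.get_state()

assert np.array_equal(indices_global, indices_local)    # identical shuffle
assert np.array_equal(state_before[1], state_after[1])  # global MT19937 key untouched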
4 changes: 2 additions & 2 deletions keras/datasets/boston_housing.py
@@ -30,9 +30,9 @@ def load_data(path='boston_housing.npz', test_split=0.2, seed=113):
         x = f['x']
         y = f['y']

-    np.random.seed(seed)
+    rng = np.random.RandomState(seed)
     indices = np.arange(len(x))
-    np.random.shuffle(indices)
+    rng.shuffle(indices)
     x = x[indices]
     y = y[indices]

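A small aside (not from the patch): indexing x and y with the same shuffled index array is what keeps each feature row paired with its label, regardless of which RNG produced the permutation.

import numpy as np

rng = np.random.RandomState(113)
x = np.arange(6) * 2   # toy features
y = x + 1              # toy labels; row i of x pairs with y[i]

indices = np.arange(len(x))
rng.shuffle(indices)

# Applying one permutation to both arrays preserves the pairing.
x_shuffled, y_shuffled = x[indices], y[indices]
assert np.array_equal(y_shuffled, x_shuffled + 1)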
6 changes: 3 additions & 3 deletions keras/datasets/imdb.py
@@ -59,14 +59,14 @@ def load_data(path='imdb.npz', num_words=None, skip_top=0,
         x_train, labels_train = f['x_train'], f['y_train']
         x_test, labels_test = f['x_test'], f['y_test']

-    np.random.seed(seed)
+    rng = np.random.RandomState(seed)
     indices = np.arange(len(x_train))
-    np.random.shuffle(indices)
+    rng.shuffle(indices)
     x_train = x_train[indices]
     labels_train = labels_train[indices]

     indices = np.arange(len(x_test))
-    np.random.shuffle(indices)
+    rng.shuffle(indices)
     x_test = x_test[indices]
     labels_test = labels_test[indices]

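Note that imdb.load_data now draws both permutations from one RandomState instance. A sketch (with assumed toy sizes and seed 113) of why that is still fully deterministic: re-creating the instance with the same seed replays the same stream, so both the train and test shuffles are reproduced exactly.

import numpy as np

def shuffle_both(seed):
    rng = np.random.RandomState(seed)
    train = np.arange(8)
    rng.shuffle(train)   # consumes the first part of the stream
    test = np.arange(4)
    rng.shuffle(test)    # continues where the train shuffle left off
    return train, test

t1, s1 = shuffle_both(113)
t2, s2 = shuffle_both(113)
assert np.array_equal(t1, t2) and np.array_equal(s1, s2)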
4 changes: 2 additions & 2 deletions keras/datasets/reuters.py
@@ -56,9 +56,9 @@ def load_data(path='reuters.npz', num_words=None, skip_top=0,
     with np.load(path) as f:
         xs, labels = f['x'], f['y']

-    np.random.seed(seed)
+    rng = np.random.RandomState(seed)
     indices = np.arange(len(xs))
-    np.random.shuffle(indices)
+    rng.shuffle(indices)
     xs = xs[indices]
     labels = labels[indices]

5 changes: 3 additions & 2 deletions keras/initializers.py
@@ -248,9 +248,10 @@ def __call__(self, shape, dtype=None):
             num_rows *= dim
         num_cols = shape[-1]
         flat_shape = (num_rows, num_cols)
+        rng = np.random
         if self.seed is not None:
-            np.random.seed(self.seed)
-        a = np.random.normal(0.0, 1.0, flat_shape)
+            rng = np.random.RandomState(self.seed)
+        a = rng.normal(0.0, 1.0, flat_shape)
         u, _, v = np.linalg.svd(a, full_matrices=False)
         # Pick the one with the correct shape.
         q = u if u.shape == flat_shape else v
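Read on its own, the updated initializer logic amounts to the following standalone sketch (orthogonal_matrix is a hypothetical helper, not a Keras API): draw a Gaussian matrix from a per-call RandomState when a seed is given, take its SVD, and keep whichever factor has the requested shape.

import numpy as np

def orthogonal_matrix(num_rows, num_cols, seed=None):
    # Fall back to the global module when no seed is given, as the patch does.
    rng = np.random if seed is None else np.random.RandomState(seed)
    a = rng.normal(0.0, 1.0, (num_rows, num_cols))
    u, _, v = np.linalg.svd(a, full_matrices=False)
    # Pick the factor with the requested shape.
    return u if u.shape == (num_rows, num_cols) else v

w = orthogonal_matrix(10, 5, seed=9876)
assert np.allclose(w.T.dot(w), np.eye(5))                      # columns are orthonormal
assert np.array_equal(w, orthogonal_matrix(10, 5, seed=9876))  # seeded, so repeatable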
90 changes: 90 additions & 0 deletions tests/keras/datasets/datasets_test.py
@@ -0,0 +1,90 @@
import tempfile

import numpy as np
import pytest

from keras.datasets import boston_housing
from keras.datasets import imdb
from keras.datasets import reuters


@pytest.fixture
def fake_downloaded_boston_path(monkeypatch):
    num_rows = 100
    num_cols = 10
    rng = np.random.RandomState(123)

    x = rng.randint(1, 100, size=(num_rows, num_cols))
    y = rng.normal(loc=100, scale=15, size=num_rows)

    with tempfile.NamedTemporaryFile('wb', delete=True) as f:
        np.savez(f, x=x, y=y)
        monkeypatch.setattr(boston_housing, 'get_file',
                            lambda *args, **kwargs: f.name)
        yield f.name


@pytest.fixture
def fake_downloaded_imdb_path(monkeypatch):
    train_rows = 100
    test_rows = 20
    seq_length = 10
    rng = np.random.RandomState(123)

    x_train = rng.randint(1, 100, size=(train_rows, seq_length))
    y_train = rng.binomial(n=1, p=0.5, size=train_rows)
    x_test = rng.randint(1, 100, size=(test_rows, seq_length))
    y_test = rng.binomial(n=1, p=0.5, size=test_rows)

    with tempfile.NamedTemporaryFile('wb', delete=True) as f:
        np.savez(f, x_train=x_train, y_train=y_train, x_test=x_test, y_test=y_test)
        monkeypatch.setattr(imdb, 'get_file', lambda *args, **kwargs: f.name)
        yield f.name


@pytest.fixture
def fake_downloaded_reuters_path(monkeypatch):
    num_rows = 100
    seq_length = 10
    rng = np.random.RandomState(123)

    x = rng.randint(1, 100, size=(num_rows, seq_length))
    y = rng.binomial(n=1, p=0.5, size=num_rows)

    with tempfile.NamedTemporaryFile('wb', delete=True) as f:
        np.savez(f, x=x, y=y)
        monkeypatch.setattr(reuters, 'get_file', lambda *args, **kwargs: f.name)
        yield f.name


def test_boston_load_does_not_affect_global_rng(fake_downloaded_boston_path):
    np.random.seed(1337)
    before = np.random.randint(0, 100, size=10)

    np.random.seed(1337)
    boston_housing.load_data(path=fake_downloaded_boston_path, seed=9876)
    after = np.random.randint(0, 100, size=10)

    assert np.array_equal(before, after)


def test_imdb_load_does_not_affect_global_rng(fake_downloaded_imdb_path):
    np.random.seed(1337)
    before = np.random.randint(0, 100, size=10)

    np.random.seed(1337)
    imdb.load_data(path=fake_downloaded_imdb_path, seed=9876)
    after = np.random.randint(0, 100, size=10)

    assert np.array_equal(before, after)


def test_reuters_load_does_not_affect_global_rng(fake_downloaded_reuters_path):
    np.random.seed(1337)
    before = np.random.randint(0, 100, size=10)

    np.random.seed(1337)
    reuters.load_data(path=fake_downloaded_reuters_path, seed=9876)
    after = np.random.randint(0, 100, size=10)

    assert np.array_equal(before, after)
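
The new tests compare draws taken before and after each load. An equivalent, slightly stricter variant (a sketch only, not in the patch; global_rng_untouched_by is a hypothetical helper) compares the raw global RNG state instead:

import numpy as np

def global_rng_untouched_by(fn, *args, **kwargs):
    """Return True if calling fn leaves the global NumPy RNG state unchanged."""
    before = np.random.get_state()
    fn(*args, **kwargs)
    after = np.random.get_state()
    # Index 1 is the Mersenne Twister key array, index 2 the stream position.
    return np.array_equal(before[1], after[1]) and before[2] == after[2]

# Hypothetical usage with the fixtures above:
# assert global_rng_untouched_by(boston_housing.load_data,
#                                path=fake_downloaded_boston_path, seed=9876)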
12 changes: 12 additions & 0 deletions tests/keras/initializers_test.py
@@ -104,6 +104,18 @@ def test_orthogonal(tensor_shape):
             target_mean=0.)


+def test_orthogonal_init_does_not_affect_global_rng():
+    np.random.seed(1337)
+    before = np.random.randint(0, 100, size=10)
+
+    np.random.seed(1337)
+    init = initializers.orthogonal(seed=9876)
+    init(shape=(10, 5))
+    after = np.random.randint(0, 100, size=10)
+
+    assert np.array_equal(before, after)
+
+
 @pytest.mark.parametrize('tensor_shape',
                          [(100, 100), (10, 20), (30, 80), (1, 2, 3, 4)],
                          ids=['FC', 'RNN', 'RNN_INVALID', 'CONV'])
