Skip to content

Commit

Permalink
Ensure all test in Sonnet are run when test.sh is called and fixes to…
Browse files Browse the repository at this point in the history
… enable this

Also remove the dataset fetching test from within examples as we shouldn't be reading from/downloading files within tests

PiperOrigin-RevId: 274140680
Change-Id: I872e4251a1a8815afbe566c3499986bb3be31429
  • Loading branch information
tamaranorman authored and sonnet-copybara committed Oct 11, 2019
1 parent adae0b2 commit 275af72
Show file tree
Hide file tree
Showing 10 changed files with 13 additions and 20 deletions.
2 changes: 1 addition & 1 deletion docs/ext/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package(default_visibility = ["//sonnet:__subpackages__"])
package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"])

licenses(["notice"]) # Apache 2.0 License

Expand Down
2 changes: 1 addition & 1 deletion docs/ext/link_tf_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import print_function

from absl.testing import absltest
from sonnet.docs.ext import link_tf_api
from docs.ext import link_tf_api
import tensorflow as tf

DOC_BASE_URL = "https://www.tensorflow.org/versions/r2.0/api_docs/python/tf"
Expand Down
15 changes: 3 additions & 12 deletions examples/simple_mnist_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from __future__ import print_function

import sonnet as snt
from sonnet.examples import simple_mnist
from examples import simple_mnist
from sonnet.src import test_utils
import tensorflow as tf

Expand All @@ -42,20 +42,11 @@ def test_train_epoch(self):
(tf.random.normal([2, 8, 8, 1]),
tf.ones([2], dtype=tf.int64))).batch(2).repeat(4)

loss = simple_mnist.train_epoch(model, optimizer, dataset)
for _ in range(3):
loss = simple_mnist.train_epoch(model, optimizer, dataset)
self.assertEqual(loss.shape, [])
self.assertEqual(loss.dtype, tf.float32)

def test_get_dataset(self):
dataset = simple_mnist.mnist("train", 4)

batch = next(iter(dataset))

self.assertEqual(batch[0].shape, [4, 28, 28, 1])
self.assertEqual(batch[1].shape, [4])
self.assertEqual(batch[0].dtype, tf.float32)
self.assertEqual(batch[1].dtype, tf.int64)

def test_test_accuracy(self):
model = snt.Sequential([
snt.Flatten(),
Expand Down
2 changes: 2 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
mock>=3.0.5
tensorflow-datasets>1
docutils
2 changes: 1 addition & 1 deletion sonnet/src/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package(default_visibility = ["//sonnet:__subpackages__"])
package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"])

licenses(["notice"]) # Apache 2.0 License

Expand Down
2 changes: 1 addition & 1 deletion sonnet/src/conformance/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package(
default_testonly = True,
default_visibility = ["//sonnet:__subpackages__"],
default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"],
)

licenses(["notice"]) # Apache 2.0 License
Expand Down
2 changes: 1 addition & 1 deletion sonnet/src/conformance/checkpoints/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package(
default_testonly = True,
default_visibility = ["//sonnet:__subpackages__"],
default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"],
)

licenses(["notice"]) # Apache 2.0 License
Expand Down
2 changes: 1 addition & 1 deletion sonnet/src/nets/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package(default_visibility = ["//sonnet:__subpackages__"])
package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"])

licenses(["notice"]) # Apache 2.0 License

Expand Down
2 changes: 1 addition & 1 deletion sonnet/src/nets/dnc/BUILD
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Description:
# Differentiable Neural Computer

package(default_visibility = ["//sonnet:__subpackages__"])
package(default_visibility = ["//sonnet:__subpackages__", "//docs/ext:__subpackages__", "//examples:__subpackages__"])

load("//sonnet/src:build_defs.bzl", "snt_py_library", "snt_py_test")

Expand Down
2 changes: 1 addition & 1 deletion test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ python3 -c 'import tensorflow as tf; print(tf.__version__)'
bazel test --jobs=${N_JOBS} --test_timeout 300,450,1200,3600 \
--build_tests_only --test_output=errors \
--cache_test_results=no \
-- //sonnet/...
-- //...

# Test docs still build.
cd docs/
Expand Down

0 comments on commit 275af72

Please sign in to comment.