Skip to content

Commit

Permalink
Shard test_nn to reduce runtime for each test target (#8678)
Browse files Browse the repository at this point in the history
* Shard test_nn to reduce runtime for each test target

* Use load_tests for selecting tests to enable

* fix lint

* Use arg parser from common.py
  • Loading branch information
yf225 authored Jun 20, 2018
1 parent 9335885 commit d6c873a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 4 deletions.
4 changes: 2 additions & 2 deletions test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@
torch.manual_seed(SEED)


def run_tests():
unittest.main(argv=UNITTEST_ARGS)
def run_tests(argv=UNITTEST_ARGS):
unittest.main(argv=argv)

PY3 = sys.version_info > (3, 0)
PY34 = sys.version_info >= (3, 4)
Expand Down
29 changes: 27 additions & 2 deletions test/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from functools import wraps, reduce
from operator import mul
from collections import OrderedDict
import hashlib
import sys

import torch
import torch.backends.cudnn as cudnn
Expand All @@ -28,7 +30,7 @@
from torch.nn.parallel._functions import Broadcast
from common import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, \
TEST_SCIPY, IS_WINDOWS, download_file, PY3, PY34, to_gpu, \
get_function_arglist, skipCUDAMemoryLeakCheckIf
get_function_arglist, skipCUDAMemoryLeakCheckIf, parser
from common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, \
TEST_CUDNN_VERSION
from common_nn import NNTestCase, ModuleTest, CriterionTest, TestBase, \
Expand Down Expand Up @@ -7609,4 +7611,27 @@ def __call__(self, input):


if __name__ == '__main__':
run_tests()
parser.add_argument(
'--num-shards',
type=int,
required=False,
help='number of shards')
parser.add_argument(
'--shard',
type=int,
required=False,
help='which shard to run')

args, remaining = parser.parse_known_args()
unittest_args = [sys.argv[0]] + remaining

if args.num_shards is not None and args.shard is not None:
def load_tests(loader, tests, pattern):
test_suite = unittest.TestSuite()
for test_group in tests:
for test in test_group:
if int(hashlib.sha256(str(test).encode('utf-8')).hexdigest(), 16) % args.num_shards == args.shard:
test_suite.addTest(test)
return test_suite

run_tests(unittest_args)

0 comments on commit d6c873a

Please sign in to comment.