From a40e8ce7c553d61024fdf4f8f2b7b13ff606e77b Mon Sep 17 00:00:00 2001 From: Will Feng Date: Fri, 1 Feb 2019 12:42:28 -0800 Subject: [PATCH] Add train() / eval() / is_training() to C++ ScriptModule API (#16044) Summary: This PR aims to fix https://discuss.pytorch.org/t/how-to-change-a-loaded-model-to-evaluation-mode-in-c/32330, by adding `train()` / `eval()` / `is_training()` to C++ ScriptModule API. Pull Request resolved: https://github.com/pytorch/pytorch/pull/16044 Differential Revision: D13857724 Pulled By: yf225 fbshipit-source-id: 16d3969fb5840ff7e66c7f72e800e6c75db8d2ff --- .gitignore | 1 + .jenkins/pytorch/test.sh | 2 ++ test/cpp/__init__.py | 0 test/cpp/jit/__init__.py | 0 test/cpp/jit/gtest.cpp | 1 + test/cpp/jit/no-gtest.cpp | 1 + test/cpp/jit/test_base.h | 5 +++++ test/cpp/jit/test_misc.h | 12 ++++++++++ test/cpp/jit/tests_setup.py | 41 ++++++++++++++++++++++++++++++++++ test/test_jit.py | 3 +++ torch/csrc/jit/script/module.h | 21 +++++++++++++++++ 11 files changed, 87 insertions(+) create mode 100644 test/cpp/__init__.py create mode 100644 test/cpp/jit/__init__.py create mode 100644 test/cpp/jit/tests_setup.py diff --git a/.gitignore b/.gitignore index 23838646a6..47ecb5a9f5 100644 --- a/.gitignore +++ b/.gitignore @@ -34,6 +34,7 @@ test/data/gpu_tensors.pt test/data/legacy_modules.t7 test/data/legacy_serialized.pt test/data/linear.pt +dropout_model.pt test/generated_type_hints_smoketest.py test/htmlcov test/cpp_extensions/install/ diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index d9bd2427bf..204e6d1061 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -148,12 +148,14 @@ test_torchvision() { test_libtorch() { if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then echo "Testing libtorch" + python test/cpp/jit/tests_setup.py setup CPP_BUILD="$PWD/../cpp-build" if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then "$CPP_BUILD"/caffe2/bin/test_jit else "$CPP_BUILD"/caffe2/bin/test_jit "[cpu]" fi + python test/cpp/jit/tests_setup.py shutdown python tools/download_mnist.py --quiet -d test/cpp/api/mnist OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$CPP_BUILD"/caffe2/bin/test_api assert_git_not_dirty diff --git a/test/cpp/__init__.py b/test/cpp/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/cpp/jit/__init__.py b/test/cpp/jit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/cpp/jit/gtest.cpp b/test/cpp/jit/gtest.cpp index d21ffc3a71..f177265811 100644 --- a/test/cpp/jit/gtest.cpp +++ b/test/cpp/jit/gtest.cpp @@ -21,6 +21,7 @@ JIT_TEST(CustomOperators) JIT_TEST(Differentiate) JIT_TEST(DifferentiateWithRequiresGrad) JIT_TEST(DynamicDAG) +JIT_TEST(EvalModeForLoadedModule) JIT_TEST(FromQualString) JIT_TEST(InternedStrings) JIT_TEST(IValue) diff --git a/test/cpp/jit/no-gtest.cpp b/test/cpp/jit/no-gtest.cpp index 250451c185..67b01a594f 100644 --- a/test/cpp/jit/no-gtest.cpp +++ b/test/cpp/jit/no-gtest.cpp @@ -20,6 +20,7 @@ std::string runJITCPPTests() { testDifferentiate(out); testDifferentiateWithRequiresGrad(out); testDynamicDAG(); + testEvalModeForLoadedModule(); testFromQualString(); testFusion(); testGraphExecutor(); diff --git a/test/cpp/jit/test_base.h b/test/cpp/jit/test_base.h index 9f6eb92214..b7cc81ea02 100644 --- a/test/cpp/jit/test_base.h +++ b/test/cpp/jit/test_base.h @@ -38,3 +38,8 @@ ASSERT_TRUE(threw); #endif // defined(USE_GTEST) + +bool isSandcastle() { + return ((std::getenv("SANDCASTLE")) || \ + (std::getenv("TW_JOB_USER") && std::string(std::getenv("TW_JOB_USER")) == "sandcastle")); +} diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index 0f8d9a64d6..ed7ba8085d 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -12,6 +12,7 @@ #include "torch/csrc/jit/custom_operator.h" #include "torch/csrc/jit/dynamic_dag.h" #include "torch/csrc/jit/fuser/interface.h" +#include "torch/csrc/jit/import.h" #include "torch/csrc/jit/interpreter.h" #include "torch/csrc/jit/passes/alias_analysis.h" #include "torch/csrc/jit/passes/common_subexpression_elimination.h" @@ -1418,6 +1419,17 @@ void testCustomOperators() { } } +void testEvalModeForLoadedModule() { + if (isSandcastle()) return; // The module file to load is not generated in Sandcastle + std::string module_path = "dropout_model.pt"; + std::shared_ptr module = torch::jit::load(module_path); + AT_ASSERT(module->get_module("dropout")->is_training()); + module->eval(); + AT_ASSERT(!module->get_module("dropout")->is_training()); + module->train(); + AT_ASSERT(module->get_module("dropout")->is_training()); +} + // test a few features that are not directly used in schemas yet void testSchemaParser() { // nested arrays diff --git a/test/cpp/jit/tests_setup.py b/test/cpp/jit/tests_setup.py new file mode 100644 index 0000000000..b80747c9cb --- /dev/null +++ b/test/cpp/jit/tests_setup.py @@ -0,0 +1,41 @@ +import sys +import os +import torch + +testEvalModeForLoadedModule_module_path = 'dropout_model.pt' + + +def testEvalModeForLoadedModule_setup(): + class Model(torch.jit.ScriptModule): + def __init__(self): + super(Model, self).__init__() + self.dropout = torch.nn.Dropout(0.1) + + def forward(self, x): + x = self.dropout(x) + return x + + model = Model() + model = model.train() + model.save(testEvalModeForLoadedModule_module_path) + + +def testEvalModeForLoadedModule_shutdown(): + if os.path.exists(testEvalModeForLoadedModule_module_path): + os.remove(testEvalModeForLoadedModule_module_path) + + +def setup(): + testEvalModeForLoadedModule_setup() + + +def shutdown(): + testEvalModeForLoadedModule_shutdown() + + +if __name__ == "__main__": + command = sys.argv[1] + if command == "setup": + setup() + elif command == "shutdown": + shutdown() diff --git a/test/test_jit.py b/test/test_jit.py index 5b00a3590f..953fca61ee 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -1303,9 +1303,12 @@ def doit(x, y): @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") @skipIfRocm def test_cpp_cuda(self): + from cpp.jit import tests_setup + tests_setup.setup() # rather than rebuild assertExpected in cpp, # just glob all the cpp outputs into one file for now self.assertExpected(torch._C._jit_run_cpp_tests()) + tests_setup.shutdown() def test_batchnorm(self): x = torch.ones(2, 2, 2, 2) diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 1e5e9f4360..4561307ca8 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -1,5 +1,6 @@ #pragma once #include +#include #include #include #include @@ -481,6 +482,26 @@ struct Module { } fn(*this); } + /// Enables "training" mode. + void train(bool on = true) { + for (auto& submod : get_modules()) { + submod->module->train(on); + } + register_parameter("training", torch::tensor(on ? 1 : 0, at::kLong), /*is_buffer=*/true); + } + /// Calls train(false) to enable "eval" mode. + /// Do not override this method, override `train()` instead. + void eval() { + train(/*on=*/false); + } + /// True if the module is in training mode. + bool is_training() { + if (auto p = find_parameter("training")) { + return p->slot()->item() == 1; + } + // We are in training mode by default + return true; + } /// Recursively casts all parameters to the given `dtype` and `device`. ///