Skip to content

Commit

Permalink
Add train() / eval() / is_training() to C++ ScriptModule API (#16044)
Browse files Browse the repository at this point in the history
Summary:
This PR addresses https://discuss.pytorch.org/t/how-to-change-a-loaded-model-to-evaluation-mode-in-c/32330 by adding `train()` / `eval()` / `is_training()` to the C++ ScriptModule API.
Pull Request resolved: pytorch/pytorch#16044

Differential Revision: D13857724

Pulled By: yf225

fbshipit-source-id: 16d3969fb5840ff7e66c7f72e800e6c75db8d2ff
  • Loading branch information
Will Feng authored and facebook-github-bot committed Feb 1, 2019
1 parent 6d373c0 commit a40e8ce
Show file tree
Hide file tree
Showing 11 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ test/data/gpu_tensors.pt
test/data/legacy_modules.t7
test/data/legacy_serialized.pt
test/data/linear.pt
dropout_model.pt
test/generated_type_hints_smoketest.py
test/htmlcov
test/cpp_extensions/install/
Expand Down
2 changes: 2 additions & 0 deletions .jenkins/pytorch/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,14 @@ test_torchvision() {
test_libtorch() {
if [[ "$BUILD_TEST_LIBTORCH" == "1" ]]; then
echo "Testing libtorch"
python test/cpp/jit/tests_setup.py setup
CPP_BUILD="$PWD/../cpp-build"
if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then
"$CPP_BUILD"/caffe2/bin/test_jit
else
"$CPP_BUILD"/caffe2/bin/test_jit "[cpu]"
fi
python test/cpp/jit/tests_setup.py shutdown
python tools/download_mnist.py --quiet -d test/cpp/api/mnist
OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$CPP_BUILD"/caffe2/bin/test_api
assert_git_not_dirty
Expand Down
Empty file added test/cpp/__init__.py
Empty file.
Empty file added test/cpp/jit/__init__.py
Empty file.
1 change: 1 addition & 0 deletions test/cpp/jit/gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ JIT_TEST(CustomOperators)
JIT_TEST(Differentiate)
JIT_TEST(DifferentiateWithRequiresGrad)
JIT_TEST(DynamicDAG)
JIT_TEST(EvalModeForLoadedModule)
JIT_TEST(FromQualString)
JIT_TEST(InternedStrings)
JIT_TEST(IValue)
Expand Down
1 change: 1 addition & 0 deletions test/cpp/jit/no-gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ std::string runJITCPPTests() {
testDifferentiate(out);
testDifferentiateWithRequiresGrad(out);
testDynamicDAG();
testEvalModeForLoadedModule();
testFromQualString();
testFusion();
testGraphExecutor();
Expand Down
5 changes: 5 additions & 0 deletions test/cpp/jit/test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,8 @@
ASSERT_TRUE(threw);

#endif // defined(USE_GTEST)

/// Reports whether the process is running in the Sandcastle environment.
///
/// Returns true when the `SANDCASTLE` environment variable is set (to any
/// value, including the empty string), or when `TW_JOB_USER` is set to exactly
/// "sandcastle"; returns false otherwise.
///
/// Declared `inline` because the definition lives in a header: a non-inline
/// definition would produce multiple-definition link errors as soon as the
/// header is included from more than one translation unit.
inline bool isSandcastle() {
  if (std::getenv("SANDCASTLE") != nullptr) {
    return true;
  }
  // Look the variable up once and reuse the pointer for both the null check
  // and the comparison.
  const char* job_user = std::getenv("TW_JOB_USER");
  return job_user != nullptr && std::string(job_user) == "sandcastle";
}
12 changes: 12 additions & 0 deletions test/cpp/jit/test_misc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "torch/csrc/jit/custom_operator.h"
#include "torch/csrc/jit/dynamic_dag.h"
#include "torch/csrc/jit/fuser/interface.h"
#include "torch/csrc/jit/import.h"
#include "torch/csrc/jit/interpreter.h"
#include "torch/csrc/jit/passes/alias_analysis.h"
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
Expand Down Expand Up @@ -1418,6 +1419,17 @@ void testCustomOperators() {
}
}

// Loads the serialized module from "dropout_model.pt" and checks that eval()
// and train() on the top-level module flip the training flag reported by its
// "dropout" submodule.
void testEvalModeForLoadedModule() {
  // The module file to load is not generated in Sandcastle
  if (isSandcastle()) {
    return;
  }

  const std::string module_path{"dropout_model.pt"};
  std::shared_ptr<torch::jit::script::Module> module =
      torch::jit::load(module_path);

  // The module is expected to report training mode right after loading.
  AT_ASSERT(module->get_module("dropout")->is_training());

  // eval() on the parent must reach the submodule ...
  module->eval();
  AT_ASSERT(!module->get_module("dropout")->is_training());

  // ... and train() must switch it back.
  module->train();
  AT_ASSERT(module->get_module("dropout")->is_training());
}

// test a few features that are not directly used in schemas yet
void testSchemaParser() {
// nested arrays
Expand Down
41 changes: 41 additions & 0 deletions test/cpp/jit/tests_setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import sys
import os
import torch

# File name of the serialized dropout module: written by
# testEvalModeForLoadedModule_setup() and removed by
# testEvalModeForLoadedModule_shutdown(). The path is relative, so it is
# resolved against the current working directory of the running process.
testEvalModeForLoadedModule_module_path = 'dropout_model.pt'


def testEvalModeForLoadedModule_setup():
    """Save a minimal dropout ScriptModule, put into training mode first.

    The module is written to ``testEvalModeForLoadedModule_module_path``.
    """
    # A ScriptModule whose only child is a Dropout layer registered under the
    # attribute name ``dropout``.
    class Model(torch.jit.ScriptModule):
        def __init__(self):
            super(Model, self).__init__()
            self.dropout = torch.nn.Dropout(0.1)

        def forward(self, x):
            return self.dropout(x)

    # train() hands back the module, so construction, mode switch and save
    # can be chained.
    Model().train().save(testEvalModeForLoadedModule_module_path)


def testEvalModeForLoadedModule_shutdown():
    """Delete the serialized module file; do nothing if it is not there."""
    target = testEvalModeForLoadedModule_module_path
    if not os.path.exists(target):
        return
    os.remove(target)


def setup():
    """Run the setup step of each test fixture.

    Currently this is only the eval-mode test's serialized module.
    """
    testEvalModeForLoadedModule_setup()


def shutdown():
    """Run the cleanup step of each test fixture.

    Currently this is only the eval-mode test's serialized module.
    """
    testEvalModeForLoadedModule_shutdown()


if __name__ == "__main__":
    # Map each supported command-line verb to the function that handles it.
    commands = {"setup": setup, "shutdown": shutdown}
    # Exit with a usage message (non-zero status) when the command is missing
    # or unrecognised. Silently doing nothing would let a misspelled command
    # in a CI script skip fixture generation without anyone noticing.
    if len(sys.argv) < 2 or sys.argv[1] not in commands:
        sys.exit("usage: {} {{setup|shutdown}}".format(sys.argv[0]))
    commands[sys.argv[1]]()
3 changes: 3 additions & 0 deletions test/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,9 +1303,12 @@ def doit(x, y):
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
@skipIfRocm
def test_cpp_cuda(self):
    """Run the C++ JIT tests in-process and check their collected output.

    The fixture files the C++ tests load are created first and are always
    removed afterwards, whether or not the tests pass.
    """
    from cpp.jit import tests_setup
    try:
        # setup() sits inside the try so that a partially created fixture is
        # cleaned up too; shutdown() only removes files that exist.
        tests_setup.setup()
        # rather than rebuild assertExpected in cpp,
        # just glob all the cpp outputs into one file for now
        self.assertExpected(torch._C._jit_run_cpp_tests())
    finally:
        # Runs even when the C++ tests raise or the comparison fails, so the
        # generated model file is not left behind.
        tests_setup.shutdown()

def test_batchnorm(self):
x = torch.ones(2, 2, 2, 2)
Expand Down
21 changes: 21 additions & 0 deletions torch/csrc/jit/script/module.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/autograd/generated/variable_factories.h>
#include <torch/csrc/jit/argument_spec.h>
#include <c10/util/Exception.h>
#include <torch/csrc/jit/graph_executor.h>
Expand Down Expand Up @@ -481,6 +482,26 @@ struct Module {
}
fn(*this);
}
/// Sets the training flag on this module and, recursively, on all of its
/// submodules. Call with `on = false` to leave training mode.
void train(bool on = true) {
  // Propagate the setting down the module tree first.
  for (auto& child : get_modules()) {
    child->module->train(on);
  }
  // Record the flag on this module as a long-typed buffer named "training"
  // holding 1 (on) or 0 (off).
  register_parameter(
      "training",
      torch::tensor(on ? 1 : 0, at::kLong),
      /*is_buffer=*/true);
}
/// Switches the module to "eval" mode; shorthand for `train(false)`.
/// Mode-switching behavior belongs in `train()`: customize that method
/// rather than this one.
void eval() {
  train(false);
}
/// Reports whether the module is in training mode.
///
/// Reads the "training" parameter when one is registered and returns true
/// iff its value equals 1. A module without that parameter is treated as
/// being in training mode.
bool is_training() {
  auto training_flag = find_parameter("training");
  if (!training_flag) {
    // No flag recorded yet: training mode is the default.
    return true;
  }
  return training_flag->slot()->item<int64_t>() == 1;
}

/// Recursively casts all parameters to the given `dtype` and `device`.
///
Expand Down

0 comments on commit a40e8ce

Please sign in to comment.