Skip to content

Commit

Permalink
fix and test the number of times evaluation are called
Browse files Browse the repository at this point in the history
Reviewed By: alexander-kirillov

Differential Revision: D27988270

fbshipit-source-id: e528f8809f8ace5fd50c7fb63856b0a4dda3db95
  • Loading branch information
ppwwyyxx authored and facebook-github-bot committed Apr 25, 2021
1 parent 28174e9 commit 1dac147
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 1 deletion.
4 changes: 3 additions & 1 deletion detectron2/engine/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,7 +381,9 @@ def _do_eval(self):
def after_step(self):
next_iter = self.trainer.iter + 1
if self._period > 0 and next_iter % self._period == 0:
self._do_eval()
# do the last eval in after_train
if next_iter != self.trainer.max_iter:
self._do_eval()

def after_train(self):
# This condition is to prevent the eval from running after a failed training
Expand Down
13 changes: 13 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tempfile
import time
import unittest
from unittest import mock
import torch
from fvcore.common.checkpoint import Checkpointer
from torch import nn
Expand Down Expand Up @@ -128,3 +129,15 @@ def test_checkpoint_resume(self):
checkpointer.resume_or_load("non_exist.pth")
self.assertEqual(trainer.iter, 11) # last finished iter
self.assertEqual(scheduler.last_epoch, 11)

def test_eval_hook(self):
model = _SimpleModel()
dataloader = self._data_loader("cpu")
opt = torch.optim.SGD(model.parameters(), 0.1)

for total_iter, period, eval_count in [(30, 15, 2), (31, 15, 3), (20, 0, 1)]:
test_func = mock.Mock(return_value={"metric": 3.0})
trainer = SimpleTrainer(model, dataloader, opt)
trainer.register_hooks([hooks.EvalHook(period, test_func)])
trainer.train(0, total_iter)
self.assertEqual(test_func.call_count, eval_count)
1 change: 1 addition & 0 deletions tools/lazyconfig_train_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def main(args):

if args.eval_only:
model = instantiate(cfg.model)
model.to(cfg.train.device)
model = create_ddp_model(model)
DetectionCheckpointer(model).load(cfg.train.init_checkpoint)
print(do_test(cfg, model))
Expand Down

0 comments on commit 1dac147

Please sign in to comment.