
Commit

Add init_eval to evaluation hook (open-mmlab#3550)
* Add init_eval to evaluation hook

* Add start to eval hook

* fix docstring

* fix docstring

* Support tmpdir in DistEvalHook.

* change according to comments

* Resolve comments

* Resolve comments and add unittest

* Simplify the code

* Deal with negative start number.

* small change
Johnson-Wang authored Aug 23, 2020
1 parent 31fb4fb commit 4b6ff75
Showing 2 changed files with 184 additions and 13 deletions.
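
For context, a hedged, self-contained usage sketch of the arguments this commit adds (not part of the diff): the tiny model, dummy dataloader, and the start/interval values are illustrative assumptions, not values taken from the commit.

    # Hedged sketch of registering the updated hook on an EpochBasedRunner;
    # TinyModel and the dummy DataLoader are stand-ins for a real pipeline.
    import logging
    import tempfile

    import torch
    import torch.nn as nn
    from mmcv.runner import EpochBasedRunner
    from torch.utils.data import DataLoader

    from mmdet.core.evaluation import DistEvalHook


    class TinyModel(nn.Module):

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(2, 1)

        def forward(self, x):
            return self.linear(x)

        def train_step(self, x, optimizer, **kwargs):
            return dict(loss=self(x))


    runner = EpochBasedRunner(
        model=TinyModel(),
        work_dir=tempfile.mkdtemp(),
        logger=logging.getLogger())

    val_loader = DataLoader(torch.ones((5, 2)))  # stand-in for a real val set
    runner.register_hook(
        DistEvalHook(
            val_loader,
            start=1,            # first evaluation after epoch 1
            interval=2,         # then after epochs 3, 5, ...
            tmpdir=None,        # None falls back to <work_dir>/.eval_hook
            gpu_collect=False))  # collect multi-GPU results on CPU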
78 changes: 66 additions & 12 deletions mmdet/core/evaluation/eval_hooks.py
@@ -1,4 +1,5 @@
import os.path as osp
import warnings

from mmcv.runner import Hook
from torch.utils.data import DataLoader
@@ -7,25 +8,67 @@
class EvalHook(Hook):
"""Evaluation hook.
Note that if new arguments are added for EvalHook, tools/test.py may be
Notes:
If new arguments are added for EvalHook, tools/test.py may be
affected.
Attributes:
dataloader (DataLoader): A PyTorch dataloader.
start (int, optional): Evaluation starting epoch. It enables evaluation
before training starts if ``start`` <= the resumed epoch.
If None, whether to evaluate is merely decided by ``interval``.
Default: None.
interval (int): Evaluation interval (by epochs). Default: 1.
eval_kwargs (dict): Evaluation arguments of :func:`dataset.evaluate()`.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""

def __init__(self, dataloader, interval=1, **eval_kwargs):
def __init__(self, dataloader, start=None, interval=1, **eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError('dataloader must be a pytorch DataLoader, but got'
f' {type(dataloader)}')
if not interval > 0:
raise ValueError(f'interval must be positive, but got {interval}')
if start is not None and start < 0:
warnings.warn(
f'The evaluation start epoch {start} is smaller than 0, '
f'use 0 instead', UserWarning)
start = 0
self.dataloader = dataloader
self.interval = interval
self.start = start
self.eval_kwargs = eval_kwargs
self.initial_epoch_flag = True

def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training."""
if not self.initial_epoch_flag:
return
if self.start is not None and runner.epoch >= self.start:
self.after_train_epoch(runner)
self.initial_epoch_flag = False

def evaluation_flag(self, runner):
"""Judge whether to perform_evaluation after this epoch.
Returns:
bool: The flag indicating whether to perform evaluation.
"""
if self.start is None:
if not self.every_n_epochs(runner, self.interval):
# No evaluation during the interval epochs.
return False
elif (runner.epoch + 1) < self.start:
# No evaluation if start is larger than the current epoch.
return False
else:
# Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
if (runner.epoch + 1 - self.start) % self.interval:
return False
return True

def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
if not self.evaluation_flag(runner):
return
from mmdet.apis import single_gpu_test
results = single_gpu_test(runner.model, self.dataloader, show=False)
@@ -42,36 +85,47 @@ def evaluate(self, runner, results):
class DistEvalHook(EvalHook):
"""Distributed evaluation hook.
Notes:
If new arguments are added, tools/test.py may be affected.
Attributes:
dataloader (DataLoader): A PyTorch dataloader.
start (int, optional): Evaluation starting epoch. It enables evaluation
before training starts if ``start`` <= the resumed epoch.
If None, whether to evaluate is merely decided by ``interval``.
Default: None.
interval (int): Evaluation interval (by epochs). Default: 1.
tmpdir (str | None): Temporary directory to save the results of all
processes. Default: None.
gpu_collect (bool): Whether to use gpu or cpu to collect results.
Default: False.
**eval_kwargs: Evaluation arguments fed into the evaluate function of
the dataset.
"""

def __init__(self,
dataloader,
start=None,
interval=1,
tmpdir=None,
gpu_collect=False,
**eval_kwargs):
if not isinstance(dataloader, DataLoader):
raise TypeError('dataloader must be a pytorch DataLoader, but got '
f'{type(dataloader)}')
self.dataloader = dataloader
self.interval = interval
super().__init__(
dataloader, start=start, interval=interval, **eval_kwargs)
self.tmpdir = tmpdir
self.gpu_collect = gpu_collect
self.eval_kwargs = eval_kwargs

def after_train_epoch(self, runner):
if not self.every_n_epochs(runner, self.interval):
if not self.evaluation_flag(runner):
return
from mmdet.apis import multi_gpu_test
tmpdir = self.tmpdir
if tmpdir is None:
tmpdir = osp.join(runner.work_dir, '.eval_hook')
results = multi_gpu_test(
runner.model,
self.dataloader,
tmpdir=osp.join(runner.work_dir, '.eval_hook'),
tmpdir=tmpdir,
gpu_collect=self.gpu_collect)
if runner.rank == 0:
print('\n')
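
To make the new schedule concrete, here is a minimal standalone sketch (not part of the commit) mirroring what EvalHook.evaluation_flag computes; should_evaluate is a hypothetical helper and epoch is the runner's 0-based epoch counter.

    # Hypothetical mirror of EvalHook.evaluation_flag for a 0-based epoch counter.
    def should_evaluate(epoch, start=None, interval=1):
        if start is None:
            # mmcv's every_n_epochs: evaluate after epochs interval, 2*interval, ...
            return (epoch + 1) % interval == 0
        if (epoch + 1) < start:
            # no evaluation before the starting epoch
            return False
        # evaluate at start, start + interval, start + 2*interval, ...
        return (epoch + 1 - start) % interval == 0

    # start=3, interval=2 -> evaluation after epochs 3, 5, 7
    assert [e + 1 for e in range(8)
            if should_evaluate(e, start=3, interval=2)] == [3, 5, 7]
    # start=None, interval=2 -> evaluation after epochs 2, 4, 6, 8
    assert [e + 1 for e in range(8)
            if should_evaluate(e, interval=2)] == [2, 4, 6, 8]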
119 changes: 118 additions & 1 deletion tests/test_data/test_dataset.py
@@ -1,14 +1,20 @@
import bisect
import logging
import math
import os.path as osp
import tempfile
from collections import defaultdict
from unittest.mock import MagicMock
from unittest.mock import MagicMock, patch

import mmcv
import numpy as np
import pytest
import torch
import torch.nn as nn
from mmcv.runner import EpochBasedRunner
from torch.utils.data import DataLoader

from mmdet.core.evaluation import DistEvalHook, EvalHook
from mmdet.datasets import (DATASETS, ClassBalancedDataset, CocoDataset,
ConcatDataset, CustomDataset, RepeatDataset,
build_dataset)
@@ -343,3 +349,114 @@ def test_dataset_wrapper():
for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
assert repeat_factor_dataset[idx] == bisect.bisect_right(
repeat_factors_cumsum, idx)


@patch('mmdet.apis.single_gpu_test', MagicMock)
@patch('mmdet.apis.multi_gpu_test', MagicMock)
@pytest.mark.parametrize('EvalHookParam', (EvalHook, DistEvalHook))
def test_evaluation_hook(EvalHookParam):
# create dummy data
dataloader = DataLoader(torch.ones((5, 2)))

# 0.1. dataloader is not a DataLoader object
with pytest.raises(TypeError):
EvalHookParam(dataloader=MagicMock(), interval=-1)

# 0.2. negative interval
with pytest.raises(ValueError):
EvalHookParam(dataloader, interval=-1)

# 1. start=None, interval=1: perform evaluation after each epoch.
runner = _build_demo_runner()
evalhook = EvalHookParam(dataloader, interval=1)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 2 # after epoch 1 & 2

# 2. start=1, interval=1: perform evaluation after each epoch.
runner = _build_demo_runner()

evalhook = EvalHookParam(dataloader, start=1, interval=1)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 2 # after epoch 1 & 2

# 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc
runner = _build_demo_runner()
evalhook = EvalHookParam(dataloader, interval=2)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 1 # after epoch 2

# 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc
runner = _build_demo_runner()
evalhook = EvalHookParam(dataloader, start=1, interval=2)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 3)
assert evalhook.evaluate.call_count == 2 # after epoch 1 & 3

# 5. start=0/negative, interval=1: perform evaluation after each epoch and
# before epoch 1.
runner = _build_demo_runner()
evalhook = EvalHookParam(dataloader, start=0)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 3 # before epoch1 and after e1 & e2

runner = _build_demo_runner()
with pytest.warns(UserWarning):
evalhook = EvalHookParam(dataloader, start=-2)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 3 # before epoch1 and after e1 & e2

# 6. resuming from epoch i, start = x (x <= i), interval=1: perform
# evaluation after each epoch and before the first epoch.
runner = _build_demo_runner()
evalhook = EvalHookParam(dataloader, start=1)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner._epoch = 2
runner.run([dataloader], [('train', 1)], 3)
assert evalhook.evaluate.call_count == 2 # before & after epoch 3

# 7. resuming from epoch i, start = i+1/None, interval=1: perform
# evaluation after each epoch.
runner = _build_demo_runner()
evalhook = EvalHookParam(dataloader, start=2)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner._epoch = 1
runner.run([dataloader], [('train', 1)], 3)
assert evalhook.evaluate.call_count == 2 # after epoch 2 & 3


def _build_demo_runner():

class Model(nn.Module):

def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)

def forward(self, x):
return self.linear(x)

def train_step(self, x, optimizer, **kwargs):
return dict(loss=self(x))

def val_step(self, x, optimizer, **kwargs):
return dict(loss=self(x))

model = Model()
tmp_dir = tempfile.mkdtemp()

runner = EpochBasedRunner(
model=model, work_dir=tmp_dir, logger=logging.getLogger())
return runner
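
The resume scenarios above (cases 6 and 7) go through before_train_epoch; the following hedged sketch traces that path with a MagicMock runner standing in for mmcv's real EpochBasedRunner, purely for illustration.

    # Hedged sketch of the resume path: the hook evaluates once before training
    # resumes when start <= the resumed epoch. The MagicMock runner is an
    # illustration stand-in, not a real runner.
    import torch
    from torch.utils.data import DataLoader
    from unittest.mock import MagicMock

    from mmdet.core.evaluation import EvalHook

    hook = EvalHook(DataLoader(torch.ones((5, 2))), start=1, interval=1)
    hook.after_train_epoch = MagicMock()  # record calls instead of evaluating

    runner = MagicMock(epoch=2)           # pretend training resumed at epoch 2 (>= start)
    hook.before_train_epoch(runner)
    assert hook.after_train_epoch.call_count == 1  # one initial evaluation

    hook.before_train_epoch(runner)       # later calls: the initial flag is cleared
    assert hook.after_train_epoch.call_count == 1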
