Skip to content

Commit

Permalink
Fixes issue with metric's output device (pytorch#2062)
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 authored Jun 18, 2021
1 parent f240e6e commit 4ead17c
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 5 deletions.
14 changes: 11 additions & 3 deletions ignite/metrics/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,14 +323,19 @@ def iteration_completed(self, engine: Engine) -> None:

def completed(self, engine: Engine, name: str) -> None:
"""Helper method to compute metric's value and put into the engine. It is automatically attached to the
`engine` with :meth:`~ignite.metrics.metric.Metric.attach`.
`engine` with :meth:`~ignite.metrics.metric.Metric.attach`. If metrics' value is torch tensor, it is
explicitly sent to CPU device.
Args:
engine: the engine to which the metric must be attached
name: the name of the metric used as key in dict `engine.state.metrics`
.. versionchanged:: 0.4.3
Added dict in metrics results.
.. versionchanged:: 0.4.5
metric's value is put on CPU if torch tensor.
"""
result = self.compute()
if isinstance(result, Mapping):
Expand All @@ -341,8 +346,11 @@ def completed(self, engine: Engine, name: str) -> None:
engine.state.metrics[key] = value
engine.state.metrics[name] = result
else:
if isinstance(result, torch.Tensor) and len(result.size()) == 0:
result = result.item()
if isinstance(result, torch.Tensor):
if len(result.size()) == 0:
result = result.item()
elif "cpu" not in result.device.type:
result = result.cpu()

engine.state.metrics[name] = result

Expand Down
25 changes: 25 additions & 0 deletions tests/ignite/metrics/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,6 +778,31 @@ def update(self, output):
assert engine.state.metrics == {"metric": "foo"}


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_completed_on_cuda():

# Checks https://github.com/pytorch/ignite/issues/1635#issuecomment-863026919

class DummyMetric(Metric):
def reset(self):
pass

def compute(self):
return torch.tensor([1.0, 2.0, 3.0], device="cuda")

def update(self, output):
pass

m = DummyMetric()

# tensor
engine = MagicMock(state=State(metrics={}))
m.completed(engine, "metric")
assert "metric" in engine.state.metrics
assert isinstance(engine.state.metrics["metric"], torch.Tensor)
assert engine.state.metrics["metric"].device.type == "cpu"


def test_usage_exception():
engine = Engine(lambda e, b: b)
m = DummyMetric2()
Expand Down
3 changes: 2 additions & 1 deletion tests/ignite/metrics/test_precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,8 @@ def update(engine, i):
assert pr._updated is True
res = engine.state.metrics["pr"]
if isinstance(res, torch.Tensor):
assert res.device == metric_device
# Fixes https://github.com/pytorch/ignite/issues/1635#issuecomment-863026919
assert res.device.type == "cpu"
res = res.cpu().numpy()

true_res = precision_score(
Expand Down
3 changes: 2 additions & 1 deletion tests/ignite/metrics/test_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,8 @@ def update(engine, i):
assert re._updated is True
res = engine.state.metrics["re"]
if isinstance(res, torch.Tensor):
assert res.device == metric_device
# Fixes https://github.com/pytorch/ignite/issues/1635#issuecomment-863026919
assert res.device.type == "cpu"
res = res.cpu().numpy()

true_res = recall_score(
Expand Down

0 comments on commit 4ead17c

Please sign in to comment.