Skip to content

Commit

Permalink
issue pytorch#710 : add event list and or operator (pytorch#868)
Browse files Browse the repository at this point in the history
* issue pytorch#710 : add event list and or operator

* fix fstring

* improve test of event list

* move EventList from on to add_event_handler

* improve docs - add_eventlist_handler

* fix naming RemovableEventHandler

* RemovableEventHandler handles EventsList

* move back RemovableEventHandler to RemovableEventHandle

* add test_events_list_removable_handle

* docstring

* minimal update of README

* Update README.md

* Added a test to make coverage happy

* Added a test with custom event and events list
- minor: fixed typo

Co-authored-by: Desroziers <[email protected]>
Co-authored-by: vfdev <[email protected]>
  • Loading branch information
3 people authored Apr 1, 2020
1 parent 10135a9 commit 6a2460e
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 11 deletions.
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,22 @@ def log_gradients(_):

</details>

### Stack events to share some actions

<details>
<summary>
Examples
</summary>

Events can be stacked together to enable multiple calls:
```python
@trainer.on(Events.COMPLETED | Events.EPOCH_COMPLETED(every=10))
def do_some_validation(engine):
# ...
```

</details>

### Custom events to go beyond standard events

<details>
Expand Down
2 changes: 1 addition & 1 deletion docs/source/concepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ Attaching an event handler is simple using method :meth:`~ignite.engine.Engine.a
trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
Event handlers can be detached via :meth:`~ignite.engine.Engine.remove_event_handler` or via the :class:`~ignite.engine.RemovableEventHandler`
Event handlers can be detached via :meth:`~ignite.engine.Engine.remove_event_handler` or via the :class:`~ignite.engine.RemovableEventHandle`
reference returned by :meth:`~ignite.engine.Engine.add_event_handler`. This can be used to reuse a configured engine for multiple loops:

.. code-block:: python
Expand Down
21 changes: 16 additions & 5 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import torch

from ignite.engine.events import Events, State, CallableEventWithFilter, RemovableEventHandle
from ignite.engine.events import Events, State, CallableEventWithFilter, RemovableEventHandle, EventsList
from ignite.engine.utils import ReproducibleBatchSampler, _update_dataloader, _check_signature
from ignite._utils import _to_hours_mins_secs

Expand Down Expand Up @@ -157,7 +157,7 @@ def register_events(self, *event_names: Union[str, int, Any], event_to_attr: Opt
.. code-block:: python
from ignite.engine import Engine, EvenEnum
from ignite.engine import Engine, EventEnum
class CustomEvents(EventEnum):
FOO_EVENT = "foo_event"
Expand Down Expand Up @@ -211,8 +211,9 @@ def add_event_handler(self, event_name: str, handler: Callable, *args, **kwargs)
"""Add an event handler to be executed when the specified event is fired.
Args:
event_name: An event to attach the handler to. Valid events are from :class:`~ignite.engine.Events`
or any `event_name` added by :meth:`~ignite.engine.Engine.register_events`.
event_name: An event or a list of events to attach the handler. Valid events are
from :class:`~ignite.engine.Events` or any `event_name` added by
:meth:`~ignite.engine.Engine.register_events`.
handler (callable): the callable event handler that should be invoked
*args: optional args to be passed to `handler`.
**kwargs: optional keyword args to be passed to `handler`.
Expand All @@ -225,7 +226,7 @@ def add_event_handler(self, event_name: str, handler: Callable, *args, **kwargs)
passed here, for example during :attr:`~ignite.engine.Events.EXCEPTION_RAISED`.
Returns:
:class:`~ignite.engine.RemovableEventHandler`, which can be used to remove the handler.
:class:`~ignite.engine.RemovableEventHandle`, which can be used to remove the handler.
Example usage:
Expand All @@ -238,12 +239,22 @@ def print_epoch(engine):
engine.add_event_handler(Events.EPOCH_COMPLETED, print_epoch)
events_list = Events.EPOCH_COMPLETED | Events.COMPLETED
def execute_validation(engine):
# do some validations
engine.add_event_handler(events_list, execute_validation)
Note:
Since v0.3.0, Events become more flexible and allow to pass an event filter to the Engine.
See :class:`~ignite.engine.Events` for more details.
"""
if isinstance(event_name, EventsList):
for e in event_name:
self.add_event_handler(e, handler, *args, **kwargs)
return RemovableEventHandle(event_name, handler, self)
if (
isinstance(event_name, CallableEventWithFilter)
and event_name.filter != CallableEventWithFilter.default_event_filter
Expand Down
63 changes: 60 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ def __eq__(self, other):
def __hash__(self):
return hash(self._name_)

def __or__(self, other):
return EventsList() | self | other


class EventEnum(CallableEventWithFilter, Enum):
pass
Expand Down Expand Up @@ -179,6 +182,55 @@ def call_once(engine):
GET_BATCH_STARTED = "get_batch_started"
GET_BATCH_COMPLETED = "get_batch_completed"

def __or__(self, other):
return EventsList() | self | other


class EventsList:
"""Collection of events stacked by operator `__or__`.
.. code-block:: python
events = Events.STARTED | Events.COMPLETED
events |= Events.ITERATION_STARTED(every=3)
engine = ...
@engine.on(events)
def call_on_events(engine):
# do something
or
.. code-block:: python
@engine.on(Events.STARTED | Events.COMPLETED | Events.ITERATION_STARTED(every=3))
def call_on_events(engine):
# do something
"""

def __init__(self):
self._events = []

def _append(self, event: Union[Events, CallableEventWithFilter]):
if not isinstance(event, (Events, CallableEventWithFilter)):
raise ValueError("Argument event should be Events or CallableEventWithFilter, got: {}".format(type(event)))
self._events.append(event)

def __getitem__(self, item):
return self._events[item]

def __iter__(self):
return iter(self._events)

def __len__(self):
return len(self._events)

def __or__(self, other: Union[Events, CallableEventWithFilter]):
self._append(event=other)
return self


class State:
"""An object that is used to pass internal and user-defined state between event handlers. By default, state
Expand Down Expand Up @@ -270,7 +322,7 @@ def print_epoch(engine):
# print_epoch handler is now unregistered
"""

def __init__(self, event_name: Union[CallableEventWithFilter, Enum], handler: Callable, engine):
def __init__(self, event_name: Union[CallableEventWithFilter, Enum, EventsList], handler: Callable, engine):
self.event_name = event_name
self.handler = weakref.ref(handler)
self.engine = weakref.ref(engine)
Expand All @@ -283,8 +335,13 @@ def remove(self) -> None:
if handler is None or engine is None:
return

if engine.has_event_handler(handler, self.event_name):
engine.remove_event_handler(handler, self.event_name)
if isinstance(self.event_name, EventsList):
for e in self.event_name:
if engine.has_event_handler(handler, e):
engine.remove_event_handler(handler, e)
else:
if engine.has_event_handler(handler, self.event_name):
engine.remove_event_handler(handler, self.event_name)

def __enter__(self):
return self
Expand Down
58 changes: 57 additions & 1 deletion tests/ignite/engine/test_custom_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch

from ignite.engine import Engine, Events
from ignite.engine.events import CallableEventWithFilter, EventEnum
from ignite.engine.events import CallableEventWithFilter, EventEnum, EventsList

import pytest

Expand Down Expand Up @@ -75,6 +75,23 @@ def handle(engine):
engine.register_events(*CustomEvents, event_to_attr=custom_event_to_attr)


def test_custom_events_with_events_list():
class CustomEvents(EventEnum):
TEST_EVENT = "test_event"

def process_func(engine, batch):
engine.fire_event(CustomEvents.TEST_EVENT)

engine = Engine(process_func)
engine.register_events(*CustomEvents)

# Handle should be called
handle = MagicMock()
engine.add_event_handler(CustomEvents.TEST_EVENT | Events.STARTED, handle)
engine.run(range(1))
assert handle.called


def test_callable_events_with_wrong_inputs():

with pytest.raises(ValueError, match=r"Only one of the input arguments should be specified"):
Expand Down Expand Up @@ -433,3 +450,42 @@ def test_distrib_gpu(distributed_context_single_node_nccl):
device = "cuda:{}".format(distributed_context_single_node_nccl["local_rank"])
_test_every_event_filter_with_engine(device)
_test_every_event_filter_with_engine_with_dataloader(device)


def test_event_list():

e1 = Events.ITERATION_STARTED(once=1)
e2 = Events.ITERATION_STARTED(every=3)
e3 = Events.COMPLETED

event_list = e1 | e2 | e3

assert type(event_list) == EventsList
assert len(event_list) == 3
assert event_list[0] == e1
assert event_list[1] == e2
assert event_list[2] == e3


def test_list_of_events():
def _test(event_list, true_iterations):

engine = Engine(lambda e, b: b)

iterations = []

num_calls = [0]

@engine.on(event_list)
def execute_some_handler(e):
iterations.append(e.state.iteration)
num_calls[0] += 1

engine.run(range(3), max_epochs=5)

assert iterations == true_iterations
assert num_calls[0] == len(true_iterations)

_test(Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(once=1), [1, 1])
_test(Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(once=10), [1, 10])
_test(Events.ITERATION_STARTED(once=1) | Events.ITERATION_STARTED(every=3), [1, 3, 6, 9, 12, 15])
111 changes: 110 additions & 1 deletion tests/ignite/engine/test_event_handlers.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import gc

from unittest.mock import MagicMock
from unittest.mock import call, MagicMock

import pytest
from pytest import raises

from ignite.engine import Engine, Events, State
from ignite.engine.events import EventsList


class DummyEngine(Engine):
Expand Down Expand Up @@ -181,6 +182,114 @@ def _handler(_):
assert removable_handle.handler() is None


def test_events_list_removable_handle():

# Removable handle removes event from engine.
engine = DummyEngine()
handler = MagicMock(spec_set=True)
assert not hasattr(handler, "_parent")

events_list = Events.STARTED | Events.COMPLETED

removable_handle = engine.add_event_handler(events_list, handler)
for e in events_list:
assert engine.has_event_handler(handler, e)

engine.run(1)
calls = [call(engine), call(engine)]
handler.assert_has_calls(calls)
assert handler.call_count == 2

removable_handle.remove()
for e in events_list:
assert not engine.has_event_handler(handler, e)

# Second engine pass does not fire handle again.
engine.run(1)
handler.assert_has_calls(calls)
assert handler.call_count == 2

# Removable handle can be used as a context manager
handler = MagicMock(spec_set=True)

with engine.add_event_handler(events_list, handler):
for e in events_list:
assert engine.has_event_handler(handler, e)
engine.run(1)

for e in events_list:
assert not engine.has_event_handler(handler, e)
handler.assert_has_calls(calls)
assert handler.call_count == 2

engine.run(1)
handler.assert_has_calls(calls)
assert handler.call_count == 2

# Removeable handle only effects a single event registration
handler = MagicMock(spec_set=True)

other_events_list = Events.EPOCH_STARTED | Events.EPOCH_COMPLETED

with engine.add_event_handler(events_list, handler):
with engine.add_event_handler(other_events_list, handler):
for e in events_list:
assert engine.has_event_handler(handler, e)
for e in other_events_list:
assert engine.has_event_handler(handler, e)
for e in events_list:
assert engine.has_event_handler(handler, e)
for e in other_events_list:
assert not engine.has_event_handler(handler, e)
for e in events_list:
assert not engine.has_event_handler(handler, e)
for e in other_events_list:
assert not engine.has_event_handler(handler, e)

# Removeable handle is re-enter and re-exitable

handler = MagicMock(spec_set=True)

remove = engine.add_event_handler(events_list, handler)

with remove:
with remove:
for e in events_list:
assert engine.has_event_handler(handler, e)
for e in events_list:
assert not engine.has_event_handler(handler, e)
for e in events_list:
assert not engine.has_event_handler(handler, e)

# Removeable handle is a weakref, does not keep engine or event alive
def _add_in_closure():
_engine = DummyEngine()

def _handler(_):
pass

_handle = _engine.add_event_handler(events_list, _handler)
assert _handle.engine() is _engine
assert _handle.handler() is _handler

return _handle

removable_handle = _add_in_closure()

# gc.collect, resolving reference cycles in engine/state
# required to ensure object deletion in python2
gc.collect()

assert removable_handle.engine() is None
assert removable_handle.handler() is None


def test_eventslist__append_raises():
ev_list = EventsList()
with pytest.raises(ValueError, match=r"Argument event should be Events or CallableEventWithFilter"):
ev_list._append("abc")


def test_has_event_handler():
engine = DummyEngine()
handlers = [MagicMock(spec_set=True), MagicMock(spec_set=True)]
Expand Down

0 comments on commit 6a2460e

Please sign in to comment.