Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add event + event handlers for custom logging #942

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions torchtnt/utils/event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Dict, Union

EventMetadataValue = Union[str, int, float, bool, None]


@dataclass
class Event:
"""
The class represents the generic event that occurs during a TorchTNT
loop. The event can be any kind of meaningful action.

Args:
name: event name.
metadata: additional data that is associated with the event.
"""

name: str
metadata: Dict[str, EventMetadataValue] = field(default_factory=dict)
85 changes: 85 additions & 0 deletions torchtnt/utils/event_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random
from contextlib import contextmanager
from functools import lru_cache
from typing import Dict, Generator, List, Optional

import importlib_metadata
from typing_extensions import Protocol, runtime_checkable

from .event import Event

logger: logging.Logger = logging.getLogger(__name__)


@runtime_checkable
class EventHandler(Protocol):
def handle_event(self, event: Event) -> None: ...


_log_handlers: List[EventHandler] = []


@lru_cache(maxsize=None)
def get_event_handlers() -> List[EventHandler]:
global _log_handlers

# Registered event handlers through entry points
eps = importlib_metadata.entry_points(group="tnt_event_handlers")
for entry in eps:
logger.debug(
f"Attempting to register event handler {entry.name}: {entry.value}"
)
factory = entry.load()
handler = factory()

if not isinstance(handler, EventHandler):
raise RuntimeError(
f"The factory function for {({entry.value})} "
"did not return a EventHandler object."
)
_log_handlers.append(handler)
return _log_handlers


def log_event(event: Event) -> None:
"""
Handle an event.

Args:
event: The event to handle.
"""

for handler in get_event_handlers():
handler.handle_event(event)


@contextmanager
def log_interval(
name: str, metadata: Optional[Dict[str, str]] = None
) -> Generator[None, None, None]:
unique_id = _generate_random_int64()
if metadata is None:
metadata = {}
metadata.update({"action": "start", "unique_id": unique_id})
start_event = Event(name=name, metadata=metadata)
log_event(start_event)

yield

metadata["action"] = "end"
end_event = Event(name=name, metadata=metadata)
log_event(end_event)


def _generate_random_int64() -> int:
# avoid being influenced by externally set seed
local_random = random.Random()
return local_random.randint(0, 2**63 - 1)
Loading