Skip to content

Commit

Permalink
Some fixes to callback (stanfordnlp#1696)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenmoneygithub authored Oct 26, 2024
1 parent c9a8cd4 commit 16ba98a
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 87 deletions.
17 changes: 8 additions & 9 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
import functools
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

import litellm
import ujson
import uuid
from litellm.caching import Cache

from dspy.utils.logging import logger
from dspy.clients.finetune import FinetuneJob, TrainingMethod
from dspy.clients.lm_finetune_utils import (
get_provider_finetune_job_class,
execute_finetune_job,
get_provider_finetune_job_class,
)

from dspy.utils.callback import with_callbacks
import litellm
from litellm.caching import Cache

from dspy.utils.logging import logger

DISK_CACHE_DIR = os.environ.get("DSPY_CACHEDIR") or os.path.join(Path.home(), ".dspy_cache")
litellm.cache = Cache(disk_cache_dir=DISK_CACHE_DIR, type="disk")
Expand Down
6 changes: 3 additions & 3 deletions dspy/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .callback import *
from .dummies import *
from .logging import *
from dspy.utils.callback import BaseCallback, with_callbacks
from dspy.utils.dummies import *
from dspy.utils.logging import *
112 changes: 53 additions & 59 deletions dspy/utils/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@


class BaseCallback:
"""
A base class for defining callback handlers for DSPy components.
"""A base class for defining callback handlers for DSPy components.
To use a callback, subclass this class and implement the desired handlers. Each handler
will be called at the appropriate time before/after the execution of the corresponding component.
will be called at the appropriate time before/after the execution of the corresponding component. For example, if
you want to print a message before and after an LM is called, implement `the on_llm_start` and `on_lm_end` handler.
Users can set the callback globally using `dspy.settings.configure` or locally by passing it to the component
constructor.
For example, if you want to print a message before and after an LM is called, implement
the on_llm_start and on_lm_end handler and set the callback to the global settings using `dspy.settings.configure`.
Example 1: Set a global callback using `dspy.settings.configure`.
```
import dspy
Expand All @@ -45,19 +47,18 @@ def on_lm_end(self, call_id, outputs, exception):
# > LM is finished with outputs: {'answer': '42'}
```
Another way to set the callback is to pass it directly to the component constructor.
In this case, the callback will only be triggered for that specific instance.
Example 2: Set a local callback by passing it to the component constructor.
```
lm = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()])
lm(question="What is the meaning of life?")
lm_1 = dspy.LM("gpt-3.5-turbo", callbacks=[LoggingCallback()])
lm_1(question="What is the meaning of life?")
# > LM is called with inputs: {'question': 'What is the meaning of life?'}
# > LM is finished with outputs: {'answer': '42'}
lm_2 = dspy.LM("gpt-3.5-turbo")
lm_2(question="What is the meaning of life?")
# No logging here
# No logging here because only `lm_1` has the callback set.
```
"""

Expand All @@ -67,8 +68,7 @@ def on_module_start(
instance: Any,
inputs: Dict[str, Any],
):
"""
A handler triggered when forward() method of a module (subclass of dspy.Module) is called.
"""A handler triggered when forward() method of a module (subclass of dspy.Module) is called.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -84,8 +84,7 @@ def on_module_end(
outputs: Optional[Any],
exception: Optional[Exception] = None,
):
"""
A handler triggered after forward() method of a module (subclass of dspy.Module) is executed.
"""A handler triggered after forward() method of a module (subclass of dspy.Module) is executed.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -101,8 +100,7 @@ def on_lm_start(
instance: Any,
inputs: Dict[str, Any],
):
"""
A handler triggered when __call__ method of dspy.LM instance is called.
"""A handler triggered when __call__ method of dspy.LM instance is called.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -118,8 +116,7 @@ def on_lm_end(
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
):
"""
A handler triggered after __call__ method of dspy.LM instance is executed.
"""A handler triggered after __call__ method of dspy.LM instance is executed.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -129,14 +126,13 @@ def on_lm_end(
"""
pass

def on_format_start(
def on_adapter_format_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
):
"""
A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called.
"""A handler triggered when format() method of an adapter (subclass of dspy.Adapter) is called.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -146,14 +142,13 @@ def on_format_start(
"""
pass

def on_format_end(
def on_adapter_format_end(
self,
call_id: str,
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
):
"""
A handler triggered after format() method of dspy.LM instance is executed.
"""A handler triggered after format() method of an adapter (subclass of dspy.Adapter) is called..
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -163,14 +158,13 @@ def on_format_end(
"""
pass

def on_parse_start(
def on_adapter_parse_start(
self,
call_id: str,
instance: Any,
inputs: Dict[str, Any],
):
"""
A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called.
"""A handler triggered when parse() method of an adapter (subclass of dspy.Adapter) is called.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -180,14 +174,13 @@ def on_parse_start(
"""
pass

def on_parse_end(
def on_adapter_parse_end(
self,
call_id: str,
outputs: Optional[Dict[str, Any]],
exception: Optional[Exception] = None,
):
"""
A handler triggered after parse() method of dspy.LM instance is executed.
"""A handler triggered after parse() method of an adapter (subclass of dspy.Adapter) is called.
Args:
call_id: A unique identifier for the call. Can be used to connect start/end handlers.
Expand All @@ -200,23 +193,23 @@ def on_parse_end(

def with_callbacks(fn):
@functools.wraps(fn)
def wrapper(self, *args, **kwargs):
# Combine global and local (per-instance) callbacks
callbacks = dspy.settings.get("callbacks", []) + getattr(self, "callbacks", [])
def wrapper(instance, *args, **kwargs):
# Combine global and local (per-instance) callbacks.
callbacks = dspy.settings.get("callbacks", []) + getattr(instance, "callbacks", [])

# if no callbacks are provided, just call the function
# If no callbacks are provided, just call the function
if not callbacks:
return fn(self, *args, **kwargs)
return fn(instance, *args, **kwargs)

# Generate call ID to connect start/end handlers if needed
# Generate call ID as the unique identifier for the call, this is useful for instrumentation.
call_id = uuid.uuid4().hex

inputs = inspect.getcallargs(fn, self, *args, **kwargs)
inputs = inspect.getcallargs(fn, instance, *args, **kwargs)
inputs.pop("self") # Not logging self as input

for callback in callbacks:
try:
_get_on_start_handler(callback, self, fn)(call_id=call_id, instance=self, inputs=inputs)
_get_on_start_handler(callback, instance, fn)(call_id=call_id, instance=instance, inputs=inputs)

except Exception as e:
logger.warning(f"Error when calling callback {callback}: {e}")
Expand All @@ -225,58 +218,59 @@ def wrapper(self, *args, **kwargs):
exception = None
try:
parent_call_id = ACTIVE_CALL_ID.get()
# Active ID must be set right before the function is called,
# not before calling the callbacks.
# Active ID must be set right before the function is called, not before calling the callbacks.
ACTIVE_CALL_ID.set(call_id)
results = fn(self, *args, **kwargs)
results = fn(instance, *args, **kwargs)
return results
except Exception as e:
exception = e
raise exception
finally:
# Execute the end handlers even if the function call raises an exception.
ACTIVE_CALL_ID.set(parent_call_id)
for callback in callbacks:
try:
_get_on_end_handler(callback, self, fn)(
_get_on_end_handler(callback, instance, fn)(
call_id=call_id,
outputs=results,
exception=exception,
)
except Exception as e:
logger.warning(f"Error when calling callback {callback}: {e}")
logger.warning(
f"Error when applying callback {callback}'s end handler on function {fn.__name__}: {e}."
)

return wrapper


def _get_on_start_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable:
"""
Selects the appropriate on_start handler of the callback
based on the instance and function name.
"""
if isinstance(instance, (dspy.LM)):
"""Selects the appropriate on_start handler of the callback based on the instance and function name."""
if isinstance(instance, dspy.LM):
return callback.on_lm_start
elif isinstance(instance, (dspy.Adapter)):

if isinstance(instance, dspy.Adapter):
if fn.__name__ == "format":
return callback.on_format_start
return callback.on_adapter_format_start
elif fn.__name__ == "parse":
return callback.on_parse_start
return callback.on_adapter_parse_start
else:
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")

# We treat everything else as a module.
return callback.on_module_start


def _get_on_end_handler(callback: BaseCallback, instance: Any, fn: Callable) -> Callable:
"""
Selects the appropriate on_end handler of the callback
based on the instance and function name.
"""
"""Selects the appropriate on_end handler of the callback based on the instance and function name."""
if isinstance(instance, (dspy.LM)):
return callback.on_lm_end
elif isinstance(instance, (dspy.Adapter)):

if isinstance(instance, (dspy.Adapter)):
if fn.__name__ == "format":
return callback.on_format_end
return callback.on_adapter_format_end
elif fn.__name__ == "parse":
return callback.on_parse_end

return callback.on_adapter_parse_end
else:
raise ValueError(f"Unsupported adapter method for using callback: {fn.__name__}.")
# We treat everything else as a module.
return callback.on_module_end
34 changes: 18 additions & 16 deletions tests/callback/test_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ def reset_settings():


class MyCallback(BaseCallback):
"""A simple callback that records the calls."""

def __init__(self):
self.calls = []

Expand All @@ -33,17 +35,17 @@ def on_lm_start(self, call_id, instance, inputs):
def on_lm_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_lm_end", "outputs": outputs, "exception": exception})

def on_format_start(self, call_id, instance, inputs):
self.calls.append({"handler": "on_format_start", "instance": instance, "inputs": inputs})
def on_adapter_format_start(self, call_id, instance, inputs):
self.calls.append({"handler": "on_adapter_format_start", "instance": instance, "inputs": inputs})

def on_format_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_format_end", "outputs": outputs, "exception": exception})
def on_adapter_format_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_adapter_format_end", "outputs": outputs, "exception": exception})

def on_parse_start(self, call_id, instance, inputs):
self.calls.append({"handler": "on_parse_start", "instance": instance, "inputs": inputs})
def on_adapter_parse_start(self, call_id, instance, inputs):
self.calls.append({"handler": "on_adapter_parse_start", "instance": instance, "inputs": inputs})

def on_parse_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_parse_end", "outputs": outputs, "exception": exception})
def on_adapter_parse_end(self, call_id, outputs, exception):
self.calls.append({"handler": "on_adapter_parse_end", "outputs": outputs, "exception": exception})


@pytest.mark.parametrize(
Expand Down Expand Up @@ -163,17 +165,17 @@ def test_callback_complex_module():
assert [call["handler"] for call in callback.calls] == [
"on_module_start",
"on_module_start",
"on_format_start",
"on_format_end",
"on_adapter_format_start",
"on_adapter_format_end",
"on_lm_start",
"on_lm_end",
# Parsing will run per output (n=3)
"on_parse_start",
"on_parse_end",
"on_parse_start",
"on_parse_end",
"on_parse_start",
"on_parse_end",
"on_adapter_parse_start",
"on_adapter_parse_end",
"on_adapter_parse_start",
"on_adapter_parse_end",
"on_adapter_parse_start",
"on_adapter_parse_end",
"on_module_end",
"on_module_end",
]
Expand Down

0 comments on commit 16ba98a

Please sign in to comment.