Skip to content

Commit

Permalink
Fix broken inspect_history and broken prompt cache (stanfordnlp#1744)
Browse files Browse the repository at this point in the history
* Fix broken inspect_history and broken prompt cache

* Remove errant print statement

* Move global inspect history into base_lm

* Remove skip parameter

* Delete examples/temp.py

* Minor adjustment to make adapters go back to original behavior

---------

Co-authored-by: Omar Khattab <[email protected]>
  • Loading branch information
isaacbmiller and okhat authored Nov 3, 2024
1 parent 2d3ed8d commit 9170658
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 19 deletions.
9 changes: 1 addition & 8 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dspy.clients import * # isort: skip
from dspy.adapters import * # isort: skip
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging

settings = dsp.settings

configure_dspy_loggers(__name__)
Expand Down Expand Up @@ -70,10 +69,4 @@
BootstrapRS = dspy.teleprompt.BootstrapFewShotWithRandomSearch
COPRO = dspy.teleprompt.COPRO
MIPROv2 = dspy.teleprompt.MIPROv2
Ensemble = dspy.teleprompt.Ensemble


# TODO: Consider if this should access settings.lm *or* a list that's shared across all LMs in the program.
def inspect_history(*args, **kwargs):
from dspy.clients.lm import GLOBAL_HISTORY, _inspect_history
return _inspect_history(GLOBAL_HISTORY, *args, **kwargs)
Ensemble = dspy.teleprompt.Ensemble
3 changes: 1 addition & 2 deletions dspy/adapters/chat_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def format_fields(fields_with_values: Dict[FieldInfoWithName, Any], assume_text=
else:
output[-1]["text"] += formatted_field_value["text"]
if assume_text:
return "\n\n".join(output)
return "\n\n".join(output).strip()
else:
return output

Expand Down Expand Up @@ -396,7 +396,6 @@ def format_signature_fields_for_instructions(fields: Dict[str, FieldInfo]):
parts.append(format_signature_fields_for_instructions(signature.input_fields))
parts.append(format_signature_fields_for_instructions(signature.output_fields))
parts.append(format_fields({BuiltInCompletedOutputFieldInfo: ""}, assume_text=True))

instructions = textwrap.dedent(signature.instructions)
objective = ("\n" + " " * 8).join([""] + instructions.splitlines())
parts.append(f"In adhering to this structure, your objective is: {objective}")
Expand Down
2 changes: 1 addition & 1 deletion dspy/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .lm import LM
from .base_lm import BaseLM
from .base_lm import BaseLM, inspect_history
26 changes: 21 additions & 5 deletions dspy/clients/base_lm.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod

GLOBAL_HISTORY = []

class BaseLM(ABC):
def __init__(self, model, model_type='chat', temperature=0.0, max_tokens=1000, cache=True, **kwargs):
Expand All @@ -14,7 +15,10 @@ def __call__(self, prompt=None, messages=None, **kwargs):
pass

def inspect_history(self, n: int = 1):
_inspect_history(self, n)
_inspect_history(self.history, n)

def update_global_history(self, entry):
GLOBAL_HISTORY.append(entry)


def _green(text: str, end: str = "\n"):
Expand All @@ -24,15 +28,21 @@ def _green(text: str, end: str = "\n"):
def _red(text: str, end: str = "\n"):
return "\x1b[31m" + str(text) + "\x1b[0m" + end

def _blue(text: str, end: str = "\n"):
return "\x1b[34m" + str(text) + "\x1b[0m" + end


def _inspect_history(lm, n: int = 1):
def _inspect_history(history, n: int = 1):
"""Prints the last n prompts and their completions."""

for item in reversed(lm.history[-n:]):
for item in history[-n:]:
messages = item["messages"] or [{"role": "user", "content": item["prompt"]}]
outputs = item["outputs"]
timestamp = item.get("timestamp", "Unknown time")

print("\n\n\n")
print("\x1b[34m" + f"[{timestamp}]" + "\x1b[0m" + "\n")

for msg in messages:
print(_red(f"{msg['role'].capitalize()} message:"))
if isinstance(msg["content"], str):
Expand All @@ -43,11 +53,13 @@ def _inspect_history(lm, n: int = 1):
if c["type"] == "text":
print(c["text"].strip())
elif c["type"] == "image_url":
image_str = ""
if "base64" in c["image_url"].get("url", ""):
len_base64 = len(c["image_url"]["url"].split("base64,")[1])
print(f"<{c['image_url']['url'].split('base64,')[0]}base64,<IMAGE BASE 64 ENCODED({str(len_base64)})>")
image_str = f"<{c['image_url']['url'].split('base64,')[0]}base64,<IMAGE BASE 64 ENCODED({str(len_base64)})>"
else:
print(f"<image_url: {c['image_url']['url']}>")
image_str = f"<image_url: {c['image_url']['url']}>"
print(_blue(image_str.strip()))
print("\n")

print(_red("Response:"))
Expand All @@ -58,3 +70,7 @@ def _inspect_history(lm, n: int = 1):
print(_red(choices_text, end=""))

print("\n\n\n")

def inspect_history(n: int = 1):
"""The global history shared across all LMs."""
return _inspect_history(GLOBAL_HISTORY, n)
4 changes: 1 addition & 3 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
if "LITELLM_LOCAL_MODEL_COST_MAP" not in os.environ:
os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

GLOBAL_HISTORY = []

logger = logging.getLogger(__name__)

class LM(BaseLM):
Expand Down Expand Up @@ -109,7 +107,7 @@ def __call__(self, prompt=None, messages=None, **kwargs):
model_type=self.model_type,
)
self.history.append(entry)
GLOBAL_HISTORY.append(entry)
self.update_global_history(entry)

return outputs

Expand Down
1 change: 1 addition & 0 deletions dspy/utils/dummies.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,7 @@ def format_answer_fields(field_names_and_values: Dict[str, Any]):
entry = dict(**entry, outputs=outputs, usage=0)
entry = dict(**entry, cost=0)
self.history.append(entry)
self.update_global_history(entry)

return outputs

Expand Down
68 changes: 68 additions & 0 deletions tests/clients/test_inspect_global_history.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import pytest
from dspy.utils.dummies import DummyLM
from dspy.clients.base_lm import GLOBAL_HISTORY
import dspy

@pytest.fixture(autouse=True)
def clear_history():
GLOBAL_HISTORY.clear()
yield

def test_inspect_history_basic(capsys):
# Configure a DummyLM with some predefined responses
lm = DummyLM([{"response": "Hello"}, {"response": "How are you?"}])
dspy.settings.configure(lm=lm)

# Make some calls to generate history
predictor = dspy.Predict("query: str -> response: str")
predictor(query="Hi")
predictor(query="What's up?")

# Test inspecting all history
history = GLOBAL_HISTORY
print(capsys)
assert len(history) > 0
assert isinstance(history, list)
assert all(isinstance(entry, dict) for entry in history)
assert all("messages" in entry for entry in history)

def test_inspect_history_with_n(capsys):
lm = DummyLM([{"response": "One"}, {"response": "Two"}, {"response": "Three"}])
dspy.settings.configure(lm=lm)

# Generate some history
predictor = dspy.Predict("query: str -> response: str")
predictor(query="First")
predictor(query="Second")
predictor(query="Third")

dspy.inspect_history(n=2)
# Test getting last 2 entries
out, err = capsys.readouterr()
assert not "First" in out
assert "Second" in out
assert "Third" in out

def test_inspect_empty_history(capsys):
# Configure fresh DummyLM
lm = DummyLM([])
dspy.settings.configure(lm=lm)

# Test inspecting empty history
dspy.inspect_history()
history = GLOBAL_HISTORY
assert len(history) == 0
assert isinstance(history, list)

def test_inspect_history_n_larger_than_history(capsys):
lm = DummyLM([{"response": "First"}, {"response": "Second"}])
dspy.settings.configure(lm=lm)

predictor = dspy.Predict("query: str -> response: str")
predictor(query="Query 1")
predictor(query="Query 2")

# Request more entries than exist
dspy.inspect_history(n=5)
history = GLOBAL_HISTORY
assert len(history) == 2 # Should return all available entries

0 comments on commit 9170658

Please sign in to comment.