Skip to content

Commit

Permalink
Include modalities in LLM cache key.
Browse files Browse the repository at this point in the history
Previously, only the text prompt was included.

PiperOrigin-RevId: 667717352
  • Loading branch information
daiyip authored and langfun authors committed Aug 26, 2024
1 parent 683ad1b commit 2ef2cae
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 1 deletion.
2 changes: 1 addition & 1 deletion langfun/core/llms/cache/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,4 @@ def _sym_clone(self, deep: bool, memo: Any = None) -> 'LMCacheBase':

def default_key(lm: lf.LanguageModel, prompt: lf.Message, seed: int) -> Any:
    """Returns the default key for an LM cache entry.

    The key is built from the prompt text annotated with the hashes of the
    modality objects it refers to, so that two prompts with identical text but
    different modality content do not share a cache entry. Keying on
    `prompt.text` alone would conflate them.

    Args:
      lm: Language model whose `sampling_options.cache_key()` is part of the
        key.
      prompt: Prompt message; its `text_with_modality_hash` is part of the key.
      seed: Cache seed.

    Returns:
      A tuple of (text with modality hashes, sampling options key, seed).
    """
    return (
        prompt.text_with_modality_hash,
        lm.sampling_options.cache_key(),
        seed,
    )
28 changes: 28 additions & 0 deletions langfun/core/llms/cache/in_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,34 @@ def cache_entry(response_text, cache_seed=0):
)
self.assertEqual(cache.stats.num_deletes, 1)

def test_cache_with_modalities(self):
    """Checks that same-text prompts with different modalities get distinct keys."""

    class CustomModality(lf.Modality):
        content: str

        def to_bytes(self):
            return self.content.encode()

    cache = in_memory.InMemory()
    lm = fake.StaticSequence(['1', '2', '3', '4', '5', '6'], cache=cache)

    # Same prompt text both times; only the referred modality's content differs.
    for content in ('foo', 'bar'):
        lm(lf.UserMessage('hi <<[[image]]>>', image=CustomModality(content)))

    sampling_key = (None, None, 1, 40, None, None)
    expected_keys = [
        ('hi <<[[image]]>><image>acbd18db</image>', sampling_key, 0),
        ('hi <<[[image]]>><image>37b51d19</image>', sampling_key, 0),
    ]
    self.assertEqual(list(cache.keys()), expected_keys)

def test_ttl(self):
cache = in_memory.InMemory(ttl=1)
lm = fake.StaticSequence(['1', '2', '3'], cache=cache)
Expand Down
10 changes: 10 additions & 0 deletions langfun/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,16 @@ def apply_updates(self, updates: dict[pg.KeyPath, pg.FieldUpdate]) -> None:
# API for supporting modalities.
#

@property
def text_with_modality_hash(self) -> str:
    """Returns the text followed by a hash tag for each referred modality.

    Each entry of `referred_modalities()` contributes one
    `<name>hash</name>` segment, appended after the message text in the
    iteration order of that mapping.
    """
    text = self.text
    hash_tags = (
        f'<{name}>{modality_obj.hash}</{name}>'
        for name, modality_obj in self.referred_modalities().items()
    )
    return ''.join([text, *hash_tags])

def get_modality(
self, var_name: str, default: Any = None, from_message_chain: bool = True
) -> modality.Modality | None:
Expand Down
14 changes: 14 additions & 0 deletions langfun/core/message_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,20 @@ def test_referred_modalities(self):
},
)

def test_text_with_modality_hash(self):
    """Checks that hash tags for top-level and nested modalities are appended."""
    msg = message.UserMessage(
        'hi, this is a <<[[img1]]>> and <<[[x.img2]]>>',
        img1=CustomModality('foo'),
        x={'img2': CustomModality('bar')},
    )
    # The original text is kept verbatim; one tag per referred modality follows.
    expected = (
        'hi, this is a <<[[img1]]>> and <<[[x.img2]]>>'
        '<img1>acbd18db</img1><x.img2>37b51d19</x.img2>'
    )
    self.assertEqual(msg.text_with_modality_hash, expected)

def test_chunking(self):
m = message.UserMessage(
inspect.cleandoc("""
Expand Down
12 changes: 12 additions & 0 deletions langfun/core/modality.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""Interface for modality (e.g. Image, Video, etc.)."""

import abc
import functools
import hashlib
from typing import Any, ContextManager
from langfun.core import component
import pyglove as pg
Expand All @@ -35,6 +37,11 @@ class Modality(component.Component):
# Opening and closing delimiters of a modality reference marker in text
# (the tests in this change use markers such as `<<[[image]]>>`).
REF_START = '<<[['
REF_END = ']]>>'

def _on_bound(self):
    """Runs the parent binding hook, then discards any memoized `hash`."""
    super()._on_bound()
    # A changed modality member makes the stored hash stale, so remove it
    # from the instance dict and let it be recomputed on next access.
    instance_state = self.__dict__
    if 'hash' in instance_state:
        del instance_state['hash']

def format(self, *args, **kwargs) -> str:
if self.referred_name is None or not pg.object_utils.thread_local_get(
_TLS_MODALITY_AS_REF, False
Expand All @@ -46,6 +53,11 @@ def format(self, *args, **kwargs) -> str:
def to_bytes(self) -> bytes:
"""Returns content in bytes."""

@functools.cached_property
def hash(self) -> str:
    """Returns a short content hash identifying this modality object.

    The value is the first 8 hexadecimal characters (i.e. 4 bytes) of the MD5
    digest of `to_bytes()`. It is computed on first access and then memoized
    on the instance.
    """
    # MD5 serves only as a content fingerprint here. Declaring it as
    # non-security use keeps it available on FIPS-restricted Python builds,
    # where a plain `hashlib.md5()` call may be rejected.
    digest = hashlib.md5(self.to_bytes(), usedforsecurity=False)
    return digest.hexdigest()[:8]

@classmethod
def text_marker(cls, var_name: str) -> str:
"""Returns a marker in the text for this object."""
Expand Down
1 change: 1 addition & 0 deletions langfun/core/modality_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def test_basic(self):
v = CustomModality('a')
self.assertIsNone(v.referred_name)
self.assertEqual(str(v), "CustomModality(\n content = 'a'\n)")
self.assertEqual(v.hash, '0cc175b9')

_ = pg.Dict(metadata=pg.Dict(x=pg.Dict(metadata=pg.Dict(y=v))))
self.assertEqual(v.referred_name, 'x.metadata.y')
Expand Down

0 comments on commit 2ef2cae

Please sign in to comment.