Skip to content

Commit

Permalink
Support lf.Message to be used to create lf.Template.
Browse files Browse the repository at this point in the history
- This allows metadata passing seamlessly from `lf.Message` to `lf.Template`, which enables multi-modal object chaining
- Pull up `lf.Template.from_value` to `lf.LangFunc`.
- Eval framework to create prompt bound with input example, allowing multi-modal object passing.

PiperOrigin-RevId: 631431707
  • Loading branch information
daiyip authored and langfun authors committed May 7, 2024
1 parent 47e3d13 commit 394f29d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 27 deletions.
6 changes: 4 additions & 2 deletions langfun/core/eval/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1179,7 +1179,7 @@ def call_postprocess(self, lm_response: str) -> str:

def process(self, example: Any, **kwargs) -> lf.Message:
"""Process an example and returns its output."""
prompt = self.prompt.render(example=example).text
prompt = lf.Template.from_value(self.prompt, example=example)
if self.method == 'call':
return lf_structured.call(
prompt,
Expand Down Expand Up @@ -1207,7 +1207,9 @@ def process(self, example: Any, **kwargs) -> lf.Message:
else:
assert self.method == 'complete', self.method
assert isinstance(self.schema.spec, pg.typing.Object), self.schema
input_value = self.schema.spec.cls.partial(prompt)
# TODO(daiyip): Currently multi-modal inputs within the prompt for
# completion is not supported.
input_value = self.schema.spec.cls.partial(prompt.render().text)
return lf_structured.complete(
input_value,
lm=self.lm,
Expand Down
18 changes: 1 addition & 17 deletions langfun/core/langfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
"""LangFunc: Language-based functions."""

import dataclasses
from typing import Annotated, Type, Union
from typing import Annotated, Type

from langfun.core import component
from langfun.core import language_model
Expand Down Expand Up @@ -328,22 +328,6 @@ def transform_output(
"""Transforms the output message before returning from __call__."""
return lm_output

@classmethod
def from_value(
cls, value: Union[str, template_lib.Template], **kwargs
) -> 'LangFunc':
"""Create a LangFunc object from a string or template."""
if isinstance(value, LangFunc):
return value
if isinstance(value, template_lib.Template):
lfun = LangFunc(value.template_str, **kwargs)
# So lfun could acccess all attributes from value.
lfun.sym_setparent(value)
return lfun
if isinstance(value, str):
return LangFunc(template_str=value, **kwargs)
return LangFunc('{{input}}', input=value, **kwargs)


# Register converter from str to LangFunc, therefore we can always
# pass strs to attributes that accept LangFunc.
Expand Down
4 changes: 4 additions & 0 deletions langfun/core/langfunc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ def test_from_value(self):
l2 = LangFunc.from_value(l1)
self.assertIs(l2, l1)

l3 = LangFunc.from_value(l1, x=1)
self.assertIsNot(l3, l1)
self.assertTrue(pg.eq(l3, LangFunc('Hello', x=1)))

c = template_lib.Template(
'{{x}} + {{l}}',
x=1,
Expand Down
9 changes: 2 additions & 7 deletions langfun/core/structured/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,13 +214,8 @@ class Flight(pg.Object):
# prompt rendering.
prompt_kwargs.pop('template_str', None)

if isinstance(prompt, str):
prompt = lf.Template(prompt, **prompt_kwargs)
elif isinstance(prompt, lf.Template):
prompt = prompt.rebind(**prompt_kwargs, raise_on_no_change=False)

if isinstance(prompt, lf.Template):
prompt = prompt.render(lm=lm)
if isinstance(prompt, (str, lf.Message, lf.Template)):
prompt = lf.Template.from_value(prompt, **prompt_kwargs).render(lm=lm)
else:
prompt = schema_lib.mark_missing(prompt)

Expand Down
23 changes: 22 additions & 1 deletion langfun/core/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import dataclasses
import functools
import inspect
from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type
from typing import Annotated, Any, Callable, Iterator, Set, Tuple, Type, Union

import jinja2
from jinja2 import meta as jinja2_meta
Expand Down Expand Up @@ -495,6 +495,27 @@ class Bar(lf.Template):
t.sym_setparent(self)
return t

@classmethod
def from_value(
cls,
value: Union[str, message_lib.Message, 'Template'],
**kwargs
) -> 'Template':
"""Create a template object from a string or template."""
if isinstance(value, cls):
return value.clone(override=kwargs) if kwargs else value # pylint: disable=no-value-for-parameter
if isinstance(value, str):
return cls(template_str=value, **kwargs)
if isinstance(value, message_lib.Message):
kwargs.update(value.metadata)
return cls(template_str=value.text, **kwargs)
if isinstance(value, Template):
lfun = cls(template_str=value.template_str, **kwargs)
# So lfun could acccess all attributes from value.
lfun.sym_setparent(value)
return lfun
return cls(template_str='{{input}}', input=value, **kwargs)


# Register converter from str to LangFunc, therefore we can always
# pass strs to attributes that accept LangFunc.
Expand Down

0 comments on commit 394f29d

Please sign in to comment.