Skip to content

Commit

Permalink
Introduce lf.query_reward.
Browse files Browse the repository at this point in the history
This helper computes the reward of an LLM response by delegating to the `__reward__` method of the user-defined output class.

```python
 class Answer(pg.Object):
   final_answer: int

   def __reward__(self,
                  inputs: lf.Template,
                  expected_output: 'Answer',
                  metadata: dict[str, Any]):
      del inputs
      return (
        1.0 if self.final_answer == expected_output.final_answer else -1.0
      ) * metadata['weight']

    mapping_example_str = pg.to_json_str(lf.MappingExample(
        input=lf.Template('{{x}} + {{y}}', x=1, y=1),
        output=Answer(final_answer=2),
        metadata=dict(weight=0.5)
    ))
    reward = lf.query_reward(
        mapping_example_str,
        'Answer(2)'
    )
    assert reward == 0.5
```

PiperOrigin-RevId: 674059707
  • Loading branch information
daiyip authored and langfun authors committed Sep 12, 2024
1 parent d71ba7a commit e626b0c
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 0 deletions.
1 change: 1 addition & 0 deletions langfun/core/structured/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@
from langfun.core.structured.prompting import query
from langfun.core.structured.prompting import query_prompt
from langfun.core.structured.prompting import query_output
from langfun.core.structured.prompting import query_reward

from langfun.core.structured.description import DescribeStructure
from langfun.core.structured.description import describe
Expand Down
55 changes: 55 additions & 0 deletions langfun/core/structured/prompting.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Symbolic query."""

import functools
from typing import Any, Callable, Type, Union

import langfun.core as lf
Expand Down Expand Up @@ -265,3 +266,57 @@ def query_output(
return query(
'Unused prompt', schema, lm=fake.StaticResponse(response), **kwargs
)


def query_reward(
    mapping_example: Union[str, mapping.MappingExample],
    response: Union[str, lf.Message],
) -> float | None:
  """Scores an LLM response against a mapping example.

  The output class is taken from the example's schema when that schema
  describes an object, or from the example's output when no schema is set.
  The reward is then delegated to that class's `__reward__` method.

  Args:
    mapping_example: A `MappingExample`, or its JSON string form.
    response: The LLM response to score.

  Returns:
    The value returned by the output class's `__reward__`, or None when no
    output class can be determined or it defines no callable `__reward__`.
  """
  if isinstance(mapping_example, str):
    mapping_example = pg.from_json_str(mapping_example)
    assert isinstance(mapping_example, mapping.MappingExample), mapping_example

  # Resolve the class whose `__reward__` will be used for scoring.
  schema = mapping_example.schema
  output_cls = None
  if schema:
    if isinstance(schema.spec, pg.typing.Object):
      output_cls = schema.spec.cls
  elif schema is None and isinstance(mapping_example.output, pg.Object):
    output_cls = mapping_example.output.__class__

  compute_reward = _reward_fn(output_cls)
  if compute_reward is None:
    return None

  # Parse the raw response into an instance of the output class, then score it.
  actual_output = query_output(response, output_cls)
  return compute_reward(
      actual_output,
      mapping_example.input,
      mapping_example.output,
      mapping_example.metadata,
  )


@functools.cache
def _reward_fn(cls) -> Callable[
    [
        pg.Object,  # Actual output object.
        Any,  # Input object.
        pg.Object,  # Expected output object.
        pg.Dict  # User metadata.
    ], float] | None:
  """Builds a uniform 4-argument reward callable for `cls`, if it has one.

  Results are memoized per class through `functools.cache`.

  Args:
    cls: The class being queried; may be None.

  Returns:
    A function `(actual, input, expected_output, metadata)` that forwards only
    as many leading arguments as `cls.__reward__` declares, or None when `cls`
    has no callable `__reward__` attribute.

  Raises:
    TypeError: If `cls.__reward__` declares fewer than 2 or more than 4
      positional arguments (counting `self`).
  """
  if not callable(getattr(cls, '__reward__', None)):
    return None

  # The count includes `self`, so valid signatures have 2 to 4 arguments.
  arity = len(pg.typing.signature(cls.__reward__).args)
  if not 2 <= arity <= 4:
    raise TypeError(
        f'`{cls.__type_name__}.__reward__` should have signature: '
        '`__reward__(self, input, [expected_output], [expected_metadata])`.'
    )

  def _reward(self, input, expected_output, metadata):  # pylint: disable=redefined-builtin
    # Drop the trailing arguments that the user-defined method does not accept.
    candidates = (self, input, expected_output, metadata)
    return cls.__reward__(*candidates[:arity])

  return _reward
124 changes: 124 additions & 0 deletions langfun/core/structured/prompting_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"""Tests for structured prompting."""

import inspect
import math
from typing import Any
import unittest

import langfun.core as lf
Expand Down Expand Up @@ -382,6 +384,128 @@ def test_query_output(self):
1,
)

def test_query_reward(self):
  # Covers `prompting.query_reward` for `__reward__` methods taking 1, 2 and 3
  # arguments after `self`, plus the no-reward, non-structured-output and
  # bad-signature paths.

  class Answer(pg.Object):
    final_answer: int

    def __reward__(self, inputs: lf.Template) -> float:
      diff = abs(self.final_answer - (inputs.x + inputs.y))
      # Skewed sigmoid: 1.0 when `diff` is 0, approaching -1.0 as it grows.
      return 4 / (1 + math.exp(diff)) - 1.0

  # Case 1: Reward function based on input and output.
  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              schema=Answer,
              output=Answer(final_answer=2),
          ),
          'Answer(2)'
      ),
      1.0
  )
  # Same reward function, but the example is passed as a JSON string and
  # carries no explicit schema. The expected output is ignored by
  # `Answer.__reward__`, so only the response and the input matter.
  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=2, y=3),
              output=Answer(final_answer=2),
          ).to_json_str(),
          'Answer(5)'
      ),
      1.0
  )

  # Case 2: Reward function based on input, result and expected output.
  class Answer2(pg.Object):
    final_answer: int

    def __reward__(self, inputs: lf.Template, expected_output: 'Answer2'):
      del inputs
      return (
          1.0 if self.final_answer == expected_output.final_answer else -1.0
      )

  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output=Answer2(final_answer=2),
          ),
          'Answer2(3)'
      ),
      -1.0
  )

  # Case 3: Reward function based on input, result, expected output
  # and metadata.
  class Answer3(pg.Object):
    final_answer: int

    def __reward__(self,
                   inputs: lf.Template,
                   expected_output: 'Answer3',
                   metadata: dict[str, Any]):
      del inputs
      return (
          1.0 if self.final_answer == expected_output.final_answer else -1.0
      ) * metadata['weight']

  self.assertEqual(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output=Answer3(final_answer=2),
              metadata=dict(weight=0.5)
          ),
          'Answer3(3)'
      ),
      -0.5
  )

  # Case 4: No reward function is provided.
  class Answer4(pg.Object):
    final_answer: int

  self.assertIsNone(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output=Answer4(final_answer=2),
          ),
          'Answer4(2)'
      )
  )

  # Case 5: Not a structured output.
  self.assertIsNone(
      prompting.query_reward(
          mapping.MappingExample(
              input=lf.Template('{{x}} + {{y}}', x=1, y=1),
              output='2',
          ),
          '2'
      )
  )

  # Case 6: Bad reward function (`__reward__` takes no argument besides
  # `self`).
  class Answer5(pg.Object):
    final_answer: int

    def __reward__(self):
      return 0.0

  with self.assertRaisesRegex(
      TypeError, '.*Answer5.__reward__` should have signature'
  ):
    prompting.query_reward(
        mapping.MappingExample(
            input=lf.Template('{{x}} + {{y}}', x=1, y=1),
            output=Answer5(final_answer=2),
        ),
        'Answer5(2)'
    )


class QueryStructurePythonTest(unittest.TestCase):

Expand Down

0 comments on commit e626b0c

Please sign in to comment.