Skip to content

Commit

Permalink
feat: add a python generator example
Browse files Browse the repository at this point in the history
  • Loading branch information
mistercrunch committed Apr 28, 2023
1 parent bab719a commit 6ab54da
Show file tree
Hide file tree
Showing 6 changed files with 309 additions and 2 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: check-docstring-first
- id: check-added-large-files
exclude: \.(geojson)$
- id: check-yaml
Expand Down
291 changes: 291 additions & 0 deletions examples/python_examples.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,291 @@
"""An example of how to test Python code generating prompts"""
import re

# Bringing in some "prompt generator" classes
from promptimize.prompt_cases import LangchainPromptCase

# Bringing in some useful eval functions that help evaluate and score responses.
# Eval functions have a handle on the prompt object and are expected
# to return a score between 0 and 1.
from langchain import PromptTemplate
from langchain.output_parsers import StructuredOutputParser, ResponseSchema
import demjson

from RestrictedPython import compile_restricted, safe_globals, safe_builtins
from RestrictedPython.Guards import guarded_unpack_sequence
from RestrictedPython.Eval import default_guarded_getiter


# Schema of the JSON payload requested from the model. The key names must
# match the ones used in the prompt template and read back in
# PythonGeneratorPrompt.post_run ("python_function", "function_name", ...).
response_schemas = [
    ResponseSchema(
        name="python_function",
        description="the python function itself",
    ),
    ResponseSchema(
        # Was misspelled "functon_name", which disagreed with the key the
        # prompt example and post_run both use.
        name="function_name",
        description="the name of the function",
    ),
    ResponseSchema(name="test_cases", description="test cases"),
    ResponseSchema(
        name="hints",
        description="if any, any recommendations to the users about clarifying their prompt",
    ),
]

output_parser = StructuredOutputParser.from_response_schemas(response_schemas)
# Tabs in the generated instructions are swapped for spaces before the text
# is handed to the prompt template.
format_instructions = output_parser.get_format_instructions().replace("\t", " ")
"""
* you include great useful docstrings and doctests that follow the Google conventions
"""

# Few-shot prompt, rendered with jinja2 (``{{ user_input }}`` is the only
# placeholder). The worked example must itself be valid JSON, since the prompt
# asks the model for a "VALID JSON blob": the comma after the
# "python_function" entry was missing and is restored here.
template = """\
System: you are an AI that writes python function that accomplish specific tasks
Python guidelines:
* you follow the PEP8 conventions
* use 4 spaces indent, no tabs!
* use snake case (using underscores)
The output should be a VALID JSON blob with the following keys:
* "python_function" as a string with the python function code
* "function_name" as the name of the function
* "hints": as some hints about how to use the function
User: write a function that multipllies a number by 2 and returns the result
System:
{
"python_function": "def multiply_by_two(number):\\n return number * 2\\n",
"function_name": "multiply_by_two",
"hints": "This function is not that helpful as you can simply mulitply by two\\ninstead of calling this function"
}
User: {{ user_input }}
System:
"""  # noqa

# LangChain prompt built from the jinja2 template above; ``user_input`` is
# supplied per prompt case.
# NOTE(review): the template text has no ``{{ format_instructions }}``
# placeholder, so this partial variable looks unused - confirm.
lc_template = PromptTemplate(
    template=template,
    template_format="jinja2",
    input_variables=["user_input"],
    partial_variables={"format_instructions": format_instructions},
)


def function_from_string(function_as_string, function_name):
    """Compile generated source with RestrictedPython and return one function.

    The source is compiled with ``compile_restricted``, executed in a
    dedicated namespace built from ``safe_globals``/``safe_builtins``, and the
    object bound to ``function_name`` in that namespace is returned.

    RestrictedPython rewrites subscripts, augmented assignments and iteration
    into calls to guard hooks (``_getitem_``, ``_inplacevar_``, ``_getiter_``,
    ...). Every hook the generated code may hit has to exist in the namespace,
    otherwise the function fails with ``NameError`` when it is called.

    NOTE(review): this executes model-generated code. RestrictedPython limits
    what that code can reach but is not a full sandbox (no CPU/memory limits).

    Args:
        function_as_string: Python source expected to define ``function_name``.
        function_name: name of the function to fetch once the source has run.

    Returns:
        The object bound to ``function_name`` in the restricted namespace.

    Raises:
        KeyError: if the executed source does not define ``function_name``.
        Exception: anything raised while compiling or executing the source
            propagates to the caller.
    """
    import operator

    from RestrictedPython.Eval import default_guarded_getitem
    from RestrictedPython.Guards import guarded_iter_unpack_sequence

    # Augmented assignments (``x += 1``) are rewritten to
    # ``x = _inplacevar_("+=", x, 1)``; map each operator string to its
    # in-place implementation.
    inplace_operators = {
        "+=": operator.iadd,
        "-=": operator.isub,
        "*=": operator.imul,
        "/=": operator.itruediv,
        "//=": operator.ifloordiv,
        "%=": operator.imod,
        "**=": operator.ipow,
        "<<=": operator.ilshift,
        ">>=": operator.irshift,
        "|=": operator.ior,
        "^=": operator.ixor,
        "&=": operator.iand,
        "@=": operator.imatmul,
    }

    def inplacevar(op, target, operand):
        return inplace_operators[op](target, operand)

    restricted_code = compile_restricted(function_as_string, "<inline code>", "exec")

    # Define a separate environment for the code to run in
    execution_globals = safe_globals.copy()
    execution_globals.update(
        {
            "__builtins__": safe_builtins,
            "_unpack_sequence_": guarded_unpack_sequence,
            "_iter_unpack_sequence_": guarded_iter_unpack_sequence,
            "_getiter_": default_guarded_getiter,
            "_getitem_": default_guarded_getitem,
            "_inplacevar_": inplacevar,
        }
    )

    # Execute the code in the restricted environment
    exec(restricted_code, execution_globals)

    # Access the function from the restricted environment
    return execution_globals[function_name]


def test(func, args, expected_result):
if func:
if not isinstance(args, (list, tuple)):
args = [args]
try:
result = func(*args)
if expected_result == result:
return 1
except Exception:
return 0
return 0


def decode_shitty_json(s):
    """Extract and decode the JSON object embedded in ``s``.

    The span from the first ``{`` to the last ``}`` (the match is greedy) is
    handed to ``demjson.decode``. Returns the decoded value, or ``None`` when
    ``s`` contains no such span.
    """
    found = re.search(r"\{[\s\S]*\}", s)
    if found is None:
        return None

    # Parse the JSON string using demjson
    return demjson.decode(found.group())


def test_is_prime(prompt_case, val, exp):
    """Score ``prompt_case.python_function`` called with ``val`` against ``exp``.

    Delegates to ``test`` and returns its 0/1 score.
    """
    generated_function = prompt_case.python_function
    return test(generated_function, val, exp)


class PythonGeneratorPrompt(LangchainPromptCase):
    """Prompt case whose response is a JSON blob describing a Python function."""

    def post_run(self):
        """Decode the response and build the generated function.

        On success ``self.response`` is replaced by the decoded dict and both
        ``self.python_function`` and ``self.f`` point at the callable. On any
        failure (undecodable response, no JSON object found, code that does
        not compile or does not define the announced function) the two
        attributes stay ``None`` and the reason is stored in ``self.error``,
        so the evaluators score the case 0 instead of the run crashing.
        """
        self.python_function = None
        self.f = None

        try:
            decoded = decode_shitty_json(self.response)
        except Exception as e:
            self.error = str(e)
            return

        # The decoder returns None when the response holds no "{...}" span;
        # keep the raw response in that case so it can still be inspected.
        if not isinstance(decoded, dict):
            self.error = "no JSON object found in the model response"
            return
        self.response = decoded

        try:
            f = function_from_string(
                decoded.get("python_function"), decoded.get("function_name")
            )
        except Exception as e:
            self.error = str(e)
            return

        self.python_function = f
        self.f = f


# Prompt cases for the Python-generator example. Each evaluator calls the
# generated function (``x.f``) through ``test`` and scores 1 or 0.
# ``test`` unpacks list/tuple arguments into positional arguments, so a
# function that takes ONE list argument needs that list wrapped in an outer
# list (see "longest_common_prefix").
sql_prompts = [
    PythonGeneratorPrompt(
        lc_template,
        key="is_prime",
        user_input="write a function that tests if an number is a prime number, returns a boolean",
        evaluators=[
            lambda x: test(x.f, 2, True),
            lambda x: test(x.f, 4, False),
            lambda x: test(x.f, 7, True),
            lambda x: test(x.f, 10, False),
            lambda x: test(x.f, 11, True),
            lambda x: test(x.f, 113, True),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="gcd",
        user_input="write a function that finds the greatest common divisor (GCD) of two numbers?",
        evaluators=[
            lambda x: test(x.f, [14, 28], 14),
            lambda x: test(x.f, [56, 98], 14),
            lambda x: test(x.f, [81, 153], 9),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="factorial",
        user_input="write a function that calculates the factorial of a given number",
        evaluators=[
            lambda x: test(x.f, 0, 1),
            lambda x: test(x.f, 1, 1),
            lambda x: test(x.f, 5, 120),
            lambda x: test(x.f, 7, 5040),
            lambda x: test(x.f, 10, 3628800),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="is_palindrome",
        user_input="write a function that determines if a given string is a palindrome",
        evaluators=[
            lambda x: test(x.f, "racecar", True),
            lambda x: test(x.f, "hello", False),
            lambda x: test(x.f, "madam", True),
            lambda x: test(x.f, "python", False),
            # NOTE(review): only passes for a case-insensitive implementation,
            # which the prompt does not ask for - confirm this is intended.
            lambda x: test(x.f, "Aibohphobia", True),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="fibonacci",
        # Adjacent literals without commas: one string, not a tuple.
        user_input=(
            "write a function that generates the Fibonacci sequence "
            "up to a specified number of terms"
        ),
        evaluators=[
            lambda x: test(x.f, 1, [0]),
            lambda x: test(x.f, 2, [0, 1]),
            lambda x: test(x.f, 5, [0, 1, 1, 2, 3]),
            lambda x: test(x.f, 10, [0, 1, 1, 2, 3, 5, 8, 13, 21, 34]),
            lambda x: test(x.f, 7, [0, 1, 1, 2, 3, 5, 8]),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="sum_of_multiples",
        user_input=(
            "write a function that calculates the sum of all multiples "
            "of 3 and 5 below a given number"
        ),
        evaluators=[
            lambda x: test(x.f, 10, 23),
            lambda x: test(x.f, 20, 78),
            lambda x: test(x.f, 30, 195),
            lambda x: test(x.f, 50, 543),
            # 1683 (multiples of 3) + 950 (of 5) - 315 (of 15) = 2318
            lambda x: test(x.f, 100, 2318),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="is_leap_year",
        user_input="write a function that checks whether a given year is a leap year",
        evaluators=[
            lambda x: test(x.f, 2000, True),
            lambda x: test(x.f, 1900, False),
            lambda x: test(x.f, 2020, True),
            lambda x: test(x.f, 2021, False),
            lambda x: test(x.f, 2400, True),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="longest_substring_without_repeating_chars",
        user_input=(
            "write a function that finds the longest substring of a "
            "given string without repeating characters"
        ),
        evaluators=[
            lambda x: test(x.f, "abcabcbb", "abc"),
            lambda x: test(x.f, "bbbbbb", "b"),
            lambda x: test(x.f, "pwwkew", "wke"),
            lambda x: test(x.f, "abcdefgh", "abcdefgh"),
            lambda x: test(x.f, "abcbdacf", "bdacf"),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="longest_common_prefix",
        user_input="write a function that finds the longest common prefix of a list of strings",
        evaluators=[
            # The list of strings is a single argument, hence the outer list.
            lambda x: test(x.f, [["flower", "flow", "flight"]], "fl"),
            lambda x: test(x.f, [["dog", "racecar", "car"]], ""),
            lambda x: test(x.f, [["interspecies", "interstellar", "interstate"]], "inter"),
            lambda x: test(x.f, [["prefix", "suffix", "infix"]], ""),
            lambda x: test(x.f, [["geeksforgeeks", "geeks", "geek"]], "geek"),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="sum_of_digits",
        user_input="write a function that calculates the sum of the digits of a given number",
        evaluators=[
            lambda x: test(x.f, 123, 6),
            lambda x: test(x.f, 456, 15),
            lambda x: test(x.f, 789, 24),
            lambda x: test(x.f, 1001, 2),
            lambda x: test(x.f, 54321, 15),
        ],
    ),
    PythonGeneratorPrompt(
        lc_template,
        key="decimal_to_binary",
        user_input=(
            "write a function that converts a given decimal number to " "its binary representation"
        ),
        evaluators=[
            lambda x: test(x.f, 2, "10"),
            lambda x: test(x.f, 7, "111"),
            lambda x: test(x.f, 10, "1010"),
            lambda x: test(x.f, 16, "10000"),
            lambda x: test(x.f, 31, "11111"),
        ],
    ),
]
2 changes: 1 addition & 1 deletion promptimize/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def cli():
"-t",
type=click.FLOAT,
default=0.5,
help="max_tokens passed to the model",
help="the temperature passed to the model",
)
@click.option(
"--engine",
Expand Down
12 changes: 12 additions & 0 deletions promptimize/evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,15 @@ def all_words(response: str, words: List[str], case_sensitive: bool = False) ->
0
"""
return _common_word_search(response, words, case_sensitive, match_type="all")


# Keep handles on the builtins before the module-level names are rebound.
base_all = all
base_any = any


def all(iteratable):
    """Return 1 when every item equals 1 (or the input is empty), else 0."""
    passed = [score == 1 for score in iteratable]
    return int(base_all(passed))


def any(iteratable):
    """Return 1 when at least one item equals 1, else 0."""
    passed = [score == 1 for score in iteratable]
    return int(base_any(passed))
3 changes: 3 additions & 0 deletions promptimize/prompt_cases.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@ def to_dict(self, verbose=False):
"weight": self.weight,
"execution": self.execution.to_dict(),
}
if hasattr(self, "error"):
d["error"] = self.error
return d

def print(self, verbose=False, style="yaml"):
Expand All @@ -133,6 +135,7 @@ def test(self):

if len(test_results):
self.execution.score = sum(test_results) / len(test_results)
self.execution.results = test_results
self.was_tested = True

@property
Expand Down
2 changes: 2 additions & 0 deletions requirements-examples.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
RestrictedPython
demjson

0 comments on commit 6ab54da

Please sign in to comment.