Force select to use right tokens
slundberg committed May 17, 2023
1 parent 686c6f7 commit b1e656f
Showing 7 changed files with 96 additions and 52 deletions.
4 changes: 2 additions & 2 deletions guidance/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.0.42"
__version__ = "0.0.43"

import types
import sys
@@ -14,7 +14,7 @@
# allows us to start inner event loops within jupyter notebooks
nest_asyncio.apply()

llm = llms.OpenAI()
llm = None

# This is makes the guidance module callable
class Guidance(types.ModuleType):
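
The `__init__.py` hunk replaces the eager `llms.OpenAI()` module default with `None`, so importing guidance no longer constructs an OpenAI client (or needs credentials) at import time. A minimal sketch of how callers supply an LLM after this change; the model name is illustrative, not part of the commit:

import guidance

# set a process-wide default explicitly...
guidance.llm = guidance.llms.OpenAI("text-curie-001")

# ...or scope an LLM to a single program instead of relying on the module default
program = guidance("Tell me a joke: {{gen 'joke' max_tokens=20}}",
                   llm=guidance.llms.OpenAI("text-curie-001"))
out = program()
print(out["joke"])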
9 changes: 6 additions & 3 deletions guidance/_program.py
@@ -82,7 +82,7 @@ def __init__(self, text, llm=None, cache_seed=0, logprobs=None, silent='auto', a

# save the given parameters
self._text = text
self.llm = llm or guidance.llm
self.llm = llm or getattr(guidance, "llm", None)
self.cache_seed = cache_seed
self.caching = caching
self.logprobs = logprobs
@@ -295,8 +295,11 @@ async def execute(self):
await asyncio.sleep(0)

# run the program and capture the output
with self.llm.session(asynchronous=True) as llm_session:
await self._executor.run(llm_session)
if self.llm is None:
await self._executor.run(None)
else:
with self.llm.session(asynchronous=True) as llm_session:
await self._executor.run(llm_session)
self._text = self._executor.prefix

# delete the executor and so mark the program as not executing
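
With `self.llm` now resolved via `getattr(guidance, "llm", None)` and `execute()` guarding against a missing LLM, a purely templated program can run end to end with no model configured. A hedged sketch, assuming a template with no `{{gen}}`/`{{select}}` calls (the `.text` accessor for reading the rendered output is an assumption):

import guidance

guidance.llm = None                    # the new module-level default
program = guidance("Hello {{name}}!")  # no generation commands, so no LLM is needed
executed = program(name="Ada")         # runs without ever opening an llm_session
print(executed.text)                   # .text assumed here; expected to render "Hello Ada!"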
65 changes: 27 additions & 38 deletions guidance/_program_executor.py
@@ -269,18 +269,16 @@ def update_return_value(s):
elif isinstance(arg, NamedArgument):
named_args[arg.name] = arg.value
sig = inspect.signature(command_function)
if "parser_prefix" in sig.parameters:
named_args["parser_prefix"] = strip_markers(self.prefix)
if "parser" in sig.parameters:
named_args["parser"] = self
if "partial_output" in sig.parameters:
named_args["partial_output"] = partial_output
if "next_node" in sig.parameters:
named_args["next_node"] = next_node
if "next_next_node" in sig.parameters:
named_args["next_next_node"] = next_next_node
if "prev_node" in sig.parameters:
named_args["prev_node"] = prev_node
if "_parser_context" in sig.parameters:
named_args["_parser_context"] = {
"parser_prefix": strip_markers(self.prefix),
"parser": self,
"partial_output": partial_output,
"next_node": next_node,
"next_next_node": next_next_node,
"prev_node": prev_node,
"block_content": None
}

# call the command
try:
@@ -294,7 +292,7 @@ def update_return_value(s):
self.caught_stop_iteration = True

# call partial output if the command didn't itself (and we are still executing)
if "partial_output" not in sig.parameters:
if command_output is not None:
partial_output(command_output)
else:
# if the variable does not exist we just pause execution
@@ -372,40 +370,31 @@ def update_return_value(s):
positional_args.append(arg.value)
elif isinstance(arg, NamedArgument):
named_args[arg.name] = arg.value

# see if the command expects parser context
sig = inspect.signature(command_function)
if "parser_prefix" in sig.parameters:
named_args["parser_prefix"] = strip_markers(self.prefix)
if "parser" in sig.parameters:
named_args["parser"] = self
if "block_content" in sig.parameters:
named_args["block_content"] = self.block_content[-1]
if "partial_output" in sig.parameters:
named_args["partial_output"] = self.extend_prefix
if "parser_node" in sig.parameters:
named_args["parser_node"] = node
if "next_node" in sig.parameters:
named_args["next_node"] = node.children[-1]
if "next_next_node" in sig.parameters:
named_args["next_next_node"] = next_node
if "prev_node" in sig.parameters:
named_args["prev_node"] = node.children[0]
if "_parser_context" in sig.parameters:
named_args["_parser_context"] = {
"parser_prefix": strip_markers(self.prefix),
"parser": self,
"block_content": self.block_content[-1],
"partial_output": self.extend_prefix,
"parser_node": node,
"next_node": node.children[-1],
"next_next_node": next_node,
"prev_node": node.children[0]
}

# call the optionally asyncronous command
if inspect.iscoroutinefunction(command_function):
command_output = await command_function(*positional_args, **named_args)
else:
command_output = command_function(*positional_args, **named_args)

if "partial_output" not in sig.parameters:
# if the command didn't send partial output we do it here
if command_output is not None:
self.extend_prefix(command_output)

# if we stopped execution we need to remove the start marker
# if not self.executing:
# self.prefix = self.prefix[:pos] + self.prefix[pos+len(start_marker):]
# return

else:
command_output = ""

# pop off the block content after the command call
self.block_content.pop()

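
Both executor hunks collapse the per-parameter introspection (`parser_prefix`, `parser`, `partial_output`, ...) into a single `_parser_context` dictionary, and the executor now emits a command's return value only when it is not `None`. A hedged sketch of a command written against the new convention; the command itself is invented for illustration:

async def shout(text, _parser_context=None):
    # commands now opt in with one keyword and pull what they need from the dict
    partial_output = _parser_context["partial_output"]
    partial_output(text.upper())   # stream the result ourselves...
    return None                    # ...so the executor does not emit it a second time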
25 changes: 23 additions & 2 deletions guidance/library/_select.py
@@ -1,3 +1,5 @@
import itertools

async def select(variable_name="selected", options=None, logprobs=None, _parser_context=None):
''' Select a value from a list of choices.
@@ -28,12 +30,31 @@ async def select(variable_name="selected", options=None, logprobs=None, _parser_
options.append(block_content[i+1].text)

option_tokens = [parser.program.llm.encode(option) for option in options]
ids_used = set(itertools.chain.from_iterable(option_tokens))

# find the common prefix of all the options BROKEN STILLL
max_tokens = max([len(o) for o in option_tokens])
# for i in range(max_tokens):
# all_match = True
# pos_val = None
# for j in range(len(option_tokens)):
# if len(option_tokens[j]) <= i:
# if pos_val is None:
# pos_val = option_tokens[j][i]
# elif option_tokens[j][i] != pos_val:
# all_match = False
# break
# if not all_match:
# max_tokens = i
# break



# [TODO] we should force the LM to generate a valid specific option
# for openai this means setting logprobs to valid token ids
# call the session to get the logprobs for each option
gen_obj = await parser.llm_session(
parser_prefix,
max_tokens=max([len(o) for o in option_tokens]),
logit_bias={str(id): 50 for id in ids_used},
logprobs=10,
cache_seed=0
)
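
This is the change the commit title refers to: every token id appearing in any option is collected and sent as a strong `logit_bias`, so the bounded `max_tokens` completion is steered toward tokens that can actually spell out one of the options. A sketch of that step in isolation, using tiktoken directly rather than `parser.program.llm.encode` and the parser's `llm_session` (the encoding name is only an example):

import itertools
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
options = ["Yes", "No"]
option_tokens = [enc.encode(o) for o in options]

# every token id used by any option gets a strong positive bias
ids_used = set(itertools.chain.from_iterable(option_tokens))
logit_bias = {str(i): 50 for i in ids_used}
max_tokens = max(len(o) for o in option_tokens)
print(logit_bias, max_tokens)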
13 changes: 12 additions & 1 deletion guidance/llms/_mock.py
@@ -17,6 +17,7 @@ class will return the next item in the list each time it is
output = [f"mock output {i}" for i in range(100)]
self.output = output
self.index = 0
self._tokenizer = MockTokenizer()

def __call__(self, *args, n=1, stream=False, **kwargs):
choices = []
@@ -40,4 +41,14 @@ def role_start(self, role):
return "<|im_start|>"+role+"\n"

def role_end(self, role=None):
return "<|im_end|>"
return "<|im_end|>"

class MockTokenizer():
def __init__(self):
pass

def encode(self, text):
return [s for s in text.encode("utf-8")]

def decode(self, ids):
return "".join([chr(i) for i in ids])
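
A quick check of the added tokenizer: `encode`/`decode` are byte-level, so the roundtrip is exact for ASCII text (multi-byte UTF-8 would not survive the chr-per-byte decode, which is acceptable for a test double). Assumes it runs where `MockTokenizer` is importable:

tok = MockTokenizer()
ids = tok.encode("Yes")
assert ids == [89, 101, 115]        # raw UTF-8 byte values
assert tok.decode(ids) == "Yes"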
6 changes: 3 additions & 3 deletions guidance/llms/_openai.py
@@ -23,15 +23,15 @@ def prompt_to_messages(prompt):

assert prompt.endswith("<|im_start|>assistant\n"), "When calling OpenAI chat models you must generate only directly inside the assistant role! The OpenAI API does not currently support partial assistant prompting."

pattern = r'<\|im_start\|>(\w+)(.*?)(?=<\|im_end\|>)'
pattern = r'<\|im_start\|>(\w+)(.*?)(?=<\|im_end\|>|$)'
matches = re.findall(pattern, prompt, re.DOTALL)

if not matches:
return [{'role': 'user', 'content': prompt.strip()}]

for match in matches:
role, content = match
content = content.strip()
content = content.strip() # should we do this?
messages.append({'role': role, 'content': content})

return messages
@@ -109,7 +109,7 @@ def __init__(self, model=None, caching=True, max_retries=5, max_calls_per_min=60
endpoint = os.environ.get("OPENAI_ENDPOINT", None)

import tiktoken
self._tokenizer = tiktoken.get_encoding("cl100k_base")
self._tokenizer = tiktoken.get_encoding(tiktoken.encoding_for_model(model).name)
self.chat_mode = chat_mode

self.model_name = model
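
Two fixes here: the message-splitting regex now also accepts a block terminated by end-of-string (the trailing `<|im_start|>assistant\n` block never has a closing `<|im_end|>`), and the tokenizer is chosen per model via `tiktoken.encoding_for_model` instead of hard-coding `cl100k_base`. A small sketch of what the regex change buys:

import re

prompt = "<|im_start|>user\nIs Everest tall?<|im_end|>\n<|im_start|>assistant\n"
old = r'<\|im_start\|>(\w+)(.*?)(?=<\|im_end\|>)'
new = r'<\|im_start\|>(\w+)(.*?)(?=<\|im_end\|>|$)'

print(re.findall(old, prompt, re.DOTALL))  # [('user', '\nIs Everest tall?')] -- assistant block lost
print(re.findall(new, prompt, re.DOTALL))  # also yields ('assistant', '\n')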
26 changes: 23 additions & 3 deletions tests/library/test_select.py
@@ -4,7 +4,27 @@ def test_select():
""" Test the behavior of `select`.
"""

llm = guidance.llms.OpenAI("text-curie-001")
prompt = guidance("Is Everest very tall?\nAnswer 'Yes' or 'No': '{{#select 'name'}}Yes{{or}}No{{/select}}", llm=llm)
out = prompt()
llm = guidance.llms.OpenAI("text-curie-001", caching=False)
program = guidance("Is Everest very tall?\nAnswer 'Yes' or 'No': '{{#select 'name'}}Yes{{or}}No{{/select}}", llm=llm)
out = program()
assert out["name"] in ["Yes", "No"]

def test_select_longtext():
""" Test the behavior of `select`.
"""

llm = guidance.llms.OpenAI("text-curie-001", caching=False)
program = guidance("""Is Everest very tall?\nAnswer:
{{#select 'name'}}No because of all the other ones.{{or}}Yes because I saw it.{{/select}}""", llm=llm)
out = program()
assert out["name"] in ["No because of all the other ones.", "Yes because I saw it."]

def test_select_with_list():
""" Test the behavior of `select` in non-block mode.
"""

# llm = guidance.llms.Mock("Yes")
llm = guidance.llms.OpenAI("text-curie-001", caching=False)
program = guidance("Is Everest very tall?\nAnswer 'Yes' or 'No': '{{select 'name' options=options}}", llm=llm)
out = program(options=["Yes", "No"])
assert out["name"] in ["Yes", "No"]
