Skip to content

Commit

Permalink
update langchain calls in stray.llm
Browse files Browse the repository at this point in the history
  • Loading branch information
pieroit committed Jul 25, 2024
1 parent ccf88a1 commit 1d0f057
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 21 deletions.
2 changes: 1 addition & 1 deletion core/cat/agents/procedures_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def generate_examples(self, allowed_procedures):
example_json = f"""
{{
"action": "{proc.name}",
"action_input": // Input of the action according to its description
"action_input": "...input here..."
}}"""
list_examples += f"\nQuestion: {random.choice(proc.start_examples)}"
list_examples += f"\n```json\n{example_json}\n```"
Expand Down
12 changes: 1 addition & 11 deletions core/cat/experimental/form/cat_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
from typing import List, Dict
from pydantic import BaseModel, ValidationError

from langchain.chains import LLMChain
from langchain_core.prompts.prompt import PromptTemplate

from cat.utils import parse_json
from cat.log import log

Expand Down Expand Up @@ -203,14 +200,7 @@ def extract(self):
prompt = self.extraction_prompt()
log.debug(prompt)

# Invoke LLM chain
extraction_chain = LLMChain(
prompt=PromptTemplate.from_template(prompt),
llm=self._cat._llm,
verbose=True,
output_key="output",
)
json_str = extraction_chain.invoke({})["output"] # {"stop": ["```"]}
json_str = self.cat.llm(prompt)

log.debug(f"Form JSON after parser:\n{json_str}")

Expand Down
2 changes: 1 addition & 1 deletion core/cat/looking_glass/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
```json
{{
"action": // str - The name of the action to take, should be one of [{tool_names}, "no_action"]
"action_input": // str or null - The input to the action according to it's description
"action_input": // str or null - The input to the action according to its description
}}
```
Expand Down
26 changes: 22 additions & 4 deletions core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.llms import BaseLLM
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage
from langchain_core.runnables import RunnableConfig, RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers.string import StrOutputParser

from fastapi import WebSocket

Expand Down Expand Up @@ -291,13 +294,28 @@ def llm(self, prompt: str, stream: bool = False) -> str:
# Add a token counter to the callbacks
callbacks.append(ModelInteractionHandler(self, self.__class__.__name__))


# TODO: add here optional convo history passed to the method, or taken from working memory
messages=[
HumanMessage(content=prompt)
]

# Check if self._llm is a completion model and generate a response
if isinstance(self._llm, BaseLLM):
return self._llm(prompt, callbacks=callbacks)
# TODOV2: do not support non-chat models
#if isinstance(self._llm, BaseLLM):
# log.critical("LLM")
# return self._llm.invoke(
# prompt,
# config=RunnableConfig(callbacks=callbacks)
# )

# Check if self._llm is a chat model and call it as a completion model
if isinstance(self._llm, BaseChatModel):
return self._llm.call_as_llm(prompt, callbacks=callbacks)
if True:#isinstance(self._llm, BaseChatModel):
log.critical("CHAT LLM")
return self._llm.invoke(
messages,
config=RunnableConfig(callbacks=callbacks)
).content # returns AIMessage

async def __call__(self, message_dict):
"""Call the Cat instance.
Expand Down
16 changes: 12 additions & 4 deletions core/cat/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,14 +156,22 @@ def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict:
# instantiate parser
parser = JsonOutputParser(pydantic_object=pydantic_model)

# clean escapes (small LLM error)
json_string_clean = json_string.replace("\_", "_").replace("\-", "-").replace("None", "null")
# clean to help small LLMs
replaces = {
"\_": "_",
"\-": "-",
"None": "null",
"{{": "{",
"}}": "}",
}
for k, v in replaces.items():
json_string = json_string.replace(k, v)

# first "{" occurrence (required by parser)
start_index = json_string_clean.index("{")
start_index = json_string.index("{")

# parse
parsed = parser.parse(json_string_clean[start_index:])
parsed = parser.parse(json_string[start_index:])

if pydantic_model:
return pydantic_model(**parsed)
Expand Down

0 comments on commit 1d0f057

Please sign in to comment.