diff --git a/core/cat/agents/procedures_agent.py b/core/cat/agents/procedures_agent.py index 99726bc5f..588899d18 100644 --- a/core/cat/agents/procedures_agent.py +++ b/core/cat/agents/procedures_agent.py @@ -204,7 +204,7 @@ def generate_examples(self, allowed_procedures): example_json = f""" {{ "action": "{proc.name}", - "action_input": // Input of the action according to its description + "action_input": "...input here..." }}""" list_examples += f"\nQuestion: {random.choice(proc.start_examples)}" list_examples += f"\n```json\n{example_json}\n```" diff --git a/core/cat/experimental/form/cat_form.py b/core/cat/experimental/form/cat_form.py index 8904c1aeb..8c13061ef 100644 --- a/core/cat/experimental/form/cat_form.py +++ b/core/cat/experimental/form/cat_form.py @@ -3,9 +3,6 @@ from typing import List, Dict from pydantic import BaseModel, ValidationError -from langchain.chains import LLMChain -from langchain_core.prompts.prompt import PromptTemplate - from cat.utils import parse_json from cat.log import log @@ -203,14 +200,7 @@ def extract(self): prompt = self.extraction_prompt() log.debug(prompt) - # Invoke LLM chain - extraction_chain = LLMChain( - prompt=PromptTemplate.from_template(prompt), - llm=self._cat._llm, - verbose=True, - output_key="output", - ) - json_str = extraction_chain.invoke({})["output"] # {"stop": ["```"]} + json_str = self.cat.llm(prompt) log.debug(f"Form JSON after parser:\n{json_str}") diff --git a/core/cat/looking_glass/prompts.py b/core/cat/looking_glass/prompts.py index 6e95d2f3d..e8ed510b2 100644 --- a/core/cat/looking_glass/prompts.py +++ b/core/cat/looking_glass/prompts.py @@ -9,7 +9,7 @@ ```json {{ "action": // str - The name of the action to take, should be one of [{tool_names}, "no_action"] - "action_input": // str or null - The input to the action according to it's description + "action_input": // str or null - The input to the action according to its description }} ``` diff --git a/core/cat/looking_glass/stray_cat.py 
b/core/cat/looking_glass/stray_cat.py index d69013d59..0a21a2514 100644 --- a/core/cat/looking_glass/stray_cat.py +++ b/core/cat/looking_glass/stray_cat.py @@ -8,6 +8,9 @@ from langchain_core.language_models.chat_models import BaseChatModel from langchain_community.llms import BaseLLM from langchain_core.messages import AIMessage, HumanMessage, BaseMessage +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers.string import StrOutputParser from fastapi import WebSocket @@ -291,13 +294,28 @@ def llm(self, prompt: str, stream: bool = False) -> str: # Add a token counter to the callbacks callbacks.append(ModelInteractionHandler(self, self.__class__.__name__)) + + # TODO: optionally accept conversation history as an argument, or take it from working memory + messages=[ + HumanMessage(content=prompt) + ] + # Check if self._llm is a completion model and generate a response - if isinstance(self._llm, BaseLLM): - return self._llm(prompt, callbacks=callbacks) + # TODOV2: do not support non-chat models + #if isinstance(self._llm, BaseLLM): + # log.critical("LLM") + # return self._llm.invoke( + # prompt, + # config=RunnableConfig(callbacks=callbacks) + # ) # Check if self._llm is a chat model and call it as a completion model - if isinstance(self._llm, BaseChatModel): - return self._llm.call_as_llm(prompt, callbacks=callbacks) + if True:#isinstance(self._llm, BaseChatModel): + log.critical("CHAT LLM") + return self._llm.invoke( + messages, + config=RunnableConfig(callbacks=callbacks) + ).content # invoke() returns an AIMessage; extract its string content async def __call__(self, message_dict): """Call the Cat instance. 
diff --git a/core/cat/utils.py b/core/cat/utils.py index 5b3563641..7edf35d51 100644 --- a/core/cat/utils.py +++ b/core/cat/utils.py @@ -156,14 +156,22 @@ def parse_json(json_string: str, pydantic_model: BaseModel = None) -> dict: # instantiate parser parser = JsonOutputParser(pydantic_object=pydantic_model) - # clean escapes (small LLM error) - json_string_clean = json_string.replace("\_", "_").replace("\-", "-").replace("None", "null") + # clean to help small LLMs + replaces = { + "\_": "_", + "\-": "-", + "None": "null", + "{{": "{", + "}}": "}", + } + for k, v in replaces.items(): + json_string = json_string.replace(k, v) # first "{" occurrence (required by parser) - start_index = json_string_clean.index("{") + start_index = json_string.index("{") # parse - parsed = parser.parse(json_string_clean[start_index:]) + parsed = parser.parse(json_string[start_index:]) if pydantic_model: return pydantic_model(**parsed)