Skip to content

Commit

Permalink
LM Requests Wrapper (langchain-ai#3457)
Browse files Browse the repository at this point in the history
Co-authored-by: jnmarti <[email protected]>
  • Loading branch information
vowelparrot and jnmarti authored Apr 24, 2023
1 parent b64c86a commit d06d47b
Showing 1 changed file with 30 additions and 9 deletions.
39 changes: 30 additions & 9 deletions langchain/agents/agent_toolkits/openapi/planner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
"""Agent that interacts with OpenAPI APIs via a hierarchical planning approach."""
import json
import re
from typing import List, Optional
from functools import partial
from typing import Callable, List, Optional

import yaml
from pydantic import Field

from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.openapi.planner_prompt import (
Expand All @@ -30,6 +32,7 @@
from langchain.llms.openai import OpenAI
from langchain.memory import ReadOnlySharedMemory
from langchain.prompts import PromptTemplate
from langchain.prompts.base import BasePromptTemplate
from langchain.requests import RequestsWrapper
from langchain.schema import BaseLanguageModel
from langchain.tools.base import BaseTool
Expand All @@ -44,13 +47,26 @@
MAX_RESPONSE_LENGTH = 5000


def _get_default_llm_chain(prompt: BasePromptTemplate) -> LLMChain:
return LLMChain(
llm=OpenAI(),
prompt=prompt,
)


def _get_default_llm_chain_factory(
prompt: BasePromptTemplate,
) -> Callable[[], LLMChain]:
"""Returns a default LLMChain factory."""
return partial(_get_default_llm_chain, prompt)


class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
name = "requests_get"
description = REQUESTS_GET_TOOL_DESCRIPTION
response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain(
llm=OpenAI(),
prompt=PARSING_GET_PROMPT,
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_GET_PROMPT)
)

def _run(self, text: str) -> str:
Expand All @@ -74,9 +90,8 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
description = REQUESTS_POST_TOOL_DESCRIPTION

response_length: Optional[int] = MAX_RESPONSE_LENGTH
llm_chain = LLMChain(
llm=OpenAI(),
prompt=PARSING_POST_PROMPT,
llm_chain: LLMChain = Field(
default_factory=_get_default_llm_chain_factory(PARSING_POST_PROMPT)
)

def _run(self, text: str) -> str:
Expand Down Expand Up @@ -173,9 +188,15 @@ def _create_api_controller_agent(
requests_wrapper: RequestsWrapper,
llm: BaseLanguageModel,
) -> AgentExecutor:
get_llm_chain = LLMChain(llm=llm, prompt=PARSING_GET_PROMPT)
post_llm_chain = LLMChain(llm=llm, prompt=PARSING_POST_PROMPT)
tools: List[BaseTool] = [
RequestsGetToolWithParsing(requests_wrapper=requests_wrapper),
RequestsPostToolWithParsing(requests_wrapper=requests_wrapper),
RequestsGetToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=get_llm_chain
),
RequestsPostToolWithParsing(
requests_wrapper=requests_wrapper, llm_chain=post_llm_chain
),
]
prompt = PromptTemplate(
template=API_CONTROLLER_PROMPT,
Expand Down

0 comments on commit d06d47b

Please sign in to comment.