diff --git a/pandas_rag_langgraph/agent.py b/pandas_rag_langgraph/agent.py index f85f099..d96680f 100644 --- a/pandas_rag_langgraph/agent.py +++ b/pandas_rag_langgraph/agent.py @@ -67,7 +67,7 @@ def get_retriever() -> BaseRetriever: # LLM / Retriever / Tools llm = ChatAnthropic(model="claude-3-5-sonnet-20240620", temperature=0) retriever = get_retriever() -tavily_search_tool = TavilySearchResults(max_results=1) +tavily_search_tool = TavilySearchResults(max_results=3) # Prompts / data models @@ -146,6 +146,10 @@ class GraphState(TypedDict): web_fallback: bool +class GraphConfig(TypedDict): + max_retries: int + + def document_search(state: GraphState): """ Retrieve documents @@ -275,7 +279,7 @@ def finalize_response(state: GraphState): # Define graph -workflow = StateGraph(GraphState) +workflow = StateGraph(GraphState, config_schema=GraphConfig) # Define the nodes workflow.add_node("document_search", document_search)