Add Spark SQL support (langchain-ai#4602) (langchain-ai#4956)

* Add Spark SQL support. It can connect to Spark by building a local or remote SparkSession.
* Include a notebook example.

I tried some complicated queries (window functions, table joins), and the tool works well. Compared to the [Spark Dataframe agent](https://python.langchain.com/en/latest/modules/agents/toolkits/examples/spark.html), this tool is able to generate queries across multiple tables.

Co-authored-by: Gengliang Wang <[email protected]>
Co-authored-by: Mike W <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
Co-authored-by: UmerHA <[email protected]>
Co-authored-by: 张城铭 <[email protected]>
Co-authored-by: assert <[email protected]>
Co-authored-by: blob42 <spike@w530>
Co-authored-by: Yuekai Zhang <[email protected]>
Co-authored-by: Richard He <[email protected]>
Co-authored-by: Dev 2049 <[email protected]>
Co-authored-by: Leonid Ganeline <[email protected]>
Co-authored-by: Alexey Nominas <[email protected]>
Co-authored-by: elBarkey <[email protected]>
Co-authored-by: Davis Chase <[email protected]>
Co-authored-by: Jeffrey D <[email protected]>
Co-authored-by: so2liu <[email protected]>
Co-authored-by: Viswanadh Rayavarapu <[email protected]>
Co-authored-by: Chakib Ben Ziane <[email protected]>
Co-authored-by: Daniel Chalef <[email protected]>
Co-authored-by: Jari Bakken <[email protected]>
Co-authored-by: escafati <[email protected]>
1 parent 5feb60f · commit 88a3a56 · 13 changed files with 812 additions and 0 deletions.
The example notebook is a large diff and is not rendered here.
**langchain/agents/agent_toolkits/spark_sql/__init__.py** (+1)

```python
"""Spark SQL agent."""
```
**langchain/agents/agent_toolkits/spark_sql/base.py** (+56)

```python
"""Spark SQL agent."""
from typing import Any, Dict, List, Optional

from langchain.agents.agent import AgentExecutor
from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX, SQL_SUFFIX
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.agents.mrkl.base import ZeroShotAgent
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.chains.llm import LLMChain


def create_spark_sql_agent(
    llm: BaseLanguageModel,
    toolkit: SparkSQLToolkit,
    callback_manager: Optional[BaseCallbackManager] = None,
    prefix: str = SQL_PREFIX,
    suffix: str = SQL_SUFFIX,
    format_instructions: str = FORMAT_INSTRUCTIONS,
    input_variables: Optional[List[str]] = None,
    top_k: int = 10,
    max_iterations: Optional[int] = 15,
    max_execution_time: Optional[float] = None,
    early_stopping_method: str = "force",
    verbose: bool = False,
    agent_executor_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Dict[str, Any],
) -> AgentExecutor:
    """Construct a Spark SQL agent from an LLM and tools."""
    tools = toolkit.get_tools()
    prefix = prefix.format(top_k=top_k)
    prompt = ZeroShotAgent.create_prompt(
        tools,
        prefix=prefix,
        suffix=suffix,
        format_instructions=format_instructions,
        input_variables=input_variables,
    )
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        callback_manager=callback_manager,
    )
    tool_names = [tool.name for tool in tools]
    agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
    return AgentExecutor.from_agent_and_tools(
        agent=agent,
        tools=tools,
        callback_manager=callback_manager,
        verbose=verbose,
        max_iterations=max_iterations,
        max_execution_time=max_execution_time,
        early_stopping_method=early_stopping_method,
        **(agent_executor_kwargs or {}),
    )
```
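A minimal usage sketch (not part of the diff): the LLM choice and the schema name are illustrative, and an OpenAI API key is assumed to be configured.

```python
from langchain.agents.agent_toolkits.spark_sql.base import create_spark_sql_agent
from langchain.agents.agent_toolkits.spark_sql.toolkit import SparkSQLToolkit
from langchain.chat_models import ChatOpenAI
from langchain.utilities.spark_sql import SparkSQL

spark_sql = SparkSQL(schema="langchain_example")  # illustrative schema name
llm = ChatOpenAI(temperature=0)  # any BaseLanguageModel works here
toolkit = SparkSQLToolkit(db=spark_sql, llm=llm)
agent_executor = create_spark_sql_agent(llm=llm, toolkit=toolkit, verbose=True)

# A typical run: list tables, fetch schemas, check the SQL, then query.
agent_executor.run("How many rows are in the titanic table?")
```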
**langchain/agents/agent_toolkits/spark_sql/prompt.py** (+21)

```python
# flake8: noqa

SQL_PREFIX = """You are an agent designed to interact with Spark SQL.
Given an input question, create a syntactically correct Spark SQL query to run, then look at the results of the query and return the answer.
Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.
You have access to tools for interacting with the database.
Only use the below tools. Only use the information returned by the below tools to construct your final answer.
You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
If the question does not seem related to the database, just return "I don't know" as the answer.
"""

SQL_SUFFIX = """Begin!

Question: {input}
Thought: I should look at the tables in the database to see what I can query.
{agent_scratchpad}"""
```
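For reference, `create_spark_sql_agent` interpolates only `{top_k}` at construction time; `{input}` and `{agent_scratchpad}` survive into the agent's prompt template and are filled per run. A quick sanity check:

```python
from langchain.agents.agent_toolkits.spark_sql.prompt import SQL_PREFIX

prefix = SQL_PREFIX.format(top_k=10)
assert "at most 10 results" in prefix  # {top_k} is the only placeholder in the prefix
```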
**langchain/agents/agent_toolkits/spark_sql/toolkit.py** (+36)

```python
"""Toolkit for interacting with Spark SQL."""
from typing import List

from pydantic import Field

from langchain.agents.agent_toolkits.base import BaseToolkit
from langchain.base_language import BaseLanguageModel
from langchain.tools import BaseTool
from langchain.tools.spark_sql.tool import (
    InfoSparkSQLTool,
    ListSparkSQLTool,
    QueryCheckerTool,
    QuerySparkSQLTool,
)
from langchain.utilities.spark_sql import SparkSQL


class SparkSQLToolkit(BaseToolkit):
    """Toolkit for interacting with Spark SQL."""

    db: SparkSQL = Field(exclude=True)
    llm: BaseLanguageModel = Field(exclude=True)

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    def get_tools(self) -> List[BaseTool]:
        """Get the tools in the toolkit."""
        return [
            QuerySparkSQLTool(db=self.db),
            InfoSparkSQLTool(db=self.db),
            ListSparkSQLTool(db=self.db),
            QueryCheckerTool(db=self.db, llm=self.llm),
        ]
```
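A small sketch of what `get_tools()` returns, assuming the `toolkit` built in the usage sketch above; the `name` fields are the identifiers the zero-shot prompt exposes to the LLM:

```python
tools = toolkit.get_tools()
print([tool.name for tool in tools])
# ['query_sql_db', 'schema_sql_db', 'list_tables_sql_db', 'query_checker_sql_db']
```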
**langchain/tools/spark_sql/__init__.py** (+1)

```python
"""Tools for interacting with Spark SQL."""
```
**langchain/tools/spark_sql/prompt.py** (+14)

```python
# flake8: noqa
QUERY_CHECKER = """
{query}
Double check the Spark SQL query above for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query."""
```
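As a sketch of how `QueryCheckerTool` (below) consumes this template: it wraps it in a single-variable `PromptTemplate` and substitutes the candidate query.

```python
from langchain.prompts import PromptTemplate
from langchain.tools.spark_sql.prompt import QUERY_CHECKER

prompt = PromptTemplate(template=QUERY_CHECKER, input_variables=["query"])
print(prompt.format(query="SELECT Name FROM titanic LIMIT 10"))  # illustrative query
```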
**langchain/tools/spark_sql/tool.py** (+152)

```python
# flake8: noqa
"""Tools for interacting with Spark SQL."""
from typing import Any, Dict, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain.chains.llm import LLMChain
from langchain.prompts import PromptTemplate
from langchain.utilities.spark_sql import SparkSQL
from langchain.tools.base import BaseTool
from langchain.tools.spark_sql.prompt import QUERY_CHECKER


class BaseSparkSQLTool(BaseModel):
    """Base tool for interacting with Spark SQL."""

    db: SparkSQL = Field(exclude=True)

    # Override BaseTool.Config to appease mypy
    # See https://github.com/pydantic/pydantic/issues/4173
    class Config(BaseTool.Config):
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True
        extra = Extra.forbid


class QuerySparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for querying Spark SQL."""

    name = "query_sql_db"
    description = """
    Input to this tool is a detailed and correct SQL query, output is a result from Spark SQL.
    If the query is not correct, an error message will be returned.
    If an error is returned, rewrite the query, check the query, and try again.
    """

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Execute the query; return the results or an error message."""
        return self.db.run_no_throw(query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("QuerySparkSQLTool does not support async")


class InfoSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for getting metadata about Spark SQL tables."""

    name = "schema_sql_db"
    description = """
    Input to this tool is a comma-separated list of tables, output is the schema and sample rows for those tables.
    Be sure that the tables actually exist by calling list_tables_sql_db first!

    Example Input: "table1, table2, table3"
    """

    def _run(
        self,
        table_names: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get the schema for tables in a comma-separated list."""
        return self.db.get_table_info_no_throw(table_names.split(", "))

    async def _arun(
        self,
        table_names: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("InfoSparkSQLTool does not support async")


class ListSparkSQLTool(BaseSparkSQLTool, BaseTool):
    """Tool for getting table names."""

    name = "list_tables_sql_db"
    description = "Input is an empty string, output is a comma-separated list of tables in the Spark SQL database."

    def _run(
        self,
        tool_input: str = "",
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Get a comma-separated list of usable table names."""
        return ", ".join(self.db.get_usable_table_names())

    async def _arun(
        self,
        tool_input: str = "",
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        raise NotImplementedError("ListSparkSQLTool does not support async")


class QueryCheckerTool(BaseSparkSQLTool, BaseTool):
    """Use an LLM to check if a query is correct.

    Adapted from https://www.patterns.app/blog/2023/01/18/crunchbot-sql-analyst-gpt/
    """

    template: str = QUERY_CHECKER
    llm: BaseLanguageModel
    llm_chain: LLMChain = Field(init=False)
    name = "query_checker_sql_db"
    description = """
    Use this tool to double check if your query is correct before executing it.
    Always use this tool before executing a query with query_sql_db!
    """

    @root_validator(pre=True)
    def initialize_llm_chain(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "llm_chain" not in values:
            values["llm_chain"] = LLMChain(
                llm=values.get("llm"),
                prompt=PromptTemplate(
                    template=QUERY_CHECKER, input_variables=["query"]
                ),
            )

        if values["llm_chain"].prompt.input_variables != ["query"]:
            raise ValueError(
                "LLM chain for QueryCheckerTool must use ['query'] as "
                "input_variables for the embedded prompt"
            )

        return values

    def _run(
        self,
        query: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Use the LLM to check the query."""
        return self.llm_chain.predict(query=query)

    async def _arun(
        self,
        query: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        return await self.llm_chain.apredict(query=query)
```
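The tools can also be exercised directly, outside the agent loop. A sketch with illustrative table and query names, reusing `spark_sql` and `llm` from the earlier sketches:

```python
from langchain.tools.spark_sql.tool import (
    InfoSparkSQLTool,
    ListSparkSQLTool,
    QueryCheckerTool,
    QuerySparkSQLTool,
)

print(ListSparkSQLTool(db=spark_sql).run(""))         # e.g. "titanic"
print(InfoSparkSQLTool(db=spark_sql).run("titanic"))  # schema plus sample rows
print(QueryCheckerTool(db=spark_sql, llm=llm).run("SELECT Name FROM titanic LIMIT 5"))
print(QuerySparkSQLTool(db=spark_sql).run("SELECT count(*) FROM titanic"))
```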