Skip to content

Commit

Permalink
add custom prompt for SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
teddylee777 committed Oct 24, 2024
1 parent f74c737 commit 218aba0
Showing 1 changed file with 119 additions and 15 deletions.
134 changes: 119 additions & 15 deletions 14-Chains/02-SQL.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 1,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# API KEY를 환경변수로 관리하기 위한 설정 파일\n",
"from dotenv import load_dotenv\n",
Expand All @@ -24,9 +35,19 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LangSmith 추적을 시작합니다.\n",
"[프로젝트명]\n",
"SQL\n"
]
}
],
"source": [
"# LangSmith 추적을 설정합니다. https://smith.langchain.com\n",
"# !pip install langchain-teddynote\n",
Expand All @@ -45,9 +66,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 47,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sqlite\n",
"['accounts', 'customers', 'transactions']\n"
]
}
],
"source": [
"from langchain_openai import ChatOpenAI\n",
"from langchain.chains import create_sql_query_chain\n",
Expand All @@ -74,7 +104,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -89,14 +119,66 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"chain 을 실행하면 DB 기반으로 쿼리를 생성합니다."
"(옵션) 아래의 방식으로 Prompt 를 직접 지정할 수 있습니다.\n",
"\n",
"직접 작성시 table_info 와 더불어 설명가능한 column description 을 추가할 수 있습니다."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.prompts import PromptTemplate\n",
"\n",
"prompt = PromptTemplate.from_template(\n",
" \"\"\"Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer. Unless the user specifies in his question a specific number of examples he wishes to obtain, always limit your query to at most {top_k} results. You can order the results by a relevant column to return the most interesting examples in the database.\n",
"Use the following format:\n",
"\n",
"Question: \"Question here\"\n",
"SQLQuery: \"SQL Query to run\"\n",
"SQLResult: \"Result of the SQLQuery\"\n",
"Answer: \"Final answer here\"\n",
"\n",
"Only use the following tables:\n",
"{table_info}\n",
"\n",
"Here is the description of the columns in the tables:\n",
"`cust`: customer name\n",
"`prod`: product name\n",
"`trans`: transaction date\n",
"\n",
"Question: {input}\"\"\"\n",
").partial(dialect=db.dialect)\n",
"\n",
"# model 은 gpt-3.5-turbo 를 지정\n",
"llm = ChatOpenAI(model=\"gpt-3.5-turbo\", temperature=0)\n",
"\n",
"# LLM 과 DB 를 매개변수로 입력하여 chain 을 생성합니다.\n",
"chain = create_sql_query_chain(llm, db, prompt)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"chain 을 실행하면 DB 기반으로 쿼리를 생성합니다."
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"'SELECT name\\nFROM customers'\n"
]
}
],
"source": [
"# chain 을 실행하고 결과를 출력합니다.\n",
"generated_sql_query = chain.invoke({\"question\": \"고객의 이름을 나열하세요\"})\n",
Expand All @@ -114,7 +196,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 49,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -126,16 +208,27 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 50,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"\"[('테디',), ('폴',), ('셜리',), ('민수',), ('지영',), ('은정',)]\""
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"execute_query.invoke({\"query\": generated_sql_query})"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -153,9 +246,20 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 52,
"metadata": {},
"outputs": [],
"outputs": [
{
"data": {
"text/plain": [
"\"[('[email protected]',)]\""
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 실행 결과 확인\n",
"chain.invoke({\"question\": \"테디의 이메일을 조회하세요\"})"
Expand Down

0 comments on commit 218aba0

Please sign in to comment.