Skip to content

Commit

Permalink
Add input types to cypher templates (langchain-ai#12800)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasonjo authored Nov 2, 2023
1 parent c4fdf78 commit 2a9f40e
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
15 changes: 10 additions & 5 deletions templates/neo4j-cypher-ft/neo4j_cypher_ft/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel, Field
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

try:
from pydantic.v1.main import BaseModel, Field
except ImportError:
from pydantic.main import BaseModel, Field

# Connection to Neo4j
graph = Neo4jGraph()

Expand Down Expand Up @@ -127,3 +123,12 @@ def map_to_database(entities: Entities) -> Optional[str]:
| qa_llm
| StrOutputParser()
)

# Add typing for input


class Question(BaseModel):
question: str


chain = chain.with_types(input_type=Question)
10 changes: 10 additions & 0 deletions templates/neo4j-cypher/neo4j_cypher/chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from langchain.chat_models import ChatOpenAI
from langchain.graphs import Neo4jGraph
from langchain.prompts import ChatPromptTemplate
from langchain.pydantic_v1 import BaseModel
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnablePassthrough

Expand Down Expand Up @@ -71,3 +72,12 @@
| qa_llm
| StrOutputParser()
)

# Add typing for input


class Question(BaseModel):
question: str


chain = chain.with_types(input_type=Question)

0 comments on commit 2a9f40e

Please sign in to comment.