Skip to content

Commit

Permalink
Created LLM chain using langchain
Browse files Browse the repository at this point in the history
  • Loading branch information
mohitbansal964 committed Oct 12, 2024
1 parent f216cc8 commit c9d5cea
Show file tree
Hide file tree
Showing 17 changed files with 265 additions and 16 deletions.
10 changes: 9 additions & 1 deletion app/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os
from dotenv import find_dotenv, load_dotenv
from src.utils import generate_metadata
from src.chains import generate_chain
from src.services import CricbotService

# Load environment variables from a .env file
Expand All @@ -24,5 +26,11 @@
user_input = input("User: ")
if user_input.lower() == "exit":
break
response = cricbot_service.bot_response(user_input)
# Using langchain to sequence LLMs and Data fetching components
metadata = generate_metadata(user_input=user_input)
response = generate_chain(openai_api_key, metadata).invoke(metadata)

# Using custom service
# response = cricbot_service.bot_response(user_input)

print("Cricbot:", response)
1 change: 1 addition & 0 deletions app/src/chains/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .cricbot_chain import generate_chain
23 changes: 23 additions & 0 deletions app/src/chains/cricbot_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from src.services import IntentHandlerService, ResponseGeneratorService, LiveMatchService, IntentIdentifierService
from src.utils import get_live_matches_as_string
from src.models import IntentDetails
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser

def generate_chain(openai_api_key: str, metadata: dict):
intent_identifier_service = IntentIdentifierService(openai_api_key)
response_generator_service = ResponseGeneratorService(openai_api_key)
intent_handler_service = IntentHandlerService()
json_parser = JsonOutputParser(pydantic_object=IntentDetails)
str_parser = StrOutputParser()
chain = (lambda x: {**metadata, "live_matches": get_live_matches_as_string(LiveMatchService().fetch_all_matches())}) \
| intent_identifier_service.get_chat_prompt_template(json_parser) \
| intent_identifier_service.llm \
| json_parser \
| intent_handler_service.get_addtional_data \
| (lambda data: {**metadata, **data}) \
| response_generator_service.get_prompt \
| response_generator_service.llm \
| str_parser
return chain


1 change: 1 addition & 0 deletions app/src/constants/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Constants:

# File names for system messages and prompts
INTENT_IDENTIFIER_SYS_MSG_FILE_NAME: str = "intent_identifier_system_message.txt"
INTENT_IDENTIFIER_SYS_MSG_FILE_NAME2: str = "intent_identifier_system_message2.txt"
LIVE_SCORE_RESPONSE_PROMPT: str = "live_score_response_prompt.txt"
ALL_LIVE_MATCHES_RESPONSE_PROMPT: str = "all_live_matches_response_prompt.txt"
FALLBACK_RESPONSE_PROMPT: str = "fallback_response_prompt.txt"
Expand Down
1 change: 1 addition & 0 deletions app/src/enums/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .intents import Intent
7 changes: 7 additions & 0 deletions app/src/enums/intents.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum


class Intent(str, Enum):
live_matches = "live_matches"
live_score = "live_score"
fallback = "fallback"
3 changes: 2 additions & 1 deletion app/src/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .match_details import MatchDetails, TeamScoreDetails
from .match_details import MatchDetails, TeamScoreDetails
from .intent_details import IntentDetails
16 changes: 16 additions & 0 deletions app/src/models/intent_details.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field

from src.enums import Intent

class Entities(BaseModel):
series: Optional[str] = Field(None, description="Series of a cricket match")
team1: Optional[str] = Field(None, description="Name of team 1")
team2: Optional[str] = Field(None, description="Name of team 2")
reason: Optional[str] = Field(None, description="Reason why intent identification failed")
date: Optional[datetime] = Field(None, description="Date of the match")

class IntentDetails(BaseModel):
intent: Intent = Field(description="intent of the text message")
entities: Optional[Entities] = Field(default_factory=Entities, description="Entities to find in the text message")
65 changes: 65 additions & 0 deletions app/src/prompts/intent_identifier_system_message2.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
Role:
You are an expert in classifying intent and identifying entities from a plain text.

Context:
We are building a chatbot about Cricket where you need to find intent and entities in the message.
Following are the list of live matches:
{live_matches}

Tasks:
- Identify the intent and entities in the given text. Possible intents and their corresponding entities are:
# 'live_matches': User is trying to find the list of all live matches. Try to identify the series name from the text based on above list: Entity to find is:
# 'series' - Series from above list [Optional]
# 'live_score': User is trying to find the live score of a cricket match between 2 teams. Check above list of live matches and identify the teams from the text. If you are not able to identify the teams, then return 'live_matches' intent. If teams are found, then return entities as:
* 'team1' - Cricket team 1 [Mandatory]
* 'team2' - Cricket team 2 [Mandatory]
# 'fallback': If text doesn't fit in any of the above intents, then return this intent. Entity to find is:
* 'reason' - output the reason because of which you are not able to indetify the intent in the given text.
- {format_instructions}
- Consider edge cases like multiple entities in a message or unclear intents and provide reasonable interpretations.
- Ensure all outputs are contextually accurate and specific to Cricket.

Example1:
Input: Get me live scores of cricket match between india and australia.
Output: {{
"intent": "live_score",
"entities": {{
"team1": "india",
"team2": "australia"
}}
}}

Example2:
Input: mumbai indians vs gujarat titans
Output: {{
"intent": "live_score",
"entities": {{
"team1": "mumbai indians",
"team2": "gujarat titans"
}}
}}

Example3:
Input: Show me live score of football match
Output: {{
"intent": "fallback"
"entities": {{
"reason": "Cannot show live score of a football match."
}}
}}

Example4:
Input: xyz vs abcd
Output: {{
"intent": "live_matches",
"entities": {{}}
}}

Example5:
Input: List all the matches of india vs bangladesh series.
Output: {{
"intent": "live_matches",
"entities": {{
"series": "india-vs-bangladesh"
}}
}}
6 changes: 5 additions & 1 deletion app/src/services/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
from .cricbot_service import CricbotService
from .cricbot_service import CricbotService
from .intent_identifier_service import IntentIdentifierService
from .live_match_service import LiveMatchService
from .response_generator_service import ResponseGeneratorService
from .intent_handler_service import IntentHandlerService
4 changes: 2 additions & 2 deletions app/src/services/cricbot_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def bot_response(self, user_input: str) -> str:
The response generated by the bot.
"""
try:
live_matches = self.__live_match_service.fetch_all_live_matches()
live_matches = self.__live_match_service.fetch_all_matches()
intent_details = self.__intent_identifier_service.invoke(user_input, live_matches)
response = ""
match intent_details.get('intent'):
Expand Down Expand Up @@ -94,7 +94,7 @@ def __handle_live_matches_intent(self, user_input: str, intent_details: dict) ->
The generated response content for all live matches.
"""
entities = intent_details.get('entities', {})
live_matches = self.__live_match_service.fetch_all_live_matches()
live_matches = self.__live_match_service.fetch_all_matches()
series = entities.get('series', '')
if series:
live_matches_of_series = [match for match in live_matches if match.series_name.lower() == series.lower()]
Expand Down
74 changes: 74 additions & 0 deletions app/src/services/intent_handler_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from .live_match_service import LiveMatchService
from src.constants import Constants
from src.enums import Intent
from src.models import IntentDetails

class IntentHandlerService:

def __init__(self):
self.__live_match_service = LiveMatchService()

def get_addtional_data(self, data: dict) -> dict:
intent_details = IntentDetails(**data)
match intent_details.intent:
case Intent.live_matches:
additional_data = self.__get_current_matches_intent_data(intent_details)
case Intent.live_score:
additional_data = self.__get_live_score_intent_data(intent_details)
case _:
additional_data = self.__get_fallback_intent_data(intent_details)
return {**data, **additional_data}

def __get_current_matches_intent_data(self, intent_details: IntentDetails) -> dict:
entities = intent_details.entities
live_matches = self.__live_match_service.fetch_all_matches(entities.date)
series = entities.series
if series:
live_matches_of_series = [match for match in live_matches if match.series_name.lower() == series.lower()]
additional_data = {
"series": series,
"live_matches": live_matches_of_series
}
else:
additional_data = {
"live_matches": live_matches
}
return additional_data

def __get_live_score_intent_data(self, intent_details: IntentDetails) -> dict:
entities = intent_details.entities
match_score, live_matches = self.__live_match_service.fetch_live_score(
entities.team1,
entities.team2
)
if match_score is None and len(live_matches) > 0:
additional_data = {
"intent": Intent.live_matches,
"entities": {},
"live_matches": live_matches
}
elif match_score is None:
additional_data = {
"intent": Intent.fallback,
"entities": {
"reason": Constants.MATCHES_NOT_PRESENT_REASON
}
}
else:
additional_data = {
"match_score": match_score
}
return additional_data

def __get_fallback_intent_data(self, intent_details: IntentDetails) -> dict:
entities = intent_details.entities
if not entities.reason:
return {
"entities": {
"reason": Constants.REASON_NOT_PRESENT
}
}
return {}



28 changes: 25 additions & 3 deletions app/src/services/intent_identifier_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Any, List
from langchain_openai import ChatOpenAI
from langchain_core.messages import SystemMessage, BaseMessage, HumanMessage
from langchain_core.prompts import SystemMessagePromptTemplate
from langchain_core.prompts import SystemMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from src.models import MatchDetails
from src.constants import Constants
from src.utils import read_prompt_from_file, get_live_matches_as_string
Expand Down Expand Up @@ -40,7 +41,7 @@ def __init__(self, openai_api_key: str):
openai_api_key : str
The API key for accessing the OpenAI service.
"""
self.__llm_chain = ChatOpenAI(
self.llm = ChatOpenAI(
model=Constants.INTENT_IDENTIFIER_GPT_MODEL,
api_key=openai_api_key
)
Expand All @@ -62,8 +63,16 @@ def invoke(self, user_text: str, live_matches: List[MatchDetails]) -> Any:
The identified intent as a JSON object.
"""
messages = self.__get_llm_messages(user_text, live_matches)
output = self.__llm_chain.invoke(messages)
output = self.llm.invoke(messages)
return json.loads(output.content)

def get_chat_prompt_template(self, parser: JsonOutputParser):
return ChatPromptTemplate.from_messages(
[
self.__get_system_message_prompt_template_for_chain(parser),
self.__get_human_message_prompt_template_for_chain()
]
)

def __get_llm_messages(self, user_text: str, live_matches: List[MatchDetails]) -> List[BaseMessage]:
"""
Expand Down Expand Up @@ -104,6 +113,19 @@ def __get_system_message(self, live_matches: List[MatchDetails]) -> SystemMessag
template=read_prompt_from_file(Constants.INTENT_IDENTIFIER_SYS_MSG_FILE_NAME)
)
return system_msg_template.format(live_matches=get_live_matches_as_string(live_matches))

def __get_system_message_prompt_template_for_chain(self, parser: JsonOutputParser):
return SystemMessagePromptTemplate.from_template(
template=read_prompt_from_file(Constants.INTENT_IDENTIFIER_SYS_MSG_FILE_NAME2),
input_variables=["live_matches"],
partial_variables={"format_instructions": parser.get_format_instructions()},
)

def __get_human_message_prompt_template_for_chain(self):
return HumanMessagePromptTemplate.from_template(
template="{user_input}",
input_variables=["user_input"],
)

def __get_human_message(self, user_text: str) -> HumanMessage:
"""
Expand Down
6 changes: 3 additions & 3 deletions app/src/services/live_match_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,10 @@ def fetch_live_score(self, team1: str, team2: str) -> Tuple[Optional[MatchDetail
A tuple containing the details of the match between the specified teams,
or None if not found, and a list of all live matches.
"""
live_matches = self.fetch_all_live_matches()
live_matches = self.fetch_all_matches()
return (self.__find_match(live_matches, team1, team2), live_matches)

def fetch_all_live_matches(self, date: Optional[str] = None) -> List[MatchDetails]:
def fetch_all_matches(self, date: Optional[datetime] = None) -> List[MatchDetails]:
"""
Retrieves all live matches from the external API for a given date.
Expand All @@ -70,7 +70,7 @@ def fetch_all_live_matches(self, date: Optional[str] = None) -> List[MatchDetail
List[MatchDetails]
A list of MatchDetails objects representing live matches.
"""
cur_date = date if date else datetime.today().strftime("%Y%m%d")
cur_date = (date if date else datetime.today()).strftime("%Y%m%d")
url = f"https://prod-public-api.livescore.com/v1/api/app/date/cricket/{cur_date}/5.30?locale=en&MD=1"
response = requests.get(url)
if response.ok:
Expand Down
Loading

0 comments on commit c9d5cea

Please sign in to comment.