Feature/bot-7: Add chat history support to the financial bot (iusztinpaul#13)

* feat: Add Makefile & GitHub Actions flow

* fix: Typo

* chore: Fix .env.example files

* fix: Financial bot Makefile

* feat: Adapt finbot Makefile. Fix finbot context bug.

* fix: Sharing keys between chains

* Add chat history support to the financial bot

* chore: Add TODOs & Fix linting issues
iusztinpaul authored Oct 12, 2023
1 parent 809d652 commit 60afa0d
Showing 16 changed files with 229 additions and 107 deletions.
38 changes: 38 additions & 0 deletions .github/workflows/pep8_financial_bot.yaml
@@ -0,0 +1,38 @@
name: pep8_financial_bot

on: [push]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
pep8-python:
name: PEP8
runs-on: ubuntu-latest

steps:
- name: Checkout
uses: actions/checkout@v3

- name: Setup Python
uses: actions/setup-python@v3
with:
python-version: '3.10'

- name: Install poetry
uses: abatilo/actions-poetry@v2
with:
poetry-version: 1.5.1

- name: Install packages
working-directory: ./modules/financial_bot
run: make install_only_dev

- name: Run PEP8 linter
working-directory: ./modules/financial_bot
run: make lint_check

- name: Run PEP8 format checker
working-directory: ./modules/financial_bot
run: make format_check
9 changes: 9 additions & 0 deletions .vscode/launch.json
@@ -4,6 +4,7 @@
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [

{
"name": "Python: Current File",
"type": "python",
@@ -118,5 +119,13 @@
"tools.run_batch:build_flow(latest_n_days=2, debug=True)"
]
},
{
"name": "Financial Bot [Dev]",
"type": "python",
"request": "launch",
"module": "tools.run_chain",
"justMyCode": false,
"cwd": "${workspaceFolder}/modules/financial_bot",
},
]
}
File renamed without changes.
52 changes: 52 additions & 0 deletions modules/financial_bot/Makefile
@@ -0,0 +1,52 @@
### Install ###

install:
@echo "Installing financial bot..."

poetry env use $(shell which python3.10) && \
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring poetry install && \
poetry run pip install torch==2.0.1

install_dev: install
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring poetry install --only dev

install_only_dev:
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring poetry install --only dev

add:
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring poetry add $(package)

add_dev:
PYTHON_KEYRING_BACKEND=keyring.backends.null.Keyring poetry add --group dev $(package)


# === Run ===

run:
@echo "Running financial_bot..."

poetry run python -m tools.run_chain


### PEP 8 ###
# Be sure to install the dev dependencies first #

lint_check:
@echo "Checking for linting issues..."

poetry run ruff check .

lint_fix:
@echo "Fixing linting issues..."

poetry run ruff check --fix .

format_check:
@echo "Checking for formatting issues..."

poetry run black --check .

format_fix:
@echo "Formatting code..."

poetry run black .
48 changes: 22 additions & 26 deletions modules/financial_bot/financial_bot/chains.py
@@ -1,11 +1,12 @@
from typing import Any, Dict, List

import qdrant_client
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.template import PromptTemplate
from langchain.chains.base import Chain
from langchain.llms import HuggingFacePipeline

from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.template import PromptTemplate


class ContextExtractorChain(Chain):
"""
@@ -17,68 +18,63 @@ class ContextExtractorChain(Chain):
embedding_model: EmbeddingModelSingleton
vector_store: qdrant_client.QdrantClient
vector_collection: str
output_key: str = "payload"
output_key: str = "context"

@property
def input_keys(self) -> List[str]:
return ["about_me", "question", "context"]
return ["about_me", "question"]

@property
def output_keys(self) -> List[str]:
return [self.output_key]

def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
# TODO: handle that None, without the need to enter chain
about_key, quest_key, contx_key = self.input_keys
question_str = inputs.get(quest_key, None)
_, quest_key = self.input_keys
question_str = inputs[quest_key]

# TODO: maybe async embed?
embeddings = self.embedding_model(question_str)

# TODO: get rid of hardcoded collection_name, specify 1 top_k or adjust multiple context insertions
# TODO: Using the metadata filter the news from the latest week (or other timeline).
matches = self.vector_store.search(
query_vector=embeddings,
k=self.top_k,
collection_name=self.vector_collection,
)

content = ""
context = ""
for match in matches:
content += match.payload["summary"] + "\n"
context += match.payload["summary"] + "\n"

payload = {
about_key: inputs[about_key],
quest_key: inputs[quest_key],
contx_key: content,
return {
self.output_key: context,
}

return {self.output_key: payload}


class FinancialBotQAChain(Chain):
"""This custom chain handles LLM generation upon given prompt"""

hf_pipeline: HuggingFacePipeline
template: PromptTemplate
output_key: str = "response"
output_key: str = "answer"

@property
def input_keys(self) -> List[str]:
return ["payload"]
return ["context"]

@property
def output_keys(self) -> List[str]:
return [self.output_key]

def _call(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
# TODO: use .get and treat default value?
about_me = inputs["about_me"]
question = inputs["question"]
context = inputs["context"]

prompt = self.template.infer_raw_template.format(
user_context=about_me, news_context=context, question=question
)
prompt = self.template.format_infer(
{
"user_context": inputs["about_me"],
"news_context": inputs["question"],
"chat_history": inputs["chat_history"],
"question": inputs.get("context"),
}
)["prompt"]
response = self.hf_pipeline(prompt)

return {self.output_key: response}
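
After this change the two custom chains compose purely by key name: ContextExtractorChain takes about_me and question and emits only context, while FinancialBotQAChain reads about_me, question, context, and the memory-injected chat_history, then emits answer. A minimal standalone sketch of that data flow, using hypothetical stand-in functions rather than the repository's classes:

def context_extractor(inputs: dict) -> dict:
    # Stands in for ContextExtractorChain._call: embed the question,
    # search the vector store, concatenate the matched news summaries.
    summaries = ["ACME beat earnings estimates.", "Rates held steady."]  # mocked matches
    return {"context": "\n".join(summaries)}

def financial_qa(inputs: dict) -> dict:
    # Stands in for FinancialBotQAChain._call: format the prompt, run the LLM.
    return {"answer": f"(mocked LLM reply to: {inputs['question']})"}

state = {"about_me": "I am a 30-year-old investor.", "question": "Should I buy tech stocks?"}
state.update(context_extractor(state))  # adds "context"
state.update(financial_qa(state))       # adds "answer"
print(state["answer"])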
3 changes: 3 additions & 0 deletions modules/financial_bot/financial_bot/constants.py
@@ -17,3 +17,6 @@
# == Prompt Template ==
TEMPLATE_NAME = "falcon"
SYSTEM_MESSAGE = "You are a financial expert. Based on the context I provide, respond in a helpful manner"

# === Misc ===
DEBUG = True
1 change: 1 addition & 0 deletions modules/financial_bot/financial_bot/embeddings.py
@@ -23,6 +23,7 @@ def __init__(

self._tokenizer = AutoTokenizer.from_pretrained(model_id)
self._model = AutoModel.from_pretrained(model_id).to(self._device)
self._model.eval()

@property
def max_input_length(self) -> int:
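
The new self._model.eval() call matters because layers such as dropout behave differently in training and inference mode; switching to eval mode makes repeated embeddings of the same text deterministic. A small illustration with a throwaway model, not the repository's embedding model:

import torch

torch.manual_seed(0)
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.5))
x = torch.ones(1, 16)

model.train()
print(torch.equal(model(x), model(x)))  # almost surely False: dropout re-samples every pass

model.eval()
print(torch.equal(model(x), model(x)))  # True: dropout is disabled, output is deterministic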
50 changes: 33 additions & 17 deletions modules/financial_bot/financial_bot/langchain_bot.py
@@ -1,21 +1,30 @@
import logging

from langchain import chains
from langchain.memory import ConversationBufferMemory

from financial_bot import constants
from financial_bot.chains import ContextExtractorChain, FinancialBotQAChain
from financial_bot.embeddings import EmbeddingModelSingleton
from financial_bot.models import build_huggingface_pipeline
from financial_bot.qdrant import build_qdrant_client
from financial_bot.template import get_llm_template
from langchain import chains

logger = logging.getLogger(__name__)


class FinancialBot:
def __init__(self):
def __init__(
self,
llm_model_id: str = constants.LLM_MODEL_ID,
llm_lora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
debug: bool = constants.DEBUG,
):
self._qdrant_client = build_qdrant_client()
self._embd_model = EmbeddingModelSingleton()
self._llm_agent = build_huggingface_pipeline()
self._llm_agent = build_huggingface_pipeline(
llm_model_id=llm_model_id, llm_lora_model_id=llm_lora_model_id, debug=debug
)
self.finbot_chain = self.build_chain()

def build_chain(self) -> chains.SequentialChain:
@@ -43,7 +52,9 @@ def build_chain(self) -> chains.SequentialChain:
Notes
-----
The actual processing flow within the chain can be visualized as:
[about: str][question: str] > ContextChain > [about: str][question:str] + [context: str] > FinancialChain > LLM Response
[about: str][question: str] > ContextChain >
[about: str][question:str] + [context: str] > FinancialChain >
[answer: str]
"""

logger.info("Building 1/3 - ContextExtractorChain")
@@ -60,20 +71,28 @@ def build_chain(self) -> chains.SequentialChain:
template=get_llm_template(name=constants.TEMPLATE_NAME),
)

logger.info("Connecting chains into SequentialChain")
logger.info("Building 3/3 - Connecting chains into SequentialChain")
# TODO: Change memory to keep TOP k messages or a summary of the conversation.
seq_chain = chains.SequentialChain(
memory=ConversationBufferMemory(
memory_key="chat_history", input_key="question"
),
chains=[context_retrieval_chain, llm_generator_chain],
input_variables=["about_me", "question"],
output_variables=["response"],
output_variables=["answer"],
verbose=True,
)

logger.info("Done building SequentialChain.")
logger.info("Workflow:")
logger.info(
"> [about: str][question: str])\
>>> ContextChain > [about: str] + [[question :str] -> VectorDB -> TopK -> + [context: str]] > [about: str][question: str][context: str]\
>>> FinancialChain > LLM Response"
"""
[about: str][question: str] > ContextChain >
[about: str][question:str] + [context: str] > FinancialChain >
[answer: str]
"""
)

return seq_chain

def answer(self, about_me: str, question: str) -> str:
@@ -93,11 +112,8 @@ def answer(self, about_me: str, question: str) -> str:
str
LLM generated response.
"""
try:
inputs = {"about_me": about_me, "question": question}
response = self.finbot_chain.run(inputs)
return response
except KeyError as e:
logger.error(f"Caught key error {e}")
except Exception as e:
logger.error(f"Caught {e}")

inputs = {"about_me": about_me, "question": question}
response = self.finbot_chain.run(inputs)

return response
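
The chat history that the QA chain consumes comes from the ConversationBufferMemory wired into the SequentialChain above: after each run it stores the question/answer pair and injects the accumulated transcript under the chat_history key. A minimal sketch of that mechanic in isolation, assuming the langchain memory API as of this commit:

from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory(memory_key="chat_history", input_key="question")
memory.save_context(
    {"question": "Should I buy tech stocks?"},
    {"answer": "That depends on your risk profile."},
)
memory.save_context(
    {"question": "And bonds?"},
    {"answer": "They can balance a tech-heavy portfolio."},
)

# The string below is what the SequentialChain injects as the "chat_history" input.
print(memory.load_memory_variables({})["chat_history"])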
28 changes: 25 additions & 3 deletions modules/financial_bot/financial_bot/models.py
@@ -2,10 +2,10 @@
import os
from pathlib import Path
from typing import Optional, Tuple
from financial_bot.utils import MockedPipeline

import torch
from comet_ml import API
from financial_bot import constants
from langchain.llms import HuggingFacePipeline
from peft import LoraConfig, PeftConfig, PeftModel
from transformers import (
@@ -15,6 +15,8 @@
pipeline,
)

from financial_bot import constants

logger = logging.getLogger(__name__)


@@ -44,13 +46,33 @@ def download_from_model_registry(model_id: str, cache_dir: Optional[Path] = None
return model_dir


def build_huggingface_pipeline():
def build_huggingface_pipeline(
llm_model_id: str,
llm_lora_model_id: str,
gradient_checkpointing: bool = False,
cache_dir: Optional[Path] = None,
debug: bool = False,
):
"""Using our custom LLM + Finetuned checkpoint we create a HF pipeline"""
model, tokenizer, _ = build_qlora_model()

if debug is True:
return HuggingFacePipeline(
pipeline=MockedPipeline(f=lambda _: "You are doing great!")
)

model, tokenizer, _ = build_qlora_model(
pretrained_model_name_or_path=llm_model_id,
peft_pretrained_model_name_or_path=llm_lora_model_id,
gradient_checkpointing=gradient_checkpointing,
cache_dir=cache_dir,
)
model.eval()

pipe = pipeline(
"text-generation", model=model, tokenizer=tokenizer, max_new_tokens=100
)
hf = HuggingFacePipeline(pipeline=pipe)

return hf


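One piece this diff imports but does not show is MockedPipeline from financial_bot.utils. For the debug branch in build_huggingface_pipeline to work, it only has to look enough like a transformers text-generation pipeline for langchain's HuggingFacePipeline wrapper to call it. A hypothetical minimal stand-in, assuming the usual task attribute and [{"generated_text": ...}] return shape; the real class lives in financial_bot/utils.py:

from typing import Callable

class MockedPipeline:
    # Hypothetical sketch based on how build_huggingface_pipeline uses it,
    # not the repository's actual implementation.
    task: str = "text-generation"

    def __init__(self, f: Callable[[str], str]):
        self.f = f

    def __call__(self, prompt: str, **kwargs):
        # transformers text-generation pipelines return a list of dicts.
        return [{"generated_text": self.f(prompt)}]

Combined with DEBUG = True in constants.py, this lets the bot start without downloading the QLoRA checkpoint and answer every prompt with a canned string.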
