CodeGen V0
ashwinaravind committed Jun 28, 2024
1 parent 140de36 commit 7bdc8e8
Showing 15 changed files with 647 additions and 240 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -4,7 +4,7 @@
long_description = fh.read()
setup(
name='ragbuilder',
version='0.0.2',
version='0.0.3',
author='Ashwin Aravind, Aravind Parameswaran',
author_email='[email protected], [email protected]',
description='RagBuilder is a toolkit designed to help you create an optimal, production-ready Retrieval-Augmented Generation (RAG) pipeline for your data'
2 changes: 1 addition & 1 deletion src/ragbuilder/eval.py
@@ -71,7 +71,7 @@ def write(self, data):
class RagEvaluator:
def __init__(
self,
rag,
rag, # code for the RAG function
test_dataset,
context_fn=None,
llm=None,
50 changes: 39 additions & 11 deletions src/ragbuilder/executor.py
@@ -1,6 +1,8 @@
from ragbuilder.rag_templates.top_n_templates import top_n_templates
from ragbuilder.rag_templates.langchain_templates import nuancedCombos
from ragbuilder.langchain_module.rag import mergerag as rag
# from ragbuilder.langchain_module.rag import mergerag as rag
from ragbuilder.langchain_module.rag import getCode as rag
# from ragbuilder.router import router
from ragbuilder.langchain_module.common import setup_logging
import logging
import json
@@ -14,13 +16,6 @@
dotenv_path = os.path.join(current_working_directory, '.env')
load_dotenv(dotenv_path)
logger = logging.getLogger("ragbuilder")
# pregenerated_configs = [

# # list of permutations of the chunking strategies, LLMs, etc.
# ]
# def exclude_filter():

# return granular_rag
# Load Synthetic Data
import pandas as pd
from datasets import Dataset
@@ -37,6 +32,22 @@
verbose=True
)
#####
from langchain_openai import ChatOpenAI
from langchain_community.document_loaders import *
from langchain_openai import OpenAIEmbeddings
from langchain.text_splitter import *
from langchain_chroma import *

import os
from operator import itemgetter
from langchain import hub
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableParallel, RunnableLambda
from langchain.retrievers import *
from langchain.retrievers.document_compressors import DocumentCompressorPipeline

#####


import time
def rag_builder(**kwargs):
@@ -59,7 +70,7 @@ def rag_builder(**kwargs):
run_config=RunConfig(timeout=RUN_CONFIG_TIMEOUT, max_workers=RUN_CONFIG_MAX_WORKERS, max_wait=RUN_CONFIG_MAX_WAIT, max_retries=RUN_CONFIG_MAX_RETRIES)
logger.info(f"{repr(run_config)}")
rageval=eval.RagEvaluator(
rag_builder,
rag_builder, # code for the RAG function
test_ds,
llm = chat_model,
embeddings = OpenAIEmbeddings(model="text-embedding-3-large"),
@@ -81,7 +92,7 @@ def rag_builder(**kwargs):
run_config=RunConfig(timeout=RUN_CONFIG_TIMEOUT, max_workers=RUN_CONFIG_MAX_WORKERS, max_wait=RUN_CONFIG_MAX_WAIT, max_retries=RUN_CONFIG_MAX_RETRIES)
logger.info(f"{repr(run_config)}")
rageval=eval.RagEvaluator(
rag_builder,
rag_builder, # the RAG function
test_ds,
llm=chat_model,
embeddings=OpenAIEmbeddings(model="text-embedding-3-large"),
@@ -109,7 +120,15 @@ def __init__(self, val):
self.retriever_kwargs=val['retriever_kwargs']
# self.prompt_text = val['prompt_text']
print(f"retrieval model: {self.retrieval_model}")
self.rag=rag.mergerag(

# self.router(Configs)  # calls the appropriate code generator, which invokes codeGen and returns the code as a string
# namespace = {}
# exec(rag_func_str, namespace)  # executes the generated code
# ragchain = namespace['ragchain']  # fetch the function object
# self.runCode = ragchain()

# output of router is the generated code as a string
self.router=rag.codeGen(
framework=self.framework,
# description=self.description,
retrieval_model = self.retrieval_model,
Expand All @@ -120,6 +139,15 @@ def __init__(self, val):
embedding_kwargs=self.embedding_kwargs,
retriever_kwargs=self.retriever_kwargs
)
locals_dict={}
globals_dict = globals()

# execute the generated code string
exec(self.router,globals_dict,locals_dict)
logger.info(f"Generated Code:\n{self.router}")

# hook the resulting RAG function into eval, where the old rag func used to go
self.rag = locals_dict['rag_pipeline']()

# def __repr__(self):
# return (
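For orientation, the new executor flow reduces to three steps: generate a Python module as a string, exec it, and pull the rag_pipeline factory out of the resulting namespace. Below is a minimal, illustrative sketch of that pattern; the generated body is a stand-in, not real codeGen output, and only the rag_pipeline() contract is borrowed from executor.py.

# Illustrative sketch of the generate-then-exec pattern used by executor.py.
generated_code = '''
def rag_pipeline():
    def invoke(question):
        # real generated code would wire up loader, splitter, retriever and LLM here
        return f"answer for: {question}"
    return invoke
'''

globals_dict = globals()
locals_dict = {}
exec(generated_code, globals_dict, locals_dict)

rag = locals_dict['rag_pipeline']()  # same lookup executor.py performs
print(rag("What does ragbuilder do?"))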
117 changes: 55 additions & 62 deletions src/ragbuilder/langchain_module/chunkingstrategy/langchain_chunking.py
@@ -1,17 +1,19 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import CharacterTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain.text_splitter import MarkdownHeaderTextSplitter
from langchain_text_splitters import HTMLHeaderTextSplitter
from ragbuilder.langchain_module.common import setup_logging
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain.text_splitter import CharacterTextSplitter
# from langchain_experimental.text_splitter import SemanticChunker
# from langchain.text_splitter import MarkdownHeaderTextSplitter
# from langchain_text_splitters import HTMLHeaderTextSplitter
from ragbuilder.langchain_module.common import setup_logging,codeGen
import logging

setup_logging()
logger = logging.getLogger("ragbuilder")

def getChunkingStrategy(**kwargs):
try:
strategy = kwargs.get('chunk_strategy')
strategy = kwargs.get('chunking_kwargs').get('chunk_strategy')
kwargs['chunk_size'] = kwargs.get('chunking_kwargs').get('chunk_size')
kwargs['chunk_overlap'] = kwargs.get('chunking_kwargs').get('chunk_overlap')
if not strategy:
raise ValueError("Missing chunking strategy in kwargs")

@@ -34,88 +36,79 @@ def getChunkingStrategy(**kwargs):
def getLangchainRecursiveCharacterTextSplitter(**kwargs):
try:
logger.info("RecursiveCharacterTextSplitter Invoked")
splitter = RecursiveCharacterTextSplitter(chunk_size=kwargs['chunk_size'], chunk_overlap=kwargs['chunk_overlap'])
retriever_type = kwargs.get('retriever_type')
if retriever_type in ["parentDocFullDoc", "parentDocLargeChunk" ]:
return splitter
else:
return splitter.split_documents(kwargs['docs'])
splitter_name=kwargs.get('splitter_name','splitter')
code_string = f"""
from langchain.text_splitter import RecursiveCharacterTextSplitter
{splitter_name} = RecursiveCharacterTextSplitter(chunk_size={kwargs['chunk_size']}, chunk_overlap={kwargs['chunk_overlap']})
splits={splitter_name}.split_documents(docs)"""
import_string = f"""from langchain.text_splitter import RecursiveCharacterTextSplitter"""
return {'code_string':code_string,'import_string':import_string}
except KeyError as e:
logger.error(f"Missing key in kwargs for RecursiveCharacterTextSplitter: {e}")
raise
except Exception as e:
logger.error(f"Error in RecursiveCharacterTextSplitter: {e}")
raise

def getLangchainCharacterTextSplitter(**kwargs):
try:
logger.info("CharacterTextSplitter Invoked")
splitter = CharacterTextSplitter(chunk_size=kwargs['chunk_size'], chunk_overlap=kwargs['chunk_overlap'])
retriever_type = kwargs.get('retriever_type')
if retriever_type in ["parentDocFullDoc", "parentDocLargeChunk" ]:
return splitter
else:
return splitter.split_documents(kwargs['docs'])
splitter_name=kwargs.get('splitter_name','splitter')
code_string = f"""
from langchain.text_splitter import CharacterTextSplitter
{splitter_name} = CharacterTextSplitter(chunk_size={kwargs['chunk_size']}, chunk_overlap={kwargs['chunk_overlap']})
splits={splitter_name}.split_documents(docs)"""
import_string = f"""from langchain.text_splitter import CharacterTextSplitter"""
return {'code_string':code_string,'import_string':import_string}
except KeyError as e:
logger.error(f"Missing key in kwargs for CharacterTextSplitter: {e}")
raise
except Exception as e:
logger.error(f"Error in CharacterTextSplitter: {e}")
raise

def getLangchainSemanticChunker(**kwargs):
try:
logger.info("SemanticChunker Invoked")
splitter = SemanticChunker(kwargs['embedding_model'], breakpoint_threshold_type=kwargs['breakpoint_threshold_type'])
retriever_type = kwargs.get('retriever_type')
if retriever_type in ["parentDocFullDoc", "parentDocLargeChunk" ]:
return splitter
else:
return splitter.create_documents(kwargs['docs'][0].page_content)
splitter_name=kwargs.get('splitter_name','splitter')
code_string = f"""
from langchain_experimental.text_splitter import SemanticChunker
{splitter_name} = SemanticChunker(embedding, breakpoint_threshold_type='{kwargs['chunking_kwargs']['breakpoint_threshold_type']}')
splits={splitter_name}.create_documents(docs[0].page_content)
"""
import_string = f"""from langchain_experimental.text_splitter import SemanticChunker"""
return {'code_string':code_string,'import_string':import_string}
except KeyError as e:
logger.error(f"Missing key in kwargs for SemanticChunker: {e}")
raise
except Exception as e:
logger.error(f"Error in SemanticChunker: {e}")
raise

def getMarkdownHeaderTextSplitter(**kwargs):
try:
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3")
]
splitter_name=kwargs.get('splitter_name','splitter')
logger.info("MarkdownHeaderTextSplitter Invoked")
splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, strip_headers=False)
retriever_type = kwargs.get('retriever_type')
if retriever_type in ["parentDocFullDoc", "parentDocLargeChunk" ]:
return splitter
else:
return splitter.split_text(kwargs['docs'][0].page_content)
code_string = f"""
headers_to_split_on = [
("#", "Header 1"),
("##", "Header 2"),
("###", "Header 3")]
from langchain.text_splitter import MarkdownHeaderTextSplitter
{splitter_name} = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on, strip_headers=False)
splits={splitter_name}.split_text(docs[0].page_content)"""
import_string = f"""from langchain.text_splitter import MarkdownHeaderTextSplitter"""
return {'code_string':code_string,'import_string':import_string}
except KeyError as e:
logger.error(f"Missing key in kwargs for MarkdownHeaderTextSplitter: {e}")
raise
except Exception as e:
logger.error(f"Error in MarkdownHeaderTextSplitter: {e}")
raise

def getHTMLHeaderTextSplitter(**kwargs):
try:
logger.info("HTMLHeaderTextSplitter Invoked")
headers_to_split_on = [
("h1", "Header 1"),
("h2", "Header 2"),
("h3", "Header 3"),]
splitter = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
retriever_type = kwargs.get('retriever_type')
if retriever_type in ["parentDocFullDoc", "parentDocLargeChunk" ]:
return splitter
else:
return splitter.split_text(kwargs['docs'][0].page_content)
splitter_name=kwargs.get('splitter_name','splitter')
code_string = f"""
headers_to_split_on = [
("h1", "Header 1"),
("h2", "Header 2"),
("h3", "Header 3"),]
from langchain_text_splitters import HTMLHeaderTextSplitter
{splitter_name} = HTMLHeaderTextSplitter(headers_to_split_on=headers_to_split_on)
splits={splitter_name}.split_text(docs[0].page_content)
"""
import_string = f"""from langchain_text_splitters import HTMLHeaderTextSplitter"""
return {'code_string':code_string,'import_string':import_string}
except KeyError as e:
logger.error(f"Missing key in kwargs for HTMLHeaderTextSplitter: {e}")
raise
except Exception as e:
logger.error(f"Error in HTMLHeaderTextSplitter: {e}")
raise
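Every splitter in this module now returns a {'code_string', 'import_string'} pair instead of split documents. A hypothetical consumer (the dict literal below stands in for a getLangchain*TextSplitter return value; none of this is in the commit) would stitch the pair together like so; note the generated snippet expects docs to already exist in the exec namespace.

from langchain_core.documents import Document

# Stand-in for the dict a splitter function above would return.
chunk = {
    'import_string': "from langchain.text_splitter import RecursiveCharacterTextSplitter",
    'code_string': (
        "splitter = RecursiveCharacterTextSplitter(chunk_size=200, chunk_overlap=20)\n"
        "splits = splitter.split_documents(docs)"
    ),
}

namespace = {'docs': [Document(page_content="some long passage " * 100)]}
exec(chunk['import_string'] + "\n" + chunk['code_string'], namespace)
print(len(namespace['splits']))  # number of chunks produced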

17 changes: 15 additions & 2 deletions src/ragbuilder/langchain_module/common.py
@@ -67,8 +67,8 @@ def setup_logging():
logger.addHandler(console_handler)

# Redirect stdout and stderr to the logger
sys.stdout = LoggerWriter(logger, logging.INFO)
sys.stderr = LoggerWriter(logger, logging.ERROR)
# sys.stdout = LoggerWriter(logger, logging.INFO)
# sys.stderr = LoggerWriter(logger, logging.ERROR)

print(log_filename)
return log_filename
@@ -99,3 +99,16 @@ def flush(self):
if not re.search(r"GET /get_log_updates|common.py - flush -", self._buffer):
self.logger.log(self.level, self._buffer.rstrip())
self._buffer = ''


def codeGen(code_string,return_code,output_var):
globals_dict = {}
locals_dict = {}
try:
if not return_code:
exec(code_string,globals_dict,locals_dict)
return locals_dict[output_var]
else:
return code_string
except Exception as e:
return e
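A quick usage sketch of codeGen as defined above; the snippet and variable name are illustrative.

# Trivial snippet in place of real generated RAG code.
snippet = "llm = 'stand-in for ChatOpenAI(...)'"

# return_code=False: execute the snippet, then return locals()[output_var].
obj = codeGen(snippet, return_code=False, output_var='llm')

# return_code=True: skip execution and hand the source string back unchanged.
src = codeGen(snippet, return_code=True, output_var='llm')

Note that, as written, a failure inside exec comes back as the exception object rather than being raised, so callers have to type-check the result.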
13 changes: 9 additions & 4 deletions src/ragbuilder/langchain_module/embedding_model/embedding.py
@@ -24,14 +24,19 @@ def getEmbedding(**kwargs):

if embedding_model in ["text-embedding-3-small","text-embedding-3-large","text-embedding-ada-002"]:
logger.info(f"OpenAIEmbeddings Invoked: {embedding_model}")
return OpenAIEmbeddings(model=embedding_model)
code_string= f"""embedding=OpenAIEmbeddings(model='{embedding_model}')"""
import_string = f"""from langchain_openai import OpenAIEmbeddings"""
return {'code_string':code_string,'import_string':import_string}
elif embedding_model == "mistral-embed":
logger.info(f"MistralAIEmbeddings Invoked: {embedding_model}")
return MistralAIEmbeddings(api_key=os.environ.get("MISTRAL_API_KEY"))
code_string= f"""embedding=MistralAIEmbeddings(api_key=os.environ.get("MISTRAL_API_KEY"))"""
import_string = f"""from langchain_mistralai import MistralAIEmbeddings"""
return {'code_string':code_string,'import_string':import_string}
elif embedding_model == "all-MiniLM-l6-v2":
logger.info(f"HuggingFaceInferenceAPIEmbeddings Invoked: {embedding_model}")
return HuggingFaceInferenceAPIEmbeddings(
api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"), model_name="sentence-transformers/all-MiniLM-l6-v2")
code_string= f"""embedding=HuggingFaceInferenceAPIEmbeddings(api_key=os.environ.get("HUGGINGFACEHUB_API_TOKEN"), model_name="sentence-transformers/all-MiniLM-l6-v2")"""
import_string = f"""from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings"""
return {'code_string':code_string,'import_string':import_string}
else:
raise ValueError(f"Invalid LLM: {embedding_model}")
except KeyError as ke:
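The embedding branch follows the same code_string/import_string contract. One subtlety worth showing: the Mistral and HuggingFace snippets call os.environ.get(...), so whoever execs them must put the os module into the namespace. The dict below is an illustrative stand-in for getEmbedding's return value, not part of this commit.

import os

result = {
    'import_string': "from langchain_mistralai import MistralAIEmbeddings",
    'code_string': 'embedding = MistralAIEmbeddings(api_key=os.environ.get("MISTRAL_API_KEY"))',
}

namespace = {'os': os}  # the generated snippet references os
exec(result['import_string'] + "\n" + result['code_string'], namespace)
print(type(namespace['embedding']).__name__)  # MistralAIEmbeddings; may need MISTRAL_API_KEY set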
39 changes: 12 additions & 27 deletions src/ragbuilder/langchain_module/llms/llmConfig.py
@@ -1,43 +1,26 @@
from langchain_mistralai.chat_models import ChatMistralAI
from langchain_huggingface import HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from ragbuilder.langchain_module.common import setup_logging
from ragbuilder.langchain_module.common import setup_logging,codeGen
import logging
setup_logging()
import os
logger = logging.getLogger("ragbuilder")
def getLLM(**kwargs):
logger.info("LLM Invoked")
return_code= kwargs.get('return_code', False)
retrieval_model=kwargs['retrieval_model']
if kwargs['retrieval_model'] in ["gpt-3.5-turbo","gpt-4o"]:
return getOpenaiLLM(kwargs['retrieval_model'],return_code)
elif kwargs['retrieval_model'] in ["mistral-7b", "mistral-small-latest","mistral-large-latest"] :
logger.info("LLM Codgen Invoked")
return getMistralLLM(kwargs['retrieval_model'],return_code)
logger.info("LLM Code Gen Invoked")
import_string = f"""from langchain_openai import ChatOpenAI"""
code_string = f"""llm=ChatOpenAI(model='{retrieval_model}')"""
elif kwargs['retrieval_model'] in ["mistral-small-latest","mistral-large-latest"] :
import_string = f"""from langchain_mistralai.chat_models import ChatMistralAI"""
code_string = f"""llm=ChatMistralAI(api_key=os.environ.get('MISTRAL_API_KEY'),model='{retrieval_model}')"""
else:
raise ValueError(f"Invalid LLM: {kwargs['retrieval_model']}")

def getOpenaiLLM(retrieval_model, return_code):
logger.info(f"model={retrieval_model} Invoked")
if not return_code:
llm = ChatOpenAI(model=retrieval_model)
else:
logger.info("getOpenaiLLM Codgen Invoked")
llm = f"""llm=ChatOpenAI(model='{retrieval_model}')""" # Return the code as a string
return llm
return {'code_string':code_string,'import_string':import_string}

def getMistralLLM(retrieval_model,return_code):
logger.info(f"model={retrieval_model} Invoked")
if return_code is None:
llm = ChatMistralAI(
api_key=os.environ.get("MISTRAL_API_KEY"),
model=retrieval_model)
else:
logger.info("getMistralLLM Codgen Invoked")
llm = f"""
llm=ChatMistralAI(api_key=os.environ.get('MISTRAL_API_KEY'),model=retrieval_model)"""
return llm
def getHuggingFaceLLM(retrieval_model,return_code):
def getHuggingFaceLLM(retrieval_model,return_code=False):
logger.info(f"model={retrieval_model} Invoked")
if return_code is None:
llm = HuggingFaceEndpoint(
@@ -46,3 +29,5 @@ def getHuggingFaceLLM(retrieval_model,return_code):
else:
llm = "HuggingFaceEndpoint(repo_id=retrieval_model,huggingfacehub_api_token=os.environ.get('HUGGINGFACEHUB_API_TOKEN'))"
return llm
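getLLM now returns the import/code pair instead of a live client; a sketch of a hypothetical caller (not part of this commit) materializing it:

# Assumes getLLM from ragbuilder.langchain_module.llms.llmConfig and an
# OPENAI_API_KEY in the environment for the final call.
pair = getLLM(retrieval_model='gpt-4o')

namespace = {}
exec(pair['import_string'] + "\n" + pair['code_string'], namespace)
llm = namespace['llm']  # a ChatOpenAI instance
print(llm.invoke("ping").content)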

