Skip to content

Commit

Permalink
simplify index (langchain-ai#126)
Browse files Browse the repository at this point in the history
Simplify indexing
  • Loading branch information
baskaryan authored Sep 6, 2023
1 parent b929e1e commit 232f5aa
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 155 deletions.
Binary file removed agent_all_transformed.pkl
Binary file not shown.
2 changes: 0 additions & 2 deletions constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1 @@
WEAVIATE_REPO_INDEX_NAME = "LangChain_agent_repo"
WEAVIATE_DOCS_INDEX_NAME = "LangChain_agent_docs"
WEAVIATE_SOURCES_INDEX_NAME = "LangChain_agent_sources"
113 changes: 7 additions & 106 deletions ingest.py
Original file line number Diff line number Diff line change
@@ -1,79 +1,26 @@
"""Load html from files, clean up, split, ingest into Weaviate."""
import logging
import os
from git import Repo
import shutil
import pickle
from bs4 import BeautifulSoup
import weaviate
from langchain.document_loaders.recursive_url_loader import RecursiveUrlLoader
from langchain.embeddings import OpenAIEmbeddings
from langchain.indexes import SQLRecordManager, index
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Weaviate
from langchain.document_transformers import Html2TextTransformer
from langchain.document_loaders.generic import GenericLoader
from langchain.document_loaders.parsers import LanguageParser

from constants import (
WEAVIATE_REPO_INDEX_NAME,
WEAVIATE_DOCS_INDEX_NAME,
WEAVIATE_SOURCES_INDEX_NAME,
)

logger = logging.getLogger(__name__)

WEAVIATE_URL = os.environ["WEAVIATE_URL"]
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]
RECORD_MANAGER_DB_URL = os.environ["RECORD_MANAGER_DB_URL"]


def ingest_repo():
repo_path = os.path.join(os.getcwd(), "test_repo")
if os.path.exists(repo_path):
shutil.rmtree(repo_path)

Repo.clone_from(
"https://github.com/langchain-ai/langchain", to_path=repo_path
)

loader = GenericLoader.from_filesystem(
repo_path + "/libs/langchain/langchain",
glob="**/*",
suffixes=[".py"],
parser=LanguageParser(language=Language.PYTHON, parser_threshold=500),
)
documents_repo = loader.load()
len(documents_repo)

with open("agent_repo_transformed.pkl", "wb") as f:
pickle.dump(documents_repo, f)

python_splitter = RecursiveCharacterTextSplitter.from_language(
language=Language.PYTHON, chunk_size=2000, chunk_overlap=200
)
texts = python_splitter.split_documents(documents_repo)

client = weaviate.Client(
url=WEAVIATE_URL,
auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
)
embedding = OpenAIEmbeddings(chunk_size=200)
vectorstore = Weaviate(
client,
WEAVIATE_REPO_INDEX_NAME,
"text",
embedding=embedding,
by_text=False,
attributes=["source"],
)

record_manager = SQLRecordManager(
f"weaviate/{WEAVIATE_REPO_INDEX_NAME}", db_url=RECORD_MANAGER_DB_URL
)
record_manager.create_schema()
index(texts, record_manager, vectorstore, cleanup="full", source_id_key="source")
return texts


def ingest_docs():
urls = [
"https://api.python.langchain.com/en/latest/api_reference.html#module-langchain",
Expand All @@ -100,12 +47,9 @@ def ingest_docs():
documents += temp_docs

html2text = Html2TextTransformer()
docs_transformed = html2text.transform_documents(documents)

text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
with open("agent_docs_transformed.pkl", "wb") as f:
pickle.dump(docs_transformed, f)
text_splitter = RecursiveCharacterTextSplitter(chunk_size=4000, chunk_overlap=200)

docs_transformed = html2text.transform_documents(documents)
docs_transformed = text_splitter.split_documents(docs_transformed)

client = weaviate.Client(
Expand Down Expand Up @@ -134,54 +78,11 @@ def ingest_docs():
source_id_key="source",
)

print(
logger.info(
"LangChain now has this many vectors: ",
client.query.aggregate(WEAVIATE_DOCS_INDEX_NAME).with_meta_count().do(),
)


def ingest_sources():
with open("agent_repo_transformed.pkl", "rb") as f:
codes = pickle.load(f)

with open("agent_docs_transformed.pkl", "rb") as f:
documentations = pickle.load(f)

all_texts = codes + documentations
with open("agent_all_transformed.pkl", "wb") as f:
pickle.dump(all_texts, f)
all_sources = [
Document(
page_content=doc.metadata["source"],
metadata={"source": doc.metadata["source"]},
)
for doc in all_texts
]

client = weaviate.Client(
url=WEAVIATE_URL,
auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
)
embedding = OpenAIEmbeddings(chunk_size=100) # rate limit
vectorstore = Weaviate(
client,
WEAVIATE_SOURCES_INDEX_NAME,
"text",
embedding=embedding,
by_text=False,
attributes=["source"],
)

record_manager = SQLRecordManager(
f"weaviate/{WEAVIATE_SOURCES_INDEX_NAME}", db_url=RECORD_MANAGER_DB_URL
)
record_manager.create_schema()
index(
all_sources, record_manager, vectorstore, cleanup="full", source_id_key="source"
)


if __name__ == "__main__":
ingest_repo()
ingest_docs()
ingest_sources()
68 changes: 21 additions & 47 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,9 @@
from langchain.agents.openai_functions_agent.agent_token_buffer_memory import (
AgentTokenBufferMemory,
)
import pickle
from langchain.callbacks.base import BaseCallbackHandler

from constants import WEAVIATE_SOURCES_INDEX_NAME, WEAVIATE_DOCS_INDEX_NAME
from constants import WEAVIATE_DOCS_INDEX_NAME

client = Client()

Expand All @@ -51,85 +50,60 @@
WEAVIATE_API_KEY = os.environ["WEAVIATE_API_KEY"]


def search(inp: str, index_name: str, callbacks=None) -> str:
def search(inp: str, callbacks=None) -> list:
client = weaviate.Client(
url=WEAVIATE_URL,
auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_API_KEY),
)
weaviate_client = Weaviate(
client=client,
index_name=index_name,
index_name=WEAVIATE_DOCS_INDEX_NAME,
text_key="text",
embedding=OpenAIEmbeddings(chunk_size=200),
by_text=False,
attributes=["source"]
if not index_name == WEAVIATE_SOURCES_INDEX_NAME
else None,
attributes=["source"],
)
retriever = weaviate_client.as_retriever(
search_kwargs=dict(k=3), callbacks=callbacks
)

return retriever.get_relevant_documents(inp, callbacks=callbacks)


with open("agent_all_transformed.pkl", "rb") as f:
all_texts = pickle.load(f)


def search_everything(inp: str, callbacks: Optional[any] = None) -> str:
global all_texts
docs_references = search(inp, WEAVIATE_DOCS_INDEX_NAME, callbacks=callbacks)
# repo_references = search(inp, "WEAVIATE_REPO_INDEX_NAME", callbacks=callbacks)
all_references = docs_references
all_references_sources = [r for r in all_references if r.metadata["source"]]

sources = search(inp, WEAVIATE_SOURCES_INDEX_NAME, callbacks=callbacks)
sources_docs = [
doc
for doc in all_texts
if doc.metadata["source"] in [source.page_content for source in sources]
]
combined_sources = sources_docs + all_references_sources

return [doc.page_content for doc in combined_sources]
docs = retriever.get_relevant_documents(inp, callbacks=callbacks)
return [doc.page_content for doc in docs]


def get_tools():
langchain_tool = Tool(
name="Documentation",
func=search_everything,
description="useful for when you need to refer to LangChain's documentation, for both API reference and codebase",
func=search,
description="useful for when you need to refer to LangChain's documentation",
)
ALL_TOOLS = [langchain_tool]

return ALL_TOOLS
return [langchain_tool]


def get_agent(llm, chat_history: Optional[list] = None):
def get_agent(llm, *, chat_history: Optional[list] = None):
chat_history = chat_history or []
system_message = SystemMessage(
content=(
"You are an expert developer who is tasked with scouring documentation to answer question about LangChain. "
"Answer the following question as best you can. "
"Be inclined to include CORRECT Python code snippets if relevant to the question. If you can't find the answer, DO NOT hallucinate. Just say you don't know. "
"You have access to a LangChain knowledge bank retriever tool for your answer but know NOTHING about LangChain otherwise. "
"Always provide articulate detail to your action input. "
"You should always first check your search tool for information on the concepts in the question. "
"You are an expert developer tasked answering questions about the LangChain Python package. "
"You have access to a LangChain knowledge bank which you can query but know NOTHING about LangChain otherwise. "
"You should always first query the knowledge bank for information on the concepts in the question. "
"For example, given the following input question:\n"
"-----START OF EXAMPLE INPUT QUESTION-----\n"
"What is the transform() method for runnables? \n"
"-----END OF EXAMPLE INPUT QUESTION-----\n"
"Your research flow should be:\n"
"1. Query your search tool for information on 'Transform() method' to get as much context as you can about it. \n"
"2. Then, query your search tool for information on 'Runnables' to get as much context as you can about it. \n"
"1. Query your search tool for information on 'Runnables.transform() method' to get as much context as you can about it.\n"
"2. Then, query your search tool for information on 'Runnables' to get as much context as you can about it.\n"
"3. Answer the question with the context you have gathered."
"For another example, given the following input question:\n"
"-----START OF EXAMPLE INPUT QUESTION-----\n"
"How can I use vLLM to run my own locally hosted model? \n"
"-----END OF EXAMPLE INPUT QUESTION-----\n"
"Your research flow should be:\n"
"1. Query your search tool for information on 'vLLM' to get as much context as you can about it. \n"
"2. Answer the question as you now have enough context."
"1. Query your search tool for information on 'run vLLM locally' to get as much context as you can about it. \n"
"2. Answer the question as you now have enough context.\n\n"
"Include CORRECT Python code snippets in your answer if relevant to the question. If you can't find the answer, DO NOT make up an answer. Just say you don't know. "
"Answer the following question as best you can:"
)
)

Expand Down Expand Up @@ -206,7 +180,7 @@ def stream() -> Generator:
)

def task():
agent = get_agent(llm, chat_history)
agent = get_agent(llm, chat_history=chat_history)
agent.invoke(
{"input": question, "chat_history": chat_history},
config=runnable_config,
Expand Down

0 comments on commit 232f5aa

Please sign in to comment.