-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor memories, introduce MemoryManager
- Loading branch information
1 parent
35d8ff1
commit 2ef341a
Showing
7 changed files
with
1,738 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,3 +30,4 @@ venv.bak/ | |
*.swp | ||
*.swo | ||
.DS_Store | ||
secret/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
[[source]] | ||
url = "https://pypi.org/simple" | ||
verify_ssl = true | ||
name = "pypi" | ||
|
||
[packages] | ||
chromadb = "*" | ||
openai = "*" | ||
beautifulsoup4 = "*" | ||
requests = "*" | ||
|
||
[dev-packages] | ||
|
||
[requires] | ||
python_version = "3.11" |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# RAFT / RATF | ||
|
||
RAFT, or Retrieval-Augmented Fine-Tuning, is a method comprising of a fine-tuning and a RAG-based retrieval phase. It is particularly suited for the creation of agents that realistically emulate a specific human target. | ||
|
||
RATF, or Replica Agent Testing Framework, is a framework for evaluating the performance of retrieval-based dialogue agents. | ||
|
||
## Abstract | ||
|
||
The emulation of specific humans in conversational agents presents unique challenges and opportunities for contextual understanding, theory of mind and personalization. In this paper, we introduce the Retrieval-Augmented Fine-Tuning (RAFT) methodology, designed explicitly for simulating individual humans. | ||
|
||
RAFT employs a dual-phase process: | ||
|
||
In the Retrieval-Augmented Fine-Tuning phase proper, combines interview transcripts featuring the human target with appropriately selected, rephrased and evaluated "memories" from the author's past output to give the model a sense of the way the target human combines past writings with the current context to generate responses. | ||
|
||
In the generation phase, these memories augment the language model's responses to create a nuanced and personalized dialogue. | ||
|
||
We demonstrate the efficacy of RAFT through a unique evaluation metric, RATF (Replica Agent Testing Framework) that compares model-generated responses with original human responses in an interview setting. Our findings highlight RAFT's potential to significantly advance the field of personalized, context-sensitive conversational agents. | ||
|
||
## Process | ||
|
||
### Retrieval-Augmented Fine-Tuning | ||
|
||
Two datasets are required for the fine-tuning phase: | ||
|
||
- A dataset of interview transcripts featuring the target human | ||
- A dataset of the author's past written output (tweets, essays, etc.) | ||
|
||
The interview transcripts used within a RAG process retreiving "memories" from the autor's written output for each of the interviewer's questions. These memories are then rephrased and evaluated in the context of the target user's answer and, if found useful, they are interpolated between question and answer for the fine-tuning phase. | ||
|
||
The steps to reproduce this process are as follows: | ||
|
||
1. Create a dataset of interview transcripts featuring the target human. Each interview should be a separate file, with the interviewer's questions and the target human's answers separated by a newline. | ||
2. Create a dataset of the author's past written output. Each piece of writing should be a separate file. | ||
3. Split the past output dataset in chuncks of a size suitable for the chosen embedding model (8192 tokens for Openai's text-embedding-ada-002), and collect the embeddings for each chunk. | ||
|
||
Then, in order to generate a fine-tuning dataset: | ||
|
||
1. For each interview, run the RAG process to retrieve memories from the author's past output for each of the interviewer's questions. | ||
2. Ask the model to rephrase each memory in the context of the interviewer's question. The same model and prompt will be used in the generation phase. | ||
3. Evaluate the resulting memory by the question only first, and discard it if it is not considered useful by the model. We apply this first pass separately because, at inference time, we will not have access to the target human's answer. | ||
4. Finally, evaluate the memory by the question and the target human's answer, and discard it if it is not considered useful by the model. | ||
5. Save the resulting context including question, memory and as many of the previous [question, memory and answers] tuples as possible, up to the maximum context size the finetune allows, as a new finetune sample. | ||
|
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
from openai import ChatCompletion | ||
|
||
def summarize_memory(memory, question:str): | ||
messages = [ | ||
{"role": "system", "content": "You are helping Paul Graham answer the questions from a young founders team. Decide whether the quote presented is helpful in answerting the question. If it is, rephrase it from his perspective, in a way that would be helpful for answeing, in one or two sentences - type it directly, without intro. Id it is not helpful, simply type 'skip'"}, | ||
{"role": "user", "content": f"Question: {question}\nMemory: {memory[0]}"} | ||
] | ||
response = ChatCompletion.create(model="gpt-3.5-turbo-16k", messages=messages) | ||
return response['choices'][0]['message']['content'].strip() | ||
|
||
def contextualise_memories_for_prompt(memories): | ||
memories_string = "\n".join([f"[memory]{memory[0]}" for memory in memories]) | ||
if (len(memories_string) > 0): | ||
[{ "role": "function", "name": "memory", "content": f"I wrote something relevant to this question in the past. To wit:\n{memories_string}" }] | ||
else: | ||
return [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from typing import List, Tuple, Dict | ||
from concurrent.futures import ThreadPoolExecutor | ||
from openai import ChatCompletion, Embedding, GPT3Encoder | ||
import logging | ||
import chromadb | ||
from ..prompts import summarize_memory | ||
|
||
MAX_EMBEDDING_LENGTH = 2048 | ||
|
||
class MemoryManager: | ||
def __init__(self, name: str): | ||
self.chroma_client = chromadb.Client() | ||
self.collection = self.chroma_client.get_collection(name) | ||
self.encoder = GPT3Encoder() | ||
|
||
def get_embedding(self, text:str): | ||
return Embedding.create(text=text) | ||
|
||
def get_similar_and_summarize(self, text:str): | ||
similar_extract = self.get_similar_extracts(text) | ||
summary = self.summarize_helpful_memories(text, similar_extract) | ||
return summary | ||
|
||
def get_similar_extracts(self, text: str) -> List[str]: | ||
results = self.collection.query( | ||
query_embeddings=[self.get_embedding(text)], | ||
min_score=0.5, | ||
max_results=3 | ||
) | ||
similar_texts = [result.document for result in results] | ||
return similar_texts | ||
|
||
def summarize_helpful_memories(self, question:str, similar_extracts:List[Tuple[str, float]]): | ||
with ThreadPoolExecutor() as executor: | ||
summaries = list(executor.map(summarize_memory, similar_extracts, [question]*len(similar_extracts))) | ||
summaries = [summary for summary in summaries if summary != "skip"] | ||
return summaries | ||
|
||
def store_grounding_embeddings(self, blog_posts: List[Dict]): | ||
# Iterate over the blog posts | ||
for post in blog_posts: | ||
# Reserve some tokens for the title and date | ||
reserved_tokens = self.encoder.encode(f"Title: {post['title']}\nDate: {post['date']}\nPart: 1\n") | ||
chunk_size = MAX_EMBEDDING_LENGTH - len(reserved_tokens) | ||
|
||
# Split the post content into chunks | ||
content_tokens = self.encoder.encode(post['content']) | ||
chunks = [content_tokens[i:i+chunk_size] for i in range(0, len(content_tokens), chunk_size)] | ||
|
||
logging.info(f"Storing {len(chunks)} chunks for post {post['title']}") | ||
|
||
# Add title, date and part to each chunk and store its embedding in the Chroma DB | ||
for j, chunk in enumerate(chunks): | ||
document = f"Title: {post['title']}\nDate: {post['date']}\nPart: {j+1}\n{self.encoder.decode(chunk)}" | ||
self.collection.add(documents=[document]) |