Skip to content

Commit

Permalink
feat: ability to use local embeddings model (sBERT)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tymec committed Apr 14, 2023
1 parent 98efd26 commit 967c927
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 21 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ orjson
Pillow
coverage
flake8
numpy
sentence_transformers
1 change: 1 addition & 0 deletions scripts/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def __init__(self):
# Note that indexes must be created on db 0 in redis, this is not configurable.

self.memory_backend = os.getenv("MEMORY_BACKEND", 'local')
self.memory_embeder = os.getenv("MEMORY_EMBEDER", 'ada')
# Initialize the OpenAI API client
openai.api_key = self.openai_api_key

Expand Down
24 changes: 18 additions & 6 deletions scripts/memory/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,28 @@
from config import AbstractSingleton, Config
import openai

cfg = Config()
try:
from sentence_transformers import SentenceTransformer
except ImportError:
SentenceTransformer = None
if cfg.memory_embeder == "sbert":
print("Error: Sentence Transformers is not installed. Please install sentence_transformers"
" to use BERT as an embeder. Defaulting to Ada.")
cfg.memory_embeder = "ada"


cfg = Config()

def get_ada_embedding(text):
def get_embedding(text):
text = text.replace("\n", " ")
if cfg.use_azure:
return openai.Embedding.create(input=[text], engine=cfg.get_azure_deployment_id_for_model("text-embedding-ada-002"))["data"][0]["embedding"]
else:
return openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]

if cfg.memory_embeder == "sbert":
embedding = SentenceTransformer("sentence-transformers/all-mpnet-base-v2", device="cpu").encode(text, show_progress_bar=False)
else:
embedding = openai.Embedding.create(input=[text], model="text-embedding-ada-002")["data"][0]["embedding"]

return embedding


class MemoryProviderSingleton(AbstractSingleton):
@abc.abstractmethod
Expand Down
10 changes: 6 additions & 4 deletions scripts/memory/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from typing import Any, List, Optional
import numpy as np
import os
from memory.base import MemoryProviderSingleton, get_ada_embedding
from memory.base import MemoryProviderSingleton, get_embedding
from config import Config

cfg = Config()

EMBED_DIM = 1536
EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768
SAVE_OPTIONS = orjson.OPT_SERIALIZE_NUMPY | orjson.OPT_SERIALIZE_DATACLASS


Expand Down Expand Up @@ -58,7 +60,7 @@ def add(self, text: str):
return ""
self.data.texts.append(text)

embedding = get_ada_embedding(text)
embedding = get_embedding(text)

vector = np.array(embedding).astype(np.float32)
vector = vector[np.newaxis, :]
Expand Down Expand Up @@ -109,7 +111,7 @@ def get_relevant(self, text: str, k: int) -> List[Any]:
Returns: List[str]
"""
embedding = get_ada_embedding(text)
embedding = get_embedding(text)

scores = np.dot(self.data.embeddings, embedding)

Expand Down
9 changes: 4 additions & 5 deletions scripts/memory/pinecone.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@

import pinecone

from memory.base import MemoryProviderSingleton, get_ada_embedding
from memory.base import MemoryProviderSingleton, get_embedding
from logger import logger
from colorama import Fore, Style


class PineconeMemory(MemoryProviderSingleton):
def __init__(self, cfg):
pinecone_api_key = cfg.pinecone_api_key
pinecone_region = cfg.pinecone_region
pinecone.init(api_key=pinecone_api_key, environment=pinecone_region)
dimension = 1536
dimension = 1536 if cfg.memory_embeder == "ada" else 768
metric = "cosine"
pod_type = "p1"
table_name = "auto-gpt"
Expand All @@ -33,7 +32,7 @@ def __init__(self, cfg):
self.index = pinecone.Index(table_name)

def add(self, data):
vector = get_ada_embedding(data)
vector = get_embedding(data)
# no metadata here. We may wish to change that long term.
resp = self.index.upsert([(str(self.vec_num), vector, {"raw_text": data})])
_text = f"Inserting data into memory at index: {self.vec_num}:\n data: {data}"
Expand All @@ -53,7 +52,7 @@ def get_relevant(self, data, num_relevant=5):
:param data: The data to compare to.
:param num_relevant: The number of relevant data to return. Defaults to 5
"""
query_embedding = get_ada_embedding(data)
query_embedding = get_embedding(data)
results = self.index.query(query_embedding, top_k=num_relevant, include_metadata=True)
sorted_results = sorted(results.matches, key=lambda x: x.score)
return [str(item['metadata']["raw_text"]) for item in sorted_results]
Expand Down
13 changes: 8 additions & 5 deletions scripts/memory/redismem.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@
from redis.commands.search.indexDefinition import IndexDefinition, IndexType
import numpy as np

from memory.base import MemoryProviderSingleton, get_ada_embedding
from memory.base import MemoryProviderSingleton, get_embedding
from logger import logger
from colorama import Fore, Style
from config import Config

cfg = Config()

EMBED_DIM = 1536 if cfg.memory_embeder == "ada" else 768

SCHEMA = [
TextField("data"),
Expand All @@ -18,7 +22,7 @@
"HNSW",
{
"TYPE": "FLOAT32",
"DIM": 1536,
"DIM": EMBED_DIM,
"DISTANCE_METRIC": "COSINE"
}
),
Expand All @@ -38,7 +42,6 @@ def __init__(self, cfg):
redis_host = cfg.redis_host
redis_port = cfg.redis_port
redis_password = cfg.redis_password
self.dimension = 1536
self.redis = redis.Redis(
host=redis_host,
port=redis_port,
Expand Down Expand Up @@ -83,7 +86,7 @@ def add(self, data: str) -> str:
"""
if 'Command Error:' in data:
return ""
vector = get_ada_embedding(data)
vector = get_embedding(data)
vector = np.array(vector).astype(np.float32).tobytes()
data_dict = {
b"data": data,
Expand Down Expand Up @@ -131,7 +134,7 @@ def get_relevant(
Returns: A list of the most relevant data.
"""
query_embedding = get_ada_embedding(data)
query_embedding = get_embedding(data)
base_query = f"*=>[KNN {num_relevant} @embedding $vector AS vector_score]"
query = Query(base_query).return_fields(
"data",
Expand Down

0 comments on commit 967c927

Please sign in to comment.