-
Notifications
You must be signed in to change notification settings - Fork 5.3k
/
Copy pathchroma_utils.py
41 lines (32 loc) · 1.07 KB
/
chroma_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import pathlib
import chromadb
from chromadb.utils import embedding_functions
from more_itertools import batched
def build_chroma_collection(
    chroma_path: pathlib.Path,
    collection_name: str,
    embedding_func_name: str,
    ids: list[str],
    documents: list[str],
    metadatas: list[dict],
    distance_func_name: str = "cosine",
    batch_size: int = 166,
) -> None:
    """Create a ChromaDB collection and load documents into it in batches.

    A persistent client is opened at ``chroma_path``, a new collection named
    ``collection_name`` is created with a SentenceTransformer embedding
    function, and every ``(id, document, metadata)`` triple is added.

    Args:
        chroma_path: Location handed to ``chromadb.PersistentClient``.
        collection_name: Name of the collection to create.
        embedding_func_name: Model name passed to
            ``SentenceTransformerEmbeddingFunction``.
        ids: One identifier per document.
        documents: Texts to store.
        metadatas: One metadata dict per document.
        distance_func_name: Value stored under the ``"hnsw:space"`` key of
            the collection metadata.
        batch_size: Maximum number of records sent per ``collection.add``
            call. NOTE(review): the default of 166 is carried over from the
            original hard-coded value; presumably chosen to stay under a
            Chroma batch limit -- confirm before raising it.

    Raises:
        ValueError: If ``batch_size`` is not positive, or if ``ids``,
            ``documents`` and ``metadatas`` differ in length.
    """
    # Validate before touching the store so bad input leaves no side effects.
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
    if not (len(ids) == len(documents) == len(metadatas)):
        raise ValueError(
            "ids, documents and metadatas must have the same length "
            f"(got {len(ids)}, {len(documents)}, {len(metadatas)})"
        )

    chroma_client = chromadb.PersistentClient(chroma_path)

    embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=embedding_func_name
    )

    collection = chroma_client.create_collection(
        name=collection_name,
        embedding_function=embedding_func,
        metadata={"hnsw:space": distance_func_name},
    )

    # Slice with an exclusive upper bound. The previous code used the index
    # of the batch's last element as the slice end, which dropped the final
    # record of every batch.
    for start_idx in range(0, len(documents), batch_size):
        end_idx = start_idx + batch_size  # slicing clamps past-the-end bounds
        collection.add(
            ids=ids[start_idx:end_idx],
            documents=documents[start_idx:end_idx],
            metadatas=metadatas[start_idx:end_idx],
        )