diff --git a/latentscope/models/providers/openai.py b/latentscope/models/providers/openai.py
index 8b9dfde..44b5797 100644
--- a/latentscope/models/providers/openai.py
+++ b/latentscope/models/providers/openai.py
@@ -52,9 +52,9 @@ def load_model(self):
self.encoder = tiktoken.encoding_for_model(self.name)
else:
self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"), base_url=self.base_url)
- self.encoder = None
- print("BASE URL", self.base_url)
- print("MODEL", self.name)
+ # Even if this is some other model, we won't be able to determine the tokenizer from a custom API,
+ # so fall back to gpt-4o's encoding; its token counts should be roughly correct.
+ self.encoder = tiktoken.encoding_for_model("gpt-4o")
config = OpenAIConfig(self.name)
self.model = outlines.models.openai(self.client, config)
self.generator = outlines.generate.text(self.model)
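
For reference, a minimal sketch of what the fallback encoder above provides, assuming tiktoken is installed; counts from the gpt-4o encoding are only approximate for models served through a custom base URL.

```python
import tiktoken

# Fallback encoder used above; gpt-4o maps to tiktoken's o200k_base encoding.
enc = tiktoken.encoding_for_model("gpt-4o")

text = "How many tokens is this sentence?"
print(len(enc.encode(text)))  # approximate token count for non-OpenAI models
```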
diff --git a/latentscope/models/providers/transformers.py b/latentscope/models/providers/transformers.py
index b44f5f3..0eb8264 100644
--- a/latentscope/models/providers/transformers.py
+++ b/latentscope/models/providers/transformers.py
@@ -41,6 +41,7 @@ def load_model(self):
from transformers import AutoTokenizer
self.model = self.outlines.models.transformers(self.name)
self.generator = self.outlines.generate.text(self.model)
+ # TODO: the regex-constrained generator below raised an error, and its [a-zA-Z0-9 ]+ pattern would exclude non-ASCII (foreign-language) characters anyway
# self.generator = self.outlines.generate.regex(
# self.model
# , r"[a-zA-Z0-9 ]+"
diff --git a/latentscope/scripts/embed.py b/latentscope/scripts/embed.py
index c501715..6d8bc7a 100644
--- a/latentscope/scripts/embed.py
+++ b/latentscope/scripts/embed.py
@@ -161,6 +161,7 @@ def embed(dataset_id, text_column, model_id, prefix, rerun, dimensions, batch_si
"text_column": text_column,
# "dimensions": np_embeds.shape[1],
"dimensions": embeddings.shape[1],
+ "max_seq_length": max_seq_length,
"prefix": prefix,
"min_values": min_values.tolist(),
"max_values": max_values.tolist(),
@@ -235,6 +236,7 @@ def embed_truncate(dataset_id, embedding_id, dimensions):
"model_id": embedding_meta["model_id"],
"dataset_id": dataset_id,
"text_column": embedding_meta["text_column"],
+ "max_seq_length": embedding_meta["max_seq_length"],
"dimensions": matroyshka.shape[1],
"prefix": embedding_meta["prefix"],
"min_values": min_values.tolist(),
diff --git a/latentscope/scripts/label_clusters.py b/latentscope/scripts/label_clusters.py
index 6c0f180..47b9290 100644
--- a/latentscope/scripts/label_clusters.py
+++ b/latentscope/scripts/label_clusters.py
@@ -43,7 +43,7 @@ def main():
parser.add_argument('samples', type=int, help='Number to sample from each cluster (default: 0 for all)', default=0)
parser.add_argument('context', type=str, help='Additional context for labeling model', default="")
parser.add_argument('--rerun', type=str, help='Rerun the given embedding from last completed batch')
- parser.add_argument('--max_tokens', type=int, help='Max tokens for the model', default=-1)
+ parser.add_argument('--max_tokens', type=int, help='Max tokens per sampled item (longer items are truncated)', default=-1)
# Parse arguments
args = parser.parse_args()
@@ -63,14 +63,24 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
# initialize the labeled property to false when loading default clusters
clusters = clusters.copy()
clusters['labeled'] = False
+
+ cluster_rows = pd.read_parquet(os.path.join(cluster_dir, f"{cluster_id}.parquet"))
+ df["cluster"] = cluster_rows["cluster"]
+ df["raw_cluster"] = cluster_rows["raw_cluster"]
+
+ with open(os.path.join(cluster_dir, f"{cluster_id}.json"), 'r') as f:
+ cluster_meta = json.load(f)
+ umap_id = cluster_meta["umap_id"]
+ umap = pd.read_parquet(os.path.join(DATA_DIR, dataset_id, "umaps", f"{umap_id}.parquet"))
+ df["x"] = umap["x"]
+ df["y"] = umap["y"]
unlabeled_row = 0
if rerun is not None:
label_id = rerun
- clusters = pd.read_parquet(os.path.join(cluster_dir, f"{label_id}.parquet"))
# print(clusters.columns)
# find the first row where labeled isnt True
- unlabeled_row = clusters[~clusters['labeled']].first_valid_index()
+ unlabeled_row = cluster_rows[~cluster_rows['labeled']].first_valid_index()
tqdm.write(f"First unlabeled row: {unlabeled_row}")
@@ -110,23 +120,44 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
extracts = []
for i, row in tqdm(clusters.iterrows(), total=clusters.shape[0], desc="Preparing extracts"):
indices = row['indices']
- items = df.loc[list(indices), text_column]
+ # items = df.loc[list(indices), text_column]
+ items = df.loc[list(indices)]
if samples > 0 and samples < len(items):
- items = items.sample(samples)
+ # First take the items whose 'raw_cluster' matches the current cluster index
+ cluster_items = items[items['raw_cluster'] == i]
+
+ if len(cluster_items) < samples:
+ cluster_items = pd.concat([cluster_items, items[items['cluster'] == i]])
+
+ # Sort cluster items by distance from centroid
+ # Get x,y coordinates for items
+ coords = cluster_items[['x', 'y']].values
+
+ # Calculate centroid
+ centroid = coords.mean(axis=0)
+
+ # Calculate distances from centroid
+ distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1))
+
+ # Add distances as column and sort
+ cluster_items = cluster_items.assign(centroid_dist=distances)
+ cluster_items = cluster_items.sort_values('centroid_dist')
+
+ items = cluster_items[0:samples]
+ # items = cluster_items.sample(samples)
+
items = items.drop_duplicates()
- tokens = 0
+ items = items[text_column]
+
keep_items = []
- if max_tokens > 0 and enc is not None:
- while tokens < max_tokens:
- for item in items:
- if item is None:
- continue
- encoded_item = enc.encode(item)
- if tokens + len(encoded_item) > max_tokens:
- break
- keep_items.append(item)
- tokens += len(encoded_item)
- break
+ if enc is not None:
+ for item in items:
+ if item is None:
+ continue
+ encoded_item = enc.encode(item)
+ if max_tokens > 0 and len(encoded_item) > max_tokens:
+ item = enc.decode(encoded_item[:max_tokens])
+ keep_items.append(item)
else:
keep_items = items
keep_items = [item for item in keep_items if item is not None]
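
Putting the sampling and truncation changes together: the hunk joins UMAP coordinates onto the rows, keeps the `samples` items nearest the cluster centroid, then clips each item to `max_tokens` tokens instead of budgeting a single token total across the whole cluster. A standalone sketch of that flow with made-up data; the x/y columns and the gpt-4o encoder mirror the assumptions above.

```python
import numpy as np
import pandas as pd
import tiktoken

# Hypothetical cluster items with UMAP coordinates already joined on
cluster_items = pd.DataFrame({
    "x": [0.10, 0.12, 0.90, 0.15],
    "y": [0.11, 0.14, 0.85, 0.12],
    "text": ["first item", "second item", "an outlier item", "third item"],
})
samples, max_tokens = 2, 5
enc = tiktoken.encoding_for_model("gpt-4o")

# 1. Rank items by Euclidean distance from the cluster centroid
coords = cluster_items[["x", "y"]].values
centroid = coords.mean(axis=0)
distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1))
nearest = cluster_items.assign(centroid_dist=distances).sort_values("centroid_dist")[0:samples]

# 2. Truncate each selected text to max_tokens tokens (per item, not per cluster)
keep_items = []
for item in nearest["text"]:
    encoded = enc.encode(item)
    if max_tokens > 0 and len(encoded) > max_tokens:
        item = enc.decode(encoded[:max_tokens])
    keep_items.append(item)

print(keep_items)
```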
diff --git a/web/src/components/Setup/ClusterLabels.jsx b/web/src/components/Setup/ClusterLabels.jsx
index c07b9d9..aca88e3 100644
--- a/web/src/components/Setup/ClusterLabels.jsx
+++ b/web/src/components/Setup/ClusterLabels.jsx
@@ -342,8 +342,8 @@ function ClusterLabels() {
🤔
- The number of samples to use from each cluster for summarization. Set to 0 to use
- all samples.
+ The number of items to use from each cluster for summarization. Set to 0 to use all
+ items. Items closest to the cluster centroid are chosen first.