diff --git a/latentscope/models/providers/openai.py b/latentscope/models/providers/openai.py
index 8b9dfde..44b5797 100644
--- a/latentscope/models/providers/openai.py
+++ b/latentscope/models/providers/openai.py
@@ -52,9 +52,9 @@ def load_model(self):
             self.encoder = tiktoken.encoding_for_model(self.name)
         else:
             self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"), base_url=self.base_url)
-            self.encoder = None
-            print("BASE URL", self.base_url)
-            print("MODEL", self.name)
+            # even if this is some other model, we won't be able to figure out the tokenizer from a custom API,
+            # so we just use gpt-4o as a fallback; it should be roughly correct for token counts
+            self.encoder = tiktoken.encoding_for_model("gpt-4o")
         config = OpenAIConfig(self.name)
         self.model = outlines.models.openai(self.client, config)
         self.generator = outlines.generate.text(self.model)
diff --git a/latentscope/models/providers/transformers.py b/latentscope/models/providers/transformers.py
index b44f5f3..0eb8264 100644
--- a/latentscope/models/providers/transformers.py
+++ b/latentscope/models/providers/transformers.py
@@ -41,6 +41,7 @@ def load_model(self):
         from transformers import AutoTokenizer
         self.model = self.outlines.models.transformers(self.name)
         self.generator = self.outlines.generate.text(self.model)
+        # TODO: this had an error, but it would also exclude foreign characters
         # self.generator = self.outlines.generate.regex(
         #     self.model
         #     , r"[a-zA-Z0-9 ]+"
diff --git a/latentscope/scripts/embed.py b/latentscope/scripts/embed.py
index c501715..6d8bc7a 100644
--- a/latentscope/scripts/embed.py
+++ b/latentscope/scripts/embed.py
@@ -161,6 +161,7 @@ def embed(dataset_id, text_column, model_id, prefix, rerun, dimensions, batch_si
         "text_column": text_column,
         # "dimensions": np_embeds.shape[1],
         "dimensions": embeddings.shape[1],
+        "max_seq_length": max_seq_length,
         "prefix": prefix,
         "min_values": min_values.tolist(),
         "max_values": max_values.tolist(),
@@ -235,6 +236,7 @@ def embed_truncate(dataset_id, embedding_id, dimensions):
         "model_id": embedding_meta["model_id"],
         "dataset_id": dataset_id,
         "text_column": embedding_meta["text_column"],
+        "max_seq_length": embedding_meta["max_seq_length"],
         "dimensions": matroyshka.shape[1],
         "prefix": embedding_meta["prefix"],
         "min_values": min_values.tolist(),
diff --git a/latentscope/scripts/label_clusters.py b/latentscope/scripts/label_clusters.py
index 6c0f180..47b9290 100644
--- a/latentscope/scripts/label_clusters.py
+++ b/latentscope/scripts/label_clusters.py
@@ -43,7 +43,7 @@ def main():
     parser.add_argument('samples', type=int, help='Number to sample from each cluster (default: 0 for all)', default=0)
     parser.add_argument('context', type=str, help='Additional context for labeling model', default="")
     parser.add_argument('--rerun', type=str, help='Rerun the given embedding from last completed batch')
-    parser.add_argument('--max_tokens', type=int, help='Max tokens for the model', default=-1)
+    parser.add_argument('--max_tokens', type=int, help='Max tokens per sample', default=-1)
 
     # Parse arguments
     args = parser.parse_args()
@@ -63,14 +63,24 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
     # initialize the labeled property to false when loading default clusters
     clusters = clusters.copy()
     clusters['labeled'] = False
+
+    cluster_rows = pd.read_parquet(os.path.join(cluster_dir, f"{cluster_id}.parquet"))
+    df["cluster"] = cluster_rows["cluster"]
+    df["raw_cluster"] = cluster_rows["raw_cluster"]
+
+    with open(os.path.join(cluster_dir, f"{cluster_id}.json"), 'r') as f:
+        cluster_meta = json.load(f)
+    umap_id = cluster_meta["umap_id"]
+    umap = pd.read_parquet(os.path.join(DATA_DIR, dataset_id, "umaps", f"{umap_id}.parquet"))
+    df["x"] = umap["x"]
+    df["y"] = umap["y"]
 
     unlabeled_row = 0
     if rerun is not None:
         label_id = rerun
-        clusters = pd.read_parquet(os.path.join(cluster_dir, f"{label_id}.parquet"))
         # print(clusters.columns)
         # find the first row where labeled isnt True
-        unlabeled_row = clusters[~clusters['labeled']].first_valid_index()
+        unlabeled_row = cluster_rows[~cluster_rows['labeled']].first_valid_index()
         tqdm.write(f"First unlabeled row: {unlabeled_row}")
@@ -110,23 +120,44 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
     extracts = []
     for i, row in tqdm(clusters.iterrows(), total=clusters.shape[0], desc="Preparing extracts"):
         indices = row['indices']
-        items = df.loc[list(indices), text_column]
+        # items = df.loc[list(indices), text_column]
+        items = df.loc[list(indices)]
         if samples > 0 and samples < len(items):
-            items = items.sample(samples)
+            # first sample the items from cluster_rows where 'raw_cluster' matches the current cluster_id
+            cluster_items = items[items['raw_cluster'] == i]
+
+            if(len(cluster_items) < samples):
+                cluster_items = pd.concat([cluster_items, items[items['cluster'] == i]])
+
+            # Sort cluster items by distance from centroid
+            # Get x,y coordinates for items
+            coords = cluster_items[['x', 'y']].values
+
+            # Calculate centroid
+            centroid = coords.mean(axis=0)
+
+            # Calculate distances from centroid
+            distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1))
+
+            # Add distances as column and sort
+            cluster_items = cluster_items.assign(centroid_dist=distances)
+            cluster_items = cluster_items.sort_values('centroid_dist')
+
+            items = cluster_items[0:samples]
+            # items = cluster_items.sample(samples)
+
         items = items.drop_duplicates()
-        tokens = 0
+        items = items[text_column]
+
         keep_items = []
-        if max_tokens > 0 and enc is not None:
-            while tokens < max_tokens:
-                for item in items:
-                    if item is None:
-                        continue
-                    encoded_item = enc.encode(item)
-                    if tokens + len(encoded_item) > max_tokens:
-                        break
-                    keep_items.append(item)
-                    tokens += len(encoded_item)
-                break
+        if enc is not None:
+            for item in items:
+                if item is None:
+                    continue
+                encoded_item = enc.encode(item)
+                if max_tokens > 0 and len(encoded_item) > max_tokens:
+                    item = enc.decode(encoded_item[:max_tokens])
+                keep_items.append(item)
         else:
             keep_items = items
         keep_items = [item for item in keep_items if item is not None]
diff --git a/web/src/components/Setup/ClusterLabels.jsx b/web/src/components/Setup/ClusterLabels.jsx
index c07b9d9..aca88e3 100644
--- a/web/src/components/Setup/ClusterLabels.jsx
+++ b/web/src/components/Setup/ClusterLabels.jsx
@@ -342,8 +342,8 @@ function ClusterLabels() {
             🤔
-            The number of samples to use from each cluster for summarization. Set to 0 to use
-            all samples.
+            The number of items to use from each cluster for summarization. Set to 0 to use all
+            items. Items are chosen based on distance from the centroid of the cluster.
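
Note on the sampling change in `label_clusters.py`: instead of a random `items.sample(samples)`, extracts are now taken from the rows closest to the cluster centroid in UMAP space. A minimal standalone sketch of that selection step, assuming a DataFrame that already carries the `x`/`y` UMAP columns as in the diff (the function name is hypothetical):

```python
import numpy as np
import pandas as pd

def closest_to_centroid(cluster_items: pd.DataFrame, samples: int) -> pd.DataFrame:
    """Return the `samples` rows nearest the centroid of the cluster's UMAP coordinates."""
    coords = cluster_items[["x", "y"]].values      # (n, 2) array of UMAP positions
    centroid = coords.mean(axis=0)                 # mean position of the cluster
    distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1))
    return (
        cluster_items.assign(centroid_dist=distances)
        .sort_values("centroid_dist")
        .head(samples)
    )
```

Picking by centroid distance makes the extracts deterministic and biases them toward the most typical members of a cluster, at the cost of never showing the labeling model any outliers.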
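
The `--max_tokens` flag also changes meaning, from a shared token budget across all of a cluster's extracts to a per-item cap: each sampled text is truncated to `max_tokens` tokens instead of being dropped once the budget runs out. A rough sketch of the new behaviour, assuming a tiktoken encoder (gpt-4o is the fallback encoder the OpenAI provider now uses for unknown models):

```python
import tiktoken

def truncate_items(items, max_tokens, enc=None):
    """Truncate each text to at most max_tokens tokens; max_tokens <= 0 disables truncation."""
    enc = enc or tiktoken.encoding_for_model("gpt-4o")  # fallback encoder, as in the provider change
    keep_items = []
    for item in items:
        if item is None:
            continue  # skip missing texts
        tokens = enc.encode(item)
        if max_tokens > 0 and len(tokens) > max_tokens:
            item = enc.decode(tokens[:max_tokens])
        keep_items.append(item)
    return keep_items
```

So `--max_tokens 100` now caps every sampled item at roughly 100 tokens, rather than stopping once the cluster's combined extracts reach 100 tokens.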