Commit fc03f6a: refine cluster label sampling
enjalot committed Nov 30, 2024
1 parent 8f50735 commit fc03f6a
Showing 5 changed files with 59 additions and 24 deletions.
6 changes: 3 additions & 3 deletions latentscope/models/providers/openai.py
@@ -52,9 +52,9 @@ def load_model(self):
self.encoder = tiktoken.encoding_for_model(self.name)
else:
self.client = AsyncOpenAI(api_key=get_key("OPENAI_API_KEY"), base_url=self.base_url)
self.encoder = None
print("BASE URL", self.base_url)
print("MODEL", self.name)
# even if this is some other model, we wont be able to figure out the tokenizer from custom API
# so we just use gpt-4o as a fallback, it should be roughly correct for token counts
self.encoder = tiktoken.encoding_for_model("gpt-4o")
config = OpenAIConfig(self.name)
self.model = outlines.models.openai(self.client, config)
self.generator = outlines.generate.text(self.model)
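The fallback added above matters because tiktoken only ships encodings for OpenAI's own model names, so calling encoding_for_model with an arbitrary custom-endpoint model name would raise a KeyError. A minimal sketch of the pattern, assuming only the tiktoken package (the helper name pick_encoder is illustrative, not part of the commit):

    import tiktoken

    def pick_encoder(model_name, is_custom_endpoint):
        # tiktoken.encoding_for_model() only knows OpenAI model names; for a
        # custom base_url the name may be anything, so fall back to the
        # gpt-4o encoding as a rough proxy for token counts.
        if is_custom_endpoint:
            return tiktoken.encoding_for_model("gpt-4o")
        return tiktoken.encoding_for_model(model_name)

    enc = pick_encoder("my-local-model", is_custom_endpoint=True)
    print(len(enc.encode("hello world")))  # approximate token count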
1 change: 1 addition & 0 deletions latentscope/models/providers/transformers.py
@@ -41,6 +41,7 @@ def load_model(self):
from transformers import AutoTokenizer
self.model = self.outlines.models.transformers(self.name)
self.generator = self.outlines.generate.text(self.model)
# TODO: this had an error. but also it would exclude foreign characters
# self.generator = self.outlines.generate.regex(
# self.model
# , r"[a-zA-Z0-9 ]+"
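The commented-out constraint above would restrict generated labels to the ASCII class [a-zA-Z0-9 ]+, which drops accented letters and non-Latin scripts entirely. A quick illustration of that character-class issue using Python's re module (a sketch only; whether outlines' regex-guided generation accepts the same pattern is a separate question):

    import re

    ascii_only = r"[a-zA-Z0-9 ]+"
    unicode_words = r"[\w ]+"  # \w is Unicode-aware for str patterns in Python's re

    for text in ["cluster label", "café niño", "東京の話題"]:
        print(text,
              bool(re.fullmatch(ascii_only, text)),
              bool(re.fullmatch(unicode_words, text)))
    # Only the first string matches the ASCII pattern; all three match the \w-based one.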
2 changes: 2 additions & 0 deletions latentscope/scripts/embed.py
@@ -161,6 +161,7 @@ def embed(dataset_id, text_column, model_id, prefix, rerun, dimensions, batch_si
"text_column": text_column,
# "dimensions": np_embeds.shape[1],
"dimensions": embeddings.shape[1],
"max_seq_length": max_seq_length,
"prefix": prefix,
"min_values": min_values.tolist(),
"max_values": max_values.tolist(),
@@ -235,6 +236,7 @@ def embed_truncate(dataset_id, embedding_id, dimensions):
"model_id": embedding_meta["model_id"],
"dataset_id": dataset_id,
"text_column": embedding_meta["text_column"],
"max_seq_length": embedding_meta["max_seq_length"],
"dimensions": matroyshka.shape[1],
"prefix": embedding_meta["prefix"],
"min_values": min_values.tolist(),
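The two additions above persist the embedding model's max_seq_length into the embedding metadata, which the labeling UI later uses as the default per-sample token cap. Metadata written before this change will not have the key, so the embedding_meta["max_seq_length"] lookup in embed_truncate would raise a KeyError on older embeddings; a tolerant read could look like this (sketch; the data-directory layout and ids are assumed, not shown in this diff):

    import json
    import os

    DATA_DIR = "data"  # stand-in; the real data dir comes from latentscope config
    dataset_id, embedding_id = "my-dataset", "embedding-001"  # illustrative ids

    meta_path = os.path.join(DATA_DIR, dataset_id, "embeddings", f"{embedding_id}.json")
    with open(meta_path) as f:
        embedding_meta = json.load(f)
    # Metadata written before this commit has no max_seq_length key; .get() avoids a KeyError.
    max_seq_length = embedding_meta.get("max_seq_length")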
65 changes: 48 additions & 17 deletions latentscope/scripts/label_clusters.py
@@ -43,7 +43,7 @@ def main():
parser.add_argument('samples', type=int, help='Number to sample from each cluster (default: 0 for all)', default=0)
parser.add_argument('context', type=str, help='Additional context for labeling model', default="")
parser.add_argument('--rerun', type=str, help='Rerun the given embedding from last completed batch')
parser.add_argument('--max_tokens', type=int, help='Max tokens for the model', default=-1)
parser.add_argument('--max_tokens', type=int, help='Max tokens per sample', default=-1)

# Parse arguments
args = parser.parse_args()
@@ -63,14 +63,24 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
# initialize the labeled property to false when loading default clusters
clusters = clusters.copy()
clusters['labeled'] = False

cluster_rows = pd.read_parquet(os.path.join(cluster_dir, f"{cluster_id}.parquet"))
df["cluster"] = cluster_rows["cluster"]
df["raw_cluster"] = cluster_rows["raw_cluster"]

with open(os.path.join(cluster_dir, f"{cluster_id}.json"), 'r') as f:
cluster_meta = json.load(f)
umap_id = cluster_meta["umap_id"]
umap = pd.read_parquet(os.path.join(DATA_DIR, dataset_id, "umaps", f"{umap_id}.parquet"))
df["x"] = umap["x"]
df["y"] = umap["y"]

unlabeled_row = 0
if rerun is not None:
label_id = rerun
clusters = pd.read_parquet(os.path.join(cluster_dir, f"{label_id}.parquet"))
# print(clusters.columns)
# find the first row where labeled isnt True
unlabeled_row = clusters[~clusters['labeled']].first_valid_index()
unlabeled_row = cluster_rows[~cluster_rows['labeled']].first_valid_index()
tqdm.write(f"First unlabeled row: {unlabeled_row}")


@@ -110,23 +120,44 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
extracts = []
for i, row in tqdm(clusters.iterrows(), total=clusters.shape[0], desc="Preparing extracts"):
indices = row['indices']
items = df.loc[list(indices), text_column]
# items = df.loc[list(indices), text_column]
items = df.loc[list(indices)]
if samples > 0 and samples < len(items):
items = items.sample(samples)
# first sample the items from cluster_rows where 'raw_cluster' matches the current cluster_id
cluster_items = items[items['raw_cluster'] == i]

if(len(cluster_items) < samples):
cluster_items = pd.concat([cluster_items, items[items['cluster'] == i]])

# Sort cluster items by distance from centroid
# Get x,y coordinates for items
coords = cluster_items[['x', 'y']].values

# Calculate centroid
centroid = coords.mean(axis=0)

# Calculate distances from centroid
distances = np.sqrt(np.sum((coords - centroid) ** 2, axis=1))

# Add distances as column and sort
cluster_items = cluster_items.assign(centroid_dist=distances)
cluster_items = cluster_items.sort_values('centroid_dist')

items = cluster_items[0:samples]
# items = cluster_items.sample(samples)

items = items.drop_duplicates()
tokens = 0
items = items[text_column]

keep_items = []
if max_tokens > 0 and enc is not None:
while tokens < max_tokens:
for item in items:
if item is None:
continue
encoded_item = enc.encode(item)
if tokens + len(encoded_item) > max_tokens:
break
keep_items.append(item)
tokens += len(encoded_item)
break
if enc is not None:
for item in items:
if item is None:
continue
encoded_item = enc.encode(item)
if max_tokens > 0 and len(encoded_item) > max_tokens:
item = enc.decode(encoded_item[:max_tokens])
keep_items.append(item)
else:
keep_items = items
keep_items = [item for item in keep_items if item is not None]
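Taken together, the new sampling logic prefers items whose raw_cluster matches the cluster being labeled, tops up from the merged cluster when that falls short, keeps the `samples` items closest to the cluster's centroid in UMAP space, and truncates each kept item to max_tokens instead of stopping once a shared running token budget is exhausted. A condensed sketch of that selection, assuming the same column names (x, y, cluster, raw_cluster) used in the diff:

    import numpy as np
    import pandas as pd

    def select_cluster_samples(df, cluster_idx, samples, text_column, enc=None, max_tokens=-1):
        # Prefer rows assigned to this raw (unmerged) cluster, then pad with
        # rows from the merged cluster if there are not enough of them.
        items = df[df["raw_cluster"] == cluster_idx]
        if len(items) < samples:
            items = pd.concat([items, df[df["cluster"] == cluster_idx]])

        # Rank by Euclidean distance to the cluster centroid in UMAP space.
        coords = items[["x", "y"]].values
        centroid = coords.mean(axis=0)
        dists = np.sqrt(((coords - centroid) ** 2).sum(axis=1))
        items = items.assign(centroid_dist=dists).sort_values("centroid_dist")

        if samples > 0:
            items = items.head(samples)
        texts = items[text_column].drop_duplicates()

        kept = []
        for text in texts:
            if text is None:
                continue
            # Truncate each sample independently so every selected item
            # contributes something to the labeling prompt.
            if enc is not None and max_tokens > 0:
                tokens = enc.encode(text)
                if len(tokens) > max_tokens:
                    text = enc.decode(tokens[:max_tokens])
            kept.append(text)
        return kept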
9 changes: 5 additions & 4 deletions web/src/components/Setup/ClusterLabels.jsx
@@ -342,24 +342,25 @@ function ClusterLabels() {
🤔
</span>
<Tooltip id="samples" place="top" effect="solid" className="tooltip-area">
The number of samples to use from each cluster for summarization. Set to 0 to use
all samples.
The number of items to use from each cluster for summarization. Set to 0 to use all
items. Items are chosen based on distance from the centroid of the cluster.
</Tooltip>
</label>
<label>
<span className={styles['cluster-labels-form-label']}>Max Tokens:</span>
<input
type="number"
name="max_tokens"
defaultValue={chatModel?.params?.max_tokens || 8192}
defaultValue={scope?.embedding?.max_seq_length || 512}
min={-1}
disabled={!!clusterLabelsJob || !cluster}
/>
<span className="tooltip" data-tooltip-id="max_tokens">
🤔
</span>
<Tooltip id="max_tokens" place="top" effect="solid" className="tooltip-area">
The maximum number of tokens to use for the model. Set to -1 to ignore limits.
The maximum number of tokens per sample to use, truncates long samples to max
tokens. Set to -1 to ignore limits.
</Tooltip>
</label>
<textarea
