Skip to content

Commit

Permalink
cluster labels can sample from clusters. handle saved scopes properly
Browse files Browse the repository at this point in the history
  • Loading branch information
enjalot committed Nov 6, 2024
1 parent dfd1b44 commit 9cfc247
Show file tree
Hide file tree
Showing 17 changed files with 222 additions and 98 deletions.
11 changes: 9 additions & 2 deletions latentscope/scripts/label_clusters.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,17 @@ def main():
parser.add_argument('text_column', type=str, help='Output file', default='text')
parser.add_argument('cluster_id', type=str, help='ID of cluster set', default='cluster-001')
parser.add_argument('model_id', type=str, help='ID of model to use', default="openai-gpt-3.5-turbo")
parser.add_argument('samples', type=int, help='Number to sample from each cluster (default: 0 for all)', default=0)
parser.add_argument('context', type=str, help='Additional context for labeling model', default="")
parser.add_argument('--rerun', type=str, help='Rerun the given embedding from last completed batch')

# Parse arguments
args = parser.parse_args()

labeler(args.dataset_id, args.text_column, args.cluster_id, args.model_id, args.context, args.rerun)
labeler(args.dataset_id, args.text_column, args.cluster_id, args.model_id, args.samples, args.context, args.rerun)


def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="openai-gpt-3.5-turbo", context="", rerun=""):
def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="openai-gpt-3.5-turbo", samples=0, context="", rerun=""):
import numpy as np
import pandas as pd
DATA_DIR = get_data_dir()
Expand Down Expand Up @@ -87,6 +88,9 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
model.load_model()
enc = model.encoder

# unescape the context
context = context.replace('\\"', '"')

system_prompt = {"role":"system", "content": f"""You're job is to summarize lists of items with a short label of no more than 4 words. The items are part of a cluster and the label will be used to distinguish this cluster from others, so pay attention to what makes this group of similar items distinct.
{context}
The user will submit a list of items in the format:
Expand All @@ -110,6 +114,8 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
for i, row in tqdm(clusters.iterrows(), total=clusters.shape[0], desc="Preparing extracts"):
indices = row['indices']
items = df.loc[list(indices), text_column]
if samples > 0 and samples < len(items):
items = items.sample(samples)
items = items.drop_duplicates()
# text = '\n'.join([f"{i+1}. {t}" for i, t in enumerate(items) if not too_many_duplicates(t)])
text = '\n'.join([f"<ListItem>{t}</ListItem>" for i, t in enumerate(items) if not too_many_duplicates(t)])
Expand Down Expand Up @@ -180,6 +186,7 @@ def labeler(dataset_id, text_column="text", cluster_id="cluster-001", model_id="
"cluster_id": cluster_id,
"model_id": model_id,
"text_column": text_column,
"samples": samples,
"context": context,
"system_prompt": system_prompt,
"max_tokens": max_tokens,
Expand Down
7 changes: 5 additions & 2 deletions latentscope/server/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,14 @@ def run_cluster_label():
text_column = request.args.get('text_column')
cluster_id = request.args.get('cluster_id')
context = request.args.get('context')
print("run cluster label", dataset, chat_id, text_column, cluster_id)
samples = request.args.get('samples')
print("run cluster label", dataset, chat_id, text_column, cluster_id, samples)
if context:
context = context.replace('"', '\\"')
print("context", context)

job_id = str(uuid.uuid4())
command = f'ls-label "{dataset}" "{text_column}" "{cluster_id}" "{chat_id}" "{context}"'
command = f'ls-label "{dataset}" "{text_column}" "{cluster_id}" "{chat_id}" {samples} "{context}"'
threading.Thread(target=run_job, args=(dataset, job_id, command)).start()
return jsonify({"job_id": job_id})

Expand Down
24 changes: 17 additions & 7 deletions web/src/components/Setup/Cluster.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import styles from './Cluster.module.scss';
// This component is responsible for the cluster state
// New clusters update the list
function Cluster() {
const { dataset, scope, updateScope, goToNextStep } = useSetup();
const { dataset, scope, savedScope, updateScope, goToNextStep, setPreviewLabel } = useSetup();

const [clusterJob, setClusterJob] = useState(null);
const { startJob: startClusterJob } = useStartJobPolling(dataset, setClusterJob, `${apiUrl}/jobs/cluster`);
Expand All @@ -29,6 +29,10 @@ function Cluster() {
const [umap, setUmap] = useState(null);
const [cluster, setCluster] = useState(null);

useEffect(() => {
setPreviewLabel(cluster?.id)
}, [cluster, setPreviewLabel])

// Update local state when scope changes
useEffect(() => {
if(scope?.embedding_id) {
Expand Down Expand Up @@ -87,6 +91,15 @@ function Cluster() {
})
}, [startClusterJob, umap])

// Commit the chosen cluster to the scope and advance to the next setup step.
const handleNextStep = useCallback(() => {
// The chosen cluster matches the saved scope: restore the saved scope as-is
// (loose == is used for the id comparison).
if(savedScope?.cluster_id == cluster?.id) {
updateScope({...savedScope})
} else {
// A different cluster was chosen: set it and clear cluster_labels_id and id
// in the scope update.
updateScope({cluster_id: cluster?.id, cluster_labels_id: null, id: null})
}
goToNextStep()
}, [updateScope, goToNextStep, cluster, savedScope])

return (
<div className={styles["cluster"]}>
<div className={styles["cluster-setup"]}>
Expand Down Expand Up @@ -126,15 +139,15 @@ function Cluster() {

<div className={styles["cluster-list"]}>
{umap && clusters.filter(d => d.umap_id == umap.id).map((cl, index) => (
<div className={styles["item"]} key={index}>
<div className={styles["item"] + (cl.id === cluster?.id ? " " + styles["selected"] : "")} key={index}>
<label htmlFor={`cluster${index}`}>
<input type="radio"
id={`cluster${index}`}
name="cluster"
value={cl}
checked={cl.id === cluster?.id}
onChange={() => setCluster(cl)} />
<span>{cl.id}</span>
<span>{cl.id} {savedScope?.cluster_id == cl.id ? <span className="tooltip" data-tooltip-id="saved">💾</span> : null}</span>
<div className={styles["item-info"]}>
<span>Samples: {cl.samples}</span>
<span>Min Samples: {cl.min_samples}</span>
Expand Down Expand Up @@ -162,10 +175,7 @@ function Cluster() {
<div className={styles["navigate"]}>
<Button
disabled={!cluster}
onClick={() => {
updateScope({cluster_id: cluster?.id})
goToNextStep()
}}
onClick={handleNextStep}
text={cluster ? `Proceed with ${cluster?.id}` : "Select a Cluster"}
/>
</div>
Expand Down
3 changes: 3 additions & 0 deletions web/src/components/Setup/Cluster.module.scss
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
.item {
@include setup-item;
}
// Class added by Cluster.jsx to the list item whose id matches the currently
// chosen cluster; styling comes from the `selected` mixin.
.selected {
@include selected;
}
.item-info {
@include setup-item-info;
}
Expand Down
85 changes: 60 additions & 25 deletions web/src/components/Setup/ClusterLabels.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@ import { useState, useEffect, useCallback} from 'react';
import { useStartJobPolling } from '../Job/Run';
import { apiService, apiUrl } from '../../lib/apiService';
import { useSetup } from '../../contexts/SetupContext';
import { Button } from 'react-element-forge';
import { Button, Select } from 'react-element-forge';
import { Tooltip } from 'react-tooltip';

import JobProgress from '../Job/Progress';
import DataTable from '../DataTable';

import styles from './ClusterLabels.module.scss';

/**
 * Build the display name for a cluster label set id.
 *
 * - A falsy id (null, undefined, "") yields an empty string.
 * - The special id "default" yields "label-default".
 * - Any other id has its first two dash-separated segments dropped and the
 *   remaining segments re-joined with "-".
 *
 * @param {string} labelId - id of the label set
 * @returns {string} the display name
 */
function labelName(labelId) {
  if (!labelId) {
    return "";
  }
  if (labelId == "default") {
    return "label-default";
  }
  // Skip the first two segments and keep everything after them.
  const [, , ...remaining] = labelId.split("-");
  return remaining.join("-");
}

// This component is responsible for the cluster labels state
// New cluster label sets update the list
function ClusterLabels() {
const { datasetId, dataset, scope, updateScope, goToNextStep } = useSetup();
const { datasetId, dataset, scope, savedScope, updateScope, goToNextStep, setPreviewLabel } = useSetup();
const [clusterLabelsJob, setClusterLabelsJob] = useState(null);
const { startJob: startClusterLabelsJob } = useStartJobPolling(dataset, setClusterLabelsJob, `${apiUrl}/jobs/cluster_label`);
const { startJob: rerunClusterLabelsJob } = useStartJobPolling(dataset, setClusterLabelsJob, `${apiUrl}/jobs/rerun`);
Expand All @@ -29,6 +31,16 @@ function ClusterLabels() {
const [embeddings, setEmbeddings] = useState([]);
const [clusters, setClusters] = useState([]);

const [chatModel, setChatModel] = useState(null);
// the models used to label a particular cluster (the ones the user has run)
const [clusterLabelSets, setClusterLabelSets] = useState([]);
// the actual labels for the given cluster
const [clusterLabelData, setClusterLabelData] = useState([]);

useEffect(() => {
setPreviewLabel(labelName(selected))
}, [selected, setPreviewLabel])

// Update local state when scope changes
useEffect(() => {
if(scope?.embedding_id) {
Expand All @@ -39,6 +51,10 @@ function ClusterLabels() {
const cl = clusters?.find(c => c.id == scope.cluster_id)
setCluster(cl)
}
if(scope?.cluster_labels_id) {
// const cl = clusterLabelSets?.find(c => c.id == scope.cluster_labels_id)
setSelected(scope.cluster_labels_id)
}
}, [scope, clusters, embeddings])

// Fetch initial data
Expand All @@ -54,16 +70,13 @@ function ClusterLabels() {
apiService.fetchChatModels()
.then(data => {
setChatModels(data)
setChatModel(data[0]?.id)
}).catch(err => {
console.log(err)
setChatModels([])
})
}, []);

// the models used to label a particular cluster (the ones the user has run)
const [clusterLabelSets, setClusterLabelSets] = useState([]);
// the actual labels for the given cluster
const [clusterLabelData, setClusterLabelData] = useState([]);
useEffect(() => {
if(datasetId && cluster && selected) {
const id = selected.split("-")[3] || selected
Expand All @@ -89,14 +102,16 @@ function ClusterLabels() {
let lbl;
const defaultLabel = { id: "default", model_id: "N/A", cluster_id: cluster.id }
if(selected){
console.log("selected", selected)
lbl = labelsAvailable.find(d => d.id == selected) || defaultLabel
console.log("found?", lbl, labelsAvailable)
} else if(labelsAvailable[0]) {
lbl = labelsAvailable[0]
} else {
lbl = defaultLabel
}
setClusterLabelSets([...labelsAvailable, defaultLabel])
setSelected(lbl?.id)
// setSelected(lbl?.id)
}).catch(err => {
console.log(err)
setClusterLabelSets([])
Expand All @@ -118,12 +133,13 @@ function ClusterLabels() {
e.preventDefault()
const form = e.target
const data = new FormData(form)
const model = data.get('chatModel')
const model = chatModel
const text_column = embedding.text_column
const cluster_id = cluster.id
const context = data.get('context')
startClusterLabelsJob({chat_id: model, cluster_id: cluster_id, text_column, context})
}, [cluster, embedding, startClusterLabelsJob])
const samples = data.get('samples')
startClusterLabelsJob({chat_id: model, cluster_id: cluster_id, text_column, context, samples})
}, [cluster, embedding, chatModel, startClusterLabelsJob])

function handleRerun(job) {
rerunClusterLabelsJob({job_id: job?.id});
Expand All @@ -138,20 +154,41 @@ function ClusterLabels() {
.catch(console.error);
}, [datasetId])

// Commit the chosen label set to the scope and advance to the next setup step.
const handleNextStep = useCallback(() => {
// The chosen label set matches the saved scope: restore the saved scope as-is
// (loose == is used for the id comparison).
if(savedScope?.cluster_labels_id == selected) {
updateScope({...savedScope})
} else {
// A different label set was chosen: set it and clear id in the scope update.
updateScope({cluster_labels_id: selected, id: null})
}
goToNextStep()
}, [updateScope, goToNextStep, selected, savedScope])

return (
<div className={styles["cluster-labels"]}>
<div className={styles["cluster-labels-setup"]}>
<div className={styles["cluster-form"]}>
<p>Automatically create labels for each cluster
{cluster ? ` in ${cluster.id}` : ''} using a chat model. Default labels are created from the top 3 words in each cluster using nltk.</p>
{cluster ? ` in ${cluster.id}` : ''} using a chat model. For quickest CPU based results use nltk top-words.</p>
<form onSubmit={handleNewLabels}>
<label>
<span className={styles["cluster-labels-form-label"]}>Chat Model:</span>
<select id="chatModel" name="chatModel" disabled={!!clusterLabelsJob}>
{chatModels.filter(d => clusterLabelSets?.indexOf(d.id) < 0).map((model, index) => (
<option key={index} value={model.id}>{model.provider} - {model.name}</option>
))}
</select>
<Select id="chatModel"
disabled={!!clusterLabelsJob}
options={chatModels.filter(d => clusterLabelSets?.indexOf(d.id) < 0).map(model => ({
label: `${model.provider} - ${model.name}`,
value: model.id
}))}
value={chatModel}
onChange={(e) => setChatModel(e.target.value)}
/>
</label>
<label>
<span className={styles["cluster-labels-form-label"]}>Samples:</span>
<input type="number" name="samples" value={10} min={0} disabled={!!clusterLabelsJob || !cluster} />
<span className="tooltip" data-tooltip-id="samples">🤔</span>
<Tooltip id="samples" place="top" effect="solid">
The number of samples to use from each cluster for summarization. Set to 0 to use all samples.
</Tooltip>
</label>
<textarea
name="context"
Expand All @@ -170,18 +207,19 @@ function ClusterLabels() {
</div>
<div className={styles["cluster-labels-list"]}>
{cluster && clusterLabelSets.filter(d => d.cluster_id == cluster.id).map((cl, index) => (
<div className={styles["item"]} key={index}>
<div className={styles["item"] + (cl.id == selected ? " " + styles["selected"] : "")} key={index}>
<label htmlFor={`cluster${index}`}>
<input type="radio"
id={`cluster${index}`}
name="cluster"
value={cl.id}
checked={cl.id === selected}
onChange={() => setSelected(cl.id)} />
<span>{labelName(cl.id)}</span>
<span>{labelName(cl.id)} {cl.id == savedScope?.cluster_labels_id && <span className="tooltip" data-tooltip-id="saved">💾</span> }</span>
<div className={styles["item-info"]}>
<span>Model: {cl.model_id}</span>
<span>Context: <code style={{width: "100%"}}>{cl.context}</code></span>
{cl.context && <span>Context: <code style={{width: "100%"}}>{cl.context}</code></span>}
{cl.samples && <span>Samples: {cl.samples}</span>}
</div>
</label>
{/* <Button className={styles["delete"]} color="secondary" onClick={() => handleKill(cl)} text="🗑️" /> */}
Expand All @@ -193,9 +231,9 @@ function ClusterLabels() {
{cluster && (
<div className={styles["cluster-labels-preview"]}>
<div className={styles["preview"]}>
<div className={styles["preview-header"]}>
{/* <div className={styles["preview-header"]}>
<h3>Preview: {labelName(selected)}</h3>
</div>
</div> */}
<div className={styles["cluster-labels-table"]}>
<DataTable
data={clusterLabelData.map((d,i) => ({
Expand All @@ -209,10 +247,7 @@ function ClusterLabels() {
<div className={styles["navigate"]}>
<Button
disabled={!selected}
onClick={() => {
updateScope({cluster_labels_id: selected})
goToNextStep()
}}
onClick={handleNextStep}
text={selected ? `Proceed with ${labelName(selected)}` : "Select a Label"}
/>
</div>
Expand Down
5 changes: 4 additions & 1 deletion web/src/components/Setup/ClusterLabels.module.scss
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
.item {
@include setup-item;
}
// Class added by ClusterLabels.jsx to the list item whose id equals the
// currently selected label set; styling comes from the `selected` mixin.
.selected {
@include selected;
}
.item-info {
@include setup-item-info;
}
Expand All @@ -45,7 +48,7 @@
@include setup-delete;
}
.cluster-labels-form-label {
width: 125px;
width: 100px;
display: inline-block;
}
.cluster-labels-form input {
Expand Down
Loading

0 comments on commit 9cfc247

Please sign in to comment.