Skip to content

Commit

Permalink
Merge branch 'vijay_gpt3_pairwise_constraints' into vijay_refactor_semi_clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
viswavi committed Sep 26, 2023
2 parents 903e786 + cb032eb commit fb18d58
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 5 deletions.
7 changes: 4 additions & 3 deletions active_semi_clustering/semi_supervised/labeled_data/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@


class KMeans:
def __init__(self, n_clusters=3, max_iter=100, num_reinit=1, normalize_vectors=False, split_normalization=False, init="random", verbose=False):
def __init__(self, n_clusters=3, max_iter=100, num_reinit=1, normalize_vectors=False, split_normalization=False, init="random", split_point=300, verbose=False):
self.n_clusters = n_clusters
self.max_iter = max_iter
self.normalize_vectors = normalize_vectors
Expand All @@ -26,6 +26,7 @@ def __init__(self, n_clusters=3, max_iter=100, num_reinit=1, normalize_vectors=F
self.init = init
self.verbose = verbose
self.num_reinit = num_reinit
self.split_point = split_point

def fit(self, X, y=None, **kwargs):
# Initialize cluster centers
Expand All @@ -51,8 +52,8 @@ def fit(self, X, y=None, **kwargs):
timer = time.perf_counter()
if self.normalize_vectors:
if self.split_normalization:
kg_centers = normalize(cluster_centers[:, :300], axis=1, norm="l2")
bert_centers = normalize(cluster_centers[:, 300:], axis=1, norm="l2")
kg_centers = normalize(cluster_centers[:, :self.split_point], axis=1, norm="l2")
bert_centers = normalize(cluster_centers[:, self.split_point:], axis=1, norm="l2")
cluster_centers = np.hstack([kg_centers, bert_centers])
else:
cluster_centers = normalize(cluster_centers, axis=1, norm="l2")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from sklearn.preprocessing import normalize
from tqdm import tqdm

from active_semi_clustering.active.pairwise_constraints import ExampleOracle
from active_semi_clustering.active.pairwise_constraints.example_oracle import MaximumQueriesExceeded

class KMeansCorrection:
def __init__(self, oracle, cluster_predictions, cluster_centers):
def __init__(self, oracle, cluster_predictions, cluster_centers, labels):
# Store the constructor arguments on the instance as-is (no copying or validation here).
# oracle: its query(point_idx, representative_idx) method is called in fit().
self.oracle = oracle
# cluster_predictions: deep-copied in fit() as the starting point for the corrected labels.
self.cluster_predictions = cluster_predictions
self.cluster_centers = cluster_centers
# labels: handed to ExampleOracle in fit() -- presumably ground-truth labels; confirm with caller.
self.labels = labels

@staticmethod
def normalize_features(features):
Expand Down Expand Up @@ -47,7 +49,84 @@ def fit(self, X, num_corrections):

corrected_labels = copy.deepcopy(self.cluster_predictions)

example_oracle = ExampleOracle(self.labels, max_queries_cnt=100000)
top_cluster_labels = []
for i, ent_idx in tqdm(enumerate(closest_top_two)):
correct_found = False
top_clust = ambiguous_points_top_five_clust_idxs[i][0]
top_cluster_rep_points = representative_points[top_clust]
top_cluster_bad = False
for top_cluster_rep in top_cluster_rep_points:
top_cluster_class = example_oracle.query(ent_idx, top_cluster_rep)
if top_cluster_class is False:
top_cluster_bad = True
break
if not top_cluster_bad:
top_cluster_labels.append(0)
continue
for i, next_best_clust in enumerate(ambiguous_points_top_five_clust_idxs[i][1:]):
next_best_cluster_rep_points = representative_points[next_best_clust]
make_correction = False
for next_best_cluster_rep in next_best_cluster_rep_points:
next_cluster_class = example_oracle.query(ent_idx, next_best_cluster_rep)
if top_cluster_bad and next_cluster_class is True:
make_correction = True
break
if make_correction:
top_cluster_labels.append(i+1)
correct_found = True
break
if not correct_found:
top_cluster_labels.append(-1)


top_cluster_predictions = []
for i, ent_idx in tqdm(enumerate(closest_top_two)):
correct_found = False
top_clust = ambiguous_points_top_five_clust_idxs[i][0]
top_cluster_rep_points = representative_points[top_clust]
top_cluster_bad = False
for top_cluster_rep in top_cluster_rep_points:
top_cluster_class = self.oracle.query(ent_idx, top_cluster_rep)
if top_cluster_class is False:
top_cluster_bad = True
break
if not top_cluster_bad:
top_cluster_predictions.append(0)
continue
for i, next_best_clust in enumerate(ambiguous_points_top_five_clust_idxs[i][1:]):
next_best_cluster_rep_points = representative_points[next_best_clust]
make_correction = False
for next_best_cluster_rep in next_best_cluster_rep_points:
next_cluster_class = self.oracle.query(ent_idx, next_best_cluster_rep)
if top_cluster_bad and next_cluster_class is True:
make_correction = True
break
if make_correction:
top_cluster_predictions.append(i+1)
correct_found = True
break
if not correct_found:
top_cluster_predictions.append(-1)

num_gpt_corrections_false_positive = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] > 0 and top_cluster_labels[i] == 0 ]); num_gpt_corrections_false_negative = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] == 0 and top_cluster_labels[i] > 0 ]); num_gpt_corrections_true_positive = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] > 0 and top_cluster_labels[i] == top_cluster_predictions[i] ]); num_gpt_corrections_true_negative = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] == 0 and top_cluster_labels[i] == top_cluster_predictions[i] ])


num_gpt_corrections_false_positive = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] > 0 and top_cluster_labels[i] == 0 ])
num_gpt_corrections_false_negative = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] == 0 and top_cluster_labels[i] > 0 ])
num_gpt_corrections_true_positive = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] > 0 and top_cluster_labels[i] == top_cluster_predictions[i] ])
num_gpt_corrections_true_negative = len([x for i, x in enumerate(top_cluster_predictions) if top_cluster_predictions[i] == 0 and top_cluster_labels[i] == top_cluster_predictions[i] ])

print(f"num_gpt_corrections_false_positive: {num_gpt_corrections_false_positive}")
print(f"num_gpt_corrections_false_negative: {num_gpt_corrections_false_negative}")
print(f"num_gpt_corrections_true_positive: {num_gpt_corrections_true_positive}")
print(f"num_gpt_corrections_true_negative: {num_gpt_corrections_true_negative}")


kmeans_corrections = []

num_corrections_made = 0
num_correct_corrections_made = 0
num_queries = 0
try:
for i, ent_idx in tqdm(enumerate(closest_top_two)):
Expand All @@ -69,17 +148,35 @@ def fit(self, X, num_corrections):
if top_cluster_bad and next_cluster_class is True:
make_correction = True
break

gt_make_correction = False
for next_best_cluster_rep in next_best_cluster_rep_points:
next_cluster_class = example_oracle.query(ent_idx, next_best_cluster_rep)
num_queries += 1
if top_cluster_bad and next_cluster_class is True:
gt_make_correction = True
break

if make_correction == True:
num_corrections_made += 1
if gt_make_correction:
num_correct_corrections_made += 1
kmeans_corrections.append({"ent_idx": int(ent_idx),
"previous_label": int(corrected_labels[ent_idx]),
"corrected_label": int(next_best_clust)})
corrected_labels[ent_idx] = next_best_clust

break
except MaximumQueriesExceeded:
pass

self.labels_ = corrected_labels

breakpoint()

print(f"Num oracle queries: {num_queries}")
print(f"Num Corrections: {num_corrections_made}")
print(f"Num Correct Corrections: {num_correct_corrections_made}")

'''
from active_semi_clustering.active.pairwise_constraints import ExampleOracle
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def fit(self, X, y=None, ml=[], cl=[]):
# Initialize centroids
start = time.perf_counter()
cluster_centers = self._init_cluster_centers(X)
#cluster_centers = self._initialize_cluster_centers(X, neighborhoods)
# cluster_centers = self._initialize_cluster_centers(X, neighborhoods)
elapsed = time.perf_counter() - start
print(f"Initializing neighborhoods took {round(elapsed, 4)} seconds")

Expand Down

0 comments on commit fb18d58

Please sign in to comment.