Skip to content

Commit

Permalink
Add DEC model for unsupervised clustering
Browse files Browse the repository at this point in the history
  • Loading branch information
viswavi committed Mar 16, 2023
1 parent fa109ee commit a6ce10e
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 13 deletions.
25 changes: 20 additions & 5 deletions active_semi_clustering/semi_supervised/labeled_data/dec.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ def __init__(self,
device="cuda:2",
include_contrastive_loss=False,
labels=None,
batch_size = 32,
batch_size = 400,
linear_transformation = False,
canonicalization_side_information=None):
self.n_clusters = n_clusters
self.max_iter = max_iter
self.normalize_vectors = normalize_vectors
self.split_normalization = split_normalization
self.verbose = verbose
self.verbose = verbose
self.batch_size = batch_size
self.num_dataloader_workers = 4
self.init = cluster_init
Expand All @@ -61,7 +61,22 @@ def __init__(self,

def fit(self, X):

cluster_centers = self._init_cluster_centers(X)
if self.init == "k-means":
clusterer = KMeans(n_clusters=self.n_clusters, normalize_vectors=self.normalize_vectors, split_normalization=self.split_normalization, init="k-means++", num_reinit=1, verbose=False)
clusterer.fit(X)
labels = clusterer.labels_
clusters_list = {}
for l, feat in zip(labels, X):
if l not in clusters_list:
clusters_list[l] = []
clusters_list[l].append(feat)

cluster_centers = np.empty((self.n_clusters, X.shape[1]))
for i, l in enumerate(clusters_list.keys()):
avg_vec = np.mean(clusters_list[l], axis=0)
cluster_centers[i] = avg_vec
else:
cluster_centers = self._init_cluster_centers(X)

torch.cuda.set_device(self.device)

Expand All @@ -70,14 +85,14 @@ def fit(self, X):
model = SCCLMatrix(emb_size=X.shape[1], cluster_centers=cluster_centers, include_contrastive_loss=self.include_contrastive_loss, linear_transformation = self.linear_transformation)
model = model.cuda()

resDir = "/projects/ogma1/vijayv/okb-canonicalization/clustering/sccl/opiec_59k_batch_size_32_knn_init_no_linear_adam_larger_lr_fixed/"
resDir = "/projects/ogma1/vijayv/okb-canonicalization/clustering/sccl/opiec_59k_batch_size_400_kpp_init_no_linear_momentum/"
resPath = "SCCL.tensorboard"
resPath = resDir + resPath
tensorboard = SummaryWriter(resPath)

# optimize
Args = namedtuple("Args", "lr lr_scale eta temperature objective print_freq max_iter batch_size tensorboard")
args = Args(1e-05, 1000, 10, 0.5, "SCCL", 300, 200 * X.shape[0] / self.batch_size, self.batch_size, tensorboard)
args = Args(1e-01, 100, 10, 0.5, "SCCL", 300, 2500 * X.shape[0] / self.batch_size, self.batch_size, tensorboard)


optimizer = get_optimizer_linear_transformation(model, args, include_contrastive_loss=self.include_contrastive_loss, linear_transformation=self.linear_transformation)
Expand Down
8 changes: 3 additions & 5 deletions active_semi_clustering/semi_supervised/labeled_data/kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def fit(self, X, y=None, **kwargs):
cluster_centers_shift = np.zeros(cluster_centers.shape)

for iteration in range(self.max_iter):
print(f"iteration {iteration}")
# print(f"iteration {iteration}")
timer_dict = {}
timer = time.perf_counter()
if self.normalize_vectors:
Expand Down Expand Up @@ -76,11 +76,9 @@ def fit(self, X, y=None, **kwargs):
converged = np.allclose(cluster_centers_shift, np.zeros(cluster_centers.shape), atol=1e-6, rtol=0)
timer_dict["Check convergence"] = round(time.perf_counter() - timer, 3)
timer = time.perf_counter()
print(f"K-Means iteration {iteration} took {round(time.perf_counter() - original_start, 3)} seconds.")
print(f"cluster_centers_shift: {cluster_centers_shift}")
print(f"cluster_centers: {cluster_centers}")
# print(f"K-Means iteration {iteration} took {round(time.perf_counter() - original_start, 3)} seconds.")

print(f"Timer dict: {timer_dict}")
# print(f"Timer dict: {timer_dict}")

if converged: break

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import numpy as np

from active_semi_clustering.exceptions import EmptyClustersException
Expand All @@ -15,14 +16,18 @@ def fit(self, X, y=None, ml=[], cl=[]):
# Preprocess constraints
ml_graph, cl_graph, neighborhoods = preprocess_constraints(ml, cl, X.shape[0])

print(f"Num neighborhoods: {neighborhoods}")
print(f"ML constraints:\n{ml}\n")
print(f"CL constraints:\n{cl}\n")

print(f"Num neighborhoods: {neighborhoods}\n\n\n")

# Initialize centroids
# cluster_centers = self._init_cluster_centers(X)
cluster_centers = self._initialize_cluster_centers(X, neighborhoods)

# Repeat until convergence
for iteration in range(self.max_iter):
print(f"iteration: {iteration}")
print(f"\n\n\n\niteration: {iteration}")
# Assign clusters
labels = self._assign_clusters(X, cluster_centers, ml_graph, cl_graph, self.w)

Expand Down Expand Up @@ -66,7 +71,7 @@ def _initialize_cluster_centers(self, X, neighborhoods):
cluster_centers = np.concatenate([cluster_centers, remaining_cluster_centers])
return cluster_centers

def _objective_function(self, X, x_i, centroids, c_i, labels, ml_graph, cl_graph, w):
def _objective_function(self, X, x_i, centroids, c_i, labels, ml_graph, cl_graph, w, print_terms=False):
distance = 1 / 2 * np.sum((X[x_i] - centroids[c_i]) ** 2)

ml_penalty = 0
Expand All @@ -78,6 +83,9 @@ def _objective_function(self, X, x_i, centroids, c_i, labels, ml_graph, cl_graph
for y_i in cl_graph[x_i]:
if labels[y_i] == c_i:
cl_penalty += w
if print_terms:
    # Debug trace: emit every term of the per-point objective as one JSON line.
    # Each penalty is read from its own accumulator so the cannot-link term
    # shows the value that is actually added to the returned objective.
    metric_dict = {"x_i": x_i, "distance": round(distance, 4), "ml_penalty": round(ml_penalty, 4), "cl_penalty": round(cl_penalty, 4)}
    print(json.dumps(metric_dict))

return distance + ml_penalty + cl_penalty

Expand All @@ -92,6 +100,9 @@ def _assign_clusters(self, X, cluster_centers, ml_graph, cl_graph, w):
min_cluster_distances.append(min(cluster_distances))
labels[x_i] = np.argmin(cluster_distances)

_ = self._objective_function(X, x_i, cluster_centers, labels[x_i], labels, ml_graph, cl_graph, w, print_terms=True)


# Handle empty clusters
# See https://github.com/scikit-learn/scikit-learn/blob/0.19.1/sklearn/cluster/_k_means.pyx#L309
n_samples_in_cluster = np.bincount(labels, minlength=self.n_clusters)
Expand Down

0 comments on commit a6ce10e

Please sign in to comment.