-
Notifications
You must be signed in to change notification settings - Fork 6.4k
/
Copy pathkmeans_mnist.py
158 lines (130 loc) · 4.4 KB
/
kmeans_mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# https://deeplearningcourses.com/c/cluster-analysis-unsupervised-machine-learning-python
# https://www.udemy.com/cluster-analysis-unsupervised-machine-learning-python
# data is from https://www.kaggle.com/c/digit-recognizer
# each image is a D = 28x28 = 784 dimensional vector
# there are N = 42000 samples
# you can plot an image by reshaping to (28,28) and using plt.imshow()
from __future__ import print_function, division
from future.utils import iteritems
from builtins import range, input
# Note: you may need to update your version of future
# sudo pip install -U future
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from .kmeans import plot_k_means, get_simple_data
from datetime import datetime
def get_data(limit=None, path='../large_files/train.csv'):
    """Load the digit-recognizer training CSV as shuffled (X, Y) arrays.

    The CSV is read with pandas; the first column is taken as the label and
    the remaining columns as pixel values, which are divided by 255 because
    the raw data ranges over 0..255.

    Args:
        limit: If not None, keep only the first `limit` rows after shuffling.
        path: Location of the CSV file. The default is the path this script
            has always used; it is relative to the current working directory.

    Returns:
        A tuple (X, Y) where X holds the scaled pixel columns and Y the label
        column, both in the same shuffled row order.
    """
    print("Reading in and transforming data...")
    df = pd.read_csv(path)
    # Copy before shuffling: `df.values` may be a view of the frame's own
    # buffer, and with pandas Copy-on-Write enabled that view is read-only,
    # which makes the in-place shuffle below raise.
    data = df.values.copy()
    # Shuffle whole rows so every label stays attached to its pixels.
    np.random.shuffle(data)
    X = data[:, 1:] / 255.0  # data is from 0..255
    Y = data[:, 0]
    if limit is not None:
        X, Y = X[:limit], Y[:limit]
    return X, Y
# hard labels
def purity2(Y, R):
    """Return the purity of the hard clustering obtained from R.

    Every sample is assigned to the cluster whose column in R is largest.
    For each cluster the size of its best-matching label is counted, and the
    counts are summed and divided by the number of samples.

    Maximum purity is 1, higher is better.

    Args:
        Y: Length-N sequence of target labels. Labels may be any values that
            support `==`; they do not have to be 0..K-1.
        R: N x K array; row n holds sample n's membership weights for the K
            clusters.

    Returns:
        The purity as a float.
    """
    Y = np.asarray(Y)
    R = np.asarray(R)
    assignments = np.argmax(R, axis=1)  # cluster assignments
    N = len(Y)  # number of data pts
    # Enumerate the labels that actually occur and the clusters R actually
    # has. The two counts are independent, so neither may stand in for the
    # other (e.g. 15 clusters on 10 digit classes).
    labels = np.unique(Y)
    n_clusters = R.shape[1]

    total = 0.0
    for k in range(n_clusters):
        in_cluster = assignments == k
        max_intersection = 0
        for label in labels:
            intersection = np.count_nonzero(in_cluster & (Y == label))
            if intersection > max_intersection:
                max_intersection = intersection
        total += max_intersection
    return total / N
def purity(Y, R):
    """Return the soft purity of a clustering described by R.

    For each cluster k, the membership weights in column k are summed
    separately per target label; the largest of those sums is the cluster's
    contribution. Contributions are added up and divided by the number of
    samples.

    Maximum purity is 1, higher is better.

    Args:
        Y: Length-N sequence of target labels. Labels may be any values that
            support `==`; they do not have to be 0..K-1.
        R: N x K array; R[n, k] is sample n's membership weight for cluster k.

    Returns:
        The purity, i.e. the summed best-label weight divided by N.
    """
    Y = np.asarray(Y)
    N, K = R.shape
    # Iterate over the labels that actually occur instead of range(K): the
    # number of clusters says nothing about which label values exist.
    labels = np.unique(Y)

    p = 0.0
    for k in range(K):
        max_intersection = 0
        for label in labels:
            intersection = R[Y == label, k].sum()
            if intersection > max_intersection:
                max_intersection = intersection
        p += max_intersection
    return p / N
# hard labels
def DBI2(X, R):
    """Return the Davies-Bouldin index of the hard clustering implied by R.

    Each sample is assigned to the cluster whose column in R is largest. For
    every non-empty cluster the mean and the root-mean-square distance of its
    members to that mean (sigma) are computed. The index is the average, over
    those clusters, of the largest (sigma_k + sigma_j) / ||mean_k - mean_j||
    against any other cluster. Lower is better.

    Clusters that receive no samples have neither a mean nor a spread, so
    they are left out of the computation and of the final average. With
    fewer than two non-empty clusters there is nothing to compare and the
    result is 0.

    Args:
        X: N x D data array.
        R: N x K array of membership weights; only the per-row argmax is used.

    Returns:
        The Davies-Bouldin index as a float.
    """
    N, D = X.shape
    _, K = R.shape
    assignments = np.argmax(R, axis=1)

    # Keep only clusters that actually own points; taking the mean of an
    # empty slice would yield NaN and emit RuntimeWarnings.
    occupied = [k for k in range(K) if np.any(assignments == k)]
    n_occupied = len(occupied)
    if n_occupied == 0:
        return 0.0

    # get sigmas, means first
    sigma = np.zeros(n_occupied)
    M = np.zeros((n_occupied, D))
    for i, k in enumerate(occupied):
        Xk = X[assignments == k]
        M[i] = Xk.mean(axis=0)
        diffs = Xk - M[i]
        sigma[i] = np.sqrt((diffs * diffs).sum() / len(Xk))

    # calculate Davies-Bouldin Index
    dbi = 0.0
    for i in range(n_occupied):
        max_ratio = 0
        for j in range(n_occupied):
            if i != j:
                numerator = sigma[i] + sigma[j]
                denominator = np.linalg.norm(M[i] - M[j])
                ratio = numerator / denominator
                if ratio > max_ratio:
                    max_ratio = ratio
        dbi += max_ratio
    return dbi / n_occupied
def DBI(X, M, R):
    """Return the Davies-Bouldin index for soft cluster memberships.

    For each pair of clusters the score is the sum of their spreads divided
    by the distance between their means; each cluster keeps its worst
    (largest) score and the index is the mean of those. Lower is better.

    A cluster's spread is the square root of the R-weighted average squared
    distance from every sample to that cluster's mean.

    Args:
        X: N x D data array.
        M: K x D array of cluster means.
        R: N x K array; column k weights each sample's contribution to
            cluster k.

    Returns:
        The index, averaged over the K clusters.
    """
    _n_samples, _n_features = X.shape  # X has to be two-dimensional
    n_clusters, _ = M.shape

    # spread of every cluster around its own mean
    spread = np.zeros(n_clusters)
    for k, mean in enumerate(M):
        delta = X - mean  # N x D
        sq_dist = (delta * delta).sum(axis=1)  # length N
        weights = R[:, k]
        spread[k] = np.sqrt((weights * sq_dist).sum() / weights.sum())

    # worst pairwise ratio per cluster, then average
    total = 0
    for k in range(n_clusters):
        worst = 0
        for j in range(n_clusters):
            if j == k:
                continue
            separation = np.linalg.norm(M[k] - M[j])
            candidate = (spread[k] + spread[j]) / separation
            if candidate > worst:
                worst = candidate
        total += worst
    return total / n_clusters
def main():
    """Cluster a 10000-sample MNIST subset and report the evaluation metrics.

    Runs plot_k_means with as many clusters as there are distinct labels,
    prints both purity variants and both Davies-Bouldin variants, then shows
    every cluster mean as a 28x28 grayscale image.
    """
    # mnist data
    X, Y = get_data(10000)

    # simple data
    # X = get_simple_data()
    # Y = np.array([0]*300 + [1]*300 + [2]*300)

    print("Number of data points:", len(Y))
    n_clusters = len(set(Y))
    M, R = plot_k_means(X, n_clusters)

    # Exercise: Try different values of K and compare the evaluation metrics
    # Each metric is computed right before it is printed, one after another.
    metrics = (
        ("Purity:", lambda: purity(Y, R)),
        ("Purity 2 (hard clusters):", lambda: purity2(Y, R)),
        ("DBI:", lambda: DBI(X, M, R)),
        ("DBI 2 (hard clusters):", lambda: DBI2(X, R)),
    )
    for caption, compute in metrics:
        print(caption, compute())

    # plot the mean images
    # they should look like digits
    for mean_vector in M:
        plt.imshow(mean_vector.reshape(28, 28), cmap='gray')
        plt.show()


if __name__ == "__main__":
    main()