forked from rushter/MLAlgorithms
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathgaussian_mixture.py
46 lines (35 loc) · 1.09 KB
/
gaussian_mixture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from mla.kmeans import KMeans
from mla.gaussian_mixture import GaussianMixture
# Seed both global RNGs so every run of this demo draws the same data.
# numpy's global RNG drives the per-cluster skew matrices in make_clusters and,
# since no random_state is passed to make_blobs, presumably the blob sampling
# too; the stdlib `random` seed is presumably for code inside the mla package
# (not visible here) -- NOTE(review): confirm it is still needed.
np.random.seed(6)
random.seed(1)
def make_clusters(skew=True, *arg, **kwargs):
    """Generate Gaussian blobs and optionally skew each cluster independently.

    All arguments other than ``skew`` are forwarded unchanged to
    ``sklearn.datasets.make_blobs``. Because ``skew`` is the first positional
    parameter, a positional call such as ``make_clusters(1500)`` binds 1500
    to ``skew``; pass ``make_blobs`` options by keyword.

    When ``skew`` is truthy, the rows belonging to each label (visited in
    sorted label order) are right-multiplied by their own random square
    matrix. Each matrix entry is drawn from numpy's global RNG and shifted
    into [-0.5, 0.5). This turns each round blob into a generally
    anisotropic, sheared one. ``X`` is updated in place.

    Returns:
        ``(X, y)``: the (possibly transformed) samples and their labels, as
        produced by ``make_blobs``.
    """
    X, y = datasets.make_blobs(*arg, **kwargs)
    if not skew:
        return X, y

    # Square transform: one row/column per feature column of X.
    n_features = X.shape[1]
    for label in np.unique(y):
        members = y == label
        transform = np.random.random((n_features, n_features)) - 0.5
        X[members] = X[members].dot(transform)
    return X, y
def KMeans_and_GMM(K, n_samples=1500):
    """Plot ground truth, KMeans and Gaussian-mixture clusterings side by side.

    Builds skewed blob data with ``make_clusters`` and draws a 1x3 grid of
    panels: the true labels, the result of ``mla`` KMeans, and the result of
    ``mla`` GaussianMixture.

    Args:
        K: Number of blob centers to generate; also the number of clusters
            both models are asked to find.
        n_samples: Total number of generated points (default 1500, the value
            previously hard-coded).
    """
    COLOR = "bgrcmyk"
    X, y = make_clusters(skew=True, n_samples=n_samples, centers=K)
    _, axes = plt.subplots(1, 3)

    # Ground truth. Only the first two feature columns are plotted. The
    # palette has 7 colours, so cycle through it: indexing it directly raised
    # IndexError whenever K > 7. Labels beyond the 7th now reuse colours.
    # NOTE(review): KMeans.plot / GaussianMixture.plot may have their own
    # palette limits; they are not visible here.
    axes[0].scatter(
        X[:, 0],
        X[:, 1],
        c=[COLOR[int(label) % len(COLOR)] for label in y],
    )
    axes[0].set_title("Ground Truth")

    # KMeans. init="++" presumably selects k-means++ seeding. predict() is
    # called with no arguments, presumably labelling the data passed to fit();
    # confirm against mla.kmeans. holdon=True presumably stops plot() from
    # finishing the figure before the GMM panel is drawn.
    kmeans = KMeans(K=K, init="++")
    kmeans.fit(X)
    kmeans.predict()
    axes[1].set_title("KMeans")
    kmeans.plot(ax=axes[1], holdon=True)

    # Gaussian mixture. init="kmeans" presumably initialises the components
    # from a k-means run. plot() is called without holdon, presumably so it
    # finishes and shows the figure.
    gmm = GaussianMixture(K=K, init="kmeans")
    gmm.fit(X)
    axes[2].set_title("Gaussian Mixture")
    gmm.plot(ax=axes[2])
if __name__ == "__main__":
    # Script entry point: compare both clusterers on four skewed blobs.
    n_clusters = 4
    KMeans_and_GMM(n_clusters)