forked from lazyprogrammer/machine_learning_examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: bayes_classifier_gaussian.py
70 lines (56 loc) · 1.77 KB
/
bayes_classifier_gaussian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# https://deeplearningcourses.com/c/deep-learning-gans-and-variational-autoencoders
# https://www.udemy.com/deep-learning-gans-and-variational-autoencoders
from __future__ import print_function, division
from builtins import range, input
# Note: you may need to update your version of future
# sudo pip install -U future
import util
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal as mvn
def clamp_sample(x):
    """Clamp every value of ``x`` into the closed interval [0, 1].

    Draws from an unbounded Gaussian can land outside [0, 1], so samples
    are clipped before being returned to callers.

    Parameters
    ----------
    x : scalar or array_like
        Values to clamp.

    Returns
    -------
    numpy.ndarray or numpy scalar
        ``x`` with entries below 0 replaced by 0 and entries above 1
        replaced by 1.
    """
    # np.clip is the single-call equivalent of minimum(..., 1) followed by
    # maximum(..., 0) when the lower bound is below the upper bound.
    return np.clip(x, 0, 1)
class BayesClassifier:
    """Generative classifier with one multivariate Gaussian per class.

    ``fit`` estimates a mean and covariance for each class plus the
    empirical class prior p(y). The sampling methods draw from those
    Gaussians and clamp the result to [0, 1].
    """

    def fit(self, X, Y):
        """Estimate the per-class Gaussians and the class prior p(y).

        Parameters
        ----------
        X : numpy.ndarray
            Training inputs, one example per row.
        Y : numpy.ndarray
            Class labels aligned with the rows of ``X``. Labels are assumed
            to be the integers 0..K-1. NOTE(review): if a label in that
            range is missing, its class gets no rows and a NaN mean.

        Sets ``self.K`` (number of distinct labels), ``self.gaussians``
        (list of ``{'m': mean, 'c': covariance}`` dicts indexed by class)
        and ``self.p_y`` (class frequencies normalized to sum to 1).
        """
        # assume classes are numbered 0...K-1
        self.K = len(np.unique(Y))
        self.gaussians = []
        self.p_y = np.zeros(self.K)
        for k in range(self.K):
            Xk = X[Y == k]
            self.p_y[k] = len(Xk)
            # np.cov treats rows as variables, so transpose to make each
            # column (feature) a variable.
            self.gaussians.append({'m': Xk.mean(axis=0), 'c': np.cov(Xk.T)})
        # normalize class counts into the prior p(y)
        self.p_y /= self.p_y.sum()

    def sample_given_y(self, y):
        """Draw one sample from the Gaussian of class index ``y``, clamped to [0, 1]."""
        g = self.gaussians[y]
        return clamp_sample(mvn.rvs(mean=g['m'], cov=g['c']))

    def sample(self):
        """Pick a class according to p(y), then sample from that class's Gaussian.

        ``sample_given_y`` already clamps its output to [0, 1], so the result
        is returned as-is (clamping is idempotent; a second pass was redundant).
        """
        y = np.random.choice(self.K, p=self.p_y)
        return self.sample_given_y(y)
if __name__ == '__main__':
    # Train the per-class Gaussian model on MNIST.
    X, Y = util.get_mnist()
    clf = BayesClassifier()
    clf.fit(X, Y)

    # For each class, display one generated sample beside the learned mean.
    for label in range(clf.K):
        # The sample is drawn before the mean is read, so random-number
        # consumption happens in the same order as a draw-then-read sequence.
        panels = (
            ("Sample", clf.sample_given_y(label).reshape(28, 28)),
            ("Mean", clf.gaussians[label]['m'].reshape(28, 28)),
        )
        for position, (caption, image) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(image, cmap='gray')
            plt.title(caption)
        plt.show()

    # Finally, let the model choose the class as well as the image.
    random_image = clf.sample().reshape(28, 28)
    plt.imshow(random_image, cmap='gray')
    plt.title("Random Sample from Random Class")
    plt.show()