Skip to content

Commit 4f3300a

Browse files
authored
Create 04_PerceotronAlgorithm.py
1.Perceotron Algorithm
1 parent b18c546 commit 4f3300a

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

04_PerceotronAlgorithm.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import numpy as np
2+
import matplotlib.pyplot as plt
3+
4+
#dataset
5+
X1 = np.array([-0.62231486, -0.96251306, 0.42269922, -1.452746 , -0.66915783,
6+
-0.35716016, 0.49505163, -1.8117848 , 0.53376487, -1.86923838,
7+
0.71434306, -0.4055084 , 0.82887254, 0.81221287, 1.44280951,
8+
-0.45599278, -1.16715888, 1.08913131, -1.61470741, 1.61113001,
9+
-1.4532688 , 1.04872588, -1.52312195, -1.62831727, -0.25191539])
10+
11+
X2 = np.array([-1.67427011, -1.81046748, 1.20384694, -0.41572751, 0.66851908,
12+
-1.75435288, -1.57532207, -1.22329618, -0.84375819, 0.52873296,
13+
-1.10837773, 0.04612922, 0.67696196, 0.84618152, -0.77362548,
14+
0.99153072, 1.7896494 , -0.38343121, -0.21337742, 0.64754817,
15+
0.36719101, 0.23132427, 1.07029963, 1.62919909, -1.53920827])
16+
17+
Y = np.array([ 1., -1., -1., -1., -1.,
18+
1., 1., -1., 1., -1.,
19+
1., -1., 1., 1., 1.,
20+
-1., -1., 1., -1., 1.,
21+
-1., 1., -1., -1., 1.])
22+
23+
xl1 = yl1 = -2.0
24+
xl2 = yl2 = 2.0
25+
26+
#Plotting distribution of a dataset
27+
def plot_data(filename = 'data0'):
28+
fig = plt.figure()
29+
ax = fig.add_subplot(111)
30+
plt.scatter(X1[Y >= 0], X2[Y >= 0], s = 80, c = 'b', marker = "o")
31+
plt.scatter(X1[Y < 0], X2[Y < 0], s = 80, c = 'r', marker = "^")
32+
#ax.set_xlim([-2.0, 2.0])
33+
#ax.set_ylim([-2.0, 2.0])
34+
ax.set_xlim(xl1, xl2)
35+
ax.set_ylim(yl1, yl2)
36+
fig.set_size_inches(6, 6)
37+
plt.show()
38+
39+
#Perceotron algorithm
40+
def plot_data_and_line(w1,w2):
41+
w1,w2 = float(w1),float(w2)
42+
if w2 != 0 :
43+
y1,y2 = (-w1*(xl1))/w2, (-w1*(xl2))/w2
44+
vx1,vy1 = [xl1,xl2,xl2,xl1,xl1], [y1,y2,yl2,yl2,y1]
45+
vx2,vy2 = [xl1,xl2,xl2,xl1,xl1], [y1,y2,yl1,yl1,y1]
46+
elif w1 != 0:
47+
vx1,vy1 = [xl2,0,0,xl2,xl2], [yl1,yl1,yl2,yl2,yl1]
48+
vx2,vy2 = [xl1,0,0,xl1,xl1], [yl1,yl1,yl2,yl2,yl1]
49+
else:
50+
print "ERROR, Invalid w1 and w2."
51+
return;
52+
if w2 > 0 or ( w2 == 0 and w1 > 0):
53+
c1,c2 = 'b','r'
54+
else:
55+
c1,c2 = 'r','b'
56+
fig = plt.figure()
57+
ax = fig.add_subplot(111)
58+
plt.scatter(X1[Y > 0], X2[Y > 0], s = 80, c = 'b', marker = "o")
59+
plt.scatter(X1[Y<= 0], X2[Y<= 0], s = 80, c = 'r', marker = "^")
60+
plt.fill(vx1, vy1, c1, alpha = 0.25)
61+
plt.fill(vx2, vy2, c2, alpha = 0.25)
62+
ax.set_title(("w1 = %s, w2 = %s")%( w1, w2))
63+
ax.set_xlim(xl1, xl2)
64+
ax.set_ylim(yl1, yl2)
65+
fig.set_size_inches(6, 6)
66+
plt.show()
67+
68+
def learn_perceptron(times=1000):
69+
w1,w2 = 1,1
70+
for i in range(times):
71+
ERR = (w1*X1+w2*X2) * Y < 0
72+
if len(filter(bool,ERR)) > 0:
73+
err_x1,err_x2,err_y = X1[ERR][0],X2[ERR][0],Y[ERR][0]
74+
w1,w2 = (w1+err_y*err_x1),(w2+err_y*err_x2)
75+
else:
76+
print "Complete!"
77+
break;
78+
plot_data_and_line(w1,w2)
79+
80+
81+
if __name__ == '__main__':
82+
#plot_data();
83+
#plot_data_and_line(1,1);
84+
learn_perceptron();
85+
86+

0 commit comments

Comments
 (0)