-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
Copy pathfigure4_5_no_sklearn.py
79 lines (67 loc) · 2.2 KB
/
figure4_5_no_sklearn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
# This code is supporting material for the book
# Building Machine Learning Systems with Python
# by Willi Richert and Luis Pedro Coelho
# published by PACKT Publishing
#
# It is made available under the MIT License
# Palette switch read by plot_decision: True draws the decision regions and
# sample points in red/green/blue; False draws grayscale regions with
# marker-shaped (diamond/circle/triangle) points instead.
COLOUR_FIGURE: bool = False
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
from load import load_dataset
import numpy as np
from knn import fit_model, predict
# Human-readable names for the feature columns, presumably in the column
# order of the 'seeds' features loaded below (verify against load.py).
# plot_decision uses entries 0 and 2 as its x/y axis labels.
feature_names = [
    'area',
    'perimeter',
    'compactness',
    'length of kernel',
    'width of kernel',
    'asymmetry coefficient',
    'length of kernel groove',
]
def plot_decision(features, labels):
    '''Plots decision boundary for KNN on feature columns 0 and 2

    A classifier is fit (via ``fit_model``) on columns 0 and 2 of
    ``features``. Its predictions over a 100 x 100 grid spanning the data
    are drawn as shaded regions, and the training samples are drawn on top.
    The palette depends on the module-level ``COLOUR_FIGURE`` flag.

    Parameters
    ----------
    features : ndarray
        2-D array of samples (one row per sample); only columns 0 and 2
        are used. Array-likes are converted with ``np.asarray``.
    labels : sequence
        Class label per row of ``features``. Converted to an ndarray so
        that boolean masking works for plain lists too. The grayscale
        branch iterates over ``range(3)``, so labels are expected to be
        the integers 0, 1 and 2.

    Returns
    -------
    fig : Matplotlib Figure
    ax : Matplotlib Axes
    '''
    features = np.asarray(features)
    # ``labels == lab`` below must be an element-wise mask; on a list it
    # would collapse to a single bool and index the wrong rows.
    labels = np.asarray(labels)

    xs = features[:, 0]
    ys = features[:, 2]
    # Pad each axis by 10% of the magnitude of its extreme values. For
    # positive data this is exactly ``min * .9`` / ``max * 1.1``; unlike
    # that form it still widens the range when an extreme is negative
    # (e.g. standardized features), so no samples are clipped.
    x0 = xs.min() - .1 * abs(xs.min())
    x1 = xs.max() + .1 * abs(xs.max())
    y0 = ys.min() - .1 * abs(ys.min())
    y1 = ys.max() + .1 * abs(ys.max())

    X, Y = np.meshgrid(np.linspace(x0, x1, 100), np.linspace(y0, y1, 100))
    # First argument 1 is presumably k (1-nearest-neighbour); confirm in knn.py.
    model = fit_model(1, features[:, (0, 2)], np.array(labels))
    grid_points = np.vstack([X.ravel(), Y.ravel()]).T
    C = predict(model, grid_points).reshape(X.shape)

    if COLOUR_FIGURE:
        cmap = ListedColormap([(1., .6, .6), (.6, 1., .6), (.6, .6, 1.)])
    else:
        cmap = ListedColormap([(1., 1., 1.), (.2, .2, .2), (.6, .6, .6)])
    fig, ax = plt.subplots()
    ax.set_xlim(x0, x1)
    ax.set_ylim(y0, y1)
    ax.set_xlabel(feature_names[0])
    ax.set_ylabel(feature_names[2])
    ax.pcolormesh(X, Y, C, cmap=cmap)

    if COLOUR_FIGURE:
        cmap = ListedColormap([(1., .0, .0), (.0, 1., .0), (.0, .0, 1.)])
        ax.scatter(features[:, 0], features[:, 2], c=labels, cmap=cmap)
    else:
        for lab, marker in enumerate("Do^"):
            mask = labels == lab
            # White face with an explicit black edge: since matplotlib 2.0 the
            # edge defaults to the line colour, which would make white markers
            # invisible on the white class-0 region.
            ax.plot(features[mask, 0], features[mask, 2], marker,
                    c=(1., 1., 1.), mec='k')
    return fig, ax
# Load the data and replace each class label by its index within the
# sorted list of distinct labels (``names``).
features, labels = load_dataset('seeds')
names = sorted(set(labels))
labels = np.array(list(map(names.index, labels)))

# Figure 4: decision regions on the features as loaded.
fig, ax = plot_decision(features, labels)
fig.savefig('figure4.png')

# Figure 5: standardize every column in place (centre on the per-column
# mean, then divide by the per-column standard deviation) and replot.
features -= features.mean(axis=0)
features /= features.std(axis=0)
fig, ax = plot_decision(features, labels)
fig.savefig('figure5.png')