-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathclasificacion_MNIST_redes_neuronales.py
118 lines (95 loc) · 3.72 KB
/
clasificacion_MNIST_redes_neuronales.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Set de datos: MNIST
#
# Contiene 60,000 datos de entrenamiento y 10,000 de validación.
# Cada imagen es de 28x28 pixeles. La clasificación se llevará a cabo
# usando una red neuronal con una capa oculta que contiene 15 neuronas.
#
# codificandobits.com - 2018
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.utils import np_utils
import matplotlib.pyplot as plt
import numpy as np
#
# Lectura, visualización y pre-procesamiento de los datos
#
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Visualizaremos 16 imágenes aleatorias tomadas del set x_train
ids_imgs = np.random.randint(0,x_train.shape[0],16)
for i in range(len(ids_imgs)):
img = x_train[ids_imgs[i],:,:]
plt.subplot(4,4,i+1)
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.title(y_train[ids_imgs[i]])
plt.suptitle('16 imágenes del set MNIST')
plt.show()
# Pre-procesamiento: para introducirlas a la red neuronal debemos
# "aplanar" cada una de las imágenes en un vector de 28x28 = 784 valores
X_train = np.reshape( x_train, (x_train.shape[0],x_train.shape[1]*x_train.shape[2]) )
X_test = np.reshape( x_test, (x_test.shape[0],x_test.shape[1]*x_test.shape[2]) )
# Adicionalmente se normalizarán las intensidades al rango 0-1
X_train = X_train/255.0
X_test = X_test/255.0
# Finalmente, convertimos y_train y y_test a representación "one-hot"
nclasses = 10
Y_train = np_utils.to_categorical(y_train,nclasses)
Y_test = np_utils.to_categorical(y_test,nclasses)
#
# Creación del modelo:
# - Capa de entrada: su dimensión será 784 (el tamaño de cada imagen aplanada)
# - Capa oculta: 15 neuronas con activación ReLU
# - Capa de salida: función de activación 'softmax' (clasificación multiclase) y un
# total de 10 categorías
#
np.random.seed(1) # Para reproducibilidad del entrenamiento
input_dim = X_train.shape[1]
output_dim = Y_train.shape[1]
modelo = Sequential()
modelo.add( Dense(15, input_dim=input_dim, activation='relu'))
modelo.add( Dense(output_dim, activation='softmax'))
print(modelo.summary())
#
# Compilación y entrenamiento: gradiente descendente, learning rate = 0.05, función
# de error: entropía cruzada, métrica de desempeño: precisión
sgd = SGD(lr=0.2)
modelo.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy'])
# Para el entrenamiento se usarán 30 iteraciones y un batch_size de 1024
num_epochs = 50
batch_size = 1024
historia = modelo.fit(X_train, Y_train, epochs=num_epochs, batch_size=batch_size, verbose=2)
#
# Resultados
#
# Error y precisión vs iteraciones
plt.subplot(1,2,1)
plt.plot(historia.history['loss'])
plt.title('Pérdida vs. iteraciones')
plt.ylabel('Pérdida')
plt.xlabel('Iteración')
plt.subplot(1,2,2)
plt.plot(historia.history['acc'])
plt.title('Precisión vs. iteraciones')
plt.ylabel('Precisión')
plt.xlabel('Iteración')
plt.show()
# Calcular la precisión sobre el set de validación
puntaje = modelo.evaluate(X_test,Y_test,verbose=0)
print('Precisión en el set de validación: {:.1f}%'.format(100*puntaje[1]))
# Realizar predicción sobre el set de validación y mostrar algunos ejemplos
# de la clasificación resultante
Y_pred = modelo.predict_classes(X_test)
ids_imgs = np.random.randint(0,X_test.shape[0],9)
for i in range(len(ids_imgs)):
idx = ids_imgs[i]
img = X_test[idx,:].reshape(28,28)
cat_original = np.argmax(Y_test[idx,:])
cat_prediccion = Y_pred[idx]
plt.subplot(3,3,i+1)
plt.imshow(img, cmap='gray')
plt.axis('off')
plt.title('"{}" clasificado como "{}"'.format(cat_original,cat_prediccion))
plt.suptitle('Ejemplos de clasificación en el set de validación')
plt.show()