forked from FluxML/model-zoo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlp.jl
37 lines (26 loc) · 879 Bytes
/
mlp.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated
# using CuArrays
# Classify MNIST digits with a simple multi-layer perceptron.
imgs = MNIST.images()
# Flatten every image into a vector and stack those vectors as the columns of
# one large batch matrix. `reduce(hcat, ...)` is used instead of splatting the
# whole collection into `hcat(...)`, which would pass one argument per image.
# NOTE(review): `gpu` is presumably a no-op unless a GPU backend (such as the
# commented-out CuArrays) is loaded — confirm against the Flux version in use.
X = reduce(hcat, float.(reshape.(imgs, :))) |> gpu
labels = MNIST.labels()
# One-hot-encode the labels over the classes 0:9.
Y = onehotbatch(labels, 0:9) |> gpu
# The network: a dense layer from 28^2 inputs to 32 units with `relu`, a dense
# layer from 32 units to 10 outputs, then `softmax` over those outputs.
# The assembled chain is handed to `gpu`, the same as the data arrays.
m = gpu(
    Chain(
        Dense(28^2, 32, relu),  # hidden layer
        Dense(32, 10),          # output layer
        softmax,
    ),
)
"""
    loss(x, y)

Return the cross-entropy between the output of the global model `m` applied to
`x` and the targets `y`.
"""
function loss(x, y)
    return crossentropy(m(x), y)
end
"""
    accuracy(x, y)

Return the fraction of samples for which the class decoded (via `onecold`) from
the global model's output `m(x)` equals the class decoded from `y`.
"""
function accuracy(x, y)
    predicted = onecold(m(x))
    expected = onecold(y)
    return mean(predicted .== expected)
end
# Present the same full-batch (X, Y) pair to the optimiser 200 times.
dataset = repeated((X, Y), 200)
# Callback that prints the current loss on the full training batch.
evalcb = () -> @show(loss(X, Y))
opt = ADAM()
# The callback is wrapped in `throttle(evalcb, 10)` so it does not run on every
# step. NOTE(review): presumably at most once every 10 seconds — confirm
# against Flux's `throttle` documentation.
Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))
# Training-set accuracy. `@show` is used so the value is actually printed when
# this file is run as a script; a bare top-level expression would be discarded.
@show accuracy(X, Y)
# Test set accuracy.
# Build the test batch the same way as the training batch: flatten each image
# and stack the vectors as columns. `reduce(hcat, ...)` avoids splatting the
# whole test set into a single varargs `hcat` call.
tX = reduce(hcat, float.(reshape.(MNIST.images(:test), :))) |> gpu
tY = onehotbatch(MNIST.labels(:test), 0:9) |> gpu
# `@show` prints the result when run as a script and still evaluates to the
# accuracy, so the value of the file's last expression is unchanged.
@show accuracy(tX, tY)