Skip to content

Commit

Permalink
Merge pull request FluxML#64 from sambitdash/master
Browse files Browse the repository at this point in the history
Changes to CNN to improve model accuracy.
  • Loading branch information
DhairyaLGandhi authored Jan 17, 2019
2 parents ea994cd + 9dff27b commit c9427b7
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions vision/mnist/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ imgs = MNIST.images()

labels = onehotbatch(MNIST.labels(), 0:9)

# Partition into batches of size 1,000
# Partition into batches of size 32
train = [(cat(float.(imgs[i])..., dims = 4), labels[:,i])
for i in partition(1:60_000, 1000)]
for i in partition(1:60_000, 32)]

train = gpu.(train)

Expand All @@ -19,12 +19,14 @@ tX = cat(float.(MNIST.images(:test)[1:1000])..., dims = 4) |> gpu
tY = onehotbatch(MNIST.labels(:test)[1:1000], 0:9) |> gpu

m = Chain(
Conv((2,2), 1=>16, relu),
x -> maxpool(x, (2,2)),
Conv((2,2), 16=>8, relu),
x -> maxpool(x, (2,2)),
x -> reshape(x, :, size(x, 4)),
Dense(288, 10), softmax) |> gpu
Conv((3, 3), 1=>32, relu),
Conv((3, 3), 32=>32, relu),
x -> maxpool(x, (2,2)),
Conv((3, 3), 32=>16, relu),
x -> maxpool(x, (2,2)),
Conv((3, 3), 16=>10, relu),
x -> reshape(x, :, size(x, 4)),
Dense(90, 10), softmax) |> gpu

m(train[1][1])

Expand Down

0 comments on commit c9427b7

Please sign in to comment.