# iris.jl (from FluxML/model-zoo)
# A single-layer Flux classifier trained on the Iris dataset.

using Flux
using Flux: logitcrossentropy, normalise, onecold, onehotbatch
using Statistics: mean
using Parameters: @with_kw

# Hyperparameters, with defaults.
@with_kw mutable struct Args
    lr::Float64 = 0.5   # learning rate
    repeat::Int = 110   # number of gradient-descent steps (the training batch is repeated this many times)
end
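
# A minimal note on `@with_kw`: Parameters.jl generates a keyword constructor
# with the defaults above, so e.g. `Args(lr = 0.1)` keeps `repeat = 110` while
# changing only the learning rate.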

function get_processed_data(args)
    labels = Flux.Data.Iris.labels()
    features = Flux.Data.Iris.features()

    # Subtract the mean and divide by the standard deviation so each feature
    # (row) has mean 0 and standard deviation 1 across the 150 samples.
    normed_features = normalise(features, dims=2)

    klasses = sort(unique(labels))
    onehot_labels = onehotbatch(labels, klasses)

    # Split into training and test sets: 2/3 for training, 1/3 for testing.
    train_indices = [1:3:150 ; 2:3:150]

    X_train = normed_features[:, train_indices]
    y_train = onehot_labels[:, train_indices]

    X_test = normed_features[:, 3:3:150]
    y_test = onehot_labels[:, 3:3:150]

    # Repeat the training data `args.repeat` times so that a single call to
    # `Flux.train!` performs that many gradient-descent steps.
    train_data = Iterators.repeated((X_train, y_train), args.repeat)
    test_data = (X_test, y_test)

    return train_data, test_data
end
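
# A quick shape check for the split above (Iris has 150 samples, 4 features and
# 3 classes), e.g. from the REPL:
#   _, (X_test, y_test) = get_processed_data(Args())
#   size(X_test), size(y_test)   # ((4, 50), (3, 50))
# and the training set keeps the remaining 100 samples (4×100 features,
# 3×100 one-hot labels).
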
# Accuracy Function
accuracy(x, y, model) = mean(onecold(model(x)) .== onecold(y))
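
# `onecold` is the inverse of `onehotbatch`: it maps each column to the index
# of its largest entry, so both the model scores and the one-hot labels
# collapse to class indices before being compared. A small illustration:
#   onecold([0.1 0.9; 0.9 0.1])   # [2, 1]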

# Function to build the confusion matrix
function confusion_matrix(X, y, model)
    ŷ = onehotbatch(onecold(model(X)), 1:3)
    y * transpose(ŷ)
end
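
# In the 3×3 matrix returned above, entry (i, j) counts samples whose true
# class is i and whose predicted class is j, so correct predictions sit on
# the diagonal.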

function train(; kws...)
    # Initialize hyperparameter arguments
    args = Args(; kws...)

    # Load the processed data
    train_data, test_data = get_processed_data(args)

    # Declare the model: 4 input features mapped to 3 output scores,
    # one for each species of iris.
    model = Chain(Dense(4, 3))

    # Define the loss function used in training.
    # `logitcrossentropy` works on the raw scores (it applies the softmax
    # internally), which is more numerically stable than softmax + crossentropy.
    loss(x, y) = logitcrossentropy(model(x), y)

    # Gradient-descent optimiser with learning rate `args.lr`.
    optimiser = Descent(args.lr)

    println("Starting training.")
    Flux.train!(loss, params(model), train_data, optimiser)

    return model, test_data
end
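
# The keyword arguments of `train` are forwarded to `Args`, so hyperparameters
# can be overridden at the call site, for example:
#   model, test_data = train(lr = 0.1, repeat = 300)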

function test(model, test)
    # Evaluate model performance on the test data
    X_test, y_test = test
    accuracy_score = accuracy(X_test, y_test, model)

    println("\nAccuracy: $accuracy_score")

    # Sanity check.
    @assert accuracy_score > 0.8

    # For the definition of a confusion matrix, see:
    # https://en.wikipedia.org/wiki/Confusion_matrix
    println("\nConfusion Matrix:\n")
    display(confusion_matrix(X_test, y_test, model))
end

cd(@__DIR__)
model, test_data = train()
test(model, test_data)