-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcrossValidate.m
87 lines (70 loc) · 2.55 KB
/
crossValidate.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
function [finalPopulation,finalAccuracy]=crossValidate(agent,classifierType,paramValue,fold)
% crossValidate  k-fold cross-validation of every agent in a population.
%
%   [finalPopulation,finalAccuracy] = crossValidate(agent,classifierType,paramValue,fold)
%
%   Inputs
%     agent          - population matrix, one agent per row; entries equal
%                      to 1 mark the columns of the global 'train' that the
%                      agent selects.
%     classifierType - forwarded unchanged to classify().
%     paramValue     - forwarded unchanged to classify().
%     fold           - number of folds (k).
%
%   Outputs
%     finalPopulation - rows of 'agent' reordered by descending mean accuracy.
%     finalAccuracy   - 1-by-numAgents mean fold accuracies, same order.
%
%   Reads the globals 'train' and 'trainLabel' and resets the random number
%   generator with rng('default') on every call.
global train trainLabel
rng('default');
numAgents=size(agent,1);
rows=size(train,1);
selection=createDivision(fold);   % fold id (1..fold) for every training row
finalAccuracy=zeros(1,numAgents);
finalPopulation=agent;
for agentIdx=1:numAgents
    % An agent that selects no column keeps its preallocated accuracy of 0.
    % Skip only this agent: leaving the function here would drop every
    % remaining agent and bypass the final sort.
    if ~any(agent(agentIdx,:)==1)
        continue;
    end
    accuracy=zeros(1,fold);
    for foldIdx=1:fold
        % Rows assigned to the current fold form the test split; all other
        % rows form the training split.
        isTest=(selection(1,1:rows).'==foldIdx);
        crossTrain=train(~isTest,:);
        crossTrainLabel=trainLabel(~isTest,:);
        crossTest=train(isTest,:);
        crossTestLabel=trainLabel(isTest,:);
        accuracy(1,foldIdx) = classify(crossTrain,crossTrainLabel,crossTest,crossTestLabel,agent(agentIdx,:),classifierType,paramValue);
    end
    finalAccuracy(1,agentIdx) = mean(accuracy);
end
% Best agents first.
[~,index]=sort(finalAccuracy,'descend');
finalAccuracy=finalAccuracy(index);
finalPopulation=finalPopulation(index,:);
end
function [selection] = createDivision(fold)
% createDivision  Assign every training row to one of k folds, class by class.
%
%   selection = createDivision(fold)
%
%   Input
%     fold      - number of folds (k).
%
%   Output
%     selection - 1-by-rows vector; selection(i) is the fold id (1..fold) of
%                 row i of the global 'trainLabel'.
%
%   The class of a row is the index of the first nonzero column in that row
%   of the global 'trainLabel'. For each class, with n samples taken in row
%   order, the first floor(n/fold) samples go to fold 1, the next
%   floor(n/fold) to fold 2, and so on; the remaining mod(n,fold) samples
%   are handed out one each to folds 1, 2, ...
global trainLabel
rows = size(trainLabel,1);
selection=zeros(1,rows);
if rows==0
    return;
end
isSet = (trainLabel~=0);
if ~all(any(isSet,2))
    error('createDivision:unlabeledRow', ...
        'Every row of trainLabel must contain at least one nonzero entry.');
end
% Index of the first nonzero column per row (max returns the first maximum).
[~,labels] = max(double(isSet),[],2);
labels = labels.';
maxLabelNum = max(labels);
for classIdx=1:maxLabelNum
    members = find(labels==classIdx);
    numMembers = numel(members);
    % Kept as a double: an integer cast would saturate for very large classes.
    samplesPerFold = floor(numMembers/fold);
    numFull = samplesPerFold*fold;
    % Consecutive blocks of samplesPerFold members go to folds 1..fold.
    % When samplesPerFold is 0 this assigns nothing, so a class with fewer
    % samples than folds is handled entirely by the remainder step below.
    selection(members(1:numFull)) = ceil((1:numFull)/samplesPerFold);
    % Spread the remainder (fewer than 'fold' samples) over folds 1,2,...
    numLeft = numMembers-numFull;
    selection(members(numFull+1:end)) = 1:numLeft;
end
end