-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmultiPred2.cpp
executable file
·90 lines (74 loc) · 2.57 KB
/
multiPred2.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#include "multi.h"
#include <omp.h>
#include <cassert>
// Evaluate a trained model on a test file and report top-1 accuracy:
// the fraction of test instances whose single highest-scoring label
// (by name) appears among that instance's true labels. Also reports
// wall-clock prediction time measured with omp_get_wtime().
//
// Usage: multiPred [testfile] [model]
//   argv[1]  test data file, loaded with readData(testFile, prob, true)
//   argv[2]  model path; readModel is called with is_binary = !isFile(path)
//            (see readModel/isFile in multi.h for how that flag is used)
//
// Returns 1 on a usage error, 0 otherwise. All output goes to stderr.
int main(int argc, char** argv){
	if( argc < 1+2 ){
		cerr << "multiPred [testfile] [model]" << endl;
		cerr << "\tcompute top-1 accuracy" << endl;
		// A usage error is a failure; a zero status would hide it from scripts.
		return 1;
	}
	char* testFile = argv[1];
	char* modelFile = argv[2];
	bool is_binary = !isFile(modelFile);
	StaticModel* model = readModel(modelFile, is_binary);
	Problem* prob = new Problem();
	readData( testFile, prob, true );
	cerr << "Ntest=" << prob->N << endl;

	//compute accuracy
	vector<SparseVec*>* data = &(prob->data);
	vector<Labels>* labels = &(prob->labels);
	Float hit=0.0;

	// Dense per-label score buffer, zero-initialised. Scoring an instance only
	// writes the labels recorded in touched_index. Those entries are reset to
	// zero afterwards, so the buffer is reused without an O(K) clear per instance.
	vector<Float> prod(model->K, 0.0);
	vector<int> touched_index;

	// Emit roughly 100 progress dots. Clamp the step to 1 so that test sets
	// with fewer than 100 instances do not evaluate `i % 0` (undefined behaviour).
	int progress_step = prob->N / 100;
	if( progress_step < 1 )
		progress_step = 1;

	double start = omp_get_wtime();
	for(int i=0;i<prob->N;i++){
		if( i % progress_step == 0 )
			cerr << ".";
		SparseVec* xi = data->at(i);
		Labels* yi = &(labels->at(i));

		//compute scores: prod[k] = sum_j x_ij * w_jk over the sparse rows of w.
		//Features with index >= model->D are not in the model and are skipped.
		for(SparseVec::iterator it=xi->begin(); it!=xi->end(); it++){
			int j= it->first;
			Float xij = it->second;
			if( j >= model->D )
				continue;
			SparseVec* wj = &(model->w[j]);
			for(SparseVec::iterator it2=wj->begin(); it2!=wj->end(); it2++){
				int k = it2->first;
				// A zero score means "not yet touched". If a partial sum happens to
				// return to exactly 0, k is recorded twice. That is harmless for the
				// argmax and for the reset below.
				if( prod[k] == 0.0 ){
					touched_index.push_back(k);
				}
				prod[k] += it2->second*xij;
			}
		}

		// Argmax over the touched labels only; the first one wins on ties.
		// If no feature produced a score, max_k stays 0, as before.
		// NOTE(review): untouched labels implicitly score 0 and are never
		// considered, even when every touched score is negative.
		int max_k = 0;
		Float max_val = 0.0;
		bool have_max = false;
		for(vector<int>::iterator it=touched_index.begin(); it!=touched_index.end(); it++){
			if( !have_max || prod[*it] > max_val ){
				have_max = true;
				max_val = prod[*it];
				max_k = *it;
			}
		}

		//compute top-1 precision: hit if the predicted label's name equals the
		//name of any true label. True label ids outside prob->label_name_list
		//are ignored.
		bool flag = false;
		for (int j = 0; j < yi->size(); j++){
			if( yi->at(j) >= prob->label_name_list.size() )
				continue;
			if (prob->label_name_list[yi->at(j)] == model->label_name_list->at(max_k)){
				flag = true;
				break;
			}
		}
		if (flag)
			hit += 1.0;

		//reset only the scores touched by this instance
		for(vector<int>::iterator it=touched_index.begin(); it!=touched_index.end(); it++)
			prod[*it] = 0.0;
		touched_index.clear();
	}
	cerr << endl;
	double end = omp_get_wtime();

	// Guard against an empty test set, where hit/N would print nan.
	Float acc = (prob->N > 0) ? (Float)(hit/prob->N) : (Float)0.0;
	cerr << "Top " << 1 << " Acc=" << acc << endl;
	cerr << "pred time=" << (end-start) << " s" << endl;
	return 0;
}