-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy patheval_joint_cat_predict.py
executable file
·142 lines (127 loc) · 5.55 KB
/
eval_joint_cat_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
#!/usr/bin/env python
"""
Evaluation Tool for Predicting Categories
For usage, see help: `python %(program)s -h`
"""
from __future__ import print_function
import sys, os
import logging
import time,re
import cPickle as pickle
from argparse import ArgumentParser
from multiprocessing import cpu_count
from threading import Thread
from Queue import Queue
import numpy as np
sys.path.append("../oss")
from cat2vec import Category2Vec, REAL
from cat2vec_bind import catsentvec_sim_sum, init_joint_pairtable, joint_catsentvec_sim_sum
from sentences import CatSentence
import utils
import matutils
logger = logging.getLogger("cat_predict_eval")
if __name__ == "__main__":
    logging.basicConfig(format='%(asctime)s %(relativeCreated)d : %(threadName)s : %(levelname)s : %(message)s', level=logging.INFO)
    logger.info("running %s", " ".join(sys.argv))
    program = os.path.basename(sys.argv[0])
    if len(sys.argv) < 2:
        print(globals()['__doc__'] % locals())
        sys.exit(1)

    parser = ArgumentParser(description="Evaluation tool for joint category vector models")
    parser.add_argument("--split", dest="split", action="store_true", help="use this option for split training data", default=False)
    parser.add_argument("--modelfile1", dest="modelfile1", type=str, help="trained model file 1")
    parser.add_argument("--modelfile2", dest="modelfile2", type=str, help="trained model file 2")
    parser.add_argument("--test", dest="test", type=str, help="test file")
    parser.set_defaults(maxn=sys.maxint)
    parser.add_argument("--maxN", dest="maxn", type=int, help="maximum number of test sentences to evaluate")
    parser.set_defaults(thread=cpu_count())
    parser.add_argument("-t", "--thread", dest="thread", type=int, help="the number of threads")
    parser.set_defaults(knn=1)
    parser.add_argument("-k", "--knn", dest="knn", type=int, help="use k of the nearest neighbors (default 1)")
    args = parser.parse_args()

    test_file = args.test
    topK = args.knn
    maxN = args.maxn
    if not args.modelfile1 or not args.modelfile2:
        print("Specify modelfile1 and modelfile2")
        sys.exit(-1)

    logger.info("load trained model file")
    modelfile1 = args.modelfile1
    model1 = Category2Vec.load(modelfile1)
    modelfile2 = args.modelfile2
    model2 = Category2Vec.load(modelfile2)
    logger.info("initializing pairnorm")
    model1.init_pairnorm()
    model2.init_pairnorm()
    # Disabled alternative: pre-merge both models' pair tables and score
    # jointly in one C call (see joint_catsentvec_sim_sum below).
    #pairtable = np.empty((model1.pair_len, model1.layer1_size * 2), dtype=REAL)
    #init_joint_pairtable(model1, model2, pairtable)

    test_sentences = CatSentence(test_file)
    # cat_id_gold -> {predicted_cat_id -> count}
    confusion_mtx = {}

    def prepare_sentences():
        """Yield up to maxN sentence tuples from the test corpus."""
        count = 0
        for sent_tuple in test_sentences:
            yield sent_tuple
            count += 1
            # Was `count > maxN`, which yielded maxN+1 sentences; `>=`
            # honors --maxN exactly (no change for the sys.maxint default).
            if count >= maxN:
                break

    def worker_infer():
        """Consume batches from `jobs`; put each batch's hit count on `qout`."""
        while True:
            job = jobs.get()
            if job is None:  # poison pill: shut this worker down
                break
            hits = 0.
            # Scratch buffers must fit BOTH models: the original sized them
            # from model1 only, overrunning when model2's layer is wider.
            buf_size = max(model1.layer1_size, model2.layer1_size) + 8
            work = matutils.zeros_aligned(buf_size, dtype=REAL)
            neu1 = matutils.zeros_aligned(buf_size, dtype=REAL)
            for sent_tuple in job:
                cat_id_gold = sent_tuple[2]
                # Infer sentence/category vectors under each model, score every
                # (sentence, category) pair, and sum the two similarity arrays.
                sent_vec1, cat_vec1 = model1.train_single_sent_id([sent_tuple[0]], 20, work, neu1)
                sims1 = np.empty(model1.pair_len, dtype=REAL)
                catsentvec_sim_sum(model1, sent_vec1, cat_vec1, sims1)
                sent_vec2, cat_vec2 = model2.train_single_sent_id([sent_tuple[0]], 20, work, neu1)
                sims2 = np.empty(model2.pair_len, dtype=REAL)
                catsentvec_sim_sum(model2, sent_vec2, cat_vec2, sims2)
                sims1 += sims2
                #joint_catsentvec_sim_sum(pairtable, sent_vec1, cat_vec1, sent_vec2, cat_vec2, sims1)
                neighbors = np.argsort(sims1)[::-1]
                # Collect the topK distinct category ids among the nearest pairs.
                # (The original's always-True `ident_cat` flag was dead code.)
                cat_ids = {}
                nearest = []
                for top_cand in neighbors:
                    (sent_no, cat_no) = model1.sent_cat_pair[top_cand]
                    cat_id = model1.cat_id_list[cat_no]
                    if cat_id not in cat_ids:
                        cat_ids[cat_id] = 1
                        nearest.append(cat_id)
                        if len(nearest) == topK:
                            break
                hits += 1. if cat_id_gold in nearest else 0.
                print(nearest, cat_id_gold)
                confusion_mtx.setdefault(cat_id_gold, {})
                confusion_mtx[cat_id_gold].setdefault(nearest[0], 0)
                confusion_mtx[cat_id_gold][nearest[0]] += 1
            qout.put(hits)

    jobs = Queue(maxsize=50)
    qout = Queue(maxsize=20000)
    threads = [Thread(target=worker_infer) for _ in xrange(args.thread)]
    sent_num = 0
    for t in threads:
        t.daemon = True
        t.start()
    for job_no, job in enumerate(utils.grouper(prepare_sentences(), 100)):
        logger.info("putting job #%i in the queue, qsize=%i", job_no, jobs.qsize())
        jobs.put(job)
        sent_num += len(job)
    logger.info("reached the end of input; waiting to finish %i outstanding jobs", jobs.qsize())
    for _ in xrange(args.thread):
        jobs.put(None)  # one poison pill per worker
    for t in threads:
        t.join()

    # Average hit rate over all evaluated sentences; guard the empty test set.
    total = 0.0
    while not qout.empty():
        total += qout.get()
    avg = total / sent_num if sent_num else 0.0
    print(avg)

    with open("result_joint_cat_eval.txt", "a") as info_file:
        info_file.write("infer_cat_k%d\t%f\tm1:%s\tm2:%s\n" % (topK, avg, modelfile1, modelfile2))
    # Binary mode: pickle output is bytes, "w" only worked by accident on POSIX.
    with open("joint_" + model1.identifier() + "_" + model2.identifier() + ".cmat", "wb") as cmat_file:
        pickle.dump(confusion_mtx, cmat_file)
    logger.info("finished running %s", program)