-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy patheval_sent_predict.py
executable file
·137 lines (121 loc) · 4.7 KB
/
eval_sent_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
"""
Evaluation Tool for Predicting Categories
For usage, see help: `python %(program)s -h`
"""
import sys, os
sys.path.append("../oss")
from sent2vec import Sentence2Vec, REAL
from sentences import CatSentence
import logging
import utils
import matutils
from threading import Thread
from Queue import Queue
import numpy as np
import time,re
import cPickle as pickle
logger = logging.getLogger("sent_predict_eval")
def readSentence(sent):
    """Build a mapping from sentence id to category id.

    Each item of `sent` is a tuple whose second element is the sentence id
    and whose third element is the category id. If a sentence id appears
    more than once, the last occurrence wins.
    """
    mapping = {}
    for record in sent:
        mapping[record[1]] = record[2]
    return mapping
if __name__ == "__main__":
    # Thread name is included in the log format so per-worker progress is traceable.
    logging.basicConfig(format='%(asctime)s %(relativeCreated)d : %(threadName)s : %(levelname)s : %(message)s', level=logging.INFO)
    logging.info("running %s" % " ".join(sys.argv))
    program = os.path.basename(sys.argv[0])
    if len(sys.argv) < 2:
        # No arguments: print the module docstring as usage text and exit.
        print(globals()['__doc__'] % locals())
        sys.exit(1)
    # Start from the shared sent2vec argument parser and add evaluation-only options.
    parser = Sentence2Vec.arg_parser()
    parser.add_argument("--split", dest="split", action="store_true", help="use this option for split training data", default=False)
    parser.add_argument("--modelfile", dest="modelfile", type=str, help="trained model file")
    parser.add_argument("--test", dest="test", type=str, help="test file")
    # maxn caps how many test sentences are evaluated; sys.maxint means "no cap".
    parser.set_defaults(maxn=sys.maxint)
    parser.add_argument("--maxN", dest="maxn", type=int, help="")
    parser.set_defaults(knn=1)
    parser.add_argument("-k","--knn", dest="knn", type=int, help="use k of the nearest neighbors (default 1)")
    args = parser.parse_args()
    test_file = args.test
    topK = args.knn
    maxN = args.maxn
    if not args.train:
        print "ERROR: specify training set"
        quit()
    input_file = args.train[0]
    if args.modelfile:
        # Reuse a previously trained model instead of training from scratch.
        logging.info("load trained model file")
        modelfile = args.modelfile
        model = Sentence2Vec.load(modelfile)
    else:
        # Derive the saved-model filename from the training file's basename.
        p_dir = re.compile("^.*/")
        basename = p_dir.sub("",input_file)
        if args.outdir:
            outdir = args.outdir
        else:
            # Default output directory: the directory of the training file itself.
            m = p_dir.search(input_file)
            outdir = m.group(0) if m else ""
        if args.split:
            # With --split, the whole list of training files is handed to CatSentence.
            input_file = args.train
        logging.info("train from input file")
        model = Sentence2Vec(CatSentence(input_file, cont_col=3, split=args.split), iteration=args.iteration, model=args.model, hs = args.hs, negative = args.neg, workers = args.thread, alpha=args.alpha, size=args.dim, update_mode = args.update)
        modelfile = "%s%s_%s.model" % (outdir, basename, model.identifier())
        model.save(modelfile)
    # Ground-truth lookup: sentence id -> category id, built from the training data.
    sent_cat = readSentence(CatSentence(input_file, cont_col=3, split=args.split))
    test_sentences = CatSentence(test_file)
    # Nested dict: true category -> predicted category -> count. Filled by workers.
    confusion_mtx = {}
def prepare_sentences():
count = 0
for sent_tuple in test_sentences:
yield sent_tuple
count += 1
if count > maxN: break
    def worker_infer():
        # Thread target: pull batches of test sentences off `jobs`, infer the
        # top-K nearest training sentences for each, and push the batch's hit
        # count onto `qout`. Exits when it receives the None sentinel.
        while True:
            job = jobs.get()
            if job is None:
                break
            # Number of sentences in this batch whose true category appears
            # among the categories of the K nearest neighbors.
            diff = 0.
            # Per-thread scratch buffers reused across inferences (padded by 8
            # for alignment, matching matutils.zeros_aligned usage).
            work = matutils.zeros_aligned(model.layer1_size + 8, dtype=REAL)
            neu1 = matutils.zeros_aligned(model.layer1_size + 8, dtype=REAL)
            for sent_tuple in job:
                cat_id = sent_tuple[2]
                # ret[1] appears to hold the ids of the K nearest training
                # sentences -- TODO confirm against Sentence2Vec.infer_sent.
                ret = model.infer_sent([sent_tuple[0]], iteration=20, k=topK, work=work, neu1=neu1)
                cats = [sent_cat[sid] for sid in ret[1]]
                # Count a hit if the true category is anywhere in the top K.
                diff += 1. if cat_id in cats else 0.
                print cats,cat_id
                # Confusion matrix is keyed on the single top-1 prediction.
                # NOTE(review): these updates on the shared dict are not
                # synchronized across worker threads -- confirm this is safe
                # under the GIL for the access pattern used here.
                confusion_mtx.setdefault(cat_id, {})
                confusion_mtx[cat_id].setdefault(cats[0], 0)
                confusion_mtx[cat_id][cats[0]] += 1
            qout.put(diff)
jobs = Queue(maxsize=50)
qout = Queue(maxsize=20000)
threads = [Thread(target=worker_infer) for _ in xrange(args.thread)]
sent_num = 0
for t in threads:
t.daemon = True
t.start()
for job_no, job in enumerate(utils.grouper(prepare_sentences(), 100)):
logger.info("putting job #%i in the queue, qsize=%i" % (job_no, jobs.qsize()))
jobs.put(job)
sent_num += len(job)
logger.info("reached the end of input; waiting to finish %i outstanding jobs" % jobs.qsize())
for _ in xrange(args.thread):
jobs.put(None)
for t in threads:
t.join()
avg = 0.0
while not qout.empty():
val = qout.get()
avg += val
avg /= sent_num
print avg
info_file = open("result_sent_eval.txt","a")
info_file.write("infer_sent_k%d\t%f\t%s\n" % (topK, avg, modelfile))
info_file.close()
pickle.dump(confusion_mtx, open(modelfile+".cmat", "w"))
program = os.path.basename(sys.argv[0])
logging.info("finished running %s" % program)