forked from VowpalWabbit/vowpal_wabbit
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpylibvw.cc
712 lines (618 loc) · 33.1 KB
/
pylibvw.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
#include "../vowpalwabbit/vw.h"
#include "../vowpalwabbit/multiclass.h"
#include "../vowpalwabbit/cost_sensitive.h"
#include "../vowpalwabbit/cb.h"
#include "../vowpalwabbit/search.h"
#include "../vowpalwabbit/search_hooktask.h"
#include "../vowpalwabbit/parse_example.h"
#include "../vowpalwabbit/gd.h"
#include <boost/make_shared.hpp>
#include <boost/python.hpp>
#include <boost/python/suite/indexing/vector_indexing_suite.hpp>
using namespace std;
namespace py=boost::python;
typedef boost::shared_ptr<vw> vw_ptr;
typedef boost::shared_ptr<example> example_ptr;
typedef boost::shared_ptr<Search::search> search_ptr;
typedef boost::shared_ptr<Search::predictor> predictor_ptr;
const size_t lDEFAULT = 0;
const size_t lBINARY = 1;
const size_t lMULTICLASS = 2;
const size_t lCOST_SENSITIVE = 3;
const size_t lCONTEXTUAL_BANDIT = 4;
const size_t lMAX = 5;
void dont_delete_me(void*arg) { }
vw_ptr my_initialize(string args)
{ vw*foo = VW::initialize(args);
return boost::shared_ptr<vw>(foo, dont_delete_me);
}
void my_finish(vw_ptr all)
{ VW::finish(*all, false); // don't delete all because python will do that for us!
}
search_ptr get_search_ptr(vw_ptr all)
{ return boost::shared_ptr<Search::search>((Search::search*)(all->searchstr), dont_delete_me);
}
void my_audit_example(vw_ptr all, example_ptr ec) { GD::print_audit_features(*all, *ec); }
predictor_ptr get_predictor(search_ptr sch, ptag my_tag)
{ Search::predictor* P = new Search::predictor(*sch, my_tag);
return boost::shared_ptr<Search::predictor>(P);
}
label_parser* get_label_parser(vw*all, size_t labelType)
{ switch (labelType)
{ case lDEFAULT: return all ? &all->p->lp : NULL;
case lBINARY: return &simple_label;
case lMULTICLASS: return &MULTICLASS::mc_label;
case lCOST_SENSITIVE: return &COST_SENSITIVE::cs_label;
case lCONTEXTUAL_BANDIT: return &CB::cb_label;
default: cerr << "get_label_parser called on invalid label type" << endl; throw exception();
}
}
void my_delete_example(void*voidec)
{ example* ec = (example*) voidec;
size_t labelType = (ec->tag.size() == 0) ? lDEFAULT : ec->tag[0];
label_parser* lp = get_label_parser(NULL, labelType);
VW::dealloc_example(lp ? lp->delete_label : NULL, *ec);
free(ec);
}
example* my_empty_example0(vw_ptr vw, size_t labelType)
{ label_parser* lp = get_label_parser(&*vw, labelType);
example* ec = VW::alloc_examples(lp->label_size, 1);
lp->default_label(&ec->l);
if (labelType == lCOST_SENSITIVE)
{ COST_SENSITIVE::wclass zero = { 0., 1, 0., 0. };
ec->l.cs.costs.push_back(zero);
}
ec->tag.erase();
if (labelType != lDEFAULT)
ec->tag.push_back((char)labelType); // hide the label type in the tag
return ec;
}
example_ptr my_empty_example(vw_ptr vw, size_t labelType)
{ example* ec = my_empty_example0(vw, labelType);
return boost::shared_ptr<example>(ec, my_delete_example);
}
example_ptr my_read_example(vw_ptr all, size_t labelType, char*str)
{ example*ec = my_empty_example0(all, labelType);
VW::read_line(*all, ec, str);
VW::parse_atomic_example(*all, ec, false);
VW::setup_example(*all, ec);
ec->example_counter = labelType;
ec->tag.erase();
if (labelType != lDEFAULT)
ec->tag.push_back((char)labelType); // hide the label type in the tag
return boost::shared_ptr<example>(ec, my_delete_example);
}
void my_finish_example(vw_ptr all, example_ptr ec)
{ // TODO
}
void my_learn(vw_ptr all, example_ptr ec)
{ if (ec->test_only)
{ all->l->predict(*ec);
}
else
{ all->learn(ec.get());
}
}
float my_learn_string(vw_ptr all, char*str)
{ example*ec = VW::read_example(*all, str);
all->learn(ec);
float pp = ec->partial_prediction;
VW::finish_example(*all, ec);
return pp;
}
float my_predict(vw_ptr all, example_ptr ec)
{ bool old_test_only = ec->test_only;
ec->test_only = true;
all->l->learn(*ec);
ec->test_only = old_test_only;
return ec->partial_prediction;
}
float my_predict_string(vw_ptr all, char*str)
{ example*ec = VW::read_example(*all, str);
ec->test_only = true;
all->learn(ec);
float pp = ec->partial_prediction;
VW::finish_example(*all, ec);
return pp;
}
string varray_char_to_string(v_array<char> &a)
{ string ret = "";
for (auto c : a)
ret += c;
return ret;
}
string my_get_tag(example_ptr ec)
{ return varray_char_to_string(ec->tag);
}
uint32_t ex_num_namespaces(example_ptr ec)
{ return ec->indices.size();
}
unsigned char ex_namespace(example_ptr ec, uint32_t ns)
{ return ec->indices[ns];
}
uint32_t ex_num_features(example_ptr ec, unsigned char ns)
{ return ec->feature_space[ns].size();
}
uint32_t ex_feature(example_ptr ec, unsigned char ns, uint32_t i)
{ return ec->feature_space[ns].indicies[i];
}
float ex_feature_weight(example_ptr ec, unsigned char ns, uint32_t i)
{ return ec->feature_space[ns].values[i];
}
float ex_sum_feat_sq(example_ptr ec, unsigned char ns)
{ return ec->feature_space[ns].sum_feat_sq;
}
void ex_push_feature(example_ptr ec, unsigned char ns, uint32_t fid, float v)
{ // warning: assumes namespace exists!
ec->feature_space[ns].push_back(v,fid);
ec->num_features++;
ec->total_sum_feat_sq += v * v;
}
void ex_push_feature_list(example_ptr ec, vw_ptr vw, unsigned char ns, py::list& a)
{ // warning: assumes namespace exists!
char ns_str[2] = { (char)ns, 0 };
uint32_t ns_hash = VW::hash_space(*vw, ns_str);
size_t count = 0; float sum_sq = 0.;
for (size_t i=0; i<len(a); i++)
{ feature f = { 1., 0 };
py::object ai = a[i];
py::extract<py::tuple> get_tup(ai);
if (get_tup.check())
{ py::tuple fv = get_tup();
if (len(fv) != 2) { cerr << "warning: malformed feature in list" << endl; continue; } // TODO str(ai)
py::extract<float> get_val(fv[1]);
if (get_val.check())
f.x = get_val();
else { cerr << "warning: malformed feature in list" << endl; continue; }
ai = fv[0];
}
if (f.x != 0.)
{ bool got = false;
py::extract<string> get_str(ai);
if (get_str.check())
{ f.weight_index = VW::hash_feature(*vw, get_str(), ns_hash);
got = true;
}
else
{ py::extract<uint32_t> get_int(ai);
if (get_int.check()) { f.weight_index = get_int(); got = true; }
else { cerr << "warning: malformed feature in list" << endl; continue; }
}
if (got)
{ ec->feature_space[ns].push_back(f.x, f.weight_index);
count++;
sum_sq += f.x*f.x;
}
}
}
ec->num_features += count;
ec->total_sum_feat_sq += sum_sq;
}
void ex_push_namespace(example_ptr ec, unsigned char ns)
{ ec->indices.push_back(ns);
}
void ex_ensure_namespace_exists(example_ptr ec, unsigned char ns)
{ for (auto nss : ec->indices)
if (ns == nss) return;
ex_push_namespace(ec, ns);
}
void ex_push_dictionary(example_ptr ec, vw_ptr vw, py::dict& dict)
{ py::object objectKey, objectVal;
const py::object objectKeys = dict.iterkeys();
const py::object objectVals = dict.itervalues();
unsigned long ulCount = boost::python::extract<unsigned long>(dict.attr("__len__")());
for (size_t u=0; u<ulCount; u++)
{ objectKey = objectKeys.attr( "next" )();
objectVal = objectVals.attr( "next" )();
char chCheckKey = objectKey.ptr()->ob_type->tp_name[0];
if (chCheckKey != 's') continue;
chCheckKey = objectVal.ptr()->ob_type->tp_name[0];
if (chCheckKey != 'l') continue;
py::extract<string> ns_e(objectKey);
if (ns_e().length() < 1) continue;
py::extract<py::list> list_e(objectVal);
py::list list = list_e();
char ns = ns_e()[0];
ex_ensure_namespace_exists(ec, ns);
ex_push_feature_list(ec, vw, ns, list);
}
}
bool ex_pop_feature(example_ptr ec, unsigned char ns)
{ if (ec->feature_space[ns].size() == 0) return false;
float val = ec->feature_space[ns].values.pop();
if (ec->feature_space[ns].indicies.size()> 0)
ec->feature_space[ns].indicies.pop();
if (ec->feature_space[ns].space_names.size()> 0)
ec->feature_space[ns].space_names.pop();
ec->num_features--;
ec->feature_space[ns].sum_feat_sq -= val * val;
ec->total_sum_feat_sq -= val * val;
return true;
}
void ex_erase_namespace(example_ptr ec, unsigned char ns)
{ ec->num_features -= ec->feature_space[ns].size();
ec->total_sum_feat_sq -= ec->feature_space[ns].sum_feat_sq;
ec->feature_space[ns].sum_feat_sq = 0.;
ec->feature_space[ns].erase();
}
bool ex_pop_namespace(example_ptr ec)
{ if (ec->indices.size() == 0) return false;
unsigned char ns = ec->indices.pop();
ex_erase_namespace(ec, ns);
return true;
}
void my_setup_example(vw_ptr vw, example_ptr ec)
{ VW::setup_example(*vw, ec.get());
}
void unsetup_example(vw_ptr vwP, example_ptr ae)
{ vw&all = *vwP;
ae->partial_prediction = 0.;
ae->num_features = 0;
ae->total_sum_feat_sq = 0;
ae->loss = 0.;
if (all.ignore_some)
{ cerr << "error: cannot unsetup example when some namespaces are ignored!" << endl;
throw exception();
}
if(all.ngram_strings.size() > 0)
{ cerr << "error: cannot unsetup example when ngrams are in use!" << endl;
throw exception();
}
if (all.add_constant)
{ ae->feature_space[constant_namespace].erase();
int hit_constant = -1;
size_t N = ae->indices.size();
for (size_t i=0; i<N; i++)
{ size_t j = N - 1 - i;
if (ae->indices[j] == constant_namespace)
{ if (hit_constant >= 0) { cerr << "error: hit constant namespace twice!" << endl; throw exception(); }
hit_constant = j;
break;
}
}
if (hit_constant >= 0)
{ for (size_t i=hit_constant; i<N-1; i++)
ae->indices[i] = ae->indices[i+1];
ae->indices.pop();
}
}
uint32_t multiplier = all.wpp << all.reg.stride_shift;
if(multiplier != 1) //make room for per-feature information.
for (auto ns : ae->indices)
for (auto& idx : ae->feature_space[ns].indicies)
idx /= multiplier;
}
void ex_set_label_string(example_ptr ec, vw_ptr vw, string label, size_t labelType)
{ // SPEEDUP: if it's already set properly, don't modify
label_parser& old_lp = vw->p->lp;
vw->p->lp = *get_label_parser(&*vw, labelType);
VW::parse_example_label(*vw, *ec, label);
vw->p->lp = old_lp;
}
float ex_get_simplelabel_label(example_ptr ec) { return ec->l.simple.label; }
float ex_get_simplelabel_weight(example_ptr ec) { return ec->l.simple.weight; }
float ex_get_simplelabel_initial(example_ptr ec) { return ec->l.simple.initial; }
float ex_get_simplelabel_prediction(example_ptr ec) { return ec->pred.scalar; }
uint32_t ex_get_multiclass_label(example_ptr ec) { return ec->l.multi.label; }
float ex_get_multiclass_weight(example_ptr ec) { return ec->l.multi.weight; }
uint32_t ex_get_multiclass_prediction(example_ptr ec) { return ec->pred.multiclass; }
uint32_t ex_get_costsensitive_prediction(example_ptr ec) { return ec->pred.multiclass; }
uint32_t ex_get_costsensitive_num_costs(example_ptr ec) { return ec->l.cs.costs.size(); }
float ex_get_costsensitive_cost(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].x; }
uint32_t ex_get_costsensitive_class(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].class_index; }
float ex_get_costsensitive_partial_prediction(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].partial_prediction; }
float ex_get_costsensitive_wap_value(example_ptr ec, uint32_t i) { return ec->l.cs.costs[i].wap_value; }
uint32_t ex_get_cbandits_prediction(example_ptr ec) { return ec->pred.multiclass; }
uint32_t ex_get_cbandits_num_costs(example_ptr ec) { return ec->l.cb.costs.size(); }
float ex_get_cbandits_cost(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].cost; }
uint32_t ex_get_cbandits_class(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].action; }
float ex_get_cbandits_probability(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].probability; }
float ex_get_cbandits_partial_prediction(example_ptr ec, uint32_t i) { return ec->l.cb.costs[i].partial_prediction; }
size_t get_example_counter(example_ptr ec) { return ec->example_counter; }
uint32_t get_ft_offset(example_ptr ec) { return ec->ft_offset; }
size_t get_num_features(example_ptr ec) { return ec->num_features; }
float get_partial_prediction(example_ptr ec) { return ec->partial_prediction; }
float get_updated_prediction(example_ptr ec) { return ec->updated_prediction; }
float get_loss(example_ptr ec) { return ec->loss; }
float get_example_t(example_ptr ec) { return ec->example_t; }
float get_total_sum_feat_sq(example_ptr ec) { return ec->total_sum_feat_sq; }
double get_sum_loss(vw_ptr vw) { return vw->sd->sum_loss; }
double get_weighted_examples(vw_ptr vw) { return vw->sd->weighted_examples; }
bool search_should_output(search_ptr sch) { return sch->output().good(); }
void search_output(search_ptr sch, string s) { sch->output() << s; }
/*
uint32_t search_predict_one_all(search_ptr sch, example_ptr ec, uint32_t one_ystar) {
return sch->predict(ec.get(), one_ystar, NULL);
}
uint32_t search_predict_one_some(search_ptr sch, example_ptr ec, uint32_t one_ystar, vector<uint32_t>& yallowed) {
v_array<uint32_t> yallowed_va;
yallowed_va.begin = yallowed.data();
yallowed_va.end = yallowed_va.begin + yallowed.size();
yallowed_va.end_array = yallowed_va.end;
yallowed_va.erase_count = 0;
return sch->predict(ec.get(), one_ystar, &yallowed_va);
}
uint32_t search_predict_many_all(search_ptr sch, example_ptr ec, vector<uint32_t>& ystar) {
v_array<uint32_t> ystar_va;
ystar_va.begin = ystar.data();
ystar_va.end = ystar_va.begin + ystar.size();
ystar_va.end_array = ystar_va.end;
ystar_va.erase_count = 0;
return sch->predict(ec.get(), &ystar_va, NULL);
}
uint32_t search_predict_many_some(search_ptr sch, example_ptr ec, vector<uint32_t>& ystar, vector<uint32_t>& yallowed) {
v_array<uint32_t> ystar_va;
ystar_va.begin = ystar.data();
ystar_va.end = ystar_va.begin + ystar.size();
ystar_va.end_array = ystar_va.end;
ystar_va.erase_count = 0;
v_array<uint32_t> yallowed_va;
yallowed_va.begin = yallowed.data();
yallowed_va.end = yallowed_va.begin + yallowed.size();
yallowed_va.end_array = yallowed_va.end;
yallowed_va.erase_count = 0;
return sch->predict(ec.get(), &ystar_va, &yallowed_va);
}
*/
void verify_search_set_properly(search_ptr sch)
{ if (sch->task_name == NULL)
{ cerr << "set_structured_predict_hook: search task not initialized properly" << endl;
throw exception();
}
if (strcmp(sch->task_name, "hook") != 0)
{ cerr << "set_structured_predict_hook: trying to set hook when search task is not 'hook'!" << endl;
throw exception();
}
}
uint32_t search_get_num_actions(search_ptr sch)
{ verify_search_set_properly(sch);
HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
return d->num_actions;
}
void search_run_fn(Search::search&sch)
{ try
{ HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
py::object run = *(py::object*)d->run_object;
run.attr("__call__")();
}
catch(...)
{ PyErr_Print();
PyErr_Clear();
throw exception();
}
}
void search_setup_fn(Search::search&sch)
{ try
{ HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
py::object run = *(py::object*)d->setup_object;
run.attr("__call__")();
}
catch(...)
{ PyErr_Print();
PyErr_Clear();
throw exception();
}
}
void search_takedown_fn(Search::search&sch)
{ try
{ HookTask::task_data* d = sch.get_task_data<HookTask::task_data>();
py::object run = *(py::object*)d->takedown_object;
run.attr("__call__")();
}
catch(...)
{ PyErr_Print();
PyErr_Clear();
throw exception();
}
}
void py_delete_run_object(void* pyobj)
{ py::object* o = (py::object*)pyobj;
delete o;
}
void set_force_oracle(search_ptr sch, bool useOracle)
{ verify_search_set_properly(sch);
sch->set_force_oracle(useOracle);
}
void set_structured_predict_hook(search_ptr sch, py::object run_object, py::object setup_object, py::object takedown_object)
{ verify_search_set_properly(sch);
HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
d->run_f = &search_run_fn;
delete (py::object*)d->run_object; d->run_object = NULL;
delete (py::object*)d->setup_object; d->setup_object = NULL;
delete (py::object*)d->takedown_object; d->takedown_object = NULL;
sch->set_force_oracle(false);
d->run_object = new py::object(run_object);
if (setup_object.ptr() != Py_None)
{ d->setup_object = new py::object(setup_object);
d->run_setup_f = &search_setup_fn;
}
if (takedown_object.ptr() != Py_None)
{ d->takedown_object = new py::object(takedown_object);
d->run_takedown_f = &search_takedown_fn;
}
d->delete_run_object = &py_delete_run_object;
}
void my_set_test_only(example_ptr ec, bool val) { ec->test_only = val; }
bool po_exists(search_ptr sch, string arg)
{ HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
return (*d->var_map).count(arg) > 0;
}
string po_get_string(search_ptr sch, string arg)
{ HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
return (*d->var_map)[arg].as<string>();
}
int32_t po_get_int(search_ptr sch, string arg)
{ HookTask::task_data* d = sch->get_task_data<HookTask::task_data>();
try { return (*d->var_map)[arg].as<int>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<size_t>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<uint32_t>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<uint64_t>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<uint16_t>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<int32_t>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<int64_t>(); }
catch (...) {}
try { return (*d->var_map)[arg].as<int16_t>(); }
catch (...) {}
// we know this'll fail but do it anyway to get the exception
return (*d->var_map)[arg].as<int>();
}
PyObject* po_get(search_ptr sch, string arg)
{ try
{ return py::incref(py::object(po_get_string(sch, arg)).ptr());
}
catch (...) {}
try
{ return py::incref(py::object(po_get_int(sch, arg)).ptr());
}
catch (...) {}
// return None
return py::incref(py::object().ptr());
}
void my_set_input(predictor_ptr P, example_ptr ec) { P->set_input(*ec); }
void my_set_input_at(predictor_ptr P, size_t posn, example_ptr ec) { P->set_input_at(posn, *ec); }
void my_add_oracle(predictor_ptr P, action a) { P->add_oracle(a); }
void my_add_oracles(predictor_ptr P, py::list& a) { for (size_t i=0; i<len(a); i++) P->add_oracle(py::extract<action>(a[i])); }
void my_add_allowed(predictor_ptr P, action a) { P->add_allowed(a); }
void my_add_alloweds(predictor_ptr P, py::list& a) { for (size_t i=0; i<len(a); i++) P->add_allowed(py::extract<action>(a[i])); }
void my_add_condition(predictor_ptr P, ptag t, char c) { P->add_condition(t, c); }
void my_add_condition_range(predictor_ptr P, ptag hi, ptag count, char name0) { P->add_condition_range(hi, count, name0); }
void my_set_oracle(predictor_ptr P, action a) { P->set_oracle(a); }
void my_set_oracles(predictor_ptr P, py::list& a) { if (len(a) > 0) P->set_oracle(py::extract<action>(a[0])); else P->erase_oracles(); for (size_t i=1; i<len(a); i++) P->add_oracle(py::extract<action>(a[i])); }
void my_set_allowed(predictor_ptr P, action a) { P->set_allowed(a); }
void my_set_alloweds(predictor_ptr P, py::list& a) { if (len(a) > 0) P->set_allowed(py::extract<action>(a[0])); else P->erase_alloweds(); for (size_t i=1; i<len(a); i++) P->add_allowed(py::extract<action>(a[i])); }
void my_set_condition(predictor_ptr P, ptag t, char c) { P->set_condition(t, c); }
void my_set_condition_range(predictor_ptr P, ptag hi, ptag count, char name0) { P->set_condition_range(hi, count, name0); }
void my_set_learner_id(predictor_ptr P, size_t id) { P->set_learner_id(id); }
void my_set_tag(predictor_ptr P, ptag t) { P->set_tag(t); }
BOOST_PYTHON_MODULE(pylibvw)
{ // This will enable user-defined docstrings and python signatures,
// while disabling the C++ signatures
py::docstring_options local_docstring_options(true, true, false);
// define the vw class
py::class_<vw, vw_ptr>("vw", "the basic VW object that holds with weight vector, parser, etc.", py::no_init)
.def("__init__", py::make_constructor(my_initialize))
// .def("__del__", &my_finish, "deconstruct the VW object by calling finish")
.def("finish", &my_finish, "stop VW by calling finish (and, eg, write weights to disk)")
.def("learn", &my_learn, "given a pyvw example, learn (and predict) on that example")
.def("learn_string", &my_learn_string, "given an example specified as a string (as in a VW data file), learn on that example")
.def("predict", &my_predict_string, "given a pyvw example, predict on that example")
.def("predict_string", &my_predict_string, "given an example specified as a string (as in a VW data file), predict on that example")
.def("hash_space", &VW::hash_space, "given a namespace (as a string), compute the hash of that namespace")
.def("hash_feature", &VW::hash_feature, "given a feature string (arg2) and a hashed namespace (arg3), hash that feature")
.def("finish_example", &my_finish_example, "tell VW that you're done with a given example")
.def("setup_example", &my_setup_example, "given an example that you've created by hand, prepare it for learning (eg, compute quadratic feature)")
.def("unsetup_example", &unsetup_example, "reverse the process of setup, so that you can go back and modify this example")
.def("num_weights", &VW::num_weights, "how many weights are we learning?")
.def("get_weight", &VW::get_weight, "get the weight for a particular index")
.def("set_weight", &VW::set_weight, "set the weight for a particular index")
.def("get_stride", &VW::get_stride, "return the internal stride")
.def("get_sum_loss", &get_sum_loss, "return the total cumulative loss suffered so far")
.def("get_weighted_examples", &get_weighted_examples, "return the total weight of examples so far")
.def("get_search_ptr", &get_search_ptr, "return a pointer to the search data structure")
.def("audit_example", &my_audit_example, "print example audit information")
.def_readonly("lDefault", lDEFAULT, "Default label type (whatever vw was initialized with) -- used as input to the example() initializer")
.def_readonly("lBinary", lBINARY, "Binary label type -- used as input to the example() initializer")
.def_readonly("lMulticlass", lMULTICLASS, "Multiclass label type -- used as input to the example() initializer")
.def_readonly("lCostSensitive", lCOST_SENSITIVE, "Cost sensitive label type (for LDF!) -- used as input to the example() initializer")
.def_readonly("lContextualBandit", lCONTEXTUAL_BANDIT, "Contextual bandit label type -- used as input to the example() initializer")
;
// define the example class
py::class_<example, example_ptr>("example", py::no_init)
.def("__init__", py::make_constructor(my_read_example), "Given a string as an argument parse that into a VW example (and run setup on it) -- default to multiclass label type")
.def("__init__", py::make_constructor(my_empty_example), "Construct an empty (non setup) example; you must provide a label type (vw.lBinary, vw.lMulticlass, etc.)")
.def("set_test_only", &my_set_test_only, "Change the test-only bit on an example")
.def("get_tag", &my_get_tag, "Returns the tag associated with this example")
.def("get_topic_prediction", &VW::get_topic_prediction, "For LDA models, returns the topic prediction for the topic id given")
.def("get_feature_number", &VW::get_feature_number, "Returns the total number of features for this example")
.def("get_example_counter", &get_example_counter, "Returns the counter of total number of examples seen up to and including this one")
.def("get_ft_offset", &get_ft_offset, "Returns the feature offset for this example (used, eg, by multiclass classification to bulk offset all features)")
.def("get_partial_prediction", &get_partial_prediction, "Returns the partial prediction associated with this example")
.def("get_updated_prediction", &get_updated_prediction, "Returns the partial prediction as if we had updated it after learning")
.def("get_loss", &get_loss, "Returns the loss associated with this example")
.def("get_example_t", &get_example_t, "The total sum of importance weights up to and including this example")
.def("get_total_sum_feat_sq", &get_total_sum_feat_sq, "The total sum of feature-value squared for this example")
.def("num_namespaces", &ex_num_namespaces, "The total number of namespaces associated with this example")
.def("namespace", &ex_namespace, "Get the namespace id for namespace i (for i = 0.. num_namespaces); specifically returns the ord() of the corresponding character id")
.def("sum_feat_sq", &ex_sum_feat_sq, "Get the sum of feature-values squared for a given namespace id (id=character-ord)")
.def("num_features_in", &ex_num_features, "Get the number of features in a given namespace id (id=character-ord)")
.def("feature", &ex_feature, "Get the feature id for the ith feature in a given namespace id (id=character-ord)")
.def("feature_weight", &ex_feature_weight, "The the feature value (weight) per .feature(...)")
.def("push_hashed_feature", &ex_push_feature, "Add a hashed feature to a given namespace (id=character-ord)")
.def("push_feature_list", &ex_push_feature_list, "Add a (Python) list of features to a given namespace")
.def("push_feature_dict", &ex_push_dictionary, "Add a (Python) dictionary of namespace/feature-list pairs")
.def("pop_feature", &ex_pop_feature, "Remove the top feature from a given namespace; returns True iff the list was non-empty")
.def("push_namespace", &ex_push_namespace, "Add a new namespace")
.def("ensure_namespace_exists", &ex_ensure_namespace_exists, "Add a new namespace if it doesn't already exist")
.def("pop_namespace", &ex_pop_namespace, "Remove the top namespace off; returns True iff the list was non-empty")
.def("erase_namespace", &ex_erase_namespace, "Remove all the features from a given namespace")
.def("set_label_string", &ex_set_label_string, "(Re)assign the label of this example to this string")
.def("get_simplelabel_label", &ex_get_simplelabel_label, "Assuming a simple_label label type, return the corresponding label (class/regression target/etc.)")
.def("get_simplelabel_weight", &ex_get_simplelabel_weight, "Assuming a simple_label label type, return the importance weight")
.def("get_simplelabel_initial", &ex_get_simplelabel_initial, "Assuming a simple_label label type, return the initial (baseline) prediction")
.def("get_simplelabel_prediction", &ex_get_simplelabel_prediction, "Assuming a simple_label label type, return the final prediction")
.def("get_multiclass_label", &ex_get_multiclass_label, "Assuming a multiclass label type, get the true label")
.def("get_multiclass_weight", &ex_get_multiclass_weight, "Assuming a multiclass label type, get the importance weight")
.def("get_multiclass_prediction", &ex_get_multiclass_prediction, "Assuming a multiclass label type, get the prediction")
.def("get_costsensitive_prediction", &ex_get_costsensitive_prediction, "Assuming a cost_sensitive label type, get the prediction")
.def("get_costsensitive_num_costs", &ex_get_costsensitive_num_costs, "Assuming a cost_sensitive label type, get the total number of label/cost pairs")
.def("get_costsensitive_cost", &ex_get_costsensitive_cost, "Assuming a cost_sensitive label type, get the cost for a given pair (i=0.. get_costsensitive_num_costs)")
.def("get_costsensitive_class", &ex_get_costsensitive_class, "Assuming a cost_sensitive label type, get the label for a given pair (i=0.. get_costsensitive_num_costs)")
.def("get_costsensitive_partial_prediction", &ex_get_costsensitive_partial_prediction, "Assuming a cost_sensitive label type, get the partial prediction for a given pair (i=0.. get_costsensitive_num_costs)")
.def("get_costsensitive_wap_value", &ex_get_costsensitive_wap_value, "Assuming a cost_sensitive label type, get the weighted-all-pairs recomputed cost for a given pair (i=0.. get_costsensitive_num_costs)")
.def("get_cbandits_prediction", &ex_get_cbandits_prediction, "Assuming a contextual_bandits label type, get the prediction")
.def("get_cbandits_num_costs", &ex_get_cbandits_num_costs, "Assuming a contextual_bandits label type, get the total number of label/cost pairs")
.def("get_cbandits_cost", &ex_get_cbandits_cost, "Assuming a contextual_bandits label type, get the cost for a given pair (i=0.. get_cbandits_num_costs)")
.def("get_cbandits_class", &ex_get_cbandits_class, "Assuming a contextual_bandits label type, get the label for a given pair (i=0.. get_cbandits_num_costs)")
.def("get_cbandits_probability", &ex_get_cbandits_probability, "Assuming a contextual_bandits label type, get the bandits probability for a given pair (i=0.. get_cbandits_num_costs)")
.def("get_cbandits_partial_prediction", &ex_get_cbandits_partial_prediction, "Assuming a contextual_bandits label type, get the partial prediction for a given pair (i=0.. get_cbandits_num_costs)")
;
py::class_<Search::predictor, predictor_ptr>("predictor", py::no_init)
.def("set_input", &my_set_input, "set the input (an example) for this predictor (non-LDF mode only)")
//.def("set_input_ldf", &my_set_input_ldf, "set the inputs (a list of examples) for this predictor (LDF mode only)")
.def("set_input_length", &Search::predictor::set_input_length, "declare the length of an LDF-sequence of examples")
.def("set_input_at", &my_set_input_at, "put a given example at position in the LDF sequence (call after set_input_length)")
.def("add_oracle", &my_add_oracle, "add an action to the current list of oracle actions")
.def("add_oracles", &my_add_oracles, "add a list of actions to the current list of oracle actions")
.def("add_allowed", &my_add_allowed, "add an action to the current list of allowed actions")
.def("add_alloweds", &my_add_alloweds, "add a list of actions to the current list of allowed actions")
.def("add_condition", &my_add_condition, "add a (tag,char) pair to the list of variables on which to condition")
.def("add_condition_range", &my_add_condition_range, "given (tag,len,char), add (tag,char), (tag-1,char+1), ..., (tag-len,char+len) to the list of conditionings")
.def("set_oracle", &my_set_oracle, "set an action as the current list of oracle actions")
.def("set_oracles", &my_set_oracles, "set a list of actions as the current list of oracle actions")
.def("set_allowed", &my_set_allowed, "set an action as the current list of allowed actions")
.def("set_alloweds", &my_set_alloweds, "set a list of actions as the current list of allowed actions")
.def("set_condition", &my_set_condition, "set a (tag,char) pair as the list of variables on which to condition")
.def("set_condition_range", &my_set_condition_range, "given (tag,len,char), set (tag,char), (tag-1,char+1), ..., (tag-len,char+len) as the list of conditionings")
.def("set_learner_id", &my_set_learner_id, "select the learner with which to make this prediction")
.def("set_tag", &my_set_tag, "change the tag of this prediction")
.def("predict", &Search::predictor::predict, "make a prediction")
;
py::class_<Search::search, search_ptr>("search")
.def("set_options", &Search::search::set_options, "Set global search options (auto conditioning, etc.)")
.def("set_num_learners", &Search::search::set_num_learners, "Set the total number of learners you want to train")
.def("get_history_length", &Search::search::get_history_length, "Get the value specified by --search_history_length")
.def("loss", &Search::search::loss, "Declare a (possibly incremental) loss")
.def("should_output", &search_should_output, "Check whether search wants us to output (only happens if you have -p running)")
.def("predict_needs_example", &Search::search::predictNeedsExample, "Check whether a subsequent call to predict is actually going to use the example you pass---i.e., can you skip feature computation?")
.def("output", &search_output, "Add a string to the coutput (should only do if should_output returns True)")
.def("get_num_actions", &search_get_num_actions, "Return the total number of actions search was initialized with")
.def("set_structured_predict_hook", &set_structured_predict_hook, "Set the hook (function pointer) that search should use for structured prediction (you don't want to call this yourself!")
.def("set_force_oracle", &set_force_oracle, "For oracle decoding when .predict is run")
.def("is_ldf", &Search::search::is_ldf, "check whether this search task is running in LDF mode")
.def("po_exists", &po_exists, "For program (cmd line) options, check to see if a given option was specified; eg sch.po_exists(\"search\") should be True")
.def("po_get", &po_get, "For program (cmd line) options, if an option was specified, get its value; eg sch.po_get(\"search\") should return the # of actions (returns either int or string)")
.def("po_get_str", &po_get_string, "Same as po_get, but specialized for string return values.")
.def("po_get_int", &po_get_int, "Same as po_get, but specialized for integer return values.")
.def("get_predictor", &get_predictor, "Get a predictor object that can be used for making predictions; requires a tag argument to tag the prediction.")
.def_readonly("AUTO_CONDITION_FEATURES", Search::AUTO_CONDITION_FEATURES, "Tell search to automatically add features based on conditioned-on variables")
.def_readonly("AUTO_HAMMING_LOSS", Search::AUTO_HAMMING_LOSS, "Tell search to automatically compute hamming loss over predictions")
.def_readonly("EXAMPLES_DONT_CHANGE", Search::EXAMPLES_DONT_CHANGE, "Tell search that on a single structured 'run', you don't change the examples you pass to predict")
.def_readonly("IS_LDF", Search::IS_LDF, "Tell search that this is an LDF task")
;
}