Skip to content

Commit

Permalink
update python-package
Browse files Browse the repository at this point in the history
  • Loading branch information
aksnzhy committed Nov 20, 2017
1 parent 7404088 commit 5deb6d9
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 20 deletions.
44 changes: 24 additions & 20 deletions python-package/xlearn/xlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
import os
import ctypes
from base import _LIB
from base import _LIB, XLearnHandle
from base import _check_call, c_str

class XLearn(object):
Expand All @@ -21,7 +21,7 @@ def __init__(self, handle):
def __del__(self):
_check_call(_LIB.XLearnHandleFree(self.handle))

def __set_param(self, param):
def _set_param(self, param):
"""Set hyper-parameter for xlearn handle
Parameters
Expand All @@ -31,34 +31,38 @@ def __set_param(self, param):
"""
for (key, value) in param.items():
if key == 'task':
__check_call(_LIB.XLearnSetStr(self.handle,
_check_call(_LIB.XLearnSetStr(self.handle,
c_str(key), c_str(value)))
elif key == 'metric':
__check_call(_LIB.XLearnSetStr(self.handle,
_check_call(_LIB.XLearnSetStr(self.handle,
c_str(key), c_str(value)))
elif key == 'log':
__check_call(_LIB.XLearnSetStr(self.handle,
_check_call(_LIB.XLearnSetStr(self.handle,
c_str(key), c_str(value)))
elif key == 'lr':
__check_call(_LIB.XLearnSetFloat(self.handle,
_check_call(_LIB.XLearnSetFloat(self.handle,
c_str(key), c_float(value)))
elif key == 'k':
__check_call(_LIB.XLearnSetInt(self.handle,
_check_call(_LIB.XLearnSetInt(self.handle,
c_str(key), c_uint(value)))
elif key == 'lambda':
__check_call(_LIB.XLearnSetFloat(self.handle,
_check_call(_LIB.XLearnSetFloat(self.handle,
c_str(key), c_float(value)))
elif key == 'init':
__check_call(_LIB.XLearnSetFloat(self.handle,
_check_call(_LIB.XLearnSetFloat(self.handle,
c_str(key), c_float(value)))
elif key == 'epoch':
__check_call(_LIB(XLearnSetInt(self.handle,
_check_call(_LIB(XLearnSetInt(self.handle,
c_str(key), c_uint(value))))
elif key == 'fold':
__check_call(_LIB(XLearnSetInt(self.handle,
_check_call(_LIB(XLearnSetInt(self.handle,
c_str(key), c_uint(value))))
else:
raise Exception("Invalid key!", key)
def show(self):
"""Show model information
"""
_check_call(_LIB.XLearnShow(self.handle))

def setTrain(self, train_path):
"""Set file path of training data.
Expand Down Expand Up @@ -93,43 +97,43 @@ def setValidate(self, val_path):
def setQuiet(self):
"""Set xlearn to quiet model"""
key = 'quiet'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(True)))

def setOnDisk(self):
"""Set xlearn to use on-disk training"""
key = 'on_disk'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(True)))

def disableNorm(self):
"""Disable instance-wise normalization"""
key = 'norm'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(False)))

def setLockFree(self):
"""Set xlearn to use lock free training"""
key = 'lock_free'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(True)))

def disableEarlyStop(self):
"""Disable early-stopping"""
key = 'early_stop'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(False)))

def setSign(self):
"""Convert output to 0 and 1"""
key = 'sign'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(True)))

def setSigmoid(self):
"""Convert output by using sigmoid"""
key = 'sigmoid'
__check_call(_LIB.XLearnSetBool(self.handle,
_check_call(_LIB.XLearnSetBool(self.handle,
c_str(key), c_bool(True)))

def fit(self, param, model_path):
Expand All @@ -142,7 +146,7 @@ def fit(self, param, model_path):
model_path : str
path of model checkpoint.
"""
__set_Param(param)
_set_Param(param)
_check_call(_LIB.XLearnFit(self.handle, c_str(model_path)))

def cv(self, param):
Expand All @@ -153,7 +157,7 @@ def cv(self, param):
param : dict
hyper-parameter used by xlearn
"""
__set_Param(param)
_set_Param(param)
_check_call(_LIB.XLearnCV(self.handle))

def predict(self, model_path, out_path):
Expand Down
10 changes: 10 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,16 @@ XL_DLL int XLearnHandleFree(XLearnHandle *out) {
API_END();
}

// Show the mode information
XL_DLL int XLearnShow(XLearnHandle *out) {
API_BEGIN()
XLearn* xl = reinterpret_cast<XLearn*>(*out);
printf("Info: \n Model: %s\n Loss: %s\n",
xl->GetHyperParam().score_func.c_str(),
xl->GetHyperParam().loss_func.c_str());
API_END()
}

// Set file path of the training data
XL_DLL int XLearnSetTrain(XLearnHandle *out,
const char *train_path) {
Expand Down
3 changes: 3 additions & 0 deletions src/c_api/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ XL_DLL int XLearnCreate(const char *model_type,
// Free the xLearn handle
XL_DLL int XLearnHandleFree(XLearnHandle *out);

// Show the model information
XL_DLL int XLearnShow(XLearnHandle *out);

// Set file path of the training data
XL_DLL int XLearnSetTrain(XLearnHandle *out,
const char *train_path);
Expand Down
1 change: 1 addition & 0 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ TEST(C_API_TEST, Initialize) {
EXPECT_EQ(XLearnSetBool(&xlearn, "early_stop", false), 0);
EXPECT_EQ(XLearnSetBool(&xlearn, "sign", true), 0);
EXPECT_EQ(XLearnSetBool(&xlearn, "sigmoid", true), 0);
EXPECT_EQ(XLearnShow(&xlearn), 0);
// Test
XLearn* xl = reinterpret_cast<XLearn*>(xlearn);
EXPECT_EQ(xl->GetHyperParam().score_func, "linear");
Expand Down

0 comments on commit 5deb6d9

Please sign in to comment.