sklearn vw improvements: clearing params when constructing model, better support for sparse matrices, fixing bug in tovw conversion for values ending in 0, adding passes argument to constructor
Scott Graham committed Dec 17, 2015
1 parent 2f844c6 commit b35c1db
Showing 2 changed files with 51 additions and 15 deletions.
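
Before the diffs, a minimal usage sketch of what this commit enables, assuming the scikit-learn-style wrapper defined in python/sklearn_vw.py. This is not part of the commit; the feature values are borrowed from the updated test_tovw below, and the passes/sparse behavior shown is what the new tests exercise.

import numpy as np
from scipy.sparse import csr_matrix
from sklearn_vw import VW, tovw

X = np.array([[1.2, 3.4, 5.6, 1.0, 10], [7.8, 9.1, 11, 0, 20]])
y = np.array([1, -1])

# tovw now accepts dense arrays or scipy CSR matrices and yields the same VW-format lines
assert tovw(x=X, y=y) == tovw(x=csr_matrix(X), y=y)

# passes is a new constructor argument; fit() loops over the data that many times
model = VW(loss_function='logistic', passes=2)
model.fit(X, y)
predictions = model.predict(X)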
30 changes: 22 additions & 8 deletions python/sklearn_vw.py
@@ -33,6 +33,7 @@ class VW(BaseEstimator, vw):
"""

params = dict()
passes = 1

def __init__(self,
random_seed=None,
@@ -91,6 +92,7 @@ def __init__(self,
f=None,
readable_model=None,
invert_hash=None,
passes=None,
save_resume=None,
output_feature_regularizer_binary=None,
output_feature_regularizer_text=None):
@@ -165,6 +167,7 @@ def __init__(self,
final_regressor,f (str): Final regressor
readable_model (str): Output human-readable final regressor with numeric features
invert_hash (str): Output human-readable final regressor with feature names. Computationally expensive.
passes (int): Number of training passes
save_resume (bool): save extra state so learning can be resumed later with new data
output_feature_regularizer_binary (str): Per feature regularization output file
output_feature_regularizer_text (str): Per feature regularization output file, in text
@@ -179,14 +182,17 @@ def __init__(self,
if hasattr(self, 'fit_'):
del self.fit_

# quiet models by default
if 'quiet' not in self.params:
self.params['quiet'] = True
# reset params and quiet models by default
self.params = {'quiet': True}

# assign all valid args to params dict
for k, v in locals().iteritems():
if k != 'self' and v is not None:
self.params[k] = v

# store passes separately to be used in fit
self.passes = self.params.pop('passes', 1)

super(VW, self).__init__(**self.params)

def fit(self, X, y=None, sample_weight=None, convert_to_vw=True):
@@ -207,8 +213,11 @@ def fit(self, X, y=None, sample_weight=None, convert_to_vw=True):
"""

# add examples to model
for ex in X if not convert_to_vw else tovw(x=X, y=y, sample_weight=sample_weight):
self.learn(ex)
for _ in xrange(self.passes):
for idx, x in enumerate(X):
if convert_to_vw:
x = tovw(x=x, y=y[idx], sample_weight=sample_weight)[0]
self.learn(x)
self.fit_ = True

def predict(self, X, convert_to_vw=True):
@@ -238,7 +247,9 @@ def predict(self, X, convert_to_vw=True):

# add test examples to model
y = np.empty([num_samples])
for idx, x in enumerate(X if not convert_to_vw else tovw(x=X)):
for idx, x in enumerate(X):
if convert_to_vw:
x = tovw(x)[0]
ex = self.example(x)
# need to set test bit to skip learning
ex.set_test_only(True)
@@ -319,6 +330,9 @@ class would be predicted.

return VW.predict(self, X=X)

def __del__(self):
VW.__del__(self)


class VWRegressor(VW, RegressorMixin):
""" Vowpal Wabbit Regressor model """
@@ -349,7 +363,7 @@ def tovw(x, y=None, sample_weight=None):
use_weight = sample_weight is not None

# convert to numpy array if needed
if not isinstance(x, np.ndarray):
if not isinstance(x, (np.ndarray, csr_matrix)):
x = np.array(x)
if not isinstance(y, np.ndarray):
y = np.array(y)
@@ -378,7 +392,7 @@
for idx, row in enumerate(rows):
truth = y[idx] if use_truth else 1
weight = sample_weight[idx] if use_weight else 1
features = row.split('0 ')[1]
features = row.split('0 ', 1)[1]
# only using a single namespace and no tags
out.append(('{y} {w} |{ns} {x}'.format(y=truth, w=weight, ns=DEFAULT_NS, x=features)))

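A note on the one-argument change in tovw() above: each intermediate row begins with a dummy label that the code strips by splitting on '0 '. Without a maxsplit limit, the split also fires at any later occurrence of '0 ' produced by a feature value ending in 0, truncating the features, which is the bug named in the commit message. A small illustration with a hypothetical intermediate row; the exact row format is built by code outside this hunk.

row = '0 0:1.2 1:10 2:5.6'     # dummy "0" label, then index:value features

row.split('0 ')[1]             # old code: '0:1.2 1:1' -- the value 10 creates a second split point
row.split('0 ', 1)[1]          # fixed: '0:1.2 1:10 2:5.6' -- only the leading label is stripped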
36 changes: 29 additions & 7 deletions python/test_sklearn_vw.py
@@ -4,6 +4,7 @@
from sklearn_vw import VW, VWClassifier, VWRegressor, tovw
from sklearn import datasets
from sklearn.utils.validation import NotFittedError
from scipy.sparse import csr_matrix


"""
@@ -32,6 +33,23 @@ def test_fit(self, data):
model.fit(data.x, data.y)
assert model.fit_

def test_passes(self, data):
n_passes = 2
model = VW(loss_function='logistic', passes=n_passes)
assert model.passes == n_passes

model.fit(data.x, data.y)
weights = model.get_coefs()

model = VW(loss_function='logistic')
# first pass weights should not be the same
model.fit(data.x, data.y)
assert not np.allclose(weights.data, model.get_coefs().data)

# second pass weights should match
model.fit(data.x, data.y)
assert np.allclose(weights.data, model.get_coefs().data)

def test_predict_not_fit(self, data):
model = VW(loss_function='logistic')
with pytest.raises(NotFittedError):
@@ -54,20 +72,22 @@ def test_set_params(self):
model.set_params(l=0.1)
assert model.params['l'] == 0.1

# confirm model params reset with new construction
model = VW()
assert 'l' not in model.params

def test_get_coefs(self, data):
model = VW()
model.fit(data.x, data.y)
weights = model.get_coefs()
print weights.data
assert np.allclose(weights.indices, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 116060])
assert np.allclose(weights.data, [0.11553502, -0.0166647, -0.00349924, 0.06911729, 0.00252684,
-0.00826817, 0.01991862, -0.02473332, 0.00483846, -0.04616702, -0.00744559])

def test_get_intercept(self, data):
model = VW()
model.fit(data.x, data.y)
intercept = model.get_intercept()
assert np.isclose(intercept, -0.00744559)
assert isinstance(intercept, float)


class TestVWClassifier:
@@ -105,11 +125,13 @@ def test_predict(self, data):


def test_tovw():
x = np.array([[1.2, 3.4, 5.6], [7.8, 9.10, 11.]])
x = np.array([[1.2, 3.4, 5.6, 1.0, 10], [7.8, 9.10, 11, 0, 20]])
y = np.array([1, -1])
w = [1, 2]

expected = ['1 1 | 0:1.2 1:3.4 2:5.6',
'-1 2 | 0:7.8 1:9.1 2:11']
expected = ['1 1 | 0:1.2 1:3.4 2:5.6 3:1 4:10',
'-1 2 | 0:7.8 1:9.1 2:11 4:20']

assert tovw(x=x, y=y, sample_weight=w) == expected

assert tovw(x=x, y=y, sample_weight=w) == expected
assert tovw(x=csr_matrix(x), y=y, sample_weight=w) == expected
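
One more note, on the "clearing params" part of the commit: VW declares params = dict() at class level, so without the reassignment added in __init__ (self.params = {'quiet': True}), every construction mutates the same shared dict and settings leak between instances, which is what the new assertions in test_set_params guard against. A stripped-down sketch of the pitfall, using a hypothetical class rather than the wrapper itself.

class Model(object):
    params = dict()                  # class attribute, shared by every instance

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            self.params[k] = v       # mutates the shared class-level dict

a = Model(l=0.1)
b = Model()
'l' in b.params                      # True: b silently inherits a's setting

# Rebinding a fresh dict per instance, as this commit does, avoids the leak:
#     self.params = {'quiet': True}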
