Skip to content

Commit

Permalink
Merge pull request scikit-learn#4966 from lazywei/multilabel-dump-svm…
Browse files Browse the repository at this point in the history
…light

ENH Add multilabel support to dump_svmlight_file
  • Loading branch information
larsmans committed Jul 12, 2015
2 parents 2d4cd28 + 82ea224 commit 193b3c8
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 10 deletions.
33 changes: 23 additions & 10 deletions sklearn/datasets/svmlight_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
dtype : numpy data type, default np.float64
Data type of dataset to be loaded. This will be the data type of the
output numpy arrays ``X`` and ``y``.
Returns
-------
[X1, y1, ..., Xn, yn]
Expand Down Expand Up @@ -275,18 +275,19 @@ def load_svmlight_files(files, n_features=None, dtype=np.float64,
return result


def _dump_svmlight(X, y, f, one_based, comment, query_id):
def _dump_svmlight(X, y, f, multilabel, one_based, comment, query_id):
is_sp = int(hasattr(X, "tocsr"))
if X.dtype.kind == 'i':
value_pattern = u("%d:%d")
else:
value_pattern = u("%d:%.16g")

if y.dtype.kind == 'i':
line_pattern = u("%d")
label_pattern = u("%d")
else:
line_pattern = u("%.16g")
label_pattern = u("%.16g")

line_pattern = u("%s")
if query_id is not None:
line_pattern += u(" qid:%d")
line_pattern += u(" %s\n")
Expand All @@ -309,14 +310,22 @@ def _dump_svmlight(X, y, f, one_based, comment, query_id):
row = zip(np.where(nz)[0], X[i, nz])

s = " ".join(value_pattern % (j + one_based, x) for j, x in row)

if multilabel:
nz_labels = np.where(y[i] != 0)[0]
labels_str = ",".join(label_pattern % j for j in nz_labels)
else:
labels_str = label_pattern % y[i]

if query_id is not None:
feat = (y[i], query_id[i], s)
feat = (labels_str, query_id[i], s)
else:
feat = (y[i], s)
feat = (labels_str, s)

f.write((line_pattern % feat).encode('ascii'))


def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None):
def dump_svmlight_file(X, y, f, multilabel=False, zero_based=True, comment=None, query_id=None):
"""Dump the dataset in svmlight / libsvm file format.
This format is a text-based format, with one sample per line. It does
Expand All @@ -339,6 +348,10 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None):
If file-like, data will be written to f. f should be opened in binary
mode.
multilabel: boolean, optional
Samples may have several labels each (see
http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel.html)
zero_based : boolean, optional
Whether column indices should be written zero-based (True) or one-based
(False).
Expand Down Expand Up @@ -368,7 +381,7 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None):
raise ValueError("comment string contains NUL byte")

y = np.asarray(y)
if y.ndim != 1:
if y.ndim != 1 and not multilabel:
raise ValueError("expected y of shape (n_samples,), got %r"
% (y.shape,))

Expand Down Expand Up @@ -396,7 +409,7 @@ def dump_svmlight_file(X, y, f, zero_based=True, comment=None, query_id=None):
one_based = not zero_based

if hasattr(f, "write"):
_dump_svmlight(X, y, f, one_based, comment, query_id)
_dump_svmlight(X, y, f, multilabel, one_based, comment, query_id)
else:
with open(f, "wb") as f:
_dump_svmlight(X, y, f, one_based, comment, query_id)
_dump_svmlight(X, y, f, multilabel, one_based, comment, query_id)
14 changes: 14 additions & 0 deletions sklearn/datasets/tests/test_svmlight_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,20 @@ def test_dump():
assert_array_equal(y, y2)


def test_dump_multilabel():
X = [[1, 0, 3, 0, 5],
[0, 0, 0, 0, 0],
[0, 5, 0, 1, 0]]
y = [[0, 1, 0], [1, 0, 1], [1, 1, 0]]
f = BytesIO()
dump_svmlight_file(X, y, f, multilabel=True)
f.seek(0)
# make sure it dumps multilabel correctly
assert_equal(f.readline(), b("1 0:1 2:3 4:5\n"))
assert_equal(f.readline(), b("0,2 \n"))
assert_equal(f.readline(), b("0,1 1:5 3:1\n"))


def test_dump_concise():
one = 1
two = 2.1
Expand Down

0 comments on commit 193b3c8

Please sign in to comment.