Skip to content

Commit 562ed8c

Browse files
author
Darius Morawiec
committedDec 2, 2017
Add new templates, examples and tests for the LinearSVC classifier
1 parent 2e98680 commit 562ed8c

30 files changed

+642
-16
lines changed
 

‎examples/estimator/classifier/LinearSVC/java/basics.ipynb

+81-5
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
},
2323
{
2424
"cell_type": "code",
25-
"execution_count": 2,
25+
"execution_count": 1,
2626
"metadata": {},
2727
"outputs": [
2828
{
@@ -85,7 +85,7 @@
8585
},
8686
{
8787
"cell_type": "code",
88-
"execution_count": 4,
88+
"execution_count": 5,
8989
"metadata": {
9090
"scrolled": false
9191
},
@@ -130,8 +130,8 @@
130130
" }\n",
131131
"\n",
132132
" // Parameters:\n",
133-
" double[][] coefficients = {{0.18424209458473811, 0.45123000025163923, -0.80794587716737576, -0.45071660033253858}, {0.052877455748516447, -0.89214995228605254, 0.40398084459610972, -0.9376821661447452}, {-0.85070784319293802, -0.98670214922204336, 1.381010448739191, 1.8654095662423917}};\n",
134-
" double[] intercepts = {0.10956266406702335, 1.6636707776739579, -1.7096109416521363};\n",
133+
" double[][] coefficients = {{0.184242094585, 0.451230000252, -0.807945877167, -0.450716600333}, {0.0528774557485, -0.892149952286, 0.403980844596, -0.937682166145}, {-0.850707843193, -0.986702149222, 1.38101044874, 1.86540956624}};\n",
134+
" double[] intercepts = {0.109562664067, 1.66367077767, -1.70961094165};\n",
135135
"\n",
136136
" // Prediction:\n",
137137
" LinearSVC clf = new LinearSVC(coefficients, intercepts);\n",
@@ -140,18 +140,94 @@
140140
"\n",
141141
" }\n",
142142
" }\n",
143-
"}\n"
143+
"}\n",
144+
"CPU times: user 957 µs, sys: 951 µs, total: 1.91 ms\n",
145+
"Wall time: 1.08 ms\n"
144146
]
145147
}
146148
],
147149
"source": [
150+
"%%time\n",
151+
"\n",
148152
"from sklearn_porter import Porter\n",
149153
"\n",
150154
"porter = Porter(clf)\n",
151155
"output = porter.export()\n",
152156
"\n",
153157
"print(output)"
154158
]
159+
},
160+
{
161+
"cell_type": "markdown",
162+
"metadata": {},
163+
"source": [
164+
"### Run classification in Java:"
165+
]
166+
},
167+
{
168+
"cell_type": "markdown",
169+
"metadata": {},
170+
"source": [
171+
"Save the transpiled estimator:"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": 6,
177+
"metadata": {
178+
"collapsed": true
179+
},
180+
"outputs": [],
181+
"source": [
182+
"with open('LinearSVC.java', 'w') as f:\n",
183+
" f.write(output)"
184+
]
185+
},
186+
{
187+
"cell_type": "markdown",
188+
"metadata": {},
189+
"source": [
190+
"Compiling:"
191+
]
192+
},
193+
{
194+
"cell_type": "code",
195+
"execution_count": 7,
196+
"metadata": {
197+
"collapsed": true
198+
},
199+
"outputs": [],
200+
"source": [
201+
"%%bash\n",
202+
"\n",
203+
"javac -cp . LinearSVC.java"
204+
]
205+
},
206+
{
207+
"cell_type": "markdown",
208+
"metadata": {},
209+
"source": [
210+
"Prediction:"
211+
]
212+
},
213+
{
214+
"cell_type": "code",
215+
"execution_count": 8,
216+
"metadata": {},
217+
"outputs": [
218+
{
219+
"name": "stdout",
220+
"output_type": "stream",
221+
"text": [
222+
"2\n"
223+
]
224+
}
225+
],
226+
"source": [
227+
"%%bash\n",
228+
"\n",
229+
"java -cp . LinearSVC 1 2 3 4"
230+
]
155231
}
156232
],
157233
"metadata": {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,333 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# sklearn-porter\n",
8+
"\n",
9+
"Repository: https://github.com/nok/sklearn-porter\n",
10+
"\n",
11+
"## LinearSVC\n",
12+
"\n",
13+
"Documentation: [sklearn.svm.LinearSVC](http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html)"
14+
]
15+
},
16+
{
17+
"cell_type": "markdown",
18+
"metadata": {},
19+
"source": [
20+
"### Loading data:"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": 1,
26+
"metadata": {},
27+
"outputs": [
28+
{
29+
"name": "stdout",
30+
"output_type": "stream",
31+
"text": [
32+
"((150, 4), (150,))\n"
33+
]
34+
}
35+
],
36+
"source": [
37+
"from sklearn.datasets import load_iris\n",
38+
"\n",
39+
"iris_data = load_iris()\n",
40+
"X = iris_data.data\n",
41+
"y = iris_data.target\n",
42+
"\n",
43+
"print(X.shape, y.shape)"
44+
]
45+
},
46+
{
47+
"cell_type": "markdown",
48+
"metadata": {},
49+
"source": [
50+
"### Train classifier:"
51+
]
52+
},
53+
{
54+
"cell_type": "code",
55+
"execution_count": 2,
56+
"metadata": {},
57+
"outputs": [
58+
{
59+
"data": {
60+
"text/plain": [
61+
"LinearSVC(C=1.0, class_weight=None, dual=True, fit_intercept=True,\n",
62+
" intercept_scaling=1, loss='squared_hinge', max_iter=1000,\n",
63+
" multi_class='ovr', penalty='l2', random_state=0, tol=0.0001,\n",
64+
" verbose=0)"
65+
]
66+
},
67+
"execution_count": 2,
68+
"metadata": {},
69+
"output_type": "execute_result"
70+
}
71+
],
72+
"source": [
73+
"from sklearn import svm\n",
74+
"\n",
75+
"clf = svm.LinearSVC(C=1., random_state=0)\n",
76+
"clf.fit(X, y)"
77+
]
78+
},
79+
{
80+
"cell_type": "markdown",
81+
"metadata": {},
82+
"source": [
83+
"### Transpile classifier:"
84+
]
85+
},
86+
{
87+
"cell_type": "code",
88+
"execution_count": 4,
89+
"metadata": {
90+
"scrolled": false
91+
},
92+
"outputs": [
93+
{
94+
"name": "stdout",
95+
"output_type": "stream",
96+
"text": [
97+
"import java.io.File;\n",
98+
"import java.io.FileNotFoundException;\n",
99+
"import java.util.*;\n",
100+
"import com.google.gson.Gson;\n",
101+
"\n",
102+
"\n",
103+
"class LinearSVC {\n",
104+
"\n",
105+
" private class Classifier {\n",
106+
" private double[][] coefficients;\n",
107+
" private double[] intercepts;\n",
108+
" }\n",
109+
"\n",
110+
" private Classifier clf;\n",
111+
"\n",
112+
" public LinearSVC(String file) throws FileNotFoundException {\n",
113+
" String jsonStr = new Scanner(new File(file)).useDelimiter(\"\\\\Z\").next();\n",
114+
" this.clf = new Gson().fromJson(jsonStr, Classifier.class);\n",
115+
" }\n",
116+
"\n",
117+
" public int predict(double[] features) {\n",
118+
" int classIdx = 0;\n",
119+
" double classVal = Double.NEGATIVE_INFINITY;\n",
120+
" for (int i = 0, il = this.clf.intercepts.length; i < il; i++) {\n",
121+
" double prob = 0.;\n",
122+
" for (int j = 0, jl = this.clf.coefficients[0].length; j < jl; j++) {\n",
123+
" prob += this.clf.coefficients[i][j] * features[j];\n",
124+
" }\n",
125+
" if (prob + this.clf.intercepts[i] > classVal) {\n",
126+
" classVal = prob + this.clf.intercepts[i];\n",
127+
" classIdx = i;\n",
128+
" }\n",
129+
" }\n",
130+
" return classIdx;\n",
131+
" }\n",
132+
"\n",
133+
" public static void main(String[] args) throws FileNotFoundException {\n",
134+
" if (args.length > 0 && args[0].endsWith(\".json\")) {\n",
135+
"\n",
136+
" // Features:\n",
137+
" double[] features = new double[args.length-1];\n",
138+
" for (int i = 1, l = args.length; i < l; i++) {\n",
139+
" features[i - 1] = Double.parseDouble(args[i]);\n",
140+
" }\n",
141+
"\n",
142+
" // Parameters:\n",
143+
" String modelData = args[0];\n",
144+
"\n",
145+
" // Estimators:\n",
146+
" LinearSVC clf = new LinearSVC(modelData);\n",
147+
"\n",
148+
" // Prediction:\n",
149+
" int prediction = clf.predict(features);\n",
150+
" System.out.println(prediction);\n",
151+
"\n",
152+
" }\n",
153+
" }\n",
154+
"}\n",
155+
"CPU times: user 1.1 ms, sys: 1.6 ms, total: 2.7 ms\n",
156+
"Wall time: 1.36 ms\n"
157+
]
158+
}
159+
],
160+
"source": [
161+
"%%time\n",
162+
"\n",
163+
"from sklearn_porter import Porter\n",
164+
"\n",
165+
"porter = Porter(clf)\n",
166+
"output = porter.export(export_data=True)\n",
167+
"\n",
168+
"print(output)"
169+
]
170+
},
171+
{
172+
"cell_type": "markdown",
173+
"metadata": {},
174+
"source": [
175+
"Parameters:"
176+
]
177+
},
178+
{
179+
"cell_type": "code",
180+
"execution_count": 5,
181+
"metadata": {},
182+
"outputs": [
183+
{
184+
"name": "stdout",
185+
"output_type": "stream",
186+
"text": [
187+
"{\"coefficients\": [[0.184242094585, 0.451230000252, -0.807945877167, -0.450716600333], [0.0528774557485, -0.892149952286, 0.403980844596, -0.937682166145], [-0.850707843193, -0.986702149222, 1.38101044874, 1.86540956624]], \"intercepts\": [0.109562664067, 1.66367077767, -1.70961094165]}"
188+
]
189+
}
190+
],
191+
"source": [
192+
"%%bash\n",
193+
"\n",
194+
"cat data.json"
195+
]
196+
},
197+
{
198+
"cell_type": "markdown",
199+
"metadata": {
200+
"hideOutput": false
201+
},
202+
"source": [
203+
"### Run classification in Java:"
204+
]
205+
},
206+
{
207+
"cell_type": "markdown",
208+
"metadata": {},
209+
"source": [
210+
"Save the transpiled estimator:"
211+
]
212+
},
213+
{
214+
"cell_type": "code",
215+
"execution_count": 6,
216+
"metadata": {
217+
"collapsed": true
218+
},
219+
"outputs": [],
220+
"source": [
221+
"with open('LinearSVC.java', 'w') as f:\n",
222+
" f.write(output)"
223+
]
224+
},
225+
{
226+
"cell_type": "markdown",
227+
"metadata": {},
228+
"source": [
229+
"Download the dependencies:"
230+
]
231+
},
232+
{
233+
"cell_type": "code",
234+
"execution_count": 7,
235+
"metadata": {},
236+
"outputs": [
237+
{
238+
"name": "stderr",
239+
"output_type": "stream",
240+
"text": [
241+
"--2017-12-02 17:35:06-- http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar\n",
242+
"Resolving central.maven.org... 151.101.36.209\n",
243+
"Connecting to central.maven.org|151.101.36.209|:80... connected.\n",
244+
"HTTP request sent, awaiting response... 200 OK\n",
245+
"Length: 232932 (227K) [application/java-archive]\n",
246+
"Saving to: 'gson-2.8.2.jar'\n",
247+
"\n",
248+
" 0K .......... .......... .......... .......... .......... 21% 1.87M 0s\n",
249+
" 50K .......... .......... .......... .......... .......... 43% 3.28M 0s\n",
250+
" 100K .......... .......... .......... .......... .......... 65% 6.27M 0s\n",
251+
" 150K .......... .......... .......... .......... .......... 87% 3.97M 0s\n",
252+
" 200K .......... .......... ....... 100% 171M=0.06s\n",
253+
"\n",
254+
"2017-12-02 17:35:06 (3.62 MB/s) - 'gson-2.8.2.jar' saved [232932/232932]\n",
255+
"\n"
256+
]
257+
}
258+
],
259+
"source": [
260+
"%%bash\n",
261+
"\n",
262+
"wget http://central.maven.org/maven2/com/google/code/gson/gson/2.8.2/gson-2.8.2.jar"
263+
]
264+
},
265+
{
266+
"cell_type": "markdown",
267+
"metadata": {},
268+
"source": [
269+
"Compiling:"
270+
]
271+
},
272+
{
273+
"cell_type": "code",
274+
"execution_count": 8,
275+
"metadata": {
276+
"collapsed": true
277+
},
278+
"outputs": [],
279+
"source": [
280+
"%%bash\n",
281+
"\n",
282+
"javac -cp .:gson-2.8.2.jar LinearSVC.java"
283+
]
284+
},
285+
{
286+
"cell_type": "markdown",
287+
"metadata": {},
288+
"source": [
289+
"Prediction:"
290+
]
291+
},
292+
{
293+
"cell_type": "code",
294+
"execution_count": 9,
295+
"metadata": {},
296+
"outputs": [
297+
{
298+
"name": "stdout",
299+
"output_type": "stream",
300+
"text": [
301+
"2\n"
302+
]
303+
}
304+
],
305+
"source": [
306+
"%%bash\n",
307+
"\n",
308+
"java -cp .:gson-2.8.2.jar LinearSVC data.json 1 2 3 4"
309+
]
310+
}
311+
],
312+
"metadata": {
313+
"kernelspec": {
314+
"display_name": "Python 2",
315+
"language": "python",
316+
"name": "python2"
317+
},
318+
"language_info": {
319+
"codemirror_mode": {
320+
"name": "ipython",
321+
"version": 2
322+
},
323+
"file_extension": ".py",
324+
"mimetype": "text/x-python",
325+
"name": "python",
326+
"nbconvert_exporter": "python",
327+
"pygments_lexer": "ipython2",
328+
"version": "2.7.13"
329+
}
330+
},
331+
"nbformat": 4,
332+
"nbformat_minor": 2
333+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# -*- coding: utf-8 -*-
2+
3+
from sklearn import svm
4+
from sklearn.datasets import load_iris
5+
from sklearn_porter import Porter
6+
7+
8+
iris_data = load_iris()
9+
X = iris_data.data
10+
y = iris_data.target
11+
12+
clf = svm.LinearSVC(C=1., random_state=0)
13+
clf.fit(X, y)
14+
15+
porter = Porter(clf)
16+
output = porter.export(export_data=True)
17+
print(output)
18+
19+
"""
20+
import java.io.File;
21+
import java.io.FileNotFoundException;
22+
import java.util.*;
23+
import com.google.gson.Gson;
24+
25+
26+
class LinearSVC {
27+
28+
private class Classifier {
29+
private double[][] coefficients;
30+
private double[] intercepts;
31+
}
32+
33+
private Classifier clf;
34+
35+
public LinearSVC(String file) throws FileNotFoundException {
36+
String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
37+
this.clf = new Gson().fromJson(jsonStr, Classifier.class);
38+
}
39+
40+
public int predict(double[] features) {
41+
int classIdx = 0;
42+
double classVal = Double.NEGATIVE_INFINITY;
43+
for (int i = 0, il = this.clf.intercepts.length; i < il; i++) {
44+
double prob = 0.;
45+
for (int j = 0, jl = this.clf.coefficients[0].length; j < jl; j++) {
46+
prob += this.clf.coefficients[i][j] * features[j];
47+
}
48+
if (prob + this.clf.intercepts[i] > classVal) {
49+
classVal = prob + this.clf.intercepts[i];
50+
classIdx = i;
51+
}
52+
}
53+
return classIdx;
54+
}
55+
56+
public static void main(String[] args) throws FileNotFoundException {
57+
if (args.length > 0 && args[0].endsWith(".json")) {
58+
59+
// Features:
60+
double[] features = new double[args.length-1];
61+
for (int i = 1, l = args.length; i < l; i++) {
62+
features[i - 1] = Double.parseDouble(args[i]);
63+
}
64+
65+
// Parameters:
66+
String modelData = args[0];
67+
68+
// Estimators:
69+
LinearSVC clf = new LinearSVC(modelData);
70+
71+
// Prediction:
72+
int prediction = clf.predict(features);
73+
System.out.println(prediction);
74+
75+
}
76+
}
77+
}
78+
"""

‎readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ Transpile trained [scikit-learn](https://github.com/scikit-learn/scikit-learn) e
4747
</tr>
4848
<tr>
4949
<td><a href="http://scikit-learn.org/stable/modules/generated/sklearn.svm.LinearSVC.html">svm.LinearSVC</a></td>
50-
<td align="center"><a href="examples/estimator/classifier/LinearSVC/java/basics.ipynb">✓</a></td>
50+
<td align="center"><a href="examples/estimator/classifier/LinearSVC/java/basics.ipynb">✓</a>, <a href="examples/estimator/classifier/LinearSVC/java/basics_imported.ipynb">✓ ᴵ</a></td>
5151
<td align="center"><a href="examples/estimator/classifier/LinearSVC/js/basics.ipynb">✓</a></td>
5252
<td align="center"><a href="examples/estimator/classifier/LinearSVC/c/basics.ipynb">✓</a></td>
5353
<td align="center"><a href="examples/estimator/classifier/LinearSVC/go/basics.ipynb">✓</a></td>

‎sklearn_porter/estimator/classifier/LinearSVC/__init__.py

+36-9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# -*- coding: utf-8 -*-
22

3+
import os
4+
import json
5+
from json import encoder
6+
37
from sklearn_porter.estimator.classifier.Classifier import Classifier
48

59

@@ -85,7 +89,9 @@ def __init__(self, estimator, target_language='java',
8589
target_method=target_method, **kwargs)
8690
self.estimator = estimator
8791

88-
def export(self, class_name, method_name, **kwargs):
92+
def export(self, class_name, method_name,
93+
export_data=False, export_dir='.',
94+
**kwargs):
8995
"""
9096
Port a trained estimator to the syntax of a chosen programming language.
9197
@@ -153,9 +159,14 @@ def export(self, class_name, method_name, **kwargs):
153159
self.intercepts = inters
154160

155161
if self.target_method == 'predict':
156-
return self.predict()
157-
158-
def predict(self):
162+
# Exported:
163+
if export_data and os.path.isdir(export_dir):
164+
self.export_data(export_dir)
165+
return self.predict('exported')
166+
# Separated:
167+
return self.predict('separated')
168+
169+
def predict(self, temp_type):
159170
"""
160171
Transpile the predict method.
161172
@@ -164,9 +175,25 @@ def predict(self):
164175
:return : string
165176
The transpiled predict method as string.
166177
"""
178+
# Exported:
179+
if temp_type == 'exported':
180+
temp = self.temp('exported.{}.class'.format(self.prefix))
181+
return temp.format(class_name=self.class_name,
182+
method_name=self.method_name)
183+
184+
# Separated
167185
self.method = self.create_method()
168-
output = self.create_class()
169-
return output
186+
return self.create_class()
187+
188+
def export_data(self, export_dir):
189+
model_data = {
190+
'coefficients': (self.estimator.coef_[0] if self.is_binary else self.estimator.coef_).tolist(),
191+
'intercepts': (self.estimator.intercept_[0] if self.is_binary else self.estimator.intercept_).tolist(),
192+
}
193+
encoder.FLOAT_REPR = lambda o: self.repr(o)
194+
path = os.path.join(export_dir, 'data.json')
195+
with open(path, 'w') as fp:
196+
json.dump(model_data, fp)
170197

171198
def create_method(self):
172199
"""
@@ -178,7 +205,7 @@ def create_method(self):
178205
The built method as string.
179206
"""
180207
n_indents = 0 if self.target_language in ['c', 'go'] else 1
181-
method_type = '{}.method'.format(self.prefix)
208+
method_type = 'separated.{}.method'.format(self.prefix)
182209
method_temp = self.temp(method_type, n_indents=n_indents, skipping=True)
183210
output = method_temp.format(**self.__dict__)
184211
return output
@@ -194,9 +221,9 @@ def create_class(self):
194221
"""
195222
if self.target_language in ['java', 'go']:
196223
n_indents = 1 if self.target_language == 'java' else 0
197-
class_head_temp = self.temp('{}.class'.format(self.prefix),
224+
class_head_temp = self.temp('separated.{}.class'.format(self.prefix),
198225
n_indents=n_indents, skipping=True)
199226
self.class_head = class_head_temp.format(**self.__dict__)
200227

201-
output = self.temp('class').format(**self.__dict__)
228+
output = self.temp('separated.class').format(**self.__dict__)
202229
return output
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import java.io.File;
2+
import java.io.FileNotFoundException;
3+
import java.util.*;
4+
import com.google.gson.Gson;
5+
6+
7+
class {class_name} {{
8+
9+
private class Classifier {{
10+
private double[] coefficients;
11+
private double intercepts;
12+
}}
13+
14+
private Classifier clf;
15+
16+
public {class_name}(String file) throws FileNotFoundException {{
17+
String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
18+
this.clf = new Gson().fromJson(jsonStr, Classifier.class);
19+
}}
20+
21+
public int {method_name}(double[] features) {{
22+
double prob = 0.;
23+
for (int i = 0, il = this.clf.coefficients.length; i < il; i++) {{
24+
prob += this.clf.coefficients[i] * features[i];
25+
}}
26+
if (prob + this.clf.intercepts > 0) {{
27+
return 1;
28+
}}
29+
return 0;
30+
}}
31+
32+
public static void main(String[] args) throws FileNotFoundException {{
33+
if (args.length > 0 && args[0].endsWith(".json")) {{
34+
35+
// Features:
36+
double[] features = new double[args.length-1];
37+
for (int i = 1, l = args.length; i < l; i++) {{
38+
features[i - 1] = Double.parseDouble(args[i]);
39+
}}
40+
41+
// Parameters:
42+
String modelData = args[0];
43+
44+
// Estimators:
45+
{class_name} clf = new {class_name}(modelData);
46+
47+
// Prediction:
48+
int prediction = clf.{method_name}(features);
49+
System.out.println(prediction);
50+
51+
}}
52+
}}
53+
}}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import java.io.File;
2+
import java.io.FileNotFoundException;
3+
import java.util.*;
4+
import com.google.gson.Gson;
5+
6+
7+
class {class_name} {{
8+
9+
private class Classifier {{
10+
private double[][] coefficients;
11+
private double[] intercepts;
12+
}}
13+
14+
private Classifier clf;
15+
16+
public {class_name}(String file) throws FileNotFoundException {{
17+
String jsonStr = new Scanner(new File(file)).useDelimiter("\\Z").next();
18+
this.clf = new Gson().fromJson(jsonStr, Classifier.class);
19+
}}
20+
21+
public int {method_name}(double[] features) {{
22+
int classIdx = 0;
23+
double classVal = Double.NEGATIVE_INFINITY;
24+
for (int i = 0, il = this.clf.intercepts.length; i < il; i++) {{
25+
double prob = 0.;
26+
for (int j = 0, jl = this.clf.coefficients[0].length; j < jl; j++) {{
27+
prob += this.clf.coefficients[i][j] * features[j];
28+
}}
29+
if (prob + this.clf.intercepts[i] > classVal) {{
30+
classVal = prob + this.clf.intercepts[i];
31+
classIdx = i;
32+
}}
33+
}}
34+
return classIdx;
35+
}}
36+
37+
public static void main(String[] args) throws FileNotFoundException {{
38+
if (args.length > 0 && args[0].endsWith(".json")) {{
39+
40+
// Features:
41+
double[] features = new double[args.length-1];
42+
for (int i = 1, l = args.length; i < l; i++) {{
43+
features[i - 1] = Double.parseDouble(args[i]);
44+
}}
45+
46+
// Parameters:
47+
String modelData = args[0];
48+
49+
// Estimators:
50+
{class_name} clf = new {class_name}(modelData);
51+
52+
// Prediction:
53+
int prediction = clf.{method_name}(features);
54+
System.out.println(prediction);
55+
56+
}}
57+
}}
58+
}}

‎tests/estimator/classifier/LinearSVC/LinearSVCJavaTest.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
from sklearn_porter import Porter
1414

1515
from tests.estimator.classifier.Classifier import Classifier
16+
from tests.estimator.classifier.ExportedData import ExportedData
1617
from tests.language.Java import Java
1718

1819

19-
class LinearSVCJavaTest(Java, Classifier, TestCase):
20+
class LinearSVCJavaTest(Java, Classifier, ExportedData, TestCase):
2021

2122
def setUp(self):
2223
super(LinearSVCJavaTest, self).setUp()

0 commit comments

Comments
 (0)
Please sign in to comment.