-
Notifications
You must be signed in to change notification settings - Fork 169
/
Copy pathbasics_imported.pct.py
75 lines (55 loc) · 1.71 KB
/
basics_imported.pct.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# %% [markdown]
# # sklearn-porter
#
# Repository: [https://github.com/nok/sklearn-porter](https://github.com/nok/sklearn-porter)
#
# ## MLPClassifier
#
# Documentation: [sklearn.neural_network.MLPClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.neural_network.MLPClassifier.html)
# %%
import sys

# Add the directory five levels up to the import path -- presumably the
# repository root, so that the local `sklearn_porter` checkout imported
# further down is found (TODO confirm). The relative path is resolved
# against the current working directory, not against this file's location.
sys.path.extend(['../../../../..'])
# %% [markdown]
# ### Load data
# %%
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

iris_data = load_iris()
X = iris_data.data
y = iris_data.target

# Shuffle features and labels in ONE call. A single permutation is drawn and
# applied to both arrays, so every row of X stays paired with its label in y
# by construction. Previously this relied on two separate calls reproducing
# the same permutation from the same seed.
X, y = shuffle(X, y, random_state=0)

# Hold out 40% of the samples for testing. The fixed seed keeps the split
# reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.4, random_state=5)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
# %% [markdown]
# ### Train classifier
# %%
from sklearn.neural_network import MLPClassifier

# Multilayer perceptron with every hyperparameter spelled out. random_state
# is fixed so the fit is repeatable before the model is transpiled below.
clf = MLPClassifier(
    hidden_layer_sizes=50,  # a bare int; sklearn reads it as one hidden layer of 50 units
    activation='relu',
    solver='sgd',
    learning_rate_init=0.1,
    alpha=1e-4,
    tol=1e-4,
    max_iter=500,
    random_state=1,
)
clf.fit(X_train, y_train)
# %% [markdown]
# ### Transpile classifier
# %%
from sklearn_porter import Porter

# Transpile the fitted `clf` into Java source. export_data=True presumably
# makes sklearn-porter write the model parameters to a separate data file
# instead of inlining them. That file is likely the `data.json` used in the
# next cell; confirm against the installed sklearn-porter version.
porter = Porter(clf, language='java')
output = porter.export(export_data=True)

# Show the generated Java class.
print(output)
# %% [markdown]
# ### Run classification in Java
# %%
# Save the generated classifier:
# with open('MLPClassifier.java', 'w') as f:
#     f.write(output)
# Check the exported model data:
# $ cat data.json
# Download dependencies (Maven Central is served over HTTPS from repo1.maven.org):
# $ wget -O gson.jar https://repo1.maven.org/maven2/com/google/code/gson/gson/2.8.5/gson-2.8.5.jar
# Compile the model:
# $ javac -cp .:gson.jar MLPClassifier.java
# Run the classification:
# $ java -cp .:gson.jar MLPClassifier data.json 1 2 3 4