Skip to content

Commit 0669645

Browse files
author
Darius Morawiec
committedDec 3, 2017
Add '--export' attribute and file handling
1 parent 5baa0f2 commit 0669645

File tree

4 files changed

+28
-23
lines changed

4 files changed

+28
-23
lines changed
 

‎readme.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ First of all have a quick view on the available arguments:
233233
$ python -m sklearn_porter [-h] --input <PICKLE_FILE> [--output <DEST_DIR>] \
234234
[--class_name <CLASS_NAME>] [--method_name <METHOD_NAME>] \
235235
[--c] [--java] [--js] [--go] [--php] [--ruby] \
236-
[--pipe]
236+
[--export] [--pipe]
237237
```
238238

239239
The following example shows how you can save an trained estimator to the [pickle format](http://scikit-learn.org/stable/modules/model_persistence.html#persistence-example):

‎sklearn_porter/Porter.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -176,10 +176,10 @@ def export(self, class_name=None, method_name=None,
176176
with further information.
177177
"""
178178

179-
if class_name is None:
179+
if class_name is None or class_name == '':
180180
class_name = self.estimator_name
181181

182-
if method_name is None:
182+
if method_name is None or method_name == '':
183183
method_name = self.target_method
184184

185185
if isinstance(num_format, types.LambdaType):
@@ -196,7 +196,7 @@ def export(self, class_name=None, method_name=None,
196196
class_name,
197197
language)
198198
output = {
199-
'model': str(output),
199+
'estimator': str(output),
200200
'filename': filename,
201201
'class_name': class_name,
202202
'method_name': method_name,
@@ -495,12 +495,12 @@ def _get_filename(class_name, language):
495495
filename : str
496496
The generated filename.
497497
"""
498-
name = str(class_name).lower()
498+
name = str(class_name).strip()
499499
lang = str(language)
500500

501501
# Name:
502502
if language in ['java', 'php']:
503-
name = name.capitalize()
503+
name = "".join([name[0].upper() + name[1:]])
504504

505505
# Suffix:
506506
suffix = {

‎sklearn_porter/__main__.py

+21-16
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,20 @@ def parse_args(args):
3333
'stored.'))
3434
optional.add_argument(
3535
'--class_name',
36-
default='Brain',
36+
default=None,
3737
required=False,
3838
help='Define the class name in the final output.')
3939
optional.add_argument(
4040
'--method_name',
4141
default='predict',
4242
required=False,
4343
help='Define the method name in the final output.')
44+
optional.add_argument(
45+
'--export', '-e',
46+
required=False,
47+
default=False,
48+
action='store_true',
49+
help='Whether to export the model data or not.')
4450
optional.add_argument(
4551
'--pipe', '-p',
4652
required=False,
@@ -91,37 +97,36 @@ def main():
9197
language = key
9298
break
9399

100+
# Define destination path:
101+
dest_dir = str(args.get('output'))
102+
if dest_dir == '' or not os.path.isdir(dest_dir):
103+
dest_dir = input_path.split(os.sep)
104+
del dest_dir[-1]
105+
dest_dir = os.sep.join(dest_dir)
106+
94107
# Port estimator:
95108
try:
96109
porter = Porter(estimator, language=language)
97-
class_name = str(args.get('class_name'))
98-
method_name = str(args.get('method_name'))
110+
class_name = args.get('class_name')
111+
method_name = args.get('method_name')
99112
output = porter.export(class_name=class_name,
100113
method_name=method_name,
101-
output=str(args.get('output')),
114+
export_dir=dest_dir,
115+
export_data=bool(args.get('export')),
102116
details=True)
103117
except Exception as e:
104118
sys.exit('Error: {}'.format(str(e)))
105119
else:
106120
# Print transpiled estimator to the console:
107121
if bool(args.get('pipe', False)):
108-
print(output.get('model'))
122+
print(output.get('estimator'))
109123
sys.exit(0)
110124

111-
# Define destination path:
112-
dest_dir = str(args.get('output'))
113125
filename = output.get('filename')
114-
if dest_dir != '' and os.path.isdir(dest_dir):
115-
dest_path = os.path.join(dest_dir, filename)
116-
else:
117-
dest_dir = input_path.split(os.sep)
118-
del dest_dir[-1]
119-
dest_dir += [filename]
120-
dest_path = os.sep.join(dest_dir)
121-
126+
dest_path = dest_dir + os.sep + filename
122127
# Save transpiled estimator:
123128
with open(dest_path, 'w') as file_:
124-
file_.write(output.get('model'))
129+
file_.write(output.get('estimator'))
125130

126131

127132
if __name__ == "__main__":

‎tests/PorterTest.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_python_command_execution(self):
6565
joblib.dump(self.estimator, pkl_path)
6666

6767
# Port estimator:
68-
cmd = 'python -m sklearn_porter -i {}'.format(pkl_path).split()
68+
cmd = 'python -m sklearn_porter -i {} --class_name Brain'.format(pkl_path).split()
6969
subp.call(cmd)
7070
# Compare file contents:
7171
equal = filecmp.cmp(cp_src, cp_dest)

0 commit comments

Comments
 (0)
Please sign in to comment.