Skip to content

Commit 8e07978

Browse files
committed
Added a script (meant to be run by hand) that helps to generate code
for DeepLearning clients (since there are so many parameters.
1 parent ac53f3f commit 8e07978

File tree

1 file changed

+229
-0
lines changed

1 file changed

+229
-0
lines changed

scripts/gen_deeplearning.py

+229
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
#!/usr/bin/env python
2+
3+
#
4+
# The purpose of the script is to parse the DeepLearning.java file and emit
5+
# code related to parameters.
6+
#
7+
# Currently pieces of R code get emitted and need to be pasted in manually to the R file.
8+
#
9+
10+
import sys
11+
import os
12+
import shutil
13+
import signal
14+
import time
15+
import random
16+
import getpass
17+
import re
18+
import subprocess
19+
20+
def read_deeplearning_file(deeplearning_file):
21+
"""
22+
Read deep learning file and generate R parameter stub stuff.
23+
24+
@param deeplearning_file: Java source code file
25+
@return: none
26+
"""
27+
try:
28+
nlist = []
29+
in_api = False
30+
31+
f = open(deeplearning_file, "r")
32+
s = f.readline()
33+
lineno = 0
34+
while (len(s) != 0):
35+
lineno = lineno + 1
36+
stripped = s.strip()
37+
if (len(stripped) == 0):
38+
s = f.readline()
39+
continue
40+
if (stripped.startswith("@API")):
41+
# print("")
42+
if (in_api):
43+
assert(False)
44+
in_api = True
45+
46+
# match_groups = re.search("help\s*=\s*\"([^\"]*)\"", stripped)
47+
# if (match_groups == None):
48+
# print("Missing help")
49+
# sys.exit(1)
50+
# help = match_groups.group(1)
51+
# print(help)
52+
s = f.readline()
53+
continue
54+
if (in_api):
55+
skip = False
56+
if "checkpoint" in stripped:
57+
skip = True
58+
if "expert_mode" in stripped:
59+
skip = True
60+
# if "activation" in stripped:
61+
# skip = True
62+
# if "initial_weight_distribution" in stripped:
63+
# skip = True
64+
# if "loss" in stripped:
65+
# skip = True
66+
# if "score_validation_sampling" in stripped:
67+
# skip = True
68+
69+
if (skip):
70+
in_api = False
71+
s = f.readline()
72+
continue
73+
74+
match_groups = re.search("public boolean (\S+) = (\S+);", s)
75+
if (match_groups is not None):
76+
t = "boolean"
77+
n = match_groups.group(1)
78+
v = match_groups.group(2)
79+
print(" parms = .addBooleanParm(parms, k=\"{}\", v={})".format(n,n))
80+
nlist.append(n)
81+
# print(t, n, v)
82+
in_api = False
83+
s = f.readline()
84+
continue
85+
86+
match_groups = re.search("public Activation (\S+) = (\S+);", s)
87+
if (match_groups is not None):
88+
t = "string"
89+
n = match_groups.group(1)
90+
v = match_groups.group(2)
91+
print(" parms = .addStringParm(parms, k=\"{}\", v={})".format(n,n))
92+
nlist.append(n)
93+
# print(t, n, v)
94+
in_api = False
95+
s = f.readline()
96+
continue
97+
98+
match_groups = re.search("public int\[\] (\S+) = .*;", s)
99+
if (match_groups is not None):
100+
t = "int array"
101+
n = match_groups.group(1)
102+
print(" parms = .addIntArrayParm(parms, k=\"{}\", v={})".format(n,n))
103+
nlist.append(n)
104+
# print(t, n)
105+
in_api = False
106+
s = f.readline()
107+
continue
108+
109+
match_groups = re.search("public int (\S+) = .*;", s)
110+
if (match_groups is not None):
111+
t = "int"
112+
n = match_groups.group(1)
113+
print(" parms = .addIntParm(parms, k=\"{}\", v={})".format(n,n))
114+
nlist.append(n)
115+
# print(t, n)
116+
in_api = False
117+
s = f.readline()
118+
continue
119+
120+
match_groups = re.search("public double (\S+) = (\S+);", s)
121+
if (match_groups is not None):
122+
t = "double"
123+
n = match_groups.group(1)
124+
v = match_groups.group(2)
125+
print(" parms = .addDoubleParm(parms, k=\"{}\", v={})".format(n,n))
126+
nlist.append(n)
127+
# print(t, n, v)
128+
in_api = False
129+
s = f.readline()
130+
continue
131+
132+
match_groups = re.search("public float (\S+) = (\S+);", s)
133+
if (match_groups is not None):
134+
t = "float"
135+
n = match_groups.group(1)
136+
v = match_groups.group(2)
137+
print(" parms = .addFloatParm(parms, k=\"{}\", v={})".format(n,n))
138+
nlist.append(n)
139+
# print(t, n, v)
140+
in_api = False
141+
s = f.readline()
142+
continue
143+
144+
match_groups = re.search("public double\[\] (\S+);", s)
145+
if (match_groups is not None):
146+
t = "double array"
147+
n = match_groups.group(1)
148+
print(" parms = .addDoubleArrayParm(parms, k=\"{}\", v={})".format(n,n))
149+
nlist.append(n)
150+
# print(t, n)
151+
in_api = False
152+
s = f.readline()
153+
continue
154+
155+
match_groups = re.search("public long (\S+) = new Random.*;", s)
156+
if (match_groups is not None):
157+
t = "long"
158+
n = match_groups.group(1)
159+
v = -1
160+
print(" parms = .addLongParm(parms, k=\"{}\", v={})".format(n,n))
161+
nlist.append(n)
162+
# print(t, n, v)
163+
in_api = False
164+
s = f.readline()
165+
continue
166+
167+
match_groups = re.search("public long (\S+) = (\S+);", s)
168+
if (match_groups is not None):
169+
t = "long"
170+
n = match_groups.group(1)
171+
v = match_groups.group(2)
172+
print(" parms = .addLongParm(parms, k=\"{}\", v={})".format(n,n))
173+
nlist.append(n)
174+
# print(t, n, v)
175+
in_api = False
176+
s = f.readline()
177+
continue
178+
179+
if (stripped == "public InitialWeightDistribution initial_weight_distribution = InitialWeightDistribution.UniformAdaptive;"):
180+
t = "string"
181+
n = "initial_weight_distribution"
182+
print(" parms = .addStringParm(parms, k=\"{}\", v={})".format(n,n))
183+
nlist.append(n)
184+
# print(t, "initial_weight_distribution", "UniformAdaptive")
185+
in_api = False
186+
s = f.readline()
187+
continue
188+
189+
if (stripped == "public Loss loss = Loss.CrossEntropy;"):
190+
t = "string"
191+
n = "loss"
192+
print(" parms = .addStringParm(parms, k=\"{}\", v={})".format(n,n))
193+
nlist.append(n)
194+
# print(t, "loss", "CrossEntropy")
195+
in_api = False
196+
s = f.readline()
197+
continue
198+
199+
if (stripped == "public ClassSamplingMethod score_validation_sampling = ClassSamplingMethod.Uniform;"):
200+
t = "string"
201+
n = "score_validation_sampling"
202+
print(" parms = .addStringParm(parms, k=\"{}\", v={})".format(n,n))
203+
nlist.append(n)
204+
# print(t, "score_validation_sampling", "Uniform")
205+
in_api = False
206+
s = f.readline()
207+
continue
208+
209+
print("ERROR: No match group found on line ", lineno)
210+
sys.exit(1)
211+
212+
s = f.readline()
213+
f.close()
214+
215+
for n in nlist:
216+
print(" {},".format(n))
217+
218+
except IOError as e:
219+
print("")
220+
print("ERROR: Failure reading test list: " + deeplearning_file)
221+
print(" (errno {0}): {1}".format(e.errno, e.strerror))
222+
print("")
223+
sys.exit(1)
224+
225+
def main(argv):
226+
read_deeplearning_file("./src/main/java/hex/deeplearning/DeepLearning.java")
227+
228+
if __name__ == "__main__":
229+
main(sys.argv)

0 commit comments

Comments
 (0)