Skip to content

Commit be3ba47

Browse files
author
cypof
committed
NN fixes and todos
1 parent 742313a commit be3ba47

17 files changed

+328
-142
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,4 @@ h2o-samples/h2o-samples.iml
3131
src/main/java/water/BuildVersion.java
3232
hadoop/build.log
3333
.DS_Store
34+
/.idea

h2o-samples/src/main/java/samples/CloudRemote.java

+2-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
*/
1515
public class CloudRemote {
1616
public static void main(String[] args) throws Exception {
17-
launchEC2(null);
17+
// launchEC2(null);
18+
launchIPs(null);
1819
}
1920

2021
/**

h2o-samples/src/main/java/samples/NeuralNetMnist.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Layer[] build(Vec[] data, Vec labels, VecsInput inputStats, VecSoftmax outputSta
4545
ls[2].rate = .02f;
4646
for( int i = 0; i < ls.length; i++ ) {
4747
ls[i].l2 = .0001f;
48-
ls[i].rateAnnealing = 1 / 2e6f;
48+
ls[i].rate_annealing = 1 / 2e6f;
4949
ls[i].init(ls, i);
5050
}
5151
return ls;
@@ -86,12 +86,12 @@ Trainer startTraining(Layer[] ls) {
8686

8787
// Build separate nets for scoring purposes, use same normalization stats as for training
8888
Layer[] temp = build(train, trainLabels, (VecsInput) ls[0], (VecSoftmax) ls[ls.length - 1]);
89-
Layer.copyWeights(ls, temp);
89+
Layer.shareWeights(ls, temp);
9090
Error error = NeuralNet.eval(temp, NeuralNet.EVAL_ROW_COUNT, null);
9191
text += "train: " + error;
9292

9393
temp = build(test, testLabels, (VecsInput) ls[0], (VecSoftmax) ls[ls.length - 1]);
94-
Layer.copyWeights(ls, temp);
94+
Layer.shareWeights(ls, temp);
9595
error = NeuralNet.eval(temp, NeuralNet.EVAL_ROW_COUNT, null);
9696
text += ", test: " + error;
9797
text += ", rates: ";

h2o-samples/src/main/java/samples/NeuralNetMnistDeep.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
*/
1313
public class NeuralNetMnistDeep extends NeuralNetMnist {
1414
public static void main(String[] args) throws Exception {
15-
// CloudLocal.launch(1, NeuralNetMnistDeep.class);
15+
CloudLocal.launch(1, NeuralNetMnistDeep.class);
1616
// CloudProcess.launch(4, NeuralNetMnistDeep.class);
1717
// CloudConnect.launch("localhost:54321", NeuralNetMnistDeep.class);
18-
CloudRemote.launchIPs(NeuralNetMnistDeep.class);
18+
// CloudRemote.launchIPs(NeuralNetMnistDeep.class);
1919
}
2020

2121
@Override public Layer[] build(Vec[] data, Vec labels, VecsInput inputStats, VecSoftmax outputStats) {
@@ -29,7 +29,7 @@ public static void main(String[] args) throws Exception {
2929
ls[ls.length - 1].rate = .02f;
3030
for( int i = 0; i < ls.length; i++ ) {
3131
ls[i].l2 = .0001f;
32-
ls[i].rateAnnealing = 1 / 1e5f;
32+
ls[i].rate_annealing = 1 / 1e5f;
3333
ls[i].init(ls, i);
3434
}
3535
return ls;

h2o-samples/src/main/java/samples/WebAPI.java

+64-4
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,37 @@
11
package samples;
22

3-
import java.io.InputStreamReader;
3+
import java.io.*;
44

55
import org.apache.commons.httpclient.HttpClient;
66
import org.apache.commons.httpclient.methods.GetMethod;
7+
import org.apache.commons.httpclient.methods.PostMethod;
8+
import org.apache.commons.httpclient.methods.multipart.*;
79

8-
import com.google.gson.Gson;
10+
import water.util.Utils;
11+
12+
import com.google.gson.*;
13+
import com.google.gson.internal.Streams;
14+
import com.google.gson.stream.JsonWriter;
915

1016
/**
11-
* Invokes an H2O functionality through the Web API.
17+
* Invokes H2O functionality through the Web API.
1218
*/
1319
public class WebAPI {
20+
static final String URL = "http://127.0.0.1:54321";
21+
static final File JSON_FILE = new File(Utils.tmp(), "model.json");
22+
1423
public static void main(String[] args) throws Exception {
24+
listJobs();
25+
exportModel();
26+
importModel();
27+
}
28+
29+
/**
30+
* Lists jobs currently running.
31+
*/
32+
static void listJobs() throws Exception {
1533
HttpClient client = new HttpClient();
16-
GetMethod get = new GetMethod("http://127.0.0.1:54321/Jobs.json");
34+
GetMethod get = new GetMethod(URL + "/Jobs.json");
1735
int status = client.executeMethod(get);
1836
if( status != 200 )
1937
throw new Exception(get.getStatusText());
@@ -36,4 +54,46 @@ public static class Job {
3654
String end_time;
3755
String exception;
3856
}
57+
58+
/**
59+
* Exports a model to a JSON file.
60+
*/
61+
static void exportModel() throws Exception {
62+
HttpClient client = new HttpClient();
63+
GetMethod get = new GetMethod(URL + "/2/ExportModel.json?model=MyInitialNeuralNet");
64+
int status = client.executeMethod(get);
65+
if( status != 200 )
66+
throw new Exception(get.getStatusText());
67+
JsonObject response = (JsonObject) new JsonParser().parse(new InputStreamReader(get.getResponseBodyAsStream()));
68+
JsonElement model = response.get("model");
69+
JsonWriter writer = new JsonWriter(new FileWriter(JSON_FILE));
70+
writer.setLenient(true);
71+
writer.setIndent(" ");
72+
Streams.write(model, writer);
73+
writer.close();
74+
get.releaseConnection();
75+
}
76+
77+
/**
78+
* Imports a model from a JSON file.
79+
*/
80+
public static void importModel() throws Exception {
81+
// Upload file to H2O
82+
HttpClient client = new HttpClient();
83+
PostMethod post = new PostMethod(URL + "/Upload.json?key=" + JSON_FILE.getName());
84+
Part[] parts = { new FilePart(JSON_FILE.getName(), JSON_FILE) };
85+
post.setRequestEntity(new MultipartRequestEntity(parts, post.getParams()));
86+
if( 200 != client.executeMethod(post) )
87+
throw new RuntimeException("Request failed: " + post.getStatusLine());
88+
post.releaseConnection();
89+
90+
// Parse the key into a model
91+
GetMethod get = new GetMethod(URL + "/2/ImportModel.json?" //
92+
+ "destination_key=MyImportedNeuralNet&" //
93+
+ "type=NeuralNetModel&" //
94+
+ "json=" + JSON_FILE.getName());
95+
if( 200 != client.executeMethod(get) )
96+
throw new RuntimeException("Request failed: " + get.getStatusLine());
97+
get.releaseConnection();
98+
}
3999
}

src/main/java/hex/KMeans2.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ public static class KMeans2Progress extends Progress2 {
146146
static public DocGen.FieldDoc[] DOC_FIELDS;
147147

148148
@Override protected Response jobDone(Job job, Key dst) {
149-
return new Response(Response.Status.redirect, this, 0, 0, new KMeans2ModelView().href(), "model", destination_key);
149+
return new Response(Response.Status.redirect, this, 0, 0, new KMeans2ModelView().href(), "destination_key", destination_key);
150150
}
151151
}
152152

0 commit comments

Comments
 (0)