Grid search launches one job at a time, fixed NN MR trainer
cypof committed Sep 28, 2013
1 parent 74c38ef commit b13938f
Showing 18 changed files with 295 additions and 215 deletions.
14 changes: 9 additions & 5 deletions .settings/org.eclipse.jdt.core.prefs
@@ -1,4 +1,5 @@
 eclipse.preferences.version=1
+org.eclipse.jdt.core.compiler.annotation.inheritNullAnnotations=disabled
 org.eclipse.jdt.core.compiler.annotation.missingNonNullByDefaultAnnotation=ignore
 org.eclipse.jdt.core.compiler.annotation.nonnull=org.eclipse.jdt.annotation.NonNull
 org.eclipse.jdt.core.compiler.annotation.nonnullbydefault=org.eclipse.jdt.annotation.NonNullByDefault
@@ -23,7 +24,7 @@ org.eclipse.jdt.core.compiler.problem.discouragedReference=warning
 org.eclipse.jdt.core.compiler.problem.emptyStatement=ignore
 org.eclipse.jdt.core.compiler.problem.enumIdentifier=error
 org.eclipse.jdt.core.compiler.problem.explicitlyClosedAutoCloseable=ignore
-org.eclipse.jdt.core.compiler.problem.fallthroughCase=ignore
+org.eclipse.jdt.core.compiler.problem.fallthroughCase=warning
 org.eclipse.jdt.core.compiler.problem.fatalOptionalError=disabled
 org.eclipse.jdt.core.compiler.problem.fieldHiding=ignore
 org.eclipse.jdt.core.compiler.problem.finalParameterBound=warning
@@ -36,17 +37,18 @@ org.eclipse.jdt.core.compiler.problem.incompleteEnumSwitch=warning
 org.eclipse.jdt.core.compiler.problem.indirectStaticAccess=ignore
 org.eclipse.jdt.core.compiler.problem.localVariableHiding=ignore
 org.eclipse.jdt.core.compiler.problem.methodWithConstructorName=warning
-org.eclipse.jdt.core.compiler.problem.missingDefaultCase=ignore
+org.eclipse.jdt.core.compiler.problem.missingDefaultCase=warning
 org.eclipse.jdt.core.compiler.problem.missingDeprecatedAnnotation=ignore
 org.eclipse.jdt.core.compiler.problem.missingEnumCaseDespiteDefault=disabled
-org.eclipse.jdt.core.compiler.problem.missingHashCodeMethod=ignore
-org.eclipse.jdt.core.compiler.problem.missingOverrideAnnotation=ignore
+org.eclipse.jdt.core.compiler.problem.missingHashCodeMethod=warning
+org.eclipse.jdt.core.compiler.problem.missingOverrideAnnotation=warning
 org.eclipse.jdt.core.compiler.problem.missingOverrideAnnotationForInterfaceMethodImplementation=enabled
 org.eclipse.jdt.core.compiler.problem.missingSerialVersion=ignore
-org.eclipse.jdt.core.compiler.problem.missingSynchronizedOnInheritedMethod=ignore
+org.eclipse.jdt.core.compiler.problem.missingSynchronizedOnInheritedMethod=warning
 org.eclipse.jdt.core.compiler.problem.noEffectAssignment=warning
 org.eclipse.jdt.core.compiler.problem.noImplicitStringConversion=warning
 org.eclipse.jdt.core.compiler.problem.nonExternalizedStringLiteral=ignore
+org.eclipse.jdt.core.compiler.problem.nonnullParameterAnnotationDropped=warning
 org.eclipse.jdt.core.compiler.problem.nullAnnotationInferenceConflict=error
 org.eclipse.jdt.core.compiler.problem.nullReference=warning
 org.eclipse.jdt.core.compiler.problem.nullSpecViolation=error
@@ -67,6 +69,7 @@ org.eclipse.jdt.core.compiler.problem.specialParameterHidingField=disabled
 org.eclipse.jdt.core.compiler.problem.staticAccessReceiver=warning
 org.eclipse.jdt.core.compiler.problem.suppressOptionalErrors=disabled
 org.eclipse.jdt.core.compiler.problem.suppressWarnings=enabled
+org.eclipse.jdt.core.compiler.problem.syntacticNullAnalysisForFields=disabled
 org.eclipse.jdt.core.compiler.problem.syntheticAccessEmulation=ignore
 org.eclipse.jdt.core.compiler.problem.typeParameterHiding=warning
 org.eclipse.jdt.core.compiler.problem.unavoidableGenericTypeProblems=enabled
@@ -90,6 +93,7 @@ org.eclipse.jdt.core.compiler.problem.unusedParameterIncludeDocCommentReference=
 org.eclipse.jdt.core.compiler.problem.unusedParameterWhenImplementingAbstract=disabled
 org.eclipse.jdt.core.compiler.problem.unusedParameterWhenOverridingConcrete=disabled
 org.eclipse.jdt.core.compiler.problem.unusedPrivateMember=warning
+org.eclipse.jdt.core.compiler.problem.unusedTypeParameter=ignore
 org.eclipse.jdt.core.compiler.problem.unusedWarningToken=warning
 org.eclipse.jdt.core.compiler.problem.varargsArgumentNeedCast=warning
 org.eclipse.jdt.core.compiler.source=1.6
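Several compiler diagnostics move from ignore to warning here. For illustration only, a hypothetical snippet that the tightened missingDefaultCase and fallthroughCase settings would now flag (standard JDT semantics, not code from this repository):

public class SwitchLint {
  static int f(int x) {
    switch( x ) {  // missingDefaultCase=warning: no default branch
    case 0:
      x++;         // fallthroughCase=warning: falls through into case 1
    case 1:
      return x;
    }
    return -1;
  }
}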
25 changes: 16 additions & 9 deletions src/main/java/hex/GridSearch.java
@@ -15,7 +15,7 @@ public class GridSearch extends Job {
   @Override protected void run() {
     UKV.put(destination_key, this);
     for( Job job : jobs )
-      job.startFJ();
+      job.startFJ().join();
   }
 
   @Override public float progress() {
@@ -25,7 +25,7 @@ public class GridSearch extends Job {
     return Math.min(1f, (float) (d / jobs.length));
   }
 
-  @Override protected Response redirect() {
+  @Override public Response redirect() {
     String n = GridSearchProgress.class.getSimpleName();
     return new Response(Response.Status.redirect, this, -1, -1, n, "job", job_key, "dst_key", destination_key);
   }
@@ -98,15 +98,22 @@ public static class GridSearchProgress extends Progress2 {
           throw new RuntimeException(e);
         }
       }
-      sb.append("<td>").append((info._job.runTimeMs()) / 1000).append("</td>");
+      String runTime = "Pending", speed = "";
+      if( info._job.start_time != 0 ) {
+        runTime = "" + (info._job.runTimeMs()) / 1000;
+        speed = perf != null ? info._job.speedValue() : "";
+      }
+      sb.append("<td>").append(runTime).append("</td>");
       if( perf != null )
-        sb.append("<td>").append(info._job.speedValue()).append("</td>");
+        sb.append("<td>").append(speed).append("</td>");
 
       String link = info._job.destination_key.toString();
-      if( info._model instanceof GBMModel )
-        link = GBMModelView.link(link, info._job.destination_key);
-      else
-        link = Inspect.link(link, info._job.destination_key);
+      if( info._job.start_time != 0 ) {
+        if( info._model instanceof GBMModel )
+          link = GBMModelView.link(link, info._job.destination_key);
+        else
+          link = Inspect.link(link, info._job.destination_key);
+      }
       sb.append("<td>").append(link).append("</td>");
 
       String pct = "", f1 = "";
@@ -128,7 +135,7 @@ static class JobInfo {
     Job _job;
     Model _model;
     ConfusionMatrix _cm;
-    double _error;
+    double _error = Double.POSITIVE_INFINITY;
   }
 
   static void filter(ArrayList<Argument> args, String... names) {
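The first hunk is the commit's headline change: each forked job is joined before the next one starts, so grid models now train sequentially instead of all at once, and the progress page shows "Pending" for jobs that have not started yet. A minimal standalone sketch of the sequential fork/join pattern, using the JDK's ForkJoinPool in place of H2O's Job API (all names here are illustrative):

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;

public class SequentialGrid {
  public static void main(String[] args) {
    ForkJoinPool pool = new ForkJoinPool();
    // Stand-ins for the grid's jobs; each would train one model.
    List<Runnable> jobs = Arrays.asList(
        () -> System.out.println("training model A"),
        () -> System.out.println("training model B"));
    for( Runnable job : jobs ) {
      ForkJoinTask<?> task = pool.submit(job); // analogous to job.startFJ()
      task.join();                             // block until done: one job at a time
    }
    pool.shutdown();
  }
}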
1 change: 0 additions & 1 deletion src/main/java/hex/KMeans2.java
@@ -44,6 +44,5 @@ public KMeans2() {
   @Override protected void run() {
     Log.info(DOC_GET + source);
     UKV.put(destination_key, source);
-    remove();
   }
 }
1 change: 0 additions & 1 deletion src/main/java/hex/KMeansGrid.java
@@ -79,6 +79,5 @@ public static String link(Key k, String content) {
     } finally {
       UKV.remove(temp);
     }
-    remove();
   }
 }
13 changes: 6 additions & 7 deletions src/main/java/hex/Layer.java
@@ -5,7 +5,6 @@
 import java.util.Random;
 
 import water.Iced;
-import water.Model;
 import water.fvec.Chunk;
 import water.fvec.Vec;
 
@@ -243,6 +242,7 @@ public static class ChunksInput extends Input {
 
     public ChunksInput(Chunk[] chunks, VecsInput stats) {
       super(stats._subs.length);
+      assert stats == null || (chunks.length == stats._subs.length && chunks.length == stats._muls.length);
       _chunks = chunks;
       _subs = stats._subs;
       _muls = stats._muls;
@@ -252,7 +252,7 @@ public ChunksInput(Chunk[] chunks, VecsInput stats) {
       for( int i = 0; i < _a.length; i++ ) {
         double d = _chunks[i].at0((int) _pos);
         d -= _subs[i];
-        d = _muls[i] > 1e-4 ? d / _muls[i] : d;
+        d *= _muls[i];
         _a[i] = (float) d;
       }
     }
@@ -319,7 +319,7 @@ public static class VecSoftmax extends Softmax {
     Vec _vec;
 
     public VecSoftmax(Vec vec) {
-      super(Model.responseDomain(vec).length);
+      super(vec.domain().length);
       _vec = vec;
     }
 
@@ -329,10 +329,10 @@ public VecSoftmax(Vec vec) {
   }
 
   public static class ChunkSoftmax extends Softmax {
-    Chunk _chunk;
+    transient Chunk _chunk;
 
     public ChunkSoftmax(Chunk chunk) {
-      super(Model.responseDomain(chunk._vec).length);
+      super(chunk._vec.domain().length);
       _chunk = chunk;
     }
 
@@ -429,8 +429,7 @@ public Rectifier(int units) {
   }
 
   public static void copyWeights(Layer[] src, Layer[] dst) {
-    for( int y = 1; y < src.length - 1; y++ ) {
-      assert dst[y]._w == null && dst[y]._b == null;
+    for( int y = 1; y < src.length; y++ ) {
      dst[y]._w = src[y]._w;
      dst[y]._b = src[y]._b;
    }
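In the ChunksInput hunk above, the guarded division d = _muls[i] > 1e-4 ? d / _muls[i] : d becomes a plain multiplication, which suggests the shared stats now hold precomputed reciprocal scale factors, with the near-zero guard applied once at precompute time. A standalone sketch of that scheme (method and field names are illustrative, not the Layer API):

public class Standardize {
  // Precompute reciprocal scales; the 1e-4 guard moves out of the hot loop.
  static double[] reciprocals(double[] sigmas) {
    double[] muls = new double[sigmas.length];
    for( int i = 0; i < sigmas.length; i++ )
      muls[i] = sigmas[i] > 1e-4 ? 1.0 / sigmas[i] : 1.0;
    return muls;
  }

  // Per-row normalization: center, then multiply, as in the fprop loop above.
  static float[] normalize(double[] row, double[] subs, double[] muls) {
    float[] a = new float[row.length];
    for( int i = 0; i < row.length; i++ ) {
      double d = row[i] - subs[i];
      d *= muls[i];
      a[i] = (float) d;
    }
    return a;
  }

  public static void main(String[] args) {
    double[] subs = { 5.0 }, muls = reciprocals(new double[] { 2.0 });
    System.out.println(normalize(new double[] { 9.0 }, subs, muls)[0]); // (9-5)/2 = 2.0
  }
}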
99 changes: 57 additions & 42 deletions src/main/java/hex/NeuralNet.java
@@ -57,7 +57,10 @@ public NeuralNet() {
 
   @Override protected void run() {
     selectCols();
-    _train = reChunk(_train);
+    Vec[] vecs = Utils.add(_train, response);
+    reChunk(vecs);
+    System.arraycopy(vecs, 0, _train, 0, _train.length);
+    response = vecs[vecs.length - 1];
 
     final Layer[] ls = new Layer[hidden.length + 2];
     ls[0] = new VecsInput(_train);
@@ -82,7 +85,6 @@ public NeuralNet() {
     UKV.put(destination_key, model);
 
     final Trainer trainer = new Trainer.MapReduce(ls, epochs, self());
-    trainer.start();
 
     // Use a separate thread for monitoring (blocked most of the time)
     Thread thread = new Thread() {
@@ -99,6 +101,22 @@ public NeuralNet() {
 
         NeuralNetModel model = new NeuralNetModel(destination_key, sourceKey, frame, ls);
         long[][] cm = new long[model.classNames().length][model.classNames().length];
+
+        VecsInput stats = (VecsInput) ls[0];
+        Layer[] clones = new Layer[ls.length];
+        if( _valid != null ) {
+          clones[0] = new VecsInput(_valid, stats);
+          clones[clones.length - 1] = new VecSoftmax(_validResponse);
+        } else {
+          clones[0] = new VecsInput(_train, stats);
+          clones[clones.length - 1] = new VecSoftmax(response);
+        }
+        for( int y = 1; y < clones.length - 1; y++ )
+          clones[y] = ls[y].clone();
+        for( int y = 0; y < clones.length; y++ )
+          clones[y].init(clones, y, false, 0);
+        Layer.copyWeights(ls, clones);
+
         Error train = NeuralNetScore.run(ls, EVAL_ROW_COUNT, cm);
         model.items = items;
         model.items_per_second = ps;
@@ -116,7 +134,8 @@ public NeuralNet() {
       }
     };
     thread.start();
-    trainer.join();
+    //trainer.join();
+    trainer.start();
   }
 
   @Override public float progress() {
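Taken together, the NeuralNet.run() hunks above reorder startup: scoring clones are built from the input layer's stats, the monitoring thread starts first, and the MapReduce trainer is started last with its blocking join() commented out, so run() returns while training continues in the background. A standalone sketch of that start-monitor-then-train ordering, with plain threads standing in for H2O's Trainer (all names illustrative):

import java.util.concurrent.atomic.AtomicBoolean;

public class MonitorThenTrain {
  public static void main(String[] args) {
    AtomicBoolean done = new AtomicBoolean(false);
    Thread monitor = new Thread(() -> {
      while( !done.get() ) {
        System.out.println("scoring current weights...");
        try { Thread.sleep(1000); } catch( InterruptedException e ) { return; }
      }
    });
    monitor.setDaemon(true); // don't keep the process alive just for monitoring
    monitor.start();         // monitor first, like thread.start() above

    Thread trainer = new Thread(() -> {
      for( int epoch = 0; epoch < 3; epoch++ )
        System.out.println("epoch " + epoch);
      done.set(true);
    });
    trainer.start();         // started last and never joined, as in the commit
  }
}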
@@ -286,7 +305,7 @@ public NeuralNetScore() {
   }
 
   @Override protected void run() {
-    selectCols();
+    initResponse();
     Layer[] clones = new Layer[model.layers.length];
     clones[0] = new VecsInput(selectVecs(source));
     for( int y = 1; y < clones.length - 1; y++ )
@@ -297,7 +316,7 @@ public NeuralNetScore() {
       clones[y]._b = model.bs[y];
       clones[y].init(clones, y, false, 0);
     }
-    int classes = Model.responseDomain(response).length;
+    int classes = response.domain().length;
     confusion_matrix = new long[classes][classes];
     Error error = run(clones, max_rows, confusion_matrix);
     classification_error = error.Value;
@@ -345,10 +364,7 @@ private static boolean correct(Layer[] ls, Error error, long[][] confusion) {
   @Override public boolean toHTML(StringBuilder sb) {
     DocGen.HTML.section(sb, "Classification error: " + String.format("%5.2f %%", 100 * classification_error));
     DocGen.HTML.section(sb, "Square error: " + sqr_error);
-    String[] classes = source.vecs()[source.numCols() - 1].domain();
-    if( classes == null )
-      classes = Model.responseDomain(source);
-    confusion(sb, "Confusion Matrix", classes, confusion_matrix);
+    confusion(sb, "Confusion Matrix", response.domain(), confusion_matrix);
     return true;
   }
 
@@ -406,43 +422,42 @@ static int cores() {
    * Makes sure small datasets are spread over enough chunks to parallelize training. Neural nets
    * can require lots of processing even for small data.
    */
-  public static Vec[] reChunk(Vec[] vecs) {
+  public static void reChunk(Vec[] vecs) {
     final int splits = cores() * 2; // More in case of unbalance
-    if( vecs[0].nChunks() >= splits )
-      return vecs;
-    for( int v = 0; v < vecs.length; v++ ) {
-      AppendableVec vec = new AppendableVec(UUID.randomUUID().toString());
-      long rows = vecs[0].length();
-      Chunk cache = null;
-      for( int split = 0; split < splits; split++ ) {
-        long off = rows * (split + 0) / splits;
-        long lim = rows * (split + 1) / splits;
-        NewChunk chunk = new NewChunk(vec, split);
-        for( long r = off; r < lim; r++ ) {
-          if( cache == null || r < cache._start || r >= cache._start + cache._len )
-            cache = vecs[v].chunk(r);
-          if( !cache.isNA(r) ) {
-            if( vecs[v]._domain != null )
-              chunk.addEnum((int) cache.at8(r));
-            else if( vecs[v].isInt() )
-              chunk.addNum(cache.at8(r), 0);
-            else
-              chunk.addNum(cache.at(r));
-          } else {
-            if( vecs[v].isInt() )
-              chunk.addNA();
-            else {
-              // Don't use addNA() for doubles, as NewChunk uses separate array
-              chunk.addNum(Double.NaN);
-            }
-          }
-        }
-        chunk.close(split, null);
-      }
-      Vec t = vec.close(null);
-      t._domain = vecs[v]._domain;
-      vecs[v] = t;
-    }
-    return vecs;
+    if( vecs[0].nChunks() < splits ) {
+      for( int v = 0; v < vecs.length; v++ ) {
+        AppendableVec vec = new AppendableVec(UUID.randomUUID().toString());
+        long rows = vecs[0].length();
+        Chunk cache = null;
+        for( int split = 0; split < splits; split++ ) {
+          long off = rows * (split + 0) / splits;
+          long lim = rows * (split + 1) / splits;
+          NewChunk chunk = new NewChunk(vec, split);
+          for( long r = off; r < lim; r++ ) {
+            if( cache == null || r < cache._start || r >= cache._start + cache._len )
+              cache = vecs[v].chunk(r);
+            if( !cache.isNA(r) ) {
+              if( vecs[v]._domain != null )
+                chunk.addEnum((int) cache.at8(r));
+              else if( vecs[v].isInt() )
+                chunk.addNum(cache.at8(r), 0);
+              else
+                chunk.addNum(cache.at(r));
+            } else {
+              if( vecs[v].isInt() )
+                chunk.addNA();
+              else {
+                // Don't use addNA() for doubles, as NewChunk uses separate array
+                chunk.addNum(Double.NaN);
+              }
+            }
+          }
+          chunk.close(split, null);
+        }
+        Vec t = vec.close(null);
+        t._domain = vecs[v]._domain;
+        vecs[v] = t;
+      }
+    }
   }
 }
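The reChunk rewrite keeps the original splitting arithmetic: a frame with too few chunks is rewritten into cores() * 2 contiguous chunks, where chunk s covers rows [rows * s / splits, rows * (s + 1) / splits). A standalone sketch of just that boundary computation (no H2O types involved):

public class SplitBounds {
  public static void main(String[] args) {
    long rows = 103; // example row count
    int splits = Runtime.getRuntime().availableProcessors() * 2; // cores() * 2
    for( int split = 0; split < splits; split++ ) {
      long off = rows * (split + 0) / splits;
      long lim = rows * (split + 1) / splits;
      System.out.println("chunk " + split + ": rows [" + off + ", " + lim + ")");
    }
  }
}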
