Skip to content

Commit

Permalink
add runifSplit to Frame
Browse files Browse the repository at this point in the history
  • Loading branch information
spennihana committed Jan 8, 2015
1 parent 62401b0 commit 8a3b7d5
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 3 deletions.
56 changes: 53 additions & 3 deletions src/main/java/water/fvec/Frame.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package water.fvec;

import jsr166y.CountedCompleter;
import jsr166y.ForkJoinTask;
import jsr166y.ForkJoinWorkerThread;
import jsr166y.RecursiveAction;
import water.*;
import water.H2O.H2OCountedCompleter;
import water.exec.Flow;
Expand All @@ -14,6 +11,7 @@
import java.util.Arrays;
import java.util.HashMap;
import java.util.IllegalFormatException;
import java.util.Random;

/**
* A collection of named Vecs. Essentially an R-like data-frame. Multiple
Expand Down Expand Up @@ -1074,6 +1072,33 @@ private static class DeepSlice extends MRTask2<DeepSlice> {
}
}


public static Frame[] runifSplit(Frame f, float threshold, long seed) {
if (seed == -1) seed = new Random().nextLong();
Vec rv = new Vec(f.anyVec().group().addVecs(1)[0],f.anyVec()._espc);
Futures fs = new Futures();
DKV.put(rv._key,rv, fs);
for(int i = 0; i < rv._espc.length-1; ++i)
DKV.put(rv.chunkKey(i),new C0DChunk(0,(int)(rv._espc[i+1]-rv._espc[i])),fs);
fs.blockForPending();
final long zeed = seed;
new MRTask2() {
@Override public void map(Chunk c){
Random rng = new Random(zeed*c.cidx());
for(int i = 0; i < c._len; ++i)
c.set0(i, (float)rng.nextDouble());
}
}.doAll(rv);
Vec[] vecs = new Vec[f.numCols()+1];
System.arraycopy(f.vecs(), 0, vecs,0, f.numCols());
vecs[f.numCols()] = rv;
Frame doAllFr = new Frame(null, vecs);
Frame left = new DeepSelectThresh(threshold, true).doAll(f.numCols(),doAllFr).outputFrame(null, doAllFr.domains());
Frame rite = new DeepSelectThresh(threshold, false).doAll(f.numCols(),doAllFr).outputFrame(null, doAllFr.domains());
UKV.remove(rv._key);
return new Frame[]{left,rite};
}

private static class DeepSelect extends MRTask2<DeepSelect> {
@Override public void map( Chunk chks[], NewChunk nchks[] ) {
Chunk pred = chks[chks.length-1];
Expand All @@ -1089,6 +1114,31 @@ private static class DeepSelect extends MRTask2<DeepSelect> {
}
}

private static class DeepSelectThresh extends MRTask2<DeepSelectThresh> {
private final float _threshold;
private final boolean _left;
DeepSelectThresh(float threshold, boolean left) { _threshold = threshold; _left = left; }

private void addRow(Chunk[] cs, NewChunk[] ncs, int i) {
for (int j = 0; j < cs.length -1; ++j) {
Chunk c = cs[j];
if (c._vec.isUUID()) ncs[j].addUUID(c,i);
else ncs[j].addNum(c.at0(i)); // NewChunk will compress later ... not set0s
}
}

@Override public void map(Chunk cs[], NewChunk ncs[]) {
Chunk rv = cs[cs.length-1];
for (int i = 0; i < rv._len; ++i) {
if (_left) {
if (rv.at0(i) <= _threshold) addRow(cs, ncs, i);
} else {
if (rv.at0(i) > _threshold) addRow(cs, ncs, i);
}
}
}
}

private Frame copyRollups( Frame fr, boolean isACopy ) {
if( !isACopy ) return fr; // Not a clean copy, do not copy rollups (will do rollups "the hard way" on first ask)
Vec vecs0[] = vecs();
Expand Down
24 changes: 24 additions & 0 deletions src/test/java/water/util/RunifSplitTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package water.util;

import org.junit.Assert;
import org.junit.Test;
import water.Key;
import water.TestUtil;
import water.fvec.Frame;
import water.fvec.NFSFileVec;
import water.fvec.ParseDataset2;

public class RunifSplitTest extends TestUtil {
final static String PATH = "smalldata/iris/iris.csv";

@Test
public void test1() {
Key file = NFSFileVec.make(find_test_file(PATH));
Frame fr = ParseDataset2.parse(Key.make("iris_nn2"), new Key[]{file});
Frame[] split = Frame.runifSplit(fr, .70f, -1);
Assert.assertTrue(split[0].numRows() + split[1].numRows() == fr.numRows());
fr.delete();
split[0].delete();
split[1].delete();
}
}

0 comments on commit 8a3b7d5

Please sign in to comment.