Skip to content

Commit

Permalink
Naive Bayes scoring on continuous variables assumes normal pdf with m…
Browse files Browse the repository at this point in the history
…ean and variance calculated from training data.
  • Loading branch information
anqif committed Apr 1, 2014
1 parent 3f623e9 commit 00f9250
Showing 4 changed files with 50 additions and 39 deletions.
77 changes: 39 additions & 38 deletions .classpath
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src/main/java"/>
<classpathentry kind="src" path="src/test/java"/>
<classpathentry kind="src" path="src/main/resources"/>
<classpathentry kind="lib" path="lib/apache/commons-configuration-1.6.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-lang-2.4.jar" sourcepath="lib/apache/commons-lang-2.4-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-logging-1.1.1.jar" sourcepath="lib/apache/commons-logging-1.1.1-sources.zip"/>
<classpathentry kind="lib" path="lib/apache/httpclient-4.1.1.jar" sourcepath="lib/apache/httpclient-4.1.1-sources.zip"/>
<classpathentry kind="lib" path="lib/apache/httpcore-4.1.jar" sourcepath="lib/apache/httpcore-4.1-sources.jar"/>
<classpathentry kind="lib" path="lib/junit/junit-4.11.jar" sourcepath="lib/junit/junit-4.11-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/guava-12.0.1.jar" sourcepath="lib/apache/guava-12.0.1-sources.jar"/>
<classpathentry kind="lib" path="lib/gson/gson-2.2.2.jar" sourcepath="lib/gson/gson-2.2.2-sources.jar"/>
<classpathentry kind="lib" path="lib/poi/poi-3.8-20120326.jar" sourcepath="lib/poi/poi-3.8-sources.jar"/>
<classpathentry kind="lib" path="lib/poi/poi-ooxml-3.8-20120326.jar"/>
<classpathentry kind="lib" path="lib/poi/poi-ooxml-schemas-3.8-20120326.jar"/>
<classpathentry kind="lib" path="lib/s3/aws-java-sdk-1.3.27.jar" sourcepath="lib/s3/aws-java-sdk-1.3.27-sources.jar"/>
<classpathentry kind="lib" path="lib/jama/Jama.jar"/>
<classpathentry kind="lib" path="lib/javassist.jar" sourcepath="lib/javassist-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-codec-1.4.jar" sourcepath="lib/apache/commons-codec-1.4-sources.zip"/>
<classpathentry kind="lib" path="lib/mockito/mockito-all-1.9.5.jar" sourcepath="lib/mockito/mockito-all-1.9.5-sources.jar"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="lib" path="lib/jets3t/commons-httpclient-3.1.jar" sourcepath="lib/jets3t/commons-httpclient-3.1-sources.jar"/>
<classpathentry kind="lib" path="lib/jets3t/jets3t-0.6.1.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/commons-cli-1.2.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/guava-r09-jarjar.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/jackson-core-asl-1.5.2.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/jackson-mapper-asl-1.5.2.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/log4j-1.2.15.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/hadoop-core-0.20.2-cdh3u6.jar" sourcepath="lib/hadoop/cdh3/hadoop-core-0.20.2-cdh3u6-sources.jar"/>
<classpathentry kind="lib" path="lib/log4j/log4j-1.2.15.jar"/>
<classpathentry kind="lib" path="lib/jogamp/jocl-natives-linux-amd64.jar"/>
<classpathentry kind="lib" path="lib/jogamp/jocl.jar" sourcepath="lib/jogamp/jocl-sources.jar"/>
<classpathentry kind="lib" path="lib/jogamp/gluegen-rt-natives-linux-amd64.jar"/>
<classpathentry kind="lib" path="lib/jogamp/gluegen-rt.jar" sourcepath="lib/jogamp/gluegen-rt-sources.jar"/>
<classpathentry kind="lib" path="lib/joda/joda-time-2.3.jar"/>
<classpathentry kind="output" path="target/classes"/>
</classpath>
<?xml version="1.0" encoding="UTF-8"?>
<classpath>
<classpathentry kind="src" path="src/main/java"/>
<classpathentry kind="src" path="src/test/java"/>
<classpathentry kind="src" path="src/main/resources"/>
<classpathentry kind="lib" path="lib/apache/commons-configuration-1.6.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-lang-2.4.jar" sourcepath="lib/apache/commons-lang-2.4-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-logging-1.1.1.jar" sourcepath="lib/apache/commons-logging-1.1.1-sources.zip"/>
<classpathentry kind="lib" path="lib/apache/httpclient-4.1.1.jar" sourcepath="lib/apache/httpclient-4.1.1-sources.zip"/>
<classpathentry kind="lib" path="lib/apache/httpcore-4.1.jar" sourcepath="lib/apache/httpcore-4.1-sources.jar"/>
<classpathentry kind="lib" path="lib/junit/junit-4.11.jar" sourcepath="lib/junit/junit-4.11-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/guava-12.0.1.jar" sourcepath="lib/apache/guava-12.0.1-sources.jar"/>
<classpathentry kind="lib" path="lib/gson/gson-2.2.2.jar" sourcepath="lib/gson/gson-2.2.2-sources.jar"/>
<classpathentry kind="lib" path="lib/poi/poi-3.8-20120326.jar" sourcepath="lib/poi/poi-3.8-sources.jar"/>
<classpathentry kind="lib" path="lib/poi/poi-ooxml-3.8-20120326.jar"/>
<classpathentry kind="lib" path="lib/poi/poi-ooxml-schemas-3.8-20120326.jar"/>
<classpathentry kind="lib" path="lib/s3/aws-java-sdk-1.3.27.jar" sourcepath="lib/s3/aws-java-sdk-1.3.27-sources.jar"/>
<classpathentry kind="lib" path="lib/jama/Jama.jar"/>
<classpathentry kind="lib" path="lib/javassist.jar" sourcepath="lib/javassist-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-codec-1.4.jar" sourcepath="lib/apache/commons-codec-1.4-sources.zip"/>
<classpathentry kind="lib" path="lib/mockito/mockito-all-1.9.5.jar" sourcepath="lib/mockito/mockito-all-1.9.5-sources.jar"/>
<classpathentry kind="lib" path="lib/apache/commons-math3-3.2.jar"/>
<classpathentry kind="con" path="org.eclipse.jdt.launching.JRE_CONTAINER"/>
<classpathentry kind="lib" path="lib/jets3t/commons-httpclient-3.1.jar" sourcepath="lib/jets3t/commons-httpclient-3.1-sources.jar"/>
<classpathentry kind="lib" path="lib/jets3t/jets3t-0.6.1.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/commons-cli-1.2.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/guava-r09-jarjar.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/jackson-core-asl-1.5.2.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/jackson-mapper-asl-1.5.2.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/log4j-1.2.15.jar"/>
<classpathentry kind="lib" path="lib/hadoop/cdh3/hadoop-core-0.20.2-cdh3u6.jar" sourcepath="lib/hadoop/cdh3/hadoop-core-0.20.2-cdh3u6-sources.jar"/>
<classpathentry kind="lib" path="lib/log4j/log4j-1.2.15.jar"/>
<classpathentry kind="lib" path="lib/jogamp/jocl-natives-linux-amd64.jar"/>
<classpathentry kind="lib" path="lib/jogamp/jocl.jar" sourcepath="lib/jogamp/jocl-sources.jar"/>
<classpathentry kind="lib" path="lib/jogamp/gluegen-rt-natives-linux-amd64.jar"/>
<classpathentry kind="lib" path="lib/jogamp/gluegen-rt.jar" sourcepath="lib/jogamp/gluegen-rt-sources.jar"/>
<classpathentry kind="lib" path="lib/joda/joda-time-2.3.jar"/>
<classpathentry kind="output" path="target/classes"/>
</classpath>
Binary file added lib/apache/commons-math3-3.2-sources.jar
Binary file not shown.
Binary file added lib/apache/commons-math3-3.2.jar
Binary file not shown.
12 changes: 11 additions & 1 deletion src/main/java/hex/nb/NBModel.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package hex.nb;

import org.apache.commons.math3.distribution.NormalDistribution;

import hex.FrameTask.DataInfo;
import hex.nb.NaiveBayes.NBTask;
import water.Key;
@@ -48,11 +50,19 @@ public NBModel(Key selfKey, Key dataKey, DataInfo dinfo, NBTask tsk, double[] pp
// Compute joint probability of predictors for every response class
for(int rlevel = 0; rlevel < pprior.length; rlevel++) {
double num = 1;
for(int col = 0; col < data.length; col++) {
for(int col = 0; col < ncats; col++) {
if(Double.isNaN(data[col])) continue; // Skip predictor in joint x_1,...,x_m if NA
int plevel = (int)data[col];
num *= pcond[col][rlevel][plevel]; // p(x|y) = \Pi_{j = 1}^m p(x_j|y)
}

// For numeric predictors, assume Gaussian distribution with sample mean and variance from model
for(int col = ncats; col < data.length; col++) {
if(Double.isNaN(data[col])) continue;
NormalDistribution nd = new NormalDistribution(pcond[col][rlevel][0], pcond[col][rlevel][1]);
num *= nd.density(data[col]);
}

num *= pprior[rlevel]; // p(x,y) = p(x|y)*p(y)
denom += num; // p(x) = \Sum_{levels of y} p(x,y)
preds[rlevel+1] = (float)num;

0 comments on commit 00f9250

Please sign in to comment.