Skip to content

Commit

Permalink
Started gibbs sampling for pr(Z|Yi)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajaynagesh committed Nov 19, 2013
1 parent aa0a28d commit 49a3701
Showing 1 changed file with 40 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.kbp.slotfilling.KBPTrainer;
import edu.stanford.nlp.kbp.slotfilling.classify.HoffmannExtractor.LabelWeights;
import edu.stanford.nlp.kbp.slotfilling.common.Constants;
import edu.stanford.nlp.kbp.slotfilling.common.Log;
import edu.stanford.nlp.kbp.slotfilling.common.Props;
import edu.stanford.nlp.sequences.SequenceGibbsSampler;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.ErasureUtils;
Expand Down Expand Up @@ -487,6 +484,18 @@ void generateYZTPredicted(List<Counter<Integer>> ys, List<Counter<Integer>> zs,
List<Counter<Integer>> ComputePrZ(int [][] crtGroup) {
List<Counter<Integer>> prZs = estimateZ(crtGroup);

for(Counter<Integer> pr_z : prZs){
double scoreTotal = 0.0;
for(double score : pr_z.values())
scoreTotal += Math.exp(score);

for(int z : pr_z.keySet()){
double score = Math.exp(pr_z.getCount(z));
pr_z.setCount(z, score/scoreTotal);
}

}

return prZs;
}

Expand All @@ -507,20 +516,18 @@ private void trainJointly(
// 1. estimate Pr(Z) .. for now estimating \hat{Z}
List<Counter<Integer>> pr_z = ComputePrZ(crtGroup);
zPredicted = generateZPredicted(pr_z);

// 2. estimate Pr(Y|Z,T)
Counter<Integer> pr_y = ComputePrY_ZiTi(zPredicted, arg1Type, arg2Type, goldPos, yLabels);

Counter<Integer> yPredicted = null;

yPredicted = generateYPredicted(pr_y, yLabels, 0.01); // TODO: temporary hack .. need to parameterize this
Counter<Integer> yPredicted = generateYPredicted(pr_y, yLabels, 0.01); // TODO: temporary hack .. need to parameterize this

Set<Integer> [] zUpdate;

if(updateCondition(yPredicted.keySet(), goldPos)){

zUpdate = generateZUpdate(goldPos, pr_z);
updateZModel(zUpdate, zPredicted, crtGroup, posUpdateStats, negUpdateStats);
}
//updateZModel(zUpdate, zPredicted, crtGroup, posUpdateStats, negUpdateStats);
}
}

else if(ALGO_TYPE == 2){
Expand Down Expand Up @@ -627,11 +634,31 @@ private static boolean updateCondition(Set<Integer> y, Set<Integer> yPos) {
*/
private Set<Integer> [] generateZUpdate(
Set<Integer> goldPos,
List<Counter<Integer>> zs) {
Set<Integer> [] zUpdate = ErasureUtils.uncheckedCast(new Set[zs.size()]);

List<Counter<Integer>> pr_z) {

/**
* while (not converged) {
* choose a random permutation for Z variable Pi
* for (j = 1 to |Z|) {
* update Z_{Pi_j} given the rest is fixed
* }
* }
*/

Set<Integer> [] zUpdate = null; //ErasureUtils.uncheckedCast(new Set[zs.size()]);

Counter<Integer> yLabels = new ClassicCounter<Integer>();
for(int y : goldPos){
yLabels.incrementCount(y);
}



for(int i = 0; i < epochsInf; i++){

}

return zUpdate;
return zUpdate;
}

/**
Expand Down

0 comments on commit 49a3701

Please sign in to comment.