From 49a370192641fd18c7a2222f5aa349c99f62a414 Mon Sep 17 00:00:00 2001 From: ajaynagesh Date: Tue, 19 Nov 2013 22:26:28 +0530 Subject: [PATCH] Started gibbs sampling for pr(Z|Yi) --- .../classify/SelPrefORExtractor.java | 53 ++++++++++++++----- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/src/edu/stanford/nlp/kbp/slotfilling/classify/SelPrefORExtractor.java b/src/edu/stanford/nlp/kbp/slotfilling/classify/SelPrefORExtractor.java index a93f122..a76bb61 100644 --- a/src/edu/stanford/nlp/kbp/slotfilling/classify/SelPrefORExtractor.java +++ b/src/edu/stanford/nlp/kbp/slotfilling/classify/SelPrefORExtractor.java @@ -1,13 +1,10 @@ package edu.stanford.nlp.kbp.slotfilling.classify; -import edu.stanford.nlp.classify.Dataset; import edu.stanford.nlp.ie.machinereading.structure.RelationMention; import edu.stanford.nlp.kbp.slotfilling.KBPTrainer; -import edu.stanford.nlp.kbp.slotfilling.classify.HoffmannExtractor.LabelWeights; import edu.stanford.nlp.kbp.slotfilling.common.Constants; import edu.stanford.nlp.kbp.slotfilling.common.Log; import edu.stanford.nlp.kbp.slotfilling.common.Props; -import edu.stanford.nlp.sequences.SequenceGibbsSampler; import edu.stanford.nlp.stats.ClassicCounter; import edu.stanford.nlp.stats.Counter; import edu.stanford.nlp.util.ErasureUtils; @@ -487,6 +484,18 @@ void generateYZTPredicted(List> ys, List> zs, List> ComputePrZ(int [][] crtGroup) { List> prZs = estimateZ(crtGroup); + for(Counter pr_z : prZs){ + double scoreTotal = 0.0; + for(double score : pr_z.values()) + scoreTotal += Math.exp(score); + + for(int z : pr_z.keySet()){ + double score = Math.exp(pr_z.getCount(z)); + pr_z.setCount(z, score/scoreTotal); + } + + } + return prZs; } @@ -507,20 +516,18 @@ private void trainJointly( // 1. estimate Pr(Z) .. for now estimating \hat{Z} List> pr_z = ComputePrZ(crtGroup); zPredicted = generateZPredicted(pr_z); + // 2. estimate Pr(Y|Z,T) Counter pr_y = ComputePrY_ZiTi(zPredicted, arg1Type, arg2Type, goldPos, yLabels); - - Counter yPredicted = null; - - yPredicted = generateYPredicted(pr_y, yLabels, 0.01); // TODO: temporary hack .. need to parameterize this + Counter yPredicted = generateYPredicted(pr_y, yLabels, 0.01); // TODO: temporary hack .. need to parameterize this Set [] zUpdate; if(updateCondition(yPredicted.keySet(), goldPos)){ zUpdate = generateZUpdate(goldPos, pr_z); - updateZModel(zUpdate, zPredicted, crtGroup, posUpdateStats, negUpdateStats); - } + //updateZModel(zUpdate, zPredicted, crtGroup, posUpdateStats, negUpdateStats); + } } else if(ALGO_TYPE == 2){ @@ -627,11 +634,31 @@ private static boolean updateCondition(Set y, Set yPos) { */ private Set [] generateZUpdate( Set goldPos, - List> zs) { - Set [] zUpdate = ErasureUtils.uncheckedCast(new Set[zs.size()]); - + List> pr_z) { + + /** + * while (not converged) { + * choose a random permutation for Z variable Pi + * for (j = 1 to |Z|) { + * update Z_{Pi_j} given the rest is fixed + * } + * } + */ + + Set [] zUpdate = null; //ErasureUtils.uncheckedCast(new Set[zs.size()]); + + Counter yLabels = new ClassicCounter(); + for(int y : goldPos){ + yLabels.incrementCount(y); + } + + + + for(int i = 0; i < epochsInf; i++){ + + } - return zUpdate; + return zUpdate; } /**