Skip to content

Commit

Permalink
Pr (y, z | t, x) is done...
Browse files Browse the repository at this point in the history
  • Loading branch information
ajaynagesh committed Nov 19, 2013
1 parent 2af878b commit aa0a28d
Showing 1 changed file with 64 additions and 48 deletions.
112 changes: 64 additions & 48 deletions src/edu/stanford/nlp/kbp/slotfilling/classify/SelPrefORExtractor.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package edu.stanford.nlp.kbp.slotfilling.classify;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.kbp.slotfilling.KBPTrainer;
import edu.stanford.nlp.kbp.slotfilling.classify.HoffmannExtractor.LabelWeights;
Expand Down Expand Up @@ -213,7 +214,7 @@ public void train(MultiLabelDataset<String, String> dataset) {
Set<Integer> arg1Type = dataset.arg1TypeArray()[i];
Set<Integer> arg2Type = dataset.arg2TypeArray()[i];

trainJointly(crtGroup, goldPos, arg1Type, arg2Type, posUpdateStats, negUpdateStats);
trainJointly(crtGroup, goldPos, arg1Type, arg2Type, posUpdateStats, negUpdateStats, labelIndex);

// update the number of iterations a weight vector has survived
for(LabelWeights zw: zWeights) zw.updateSurvivalIterations();
Expand Down Expand Up @@ -296,24 +297,7 @@ public static void main(String[] args) throws Exception{

train(props);

}

/*
* while (not_converged){
Step 1: // update Z vars assuming Y vars to be fixed ..
done by simple enumeration
Step 2: // update Y assuming Z to be fixed .. done by Gibbs Sampling
while (not_converged){
choose a random permutation of Y variables → \Pi
for ( i = 1 to |Y| ) {
update Y_{\Pi_i} assuming the rest is fixed
}
}
}
*
*/

}

private Counter<Integer> estimateZ(int [] datum) {
Counter<Integer> vector = new ClassicCounter<Integer>();
Expand All @@ -333,7 +317,7 @@ private Counter<Integer> estimateZ(int [] datum) {
* @param Yi
* @return
*/
private Counter<Integer> estimateZ(int [] datum, Set<Integer> Yi) {
/*private Counter<Integer> estimateZ(int [] datum, Set<Integer> Yi) {
Counter<Integer> vector = new ClassicCounter<Integer>();
for(int d: datum) vector.incrementCount(d);
Expand All @@ -349,20 +333,20 @@ private Counter<Integer> estimateZ(int [] datum, Set<Integer> Yi) {
}
return scores;
}
}*/

// Collects one counter per row of datums by calling estimateZ(datum, Yi) on each row, in
// order, and returns the counters as a list.
// NOTE(review): the method header appears twice below; the second copy opens a block comment
// that runs through the end of the method, so this view also shows the method being
// commented out.
private List<Counter<Integer>> ComputePrZ_Yi (int [][] datums, Set<Integer> Yi) {
/*private List<Counter<Integer>> ComputePrZ_Yi (int [][] datums, Set<Integer> Yi) {
List<Counter<Integer>> zs = new ArrayList<Counter<Integer>>();
for(int [] datum: datums) {
zs.add(estimateZ(datum, Yi));
}
return zs;
}
}*/

//\hat{Y,Z} = argmax_{Y,Z} Pr_{\theta} (Y, Z | T_i, x_i)
private void ComputePrYZ_Ti (int [][] datums, int szY, Set<Integer> arg1Type, Set<Integer> arg2Type) {
/* private void ComputePrYZ_Ti (int [][] datums, int szY, Set<Integer> arg1Type, Set<Integer> arg2Type) {
Set<Integer> yPredicted = null; // TODO: Initialize yPredicted. How ?
for (int i = 0; i < epochsInf; i ++){ // TODO: What is the stopping criterion
Expand All @@ -382,19 +366,19 @@ private void ComputePrYZ_Ti (int [][] datums, int szY, Set<Integer> arg1Type, Se
}
}
}
}*/

// Empty placeholder: takes no parameters and has no body.
// NOTE(review): the header appears twice below; the second copy opens a block comment that
// runs through the closing brace, so this view also shows the stub being commented out.
private void ComputePrZT_Yi () {
/*private void ComputePrZT_Yi () {
}
}*/

/**
 * Unimplemented stub: the body is empty, so calling it has no effect.
 * NOTE(review): presumably meant to combine the Y, Z and T estimates - confirm before
 * implementing.
 *
 * @param pr_y counters for Y (currently unused)
 * @param pr_z counters for Z (currently unused)
 * @param pr_t counters for T (currently unused)
 */
private void ComputePrYZT(
    List<Counter<Integer>> pr_y,
    List<Counter<Integer>> pr_z,
    List<Counter<Integer>> pr_t) {
  // no implementation yet
}

// Empty placeholder: takes no parameters and has no body.
// NOTE(review): the header appears twice below; the second copy opens a block comment that
// runs through the closing brace, so this view also shows the stub being commented out.
private void computeFactor() {
/*private void computeFactor() {
}
}*/

// TODO: Need to implement the randomization routine
private Counter<Integer> randomizeVar(){
Expand Down Expand Up @@ -427,42 +411,72 @@ private Counter<Integer> randomizeVar(){
return yPredicted;
}*/

/*
 * Builds a count vector from the two argument-type sets: one feature per
 * (arg1 type, arg2 type) pair, plus one feature per individual arg1 type and per
 * individual arg2 type.
 *
 * Two versions of the method are interleaved in the lines below:
 *   - constructFeatureVector keys features by String
 *     ("arg1=<t1>_arg2=<t2>", "arg1=<t>", "arg2=<t>");
 *   - constructArgFeatureVector keys features by Integer
 *     (type2*10 + type1 for a pair, the raw type id for a single type).
 * Presumably the Integer-keyed version is the replacement - TODO confirm against the
 * committed file.
 *
 * NOTE(review): the Integer keys share one key space. A pair key type2*10 + type1 can equal
 * a raw type id (e.g. type1=1, type2=0 gives 1), arg1 and arg2 ids increment the same key,
 * and the pair encoding is only unambiguous while type1 < 10. The String keys kept these
 * features distinct.
 */
Counter<String> constructFeatureVector(Set<Integer> arg1Type, Set<Integer> arg2Type){
Counter<String> yFeats_select = new ClassicCounter<String>();
Counter<Integer> constructArgFeatureVector(Set<Integer> arg1Type, Set<Integer> arg2Type){
Counter<Integer> yFeats_select = new ClassicCounter<Integer>();

// pairwise features: every combination of an arg1 type with an arg2 type
for(int type1 : arg1Type){
for(int type2 : arg2Type){
String type = "arg1=" + type1 + "_" + "arg2=" + type2;
for(int type2 : arg2Type){
for(int type1 : arg1Type){
int type = type2*10 + type1;
yFeats_select.incrementCount(type);
}
}

// unary features: each arg1 type on its own
for(int typ : arg1Type){
String type = "arg1=" + typ;
yFeats_select.incrementCount(type);
yFeats_select.incrementCount(typ);
}
// unary features: each arg2 type on its own
for(int typ : arg2Type){
String type = "arg2=" + typ;
yFeats_select.incrementCount(type);
yFeats_select.incrementCount(typ);
}


return yFeats_select;
}

/*
 * Scores every label in yLabels and normalizes the scores so they sum to 1.
 *
 * For each label, with idx = yLabels.indexOf(label):
 *   score(idx) = exp( yWeights_select[idx] . argFeatures + yWeights_mention[idx] . zCounts )
 * where argFeatures comes from constructArgFeatureVector(arg1Type, arg2Type) and zCounts
 * counts how often each value occurs in zPredicted. Each score is then divided by the sum
 * of all scores, and the resulting counter (label index -> normalized score) is returned.
 *
 * Two versions are interleaved at the top of this span: the List<Counter<Integer>>-returning
 * header with its constructFeatureVector / yWeights_select[1] lines, and the
 * Counter<Integer>-returning version that takes yLabels. Presumably the latter is the
 * replacement - TODO confirm against the committed file.
 *
 * NOTE(review): goldPos is not read anywhere in the lines shown.
 * NOTE(review): Math.exp is applied to the raw score with no max-subtraction. A large score
 * overflows to Infinity, after which the division by totalScore produces NaN.
 */
List<Counter<Integer>> ComputePrY_ZiTi(int [] zPredicted, Set<Integer> arg1Type,
Set<Integer> arg2Type, Set<Integer> goldPos){
Counter<Integer> ComputePrY_ZiTi(int [] zPredicted, Set<Integer> arg1Type,
Set<Integer> arg2Type, Set<Integer> goldPos, Index<String> yLabels){

Counter<Integer> ys = new ClassicCounter<Integer>();

Counter<String> yFeats_select = constructFeatureVector(arg1Type, arg2Type);
yWeights_select[1].dotProduct(yFeats_select);
List<Counter<Integer>> ys = new ArrayList<Counter<Integer>>();
Counter<Integer> yFeats_select = constructArgFeatureVector(arg1Type, arg2Type);

// count how many times each value appears in zPredicted
Counter<Integer> yFeats_mention = new ClassicCounter<Integer>();
for(int z : zPredicted){
yFeats_mention.incrementCount(z);
}

// first pass: unnormalized exp(score) per label, accumulating the normalizer
double totalScore = 0.0;
for(String label : yLabels){
int indx = yLabels.indexOf(label);
double score = yWeights_select[indx].dotProduct(yFeats_select);

score += yWeights_mention[indx].dotProduct(yFeats_mention);

score = Math.exp(score);
// System.out.println("score : " + score);
totalScore += score;
ys.setCount(indx, score);
}

// second pass: divide each stored score by the normalizer
for(String label : yLabels){
int indx = yLabels.indexOf(label);
double prob = ys.getCount(indx)/totalScore;
ys.setCount(indx, prob);
}

return ys;
}

Counter<Integer> generateYPredicted(List<Counter<Integer>> ys) {
/**
 * Turns per-label scores into a set of predicted labels by thresholding.
 *
 * @param ys        counter from label index to score
 * @param yLabels   label index; each label's position is looked up with indexOf
 * @param threshold a label is selected only when its score is strictly greater than this
 * @return counter holding a count of 1 for every label index whose score exceeds the
 *         threshold; labels at or below the threshold are absent
 */
Counter<Integer> generateYPredicted(Counter<Integer> ys, Index<String> yLabels, double threshold) {
  final Counter<Integer> selected = new ClassicCounter<Integer>();
  for (String name : yLabels) {
    final int labelId = yLabels.indexOf(name);
    // strict '>' is deliberate: a NaN score compares false and is therefore not selected
    if (ys.getCount(labelId) > threshold) {
      selected.setCount(labelId, 1);
    }
  }
  return selected;
}

Expand All @@ -482,7 +496,8 @@ private void trainJointly(
Set<Integer> arg1Type,
Set<Integer> arg2Type,
Counter<Integer> posUpdateStats,
Counter<Integer> negUpdateStats) {
Counter<Integer> negUpdateStats,
Index<String> yLabels) {

int [] zPredicted = null;
int [] tPredicted = null;
Expand All @@ -493,11 +508,11 @@ private void trainJointly(
List<Counter<Integer>> pr_z = ComputePrZ(crtGroup);
zPredicted = generateZPredicted(pr_z);
// 2. estimate Pr(Y|Z,T)
List<Counter<Integer>> pr_y = ComputePrY_ZiTi(zPredicted, arg1Type, arg2Type, goldPos);
Counter<Integer> pr_y = ComputePrY_ZiTi(zPredicted, arg1Type, arg2Type, goldPos, yLabels);

Counter<Integer> yPredicted = null;

yPredicted = generateYPredicted(pr_y);
yPredicted = generateYPredicted(pr_y, yLabels, 0.01); // TODO: temporary hack .. need to parameterize this

Set<Integer> [] zUpdate;

Expand Down Expand Up @@ -623,6 +638,7 @@ private static boolean updateCondition(Set<Integer> y, Set<Integer> yPos) {
* TODO: Need to code this up correctly
*
*/
/*
private Counter<Integer> estimateY(int [] zPredicted, int ySz, Set<Integer> arg1Type, Set<Integer> arg2Type) {
Counter<Integer> ys = new ClassicCounter<Integer>();
Expand All @@ -641,7 +657,7 @@ private Counter<Integer> estimateY(int [] zPredicted, int ySz, Set<Integer> arg1
for(int z : zPredicted) zPredVector.incrementCount(z);
return ys;
}
}*/

private List<Counter<Integer>> estimateZ(int [][] datums) {
List<Counter<Integer>> zs = new ArrayList<Counter<Integer>>();
Expand Down

0 comments on commit aa0a28d

Please sign in to comment.