Skip to content

Commit

Permalink
Corrected formulation of Pr(Z|Y,X). Now the accuracies match that of
Browse files Browse the repository at this point in the history
Hoffmann (only the inference algorithms differ in both of these)
  • Loading branch information
ajaynagesh committed Jan 17, 2014
1 parent 5dafb10 commit 51eb734
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ private Counter<Integer> calScore(int [] mentionFeatures, Set<Integer> goldPos)
}
}

if(goldPos != null && goldPos.size() == 0){ //goldPos is empty that means relation label is -nil-
if(goldPos != null){ //Also calculate the nil score when goldPos is given
int zLabel = nilIndex;
// score of Xj taking on label i \in Y' = Sj_i
double score = zWeights[zLabel].dotProduct(vector);
Expand Down Expand Up @@ -822,10 +822,10 @@ private void trainJointly(
//TODO: Do we need to differentiate between nil labels and non-nil labels (as in updateZModel) ? Verify during small dataset runs
//zUpdate = generateZUpdate(goldPos, crtGroup);

if(goldPos.size() - crtGroup.length > 0){
System.out.println("How come ? " + "data : " + crtGroup + " goldPos : " + goldPos + " EgId : " + egId + " ... skipping ...");
return;
}
// if(goldPos.size() - crtGroup.length > 0){
// System.out.println("How come ? " + "data : " + crtGroup + " goldPos : " + goldPos + " EgId : " + egId + " ... skipping ...");
// return;
// }

List<Counter<Integer>> scoresWithYgiven = computeScores(crtGroup, goldPos);
zUpdate = ilpInfHandle.generateZUpdateILP(scoresWithYgiven, crtGroup.length, goldPos, nilIndex);
Expand Down
64 changes: 16 additions & 48 deletions src/ilpInference/InferenceWrappers.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

public class InferenceWrappers {

public Set<Integer> [] generateZUpdateILP(List<Counter<Integer>> scores,
public Set<Integer> [] generateZUpdateILP(List<Counter<Integer>> scoresYGiven,
int numOfMentions,
Set<Integer> goldPos,
int nilIndex){
Expand All @@ -43,7 +43,7 @@ public class InferenceWrappers {
if(goldPos.size() > numOfMentions){
//////////////Objective --------------------------------------
for(int mentionIdx = 0; mentionIdx < numOfMentions; mentionIdx ++){
Counter<Integer> score = scores.get(mentionIdx);
Counter<Integer> score = scoresYGiven.get(mentionIdx);
for(int label : score.keySet()){
if(label == nilIndex)
continue;
Expand Down Expand Up @@ -87,7 +87,7 @@ public class InferenceWrappers {
else {
//////////////Objective --------------------------------------
for(int mentionIdx = 0; mentionIdx < numOfMentions; mentionIdx ++){
Counter<Integer> score = scores.get(mentionIdx);
Counter<Integer> score = scoresYGiven.get(mentionIdx);
for(int label : score.keySet()){
String var = "z"+mentionIdx+"_"+"y"+label;
double coeff = score.getCount(label);
Expand All @@ -102,43 +102,21 @@ public class InferenceWrappers {
/// 1. equality constraints \Sum_{i \in Y'} z_ji = 1 \forall j
for(int mentionIdx = 0; mentionIdx < numOfMentions; mentionIdx ++){
constraint = new Linear();
if(goldPos.size() == 0) { // if goldPos is [] ==> -nil- index
String var = "z"+mentionIdx+"_"+"y"+nilIndex;
for(int y : goldPos){
String var = "z"+mentionIdx+"_"+"y"+y;
constraint.add(1, var);

//System.out.print("z"+mentionIdx+"_"+"y"+y + " + ");
}
else {
for(int y : goldPos){
String var = "z"+mentionIdx+"_"+"y"+y;
constraint.add(1, var);

//System.out.print("z"+mentionIdx+"_"+"y"+y + " + ");
}
//System.out.print("z"+mentionIdx+"_"+"y"+y + " + ");
}
constraint.add(1, "z"+mentionIdx+"_"+"y"+nilIndex); //nil index added to constraint

problem.add(constraint, "=", 1);
//System.out.println(" 0 = "+ "1");
}

}



//System.out.println("\n-----------------");
/// 2. inequality constraint ===> 1 <= \Sum_j z_ji \forall i \in Y' {lhs=1, since we consider only Y' i.e goldPos}
/////////// ------------------------------------------------------
if(goldPos.size() == 0){ // if goldPos is [] ==> -nil- index
constraint = new Linear();
for(int mentionIdx = 0; mentionIdx < numOfMentions; mentionIdx ++){
String var = "z"+mentionIdx+"_"+"y"+nilIndex;
constraint.add(1, var);
//System.out.print("z"+mentionIdx+"_"+"y"+y + " + ");
}
problem.add(constraint, ">=", 1);
//System.out.println(" 0 - " + "y"+y +" >= 0" );
}
else {

//System.out.println("\n-----------------");
/// 2. inequality constraint ===> 1 <= \Sum_j z_ji \forall i \in Y' {lhs=1, since we consider only Y' i.e goldPos}
/////////// ------------------------------------------------------
for(int y : goldPos){
constraint = new Linear();
for(int mentionIdx = 0; mentionIdx < numOfMentions; mentionIdx ++){
Expand All @@ -149,9 +127,9 @@ public class InferenceWrappers {
problem.add(constraint, ">=", 1);
//System.out.println(" 0 - " + "y"+y +" >= 0" );
}
/////////// ------------------------------------------------------
}
/////////// ------------------------------------------------------


// Set the types of all variables to Binary
for(Object var : problem.getVariables())
problem.setVarType(var, Boolean.class);
Expand Down Expand Up @@ -185,28 +163,18 @@ public class InferenceWrappers {

}

int numOfUpdates = 0;
for(Object var : problem.getVariables()) {
if(result.containsVar(var) && (result.get(var).intValue() == 1)){
String [] split = var.toString().split("_");
//System.out.println(split[0]);
int mentionIdx = Integer.parseInt(split[0].toString().substring(1));
//System.out.println(split[1]);
int ylabel = Integer.parseInt(split[1].toString().substring(1));
zUpdate[mentionIdx].add(ylabel);
numOfUpdates++;
if(ylabel != nilIndex)
zUpdate[mentionIdx].add(ylabel);
}
}

if (numOfUpdates != numOfMentions)
{
System.out.println(result);
System.out.println("----------ERROR-----------");
System.out.println("GOLDPOS : " + goldPos);
}

//assert (numOfUpdates == numOfMentions);


return zUpdate;
}

Expand Down

0 comments on commit 51eb734

Please sign in to comment.