Skip to content

Commit

Permalink
fixed bug in grammar induction
Browse files Browse the repository at this point in the history
  • Loading branch information
Kevin M Ellis committed Apr 22, 2018
1 parent ebde3ed commit 0c16b95
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 2 deletions.
3 changes: 2 additions & 1 deletion fragmentGrammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def restrictFrontiers():
def grammarScore(g):
g = g.makeUniform().insideOutside(restrictedFrontiers, pseudoCounts)
likelihood = g.jointFrontiersMDL(restrictedFrontiers)
structure = sum(fragmentSize(p) for p in g.primitives)
structure = sum(primitiveSize(p) for p in g.primitives)
score = likelihood - aic*len(g) - structurePenalty*structure
g.clearCache()
if invalid(score):
Expand All @@ -287,6 +287,7 @@ def grammarScore(g):
if aic is not POSITIVEINFINITY:
restrictedFrontiers = restrictFrontiers()
bestScore, _ = grammarScore(bestGrammar)
eprint("Starting score",bestScore)
while True:
restrictedFrontiers = restrictFrontiers()
fragments = [ f
Expand Down
4 changes: 4 additions & 0 deletions fragmentUtilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def fragmentSize(f, boundVariableCost = 0.1, freeVariableCost = 0.01):
assert not isinstance(e,FragmentVariable)
return leaves + boundVariableCost*boundVariables + freeVariableCost*freeVariables

def primitiveSize(e):
    '''Size of a grammar primitive, as computed by fragmentSize.

    An invented primitive (e.isInvented is true) is measured by its
    body; any other primitive is measured directly.'''
    target = e.body if e.isInvented else e
    return fragmentSize(target)

def defragment(expression):
'''Converts a fragment into an invented primitive'''
if isinstance(expression, (Primitive,Invented)): return expression
Expand Down
6 changes: 6 additions & 0 deletions frontier.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def summarize(self):
best = self.bestPosterior
return "HIT %s w/ %s ; log prior = %f ; log likelihood = %f"%(self.task.name, best.program, best.logPrior, best.logLikelihood)

def summarizeFull(self):
    '''Multi-line summary of this frontier.

    Returns "MISS <task name>" when the frontier is empty. Otherwise
    returns the task name on the first line, followed by one line per
    entry of self.normalize(), formatted as "<logPosterior>\\t<program>".'''
    if self.empty:
        return "MISS " + self.task.name

    lines = [self.task.name]
    for entry in self.normalize():
        lines.append("%f\t%s"%(entry.logPosterior, entry.program))
    return "\n".join(lines)

@staticmethod
def describe(frontiers):
numberOfHits = sum(not f.empty for f in frontiers)
Expand Down
6 changes: 5 additions & 1 deletion makeTextTasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,12 @@ def problem(n, examples, needToTrain = False):
[ ((x,), x.replace(d1,d2))
for _ in range(NUMBEROFEXAMPLES)
for x in [randomWords(d1)] ],
needToTrain=True)
needToTrain=False)
for d in delimiters:
problem("drop first were delimited by '%s'"%d,
[ ((x,), d.join(x.split(d)[1:]))
for _ in range(NUMBEROFEXAMPLES)
for x in [randomWords(d)] ])
for n in [0,1,-1]:
problem("nth (n=%d) word delimited by '%s'"%(n,d),
[ ((x,), x.split(d)[n])
Expand Down

0 comments on commit 0c16b95

Please sign in to comment.