Skip to content

Commit

Permalink
text primitives and problems targeting syntax guided synthesis
Browse files Browse the repository at this point in the history
  • Loading branch information
ellisk42 committed Apr 19, 2018
1 parent 1bbb57e commit 96b58b6
Show file tree
Hide file tree
Showing 7 changed files with 86 additions and 28 deletions.
1 change: 1 addition & 0 deletions listPrimitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ def bootstrapTarget():
Primitive("range", arrow(tint, tlist(tint)), range),
Primitive("index", arrow(tint, tlist(t0), t0), _index),
Primitive("fold", arrow(tlist(t0), t1, arrow(t0,t1,t1), t1), _fold),
Primitive("length", arrow(tlist(t0),tint), len),

# built-ins
Primitive("if", arrow(tbool, t0, t0, t0), _if),
Expand Down
67 changes: 64 additions & 3 deletions makeTextTasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random


delimiters = ['.',',',' ','<','>','/','@','-','|']
delimiters = ['.',',',' ','(',')','-']

def randomDelimiter():
    """Pick one entry of the module-level `delimiters` list at random."""
    chosen = random.choice(delimiters)
    return chosen
Expand Down Expand Up @@ -63,7 +63,7 @@ def randomWords(d):
("drop first character","drop first character")
}

def makeTasks():
def makeOldTasks():
NUMBEROFEXAMPLES = 4
problems = []
def toList(s): return [c for c in s]
Expand Down Expand Up @@ -241,6 +241,66 @@ def problem(n, examples, needToTrain = False):
for x in [randomWord() + d1 + y + d2 + randomWord()] ])
return problems

def makeTasks():
    """Build the synthetic string-manipulation tasks.

    Every task is constructed from NUMBEROFEXAMPLES randomly generated
    (inputs, output) examples, where `inputs` is a tuple of one or two
    strings and `output` is a string.  Before being handed to `Task`, every
    string is converted into a list of its characters.  The task type is
    whatever `guess_arrow_type` infers from the raw (string) examples.

    Task families, in the order they are generated:
      * replace one delimiter with a different one      (marked mustTrain)
      * select the first / second / last delimited word (marked mustTrain)
      * join two words with one or two delimiters
      * drop the last 1-3 characters
      * take the first character and append a delimiter

    Returns:
        The list of generated `Task` objects, in creation order.
    """
    NUMBEROFEXAMPLES = 4

    problems = []

    def preprocess(x):
        # Recursively replace each string by the list of its characters,
        # preserving the surrounding tuple / list structure.
        if isinstance(x, tuple):
            return tuple(preprocess(z) for z in x)
        if isinstance(x, list):
            return [preprocess(z) for z in x]
        if isinstance(x, str):
            return list(x)
        # Raised explicitly (rather than `assert False`) so the guard is
        # not stripped when Python runs with -O.
        raise AssertionError(
            "cannot preprocess value of type %s" % type(x).__name__)

    def problem(n, examples, needToTrain=False):
        # Register one task named `n` built from `examples`.
        task = Task(n, guess_arrow_type(examples),
                    [(preprocess(x), preprocess(y))
                     for x, y in examples])
        if needToTrain:
            task.mustTrain = True
        problems.append(task)

    # Replace every occurrence of one delimiter with another.
    for d1 in delimiters:
        for d2 in delimiters:
            if d1 != d2:
                problem("Replace '%s' w/ '%s'" % (d1, d2),
                        [((x,), x.replace(d1, d2))
                         for _ in range(NUMBEROFEXAMPLES)
                         for x in [randomWords(d1)]],
                        needToTrain=True)

    # Select the first, second, or last word of a delimited string.
    for d in delimiters:
        for n in [0, 1, -1]:
            problem("nth (n=%d) word delimited by '%s'" % (n, d),
                    [((x,), x.split(d)[n])
                     for _ in range(NUMBEROFEXAMPLES)
                     for x in [randomWords(d)]],
                    needToTrain=True)

    # Join two words with a single delimiter, then with every pair of
    # delimiters starting with that one.
    for d1 in delimiters:
        problem("Append two words delimited by '%s'" % (d1),
                [((x, y), x + d1 + y)
                 for _ in range(NUMBEROFEXAMPLES)
                 for x in [randomWord()]
                 for y in [randomWord()]])
        for d2 in delimiters:
            problem("Append two words delimited by '%s%s'" % (d1, d2),
                    [((x, y), x + d1 + d2 + y)
                     for _ in range(NUMBEROFEXAMPLES)
                     for x in [randomWord()]
                     for y in [randomWord()]])

    # Drop a fixed number of trailing characters.
    for n in range(1, 4):
        problem("Drop last %d characters" % n,
                [((x,), x[:-n])
                 for _ in range(NUMBEROFEXAMPLES)
                 for x in [randomWord() + randomWord()]])

    # Keep only the first character and append a delimiter to it.
    for d in delimiters:
        problem("Take first character and append '%s'" % d,
                [((x,), x[0] + d)
                 for _ in range(NUMBEROFEXAMPLES)
                 for x in [randomWord()]])

    return problems



def loadPBETasks(directory="PBE_Strings_Track"):
"""
Expand Down Expand Up @@ -302,9 +362,10 @@ def findStrings(s):
if __name__ == "__main__":
import sys
loadPBETasks()
assert False

tasks = makeTasks()
for t in tasks: print t.describe()
assert False
# def maximumLength(x):
# if isinstance(x,list):
# return max([len(x)] + map(maximumLength,x))
Expand Down
Binary file modified solver
Binary file not shown.
6 changes: 4 additions & 2 deletions solvers/program.ml
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,8 @@ let primitive_constant_strings = [primitive "','" tcharacter ',';
primitive "'/'" tcharacter '/';
primitive "'|'" tcharacter '|';
primitive "'-'" tcharacter '-';
primitive "LPAREN" tcharacter '(';
primitive "RPAREN" tcharacter ')';
];;
(* let primitive_slice_string = primitive "slice-string" (tint @> tint @> tstring @> tstring)
* (fun i j s ->
Expand Down Expand Up @@ -342,8 +344,8 @@ let primitive_reverse = primitive "reverse" (tlist tint @> tlist tint) (List.rev
let primitive_append = primitive "append" (tlist tint @> tlist tint @> tlist tint) (@);;
let primitive_singleton = primitive "singleton" (tint @> tlist tint) (fun x -> [x]);;
let primitive_slice = primitive "slice" (tint @> tint @> tlist tint @> tlist tint) slice;;
let primitive_length = primitive "length" (tlist tint @> tint) (List.length);;
let primitive_map = primitive "map" ((tint @> tint) @> (tlist tint) @> (tlist tint)) (fun f l -> List.map ~f:f l);;
let primitive_length = primitive "length" (tlist t0 @> tint) (List.length);;
let primitive_map = primitive "map" ((t0 @> t1) @> (tlist t0) @> (tlist t1)) (fun f l -> List.map ~f:f l);;
let primitive_fold_right = primitive "fold_right" ((tint @> tint @> tint) @> tint @> (tlist tint) @> tint) (fun f x0 l -> List.fold_right ~f:f ~init:x0 l);;
let primitive_mapi = primitive "mapi" ((tint @> t0 @> t1) @> (tlist t0) @> (tlist t1)) (fun f l ->
List.mapi l ~f:f);;
Expand Down
5 changes: 3 additions & 2 deletions text.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from utilities import eprint, testTrainSplit, numberOfCPUs
from makeTextTasks import makeTasks, delimiters
from textPrimitives import primitives
from listPrimitives import bootstrapTarget
from program import *
from recognition import *

Expand Down Expand Up @@ -40,10 +41,10 @@ def __init__(self, tasks):
tasks = makeTasks()
eprint("Generated",len(tasks),"tasks")

test, train = testTrainSplit(tasks, 0.2)
test, train = testTrainSplit(tasks, 0.9)
eprint("Split tasks into %d/%d test/train"%(len(test),len(train)))

baseGrammar = Grammar.uniform(primitives)
baseGrammar = Grammar.uniform(primitives + bootstrapTarget())

explorationCompression(baseGrammar, train,
testingTasks = test,
Expand Down
28 changes: 7 additions & 21 deletions textPrimitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,11 @@ def _identity(x): return x
def _strip(x):
    """Return `x` with leading and trailing whitespace removed."""
    stripped = x.strip()
    return stripped
def _eq(x):
    """Curried equality: return a one-argument function testing `x == y`."""
    def equals_x(y):
        return x == y
    return equals_x

specialCharacters = {' ': 'SPACE',
')': 'RPAREN',
'(': 'LPAREN'}

primitives = [
# Primitive("0",tint,0),
#Primitive("len",arrow(tstr,tint),len),
# Primitive("+1",arrow(tint,tint),_increment),
# Primitive("-1",arrow(tint,tint),_decrement),
# Primitive("emptyString",tstr,""),
Primitive("char-eq?",arrow(tcharacter,tcharacter,tboolean),_eq),
Primitive("caseLower",arrow(tcharacter,tcharacter), _lower),
Primitive("caseUpper",arrow(tcharacter,tcharacter), _upper),
#Primitive("caseCapitalize",arrow(tstr,tstr), _capitalize),
# Primitive("concatenate",arrow(tstr,tstr,tstr), _append),
# Primitive("slice-string", arrow(tint,tint,tstr,tstr),_slice),
# Primitive("nth", arrow(tint, tlist(tstr), tstr),_index),
# Primitive("map-string", arrow(arrow(tstr,tstr), tlist(tstr), tlist(tstr)),_map),
#Primitive("find", arrow(tcharacter, tstr, tint),_find),
#Primitive("replace", arrow(tstr, tstr, tstr, tstr),_replace),
# Primitive("strip", arrow(tstr,tstr),_strip),
# Primitive("split", arrow(tcharacter, tstr, tlist(tstr)),_split),
# Primitive("join", arrow(tstr, tlist(tstr), tstr),_join),
# Primitive("chr2str", arrow(tcharacter, tstr), _identity),
] + [ Primitive("'%s'"%d, tcharacter, d) for d in delimiters if d != ' '] + \
[ Primitive("SPACE", tcharacter, ' ')]
Primitive("char-eq?",arrow(tcharacter,tcharacter,tboolean),_eq)
] + [ Primitive("'%s'"%d, tcharacter, d) for d in delimiters if d not in specialCharacters] + \
[ Primitive(name, tcharacter, value) for value, name in specialCharacters.iteritems() ]
7 changes: 7 additions & 0 deletions type.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,3 +199,10 @@ def guess_type(xs):
return tlist(guess_type([y for ys in xs for y in ys]))
else:
raise ValueError("cannot guess type from {}".format(xs))
def guess_arrow_type(examples):
    """Infer the function type of a task from its input/output examples.

    Args:
        examples: non-empty sequence of (inputs, output) pairs, where
            `inputs` is an indexable sequence of argument values.  The
            arity is taken from the first example; every other example is
            indexed at the same positions.

    Returns:
        `arrow(t_1, ..., t_a, t_out)`, where `t_i` is `guess_type` applied
        to the i-th argument across all examples and `t_out` is
        `guess_type` applied to all outputs.

    Raises:
        IndexError: if `examples` is empty, or if an example has fewer
            arguments than the first one.
        Any exception raised by `guess_type` propagates unchanged.
    """
    arity = len(examples[0][0])
    # One guessed type per argument position, in positional order.
    input_types = [guess_type([xs[n] for xs, _ in examples])
                   for n in range(arity)]
    output_type = guess_type([y for _, y in examples])
    return arrow(*(input_types + [output_type]))

0 comments on commit 96b58b6

Please sign in to comment.