Skip to content

Commit

Permalink
Add tree_from_str for loading a single tree
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitakit committed Jun 28, 2019
1 parent 589825e commit 9864130
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions src/trees.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,55 @@ def leaves(self):
def convert(self):
return LeafTreebankNode(self.tag, self.word)


def tree_from_str(treebank, strip_top=True, strip_spmrl_features=True):
# Features bounded by `##` may contain spaces, so if we strip the features
# we need to do so prior to tokenization
if strip_spmrl_features:
treebank = "".join(treebank.split("##")[::2])

tokens = treebank.replace("(", " ( ").replace(")", " ) ").split()

def helper(index):
trees = []

while index < len(tokens) and tokens[index] == "(":
paren_count = 0
while tokens[index] == "(":
index += 1
paren_count += 1

label = tokens[index]
index += 1

if tokens[index] == "(":
children, index = helper(index)
trees.append(InternalTreebankNode(label, children))
else:
word = tokens[index]
index += 1
trees.append(LeafTreebankNode(label, word))

while paren_count > 0:
assert tokens[index] == ")"
index += 1
paren_count -= 1

return trees, index

trees, index = helper(0)
assert index == len(tokens)

if strip_top:
for i, tree in enumerate(trees):
if tree.label in ("TOP", "ROOT"):
assert len(tree.children) == 1
trees[i] = tree.children[0]

assert len(trees) == 1

return trees[0]

def load_trees(path, strip_top=True, strip_spmrl_features=True):
with open(path) as infile:
treebank = infile.read()
Expand Down

0 comments on commit 9864130

Please sign in to comment.