Skip to content

Commit

Permalink
Add decoding option to not label the root node
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitakit committed Feb 5, 2021
1 parent d6407a6 commit 8198cf4
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/decode_chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,14 @@ def uncollapse_unary(tree, ensure_top=False):
class ChartDecoder:
"""A chart decoder for parsing formulated as span classification."""

def __init__(self, label_vocab):
def __init__(self, label_vocab, force_root_constituent=True):
"""Constructs a new ChartDecoder object.
Args:
label_vocab: A mapping from span labels to integer indices.
"""
self.label_vocab = label_vocab
self.label_from_index = {i: label for label, i in label_vocab.items()}
self.force_root_constituent = force_root_constituent

@staticmethod
def build_vocab(trees):
Expand Down Expand Up @@ -169,10 +170,9 @@ def tree_from_scores(self, scores, leaves):
label_scores = scores[left, right - 1]
label_scores = label_scores - label_scores[0]

# TODO(nikita): add option to not label the root node
argmax_label_index = int(
label_scores.argmax()
if length < len(leaves)
if length < len(leaves) or not self.force_root_constituent
else label_scores[1:].argmax() + 1
)
argmax_label = self.label_from_index[argmax_label_index]
Expand Down

0 comments on commit 8198cf4

Please sign in to comment.