Skip to content

Commit

Permalink
Delete unused viz folder and subcommand
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitakit committed Feb 5, 2021
1 parent 148b6fc commit 3cee9f1
Show file tree
Hide file tree
Showing 3 changed files with 0 additions and 434 deletions.
59 changes: 0 additions & 59 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,58 +502,6 @@ def run_parse(args):
output_file.write("{}\n".format(tree.linearize()))
print("Output written to:", args.output_path)

#%%
def run_viz(args):
assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

print("Loading test trees from {}...".format(args.viz_path))
viz_treebank = trees.load_trees(args.viz_path)
print("Loaded {:,} test examples.".format(len(viz_treebank)))

print("Loading model from {}...".format(args.model_path_base))

info = torch_load(args.model_path_base)

assert 'hparams' in info['spec'], "Only self-attentive models are supported"
parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])

from viz import viz_attention

stowed_values = {}
orig_multihead_forward = parse_nk.MultiHeadAttention.forward
def wrapped_multihead_forward(self, inp, batch_idxs, **kwargs):
res, attns = orig_multihead_forward(self, inp, batch_idxs, **kwargs)
stowed_values[f'attns{stowed_values["stack"]}'] = attns.cpu().data.numpy()
stowed_values['stack'] += 1
return res, attns

parse_nk.MultiHeadAttention.forward = wrapped_multihead_forward

# Select the sentences we will actually be visualizing
max_len_viz = 15
if max_len_viz > 0:
viz_treebank = [tree for tree in viz_treebank if len(list(tree.leaves())) <= max_len_viz]
viz_treebank = viz_treebank[:1]

print("Parsing viz sentences...")

for start_index in range(0, len(viz_treebank), args.eval_batch_size):
subbatch_trees = viz_treebank[start_index:start_index+args.eval_batch_size]
subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]
stowed_values = dict(stack=0)
predicted, _ = parser.parse_batch(subbatch_sentences)
del _
predicted = [p.convert() for p in predicted]
stowed_values['predicted'] = predicted

for snum, sentence in enumerate(subbatch_sentences):
sentence_words = [tokens.START] + [x[1] for x in sentence] + [tokens.STOP]

for stacknum in range(stowed_values['stack']):
attns_padded = stowed_values[f'attns{stacknum}']
attns = attns_padded[snum::len(subbatch_sentences), :len(sentence_words), :len(sentence_words)]
viz_attention(sentence_words, attns)


def main():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -597,13 +545,6 @@ def main():
subparser.add_argument("--output-path", type=str, default="-")
subparser.add_argument("--eval-batch-size", type=int, default=100)

subparser = subparsers.add_parser("viz")
subparser.set_defaults(callback=run_viz)
subparser.add_argument("--model-path-base", required=True)
subparser.add_argument("--evalb-dir", default="EVALB/")
subparser.add_argument("--viz-path", default="data/22.auto.clean")
subparser.add_argument("--eval-batch-size", type=int, default=100)

args = parser.parse_args()
args.callback(args)

Expand Down
52 changes: 0 additions & 52 deletions src/viz.py

This file was deleted.

Loading

0 comments on commit 3cee9f1

Please sign in to comment.