From 95a941ebca4a551bad4a35e9fd46cc03efc6c1ad Mon Sep 17 00:00:00 2001 From: Ottokar Tilk Date: Wed, 5 Apr 2017 02:04:48 +0300 Subject: [PATCH] enable showing UNK in play_with_model.py --- README.md | 8 +++++++- play_with_model.py | 11 +++++++++-- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 54062a5..39aa119 100644 --- a/README.md +++ b/README.md @@ -110,6 +110,12 @@ You can play with a trained model with: `python play_with_model.py <model_path>` +or with: + +`python play_with_model.py <model_path> 1` + +if you want to see which words the model sees as UNKs (OOVs). + # Citing @@ -122,4 +128,4 @@ The software is described in: year = {2016} } -We used the [release v1.0](https://github.com/ottokart/punctuator2/releases/tag/v1.0) +We used the [release v1.0](https://github.com/ottokart/punctuator2/releases/tag/v1.0) in the paper. diff --git a/play_with_model.py b/play_with_model.py index 4b6a879..bfc1e8f 100644 --- a/play_with_model.py +++ b/play_with_model.py @@ -22,7 +22,7 @@ def convert_punctuation_to_readable(punct_token): else: return punct_token[0] -def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, text, f_out): +def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, reverse_word_vocabulary, text, f_out, show_unk): if len(text) == 0: sys.exit("Input text from stdin missing.") @@ -39,6 +39,8 @@ def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuat break converted_subsequence = [word_vocabulary.get(w, word_vocabulary[data.UNK]) for w in subsequence] + if show_unk: + subsequence = [reverse_word_vocabulary[w] for w in converted_subsequence] y = predict(to_array(converted_subsequence)) @@ -81,6 +83,10 @@ def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuat else: sys.exit("Model file path argument missing") + show_unk = False + if len(sys.argv) > 2: + show_unk = bool(int(sys.argv[2])) + x = T.imatrix('x')
print "Loading model parameters..." @@ -90,9 +96,10 @@ def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuat predict = theano.function(inputs=[x], outputs=net.y) word_vocabulary = net.x_vocabulary punctuation_vocabulary = net.y_vocabulary + reverse_word_vocabulary = {v:k for k,v in net.x_vocabulary.items()} reverse_punctuation_vocabulary = {v:k for k,v in net.y_vocabulary.items()} with codecs.getwriter('utf-8')(sys.stdout) as f_out: while True: text = raw_input("\nTEXT: ").decode('utf-8') - punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, text, f_out) + punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, reverse_word_vocabulary, text, f_out, show_unk)