Skip to content

Commit

Permalink
enable showing UNK in play_with_model.py
Browse files Browse the repository at this point in the history
  • Loading branch information
ottokart committed Apr 4, 2017
1 parent 8e27b0b commit 95a941e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,12 @@ You can play with a trained model with:

`python play_with_model.py <model_path>`

or with:

`python play_with_model.py <model_path> 1`

if you want to see, which words the model sees as UNKs (OOVs).


# Citing

Expand All @@ -122,4 +128,4 @@ The software is described in:
year = {2016}
}

We used the [release v1.0](https://github.com/ottokart/punctuator2/releases/tag/v1.0)
We used the [release v1.0](https://github.com/ottokart/punctuator2/releases/tag/v1.0) in the paper.
11 changes: 9 additions & 2 deletions play_with_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def convert_punctuation_to_readable(punct_token):
else:
return punct_token[0]

def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, text, f_out):
def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, reverse_word_vocabulary, text, f_out, show_unk):

if len(text) == 0:
sys.exit("Input text from stdin missing.")
Expand All @@ -39,6 +39,8 @@ def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuat
break

converted_subsequence = [word_vocabulary.get(w, word_vocabulary[data.UNK]) for w in subsequence]
if show_unk:
subsequence = [reverse_word_vocabulary[w] for w in converted_subsequence]

y = predict(to_array(converted_subsequence))

Expand Down Expand Up @@ -81,6 +83,10 @@ def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuat
else:
sys.exit("Model file path argument missing")

show_unk = False
if len(sys.argv) > 2:
show_unk = bool(int(sys.argv[2]))

x = T.imatrix('x')

print "Loading model parameters..."
Expand All @@ -90,9 +96,10 @@ def punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuat
predict = theano.function(inputs=[x], outputs=net.y)
word_vocabulary = net.x_vocabulary
punctuation_vocabulary = net.y_vocabulary
reverse_word_vocabulary = {v:k for k,v in net.x_vocabulary.items()}
reverse_punctuation_vocabulary = {v:k for k,v in net.y_vocabulary.items()}

with codecs.getwriter('utf-8')(sys.stdout) as f_out:
while True:
text = raw_input("\nTEXT: ").decode('utf-8')
punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, text, f_out)
punctuate(predict, word_vocabulary, punctuation_vocabulary, reverse_punctuation_vocabulary, reverse_word_vocabulary, text, f_out, show_unk)

0 comments on commit 95a941e

Please sign in to comment.