Skip to content

Commit

Permalink
babi_rnn bugfix: QA19 requires vocab from the answer
Browse files Browse the repository at this point in the history
For all other questions, the full vocab is in the stories and the queries
  • Loading branch information
Smerity committed Aug 6, 2015
1 parent e42f738 commit 63284a4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion examples/babi_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def vectorize_stories(data):
train = get_stories(tar.extractfile(challenge.format('train')))
test = get_stories(tar.extractfile(challenge.format('test')))

vocab = sorted(reduce(lambda x, y: x | y, (set(story + q) for story, q, answer in train + test)))
vocab = sorted(reduce(lambda x, y: x | y, (set(story + q + [answer]) for story, q, answer in train + test)))
# Reserve 0 for masking via pad_sequences
vocab_size = len(vocab) + 1
word_idx = dict((c, i + 1) for i, c in enumerate(vocab))
Expand Down

0 comments on commit 63284a4

Please sign in to comment.