Skip to content

Commit

Permalink
Implement BART incremental decoding (facebookresearch#3231)
Browse files Browse the repository at this point in the history
* Implement BART incremental decoding

* Fix lint in random other spot.

* Reviewer comments.
  • Loading branch information
stephenroller authored Oct 28, 2020
1 parent 04f9884 commit 5e0ec44
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions parlai/agents/bart/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,18 @@ def reorder_decoder_incremental_state(
"""
Incremental state is weird to handle when we seed decoder with two inputs
initially.
We leave as a future exercise.
"""
return None
# we only have this method called when it's actually being used
assert incremental_state is not None
assert len(incremental_state) > 0

for incr_state_l in incremental_state.values():
assert 'self_attn' in incr_state_l
assert 'prev_mask' in incr_state_l['self_attn']
self_attn_mask = incr_state_l['self_attn']['prev_mask']
# check this is on the very first run with incremental state
if self_attn_mask.ndim == 3 and tuple(self_attn_mask.shape[1:]) == (2, 2):
# cut off the inappropriate incremental state
incr_state_l['self_attn']['prev_mask'] = self_attn_mask[:, -1:, :]

return super().reorder_decoder_incremental_state(incremental_state, inds)

0 comments on commit 5e0ec44

Please sign in to comment.