Skip to content

Commit

Permalink
Merge pull request guillaume-be#83 from guillaume-be/t5_minor_fixes
Browse files (browse the repository at this point in the history)
T5 minor fixes
  • Loading branch information
guillaume-be authored Oct 3, 2020
2 parents 2bf14ec + ce7333c commit 73e20aa
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 24 deletions.
2 changes: 1 addition & 1 deletion src/pipelines/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1239,7 +1239,7 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
None,
Some(attention_mask),
encoder_outputs,
- Some(input_ids),
+ Some(input_ids.narrow(1, -1, 1)),
Cache::T5Cache(past),
),
Cache::None => (
Expand Down
7 changes: 6 additions & 1 deletion src/t5/attention.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ impl LayerState {
#[derive(Debug)]
pub struct T5Attention {
is_decoder: bool,
+ is_bidirectional: bool,
has_relative_attention_bias: bool,
relative_attention_num_buckets: i64,
d_model: i64,
Expand All @@ -67,6 +68,7 @@ impl T5Attention {
p: P,
config: &T5Config,
is_decoder: bool,
+ is_bidirectional: bool,
store_cache: bool,
output_attentions: bool,
has_relative_attention_bias: bool,
Expand Down Expand Up @@ -101,6 +103,7 @@ impl T5Attention {

T5Attention {
is_decoder,
+ is_bidirectional,
has_relative_attention_bias,
relative_attention_num_buckets: config.relative_attention_num_buckets,
d_model: config.d_model,
Expand Down Expand Up @@ -275,7 +278,7 @@ impl T5Attention {

let rp_bucket = self.get_relative_position_bucket(
&relative_position,
- !self.is_decoder,
+ self.is_bidirectional,
self.relative_attention_num_buckets,
128,
);
Expand Down Expand Up @@ -310,6 +313,7 @@ impl T5LayerSelfAttention {
p / "SelfAttention",
config,
is_decoder,
+ !is_decoder,
store_cache,
output_attentions,
has_relative_attention_bias,
Expand Down Expand Up @@ -375,6 +379,7 @@ impl T5LayerCrossAttention {
p / "EncDecAttention",
config,
is_decoder,
+ true,
store_cache,
output_attentions,
has_relative_attention_bias,
Expand Down
22 changes: 0 additions & 22 deletions src/t5/t5_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,28 +324,6 @@ impl T5Model {
)
.unwrap(),
};
- let (calculated_decoder_input_ids, calculated_decoder_input_embeds) =
- if old_layer_states.is_some() {
- let decoder_input_ids = match decoder_input_ids {
- Some(value) => Some(value.narrow(1, -1, 1)),
- None => None,
- };
- let decoder_input_embeds = match &decoder_input_embeds {
- Some(value) => Some(value.narrow(1, -1, 1)),
- None => None,
- };
- (decoder_input_ids, decoder_input_embeds)
- } else {
- (None, None)
- };
- let (decoder_input_ids, decoder_input_embeds) = if old_layer_states.is_some() {
- (
- calculated_decoder_input_ids.as_ref(),
- calculated_decoder_input_embeds,
- )
- } else {
- (decoder_input_ids, decoder_input_embeds)
- };

let decoder_output = self
.decoder
Expand Down

0 comments on commit 73e20aa

Please sign in to comment.