diff --git a/src/pipelines/generation.rs b/src/pipelines/generation.rs
index c9de1341d..5163e93b0 100644
--- a/src/pipelines/generation.rs
+++ b/src/pipelines/generation.rs
@@ -1239,7 +1239,7 @@ impl PrivateLanguageGenerator
                 None,
                 Some(attention_mask),
                 encoder_outputs,
-                Some(input_ids),
+                Some(input_ids.narrow(1, -1, 1)),
                 Cache::T5Cache(past),
             ),
             Cache::None => (
diff --git a/src/t5/attention.rs b/src/t5/attention.rs
index 5740a70ab..120b382b1 100644
--- a/src/t5/attention.rs
+++ b/src/t5/attention.rs
@@ -46,6 +46,7 @@ impl LayerState {
 #[derive(Debug)]
 pub struct T5Attention {
     is_decoder: bool,
+    is_bidirectional: bool,
     has_relative_attention_bias: bool,
     relative_attention_num_buckets: i64,
     d_model: i64,
@@ -67,6 +68,7 @@ impl T5Attention {
         p: P,
         config: &T5Config,
         is_decoder: bool,
+        is_bidirectional: bool,
         store_cache: bool,
         output_attentions: bool,
         has_relative_attention_bias: bool,
@@ -101,6 +103,7 @@ impl T5Attention {
         T5Attention {
             is_decoder,
+            is_bidirectional,
             has_relative_attention_bias,
             relative_attention_num_buckets: config.relative_attention_num_buckets,
             d_model: config.d_model,
@@ -275,7 +278,7 @@ impl T5Attention {
         let rp_bucket = self.get_relative_position_bucket(
             &relative_position,
-            !self.is_decoder,
+            self.is_bidirectional,
             self.relative_attention_num_buckets,
             128,
         );
@@ -310,6 +313,7 @@ impl T5LayerSelfAttention {
             p / "SelfAttention",
             config,
             is_decoder,
+            !is_decoder,
             store_cache,
             output_attentions,
             has_relative_attention_bias,
@@ -375,6 +379,7 @@ impl T5LayerCrossAttention {
             p / "EncDecAttention",
             config,
             is_decoder,
+            true,
             store_cache,
             output_attentions,
             has_relative_attention_bias,
diff --git a/src/t5/t5_model.rs b/src/t5/t5_model.rs
index bba00b6f3..5c1a18d1e 100644
--- a/src/t5/t5_model.rs
+++ b/src/t5/t5_model.rs
@@ -324,28 +324,6 @@ impl T5Model {
             )
             .unwrap(),
         };
-        let (calculated_decoder_input_ids, calculated_decoder_input_embeds) =
-            if old_layer_states.is_some() {
-                let decoder_input_ids = match decoder_input_ids {
-                    Some(value) => Some(value.narrow(1, -1, 1)),
-                    None => None,
-                };
-                let decoder_input_embeds = match &decoder_input_embeds {
-                    Some(value) => Some(value.narrow(1, -1, 1)),
-                    None => None,
-                };
-                (decoder_input_ids, decoder_input_embeds)
-            } else {
-                (None, None)
-            };
-        let (decoder_input_ids, decoder_input_embeds) = if old_layer_states.is_some() {
-            (
-                calculated_decoder_input_ids.as_ref(),
-                calculated_decoder_input_embeds,
-            )
-        } else {
-            (decoder_input_ids, decoder_input_embeds)
-        };
         let decoder_output = self
             .decoder
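
Notes on the change: in generation.rs, once a Cache::T5Cache is populated, earlier decoder positions are replayed from the cached keys and values, so only the last generated token needs to be fed forward; the narrow(1, -1, 1) now happens in the generation pipeline, replacing the equivalent narrowing removed from T5Model in t5_model.rs. In attention.rs, bidirectionality was previously derived from !is_decoder inside T5Attention itself, which made decoder cross-attention causal; the new is_bidirectional flag keeps decoder self-attention causal (!is_decoder) while cross-attention always passes true.

For intuition, below is a minimal scalar sketch of the relative-position bucketing that consumes this flag, assuming the standard T5 algorithm. The standalone function and its scalar signature are illustrative only; the crate's get_relative_position_bucket operates on tch Tensors.

// Illustrative scalar rewrite of T5 relative-position bucketing (assumed
// standard algorithm); rust-bert's own helper works on tensors instead.
fn relative_position_bucket(
    relative_position: i64,
    bidirectional: bool,
    mut num_buckets: i64,
    max_distance: i64,
) -> i64 {
    let mut bucket = 0;
    let mut rp = relative_position;
    if bidirectional {
        // Half the buckets encode "key after query", half "key before query".
        num_buckets /= 2;
        if rp > 0 {
            bucket += num_buckets;
        }
        rp = rp.abs();
    } else {
        // Causal case: keys after the query (rp > 0) collapse to bucket 0.
        rp = (-rp).max(0);
    }
    let max_exact = num_buckets / 2;
    if rp < max_exact {
        // Small distances each get their own bucket.
        bucket + rp
    } else {
        // Larger distances share logarithmically spaced buckets, capped at
        // num_buckets - 1 once rp reaches max_distance.
        let large = max_exact
            + ((rp as f64 / max_exact as f64).ln()
                / (max_distance as f64 / max_exact as f64).ln()
                * (num_buckets - max_exact) as f64) as i64;
        bucket + large.min(num_buckets - 1)
    }
}

In the unidirectional branch, every key after the query maps to bucket 0; that is harmless in decoder self-attention, where those positions are masked, but bidirectional attention needs the sign split to keep the before/after distinction, which is what the explicit is_bidirectional flag provides.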