Skip to content

Commit

Permalink
Summarization without beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Apr 4, 2020
1 parent dbf6841 commit fa87fce
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 12 deletions.
8 changes: 6 additions & 2 deletions examples/summarization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,13 @@ fn main() -> failure::Fallible<()> {
let device = Device::cuda_if_available();
let generate_config = GenerateConfig {
max_length: 142,
do_sample: false,
num_beams: 1,
do_sample: true,
num_beams: 3,
temperature: 1.0,
top_k: 50,
top_p: 1.0,
length_penalty: 2.0,
min_length: 56,
num_return_sequences: 1,
..Default::default()
};
Expand Down
3 changes: 1 addition & 2 deletions src/bart/bart.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ impl LMHeadModel for BartForConditionalGeneration {
train);

let lm_logits = decoder_output.linear::<Tensor>(&self.base_model.encoder.embed_tokens.ws, None);

Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
Ok((lm_logits, Some(encoder_hidden_states), None, None, None))
}
}
9 changes: 4 additions & 5 deletions src/bart/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,20 @@ impl BartDecoder {
};

let positions = self.embed_positions.forward(input_ids, self.generation_mode);

let (input_ids, positions) = if self.generation_mode {
let end = input_ids.size()[1];
(input_ids.slice(1, end - 1, end, 1), positions.slice(1, end - 1, end, 1))
let end_inputs = input_ids.size()[1];
let end_positions = positions.size()[1];
(input_ids.slice(1, end_inputs - 1, end_inputs, 1),
positions.slice(1, end_positions - 1, end_positions, 1))
} else {
(input_ids.copy(), positions)
};

let x: Tensor = input_ids.as_ref().apply(&self.embed_tokens) + positions;

let x = x
.apply(&self.layer_norm_embedding)
.apply_t(&self.dropout, train)
.transpose(0, 1);

let mut all_hidden_states: Option<Vec<Tensor>> = if self.output_hidden_states { Some(vec!()) } else { None };
let mut all_attentions: Option<Vec<Tensor>> = if self.output_attentions { Some(vec!()) } else { None };
let mut next_decoder_cache: Option<Vec<(&LayerState, &LayerState)>> = if self.output_past { Some(vec!()) } else { None };
Expand Down
6 changes: 3 additions & 3 deletions src/pipelines/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ impl BartGenerator {
///# use std::path::PathBuf;
///# use tch::Device;
///# fn main() -> failure::Fallible<()> {
/// use rust_bert::pipelines::generation::{GenerateConfig, OpenAIGenerator, BartGenerator};
/// use rust_bert::pipelines::generation::{GenerateConfig, BartGenerator};
///# let mut home: PathBuf = dirs::home_dir().unwrap();
///# home.push("rustbert");
///# home.push("openai-gpt");
Expand Down Expand Up @@ -665,7 +665,7 @@ mod private_generation_utils {

// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&next_token_logits.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()), std::f64::NEG_INFINITY);
&next_token_logits.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(next_token_logits.device()), std::f64::NEG_INFINITY);
}

// Top-k and top-p sampling
Expand Down Expand Up @@ -783,7 +783,7 @@ mod private_generation_utils {
let mut scores = next_token_logits.log_softmax(-1, Float);
// Do not allow eos token if min length is not reached
if (&eos_token_ids.is_some()) & (current_length < min_length) {
&scores.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()), std::f64::NEG_INFINITY);
&scores.index_fill_(1, &Tensor::of_slice(eos_token_ids.as_ref().unwrap()).to(scores.device()), std::f64::NEG_INFINITY);
}
// Get banned tokens and set their probability to 0
let banned_tokens = self.get_banned_tokens(&input_ids, no_repeat_ngram_size as i64, current_length as i64);
Expand Down

0 comments on commit fa87fce

Please sign in to comment.