Skip to content

Commit

Permalink
updated summarization example
Browse files: browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Apr 6, 2020
1 parent fa4cdc0 commit 2d7f4c6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
25 changes: 16 additions & 9 deletions examples/summarization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::pipelines::summarization::{SummarizationModel, SummarizationConfig};
use std::time::Instant;


fn main() -> failure::Fallible<()> {
Expand All @@ -40,7 +41,7 @@ fn main() -> failure::Fallible<()> {
let device = Device::cuda_if_available();

let summarization_config = SummarizationConfig {
num_beams: 1,
num_beams: 3,
..Default::default()
};

Expand All @@ -62,21 +63,27 @@ but previous discoveries were made on planets with high temperatures or other pr
said UCL astronomer Angelos Tsiaras. \"It's the best candidate for habitability right now.\" \"It's a good sign\", \
said Ryan Cloutier of the Harvard–Smithsonian Center for Astrophysics, who was not one of either study's authors. \
\"Overall,\" he continued, \"the presence of water in its atmosphere certainly improves the prospect of K2-18b being \
a potentially habitable planet, but further observations will be required to say for sure. \"
a potentially habitable planet, but further observations will be required to say for sure. \" \
K2-18b was first identified in 2015 by the Kepler space telescope. It is about 110 light-years from Earth and larger \
but less dense. Its star, a red dwarf, is cooler than the Sun, but the planet's orbit is much closer, such that a year \
on K2-18b lasts 33 Earth days. According to The Guardian, astronomers were optimistic that NASA's James Webb space \
telescope — scheduled for launch in 2021 — and the European Space Agency's 2028 ARIEL program, could reveal more \
about exoplanets like K2-18b."];

// Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
let output = summarization_model.summarize(&input);
for sentence in output {
println!("{:?}", sentence);
}
let output = summarization_model.summarize(&input);
for sentence in output {
println!("{:?}", sentence);

let num_iterations = 3;
let mut all_iteration_times = Vec::with_capacity(num_iterations);

for _ in 0..num_iterations {
let start_iteration = Instant::now();
let output = summarization_model.summarize(&input);
for sentence in output {
println!("{:?}", sentence);
}
println!("iteration total time {:?}", start_iteration.elapsed().as_millis());
all_iteration_times.push(start_iteration.elapsed().as_millis())
}
println!("average total time {:?}", all_iteration_times.iter().sum::<u128>() / all_iteration_times.len() as u128);
Ok(())
}
39 changes: 38 additions & 1 deletion src/pipelines/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,7 @@ mod private_generation_utils {
use crate::pipelines::generation::{BeamHypotheses, GenerateConfig, LMHeadModel};
use itertools::Itertools;
use super::ordered_float::OrderedFloat;
use std::time::Instant;

pub trait PrivateLanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>> {
fn get_model(&mut self) -> &mut T;
Expand Down Expand Up @@ -806,7 +807,16 @@ mod private_generation_utils {
let mut encoder_outputs = encoder_outputs;
let mut current_length = cur_len;

let mut all_forwards = Vec::with_capacity(max_length as usize);
let mut all_repetition_penalty_time = Vec::with_capacity(max_length as usize);
let mut all_banned_tokens_time = Vec::with_capacity(max_length as usize);
let mut all_sample_time = Vec::with_capacity(max_length as usize);
let mut all_beam_decoding_time = Vec::with_capacity(max_length as usize);
let mut all_cache_reorder_time = Vec::with_capacity(max_length as usize);


while current_length < max_length {
let start_iter = Instant::now();
let (prepared_input,
prepared_encoder_output,
prepared_decoder_input,
Expand All @@ -826,13 +836,20 @@ mod private_generation_utils {
outputs = temp.0;
encoder_outputs = temp.1;
past = temp.2;

let forward_time = Instant::now();
let elapsed_forward_time = (forward_time - start_iter).as_millis();
all_forwards.push(elapsed_forward_time);

let mut next_token_logits = outputs.select(1, -1);

// Reduce probability for repeated inputs
if repetition_penalty > 1f64 {
self.enforce_repetition_penalty(&mut next_token_logits, batch_size, 1, &input_ids, repetition_penalty)
}

let repetition_penalty_time = Instant::now();
let elapsed_repetition_penalty_time = (repetition_penalty_time - forward_time).as_millis();
all_repetition_penalty_time.push(elapsed_repetition_penalty_time);
if temperature > 1f64 {
next_token_logits = next_token_logits / temperature;
}
Expand All @@ -849,6 +866,9 @@ mod private_generation_utils {
for (batch_index, index_banned_token) in (0..banned_tokens.len() as i64).zip(banned_tokens) {
&scores.get(batch_index).index_fill_(0, &Tensor::of_slice(&index_banned_token).to_device(next_token_logits.device()), std::f64::NEG_INFINITY);
}
let banned_tokens_time = Instant::now();
let elapsed_banned_tokens_time = (banned_tokens_time - repetition_penalty_time).as_millis();
all_banned_tokens_time.push(elapsed_banned_tokens_time);
let (next_scores, next_tokens) = if do_sample {
let mut _scores: Tensor = &scores + &beam_scores.unsqueeze(-1).expand_as(&scores);
self.top_k_top_p_filtering(&mut _scores, top_k as i64, top_p, 2);
Expand All @@ -865,6 +885,9 @@ mod private_generation_utils {
let next_scores = next_scores.contiguous().view((batch_size, num_beams * vocab_size));
next_scores.topk(2 * num_beams, 1, true, true)
};
let sample_time = Instant::now();
let elapsed_sample_time = (sample_time - banned_tokens_time).as_millis();
all_sample_time.push(elapsed_sample_time);
let mut next_batch_beam: Vec<(f64, i64, i64)> = vec!();
for batch_index in 0..batch_size {
if done[batch_index as usize] {
Expand Down Expand Up @@ -923,6 +946,10 @@ mod private_generation_utils {
if done.iter().all(|&x| x) {
break;
}
let beam_decoding_time = Instant::now();
let elapsed_beam_decoding_time = (beam_decoding_time - sample_time).as_millis();
all_beam_decoding_time.push(elapsed_beam_decoding_time);

beam_scores = Tensor::of_slice(&next_batch_beam.iter().map(|(score, _, _)| *score).collect_vec()).to(input_ids.device());
beam_tokens = Tensor::of_slice(&next_batch_beam.iter().map(|(_, token, _)| *token).collect_vec()).to(input_ids.device());
beam_indices = Tensor::of_slice(&next_batch_beam.iter().map(|(_, _, index)| *index).collect_vec()).to(input_ids.device());
Expand All @@ -936,6 +963,9 @@ mod private_generation_utils {
attention_mask = Tensor::cat(&[attention_mask.as_ref(), Tensor::ones(&[*attention_mask.size().first().unwrap(), 1],
(Int64, attention_mask.device())).as_ref()], -1);
}
let cache_reorder_time = Instant::now();
let elapsed_cache_reorder_time = (cache_reorder_time - beam_decoding_time).as_millis();
all_cache_reorder_time.push(elapsed_cache_reorder_time);
current_length += 1;
}

Expand Down Expand Up @@ -1002,6 +1032,13 @@ mod private_generation_utils {
} else {
Tensor::stack(&best_ids, 0).to_kind(Int64).to(input_ids.device())
};

println!(" Total forward time {:?}", all_forwards.iter().sum::<u128>());
println!(" Total repetition time {:?}", all_repetition_penalty_time.iter().sum::<u128>());
println!(" Total banned tokens time {:?}", all_banned_tokens_time.iter().sum::<u128>());
println!(" Total sample time {:?}", all_sample_time.iter().sum::<u128>());
println!(" Total beam decoding time {:?}", all_beam_decoding_time.iter().sum::<u128>());
println!(" Total cache reorder time {:?}", all_cache_reorder_time.iter().sum::<u128>());
decoded
}

Expand Down

0 comments on commit 2d7f4c6

Please sign in to comment.