
Commit

Merge pull request guillaume-be#19 from guillaume-be/pipeline_optimizations

BART caching update and pipelines optimization
guillaume-be authored Apr 6, 2020
2 parents 03f642f + 35fcd0b commit 2bb284c
Showing 9 changed files with 194 additions and 75 deletions.
5 changes: 3 additions & 2 deletions Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "rust-bert"
-version = "0.6.0"
+version = "0.6.1"
 authors = ["Guillaume Becquin <[email protected]>"]
 edition = "2018"
 description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
@@ -37,4 +37,5 @@ serde = {version = "1.0.104", features = ["derive"]}
 failure = "0.1.6"
 dirs = "2.0"
 itertools = "0.9.0"
-ordered-float = "1.0.2"
+ordered-float = "1.0.2"
+csv = "1.1.3"
2 changes: 1 addition & 1 deletion examples/generation.rs
@@ -39,7 +39,7 @@ fn main() -> failure::Fallible<()> {
     // Set-up masked LM model
     let device = Device::cuda_if_available();
     let generate_config = GenerateConfig {
-        max_length: 20,
+        max_length: 30,
         do_sample: true,
         num_beams: 5,
         temperature: 1.1,
64 changes: 64 additions & 0 deletions examples/sst2.rs
@@ -0,0 +1,64 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate failure;
extern crate dirs;

use std::path::PathBuf;
use tch::Device;
use failure::err_msg;
use rust_bert::pipelines::sentiment::{SentimentClassifier, ss2_processor};
use std::env;


fn main() -> failure::Fallible<()> {
    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("distilbert_sst2");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let weights_path = &home.as_path().join("model.ot");

    if !config_path.is_file() || !vocab_path.is_file() || !weights_path.is_file() {
        return Err(
            err_msg("Could not find required resources to run example. \
                     Please run ../utils/download_dependencies_sst2_sentiment.py \
                     in a Python environment with dependencies listed in ../requirements.txt"));
    }

    // Set-up classifier
    let device = Device::cuda_if_available();
    let sentiment_classifier = SentimentClassifier::new(vocab_path,
                                                        config_path,
                                                        weights_path, device)?;

    // Define input
    let mut sst2_path = PathBuf::from(env::var("SST2_PATH")
        .expect("Please set the \"SST2_PATH\" environment variable pointing to the SST2 dataset folder"));
    sst2_path.push("train.tsv");
    let inputs = ss2_processor(sst2_path).unwrap();

    // Run model
    let batch_size = 64;
    let mut output = vec!();
    for batch in inputs.chunks(batch_size) {
        output.push(sentiment_classifier.predict(batch.iter().map(|v| v.as_str()).collect::<Vec<&str>>().as_slice()));
    }
    let mut flat_outputs = vec!();
    for batch_output in output.iter_mut() {
        flat_outputs.append(batch_output);
    }
    println!("{:?}", flat_outputs.len());

    Ok(())
}
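
The batch-and-flatten loop above can also be written as a single iterator chain. A sketch that drops into the example as-is, assuming predict returns one prediction per input string:

let flat_outputs: Vec<_> = inputs
    .chunks(batch_size)
    .flat_map(|batch| {
        // Borrow each String as &str for the classifier's slice input.
        let batch: Vec<&str> = batch.iter().map(String::as_str).collect();
        sentiment_classifier.predict(batch.as_slice())
    })
    .collect();
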
1 change: 0 additions & 1 deletion examples/summarization.rs
@@ -71,7 +71,6 @@ about exoplanets like K2-18b."];

     // Credits: WikiNews, CC BY 2.5 license (https://en.wikinews.org/wiki/Astronomers_find_water_vapour_in_atmosphere_of_exoplanet_K2-18b)
     let output = summarization_model.summarize(&input);
-
     for sentence in output {
         println!("{:?}", sentence);
     }
6 changes: 6 additions & 0 deletions src/bart/attention.rs
@@ -39,6 +39,12 @@ impl LayerState {
             self.prev_key_padding_mask = Some(self.prev_key_padding_mask.as_ref().unwrap().index_select(0, new_indices));
         }
     }
+
+    pub(crate) fn reset_cache(&mut self) {
+        self.prev_key = None;
+        self.prev_value = None;
+        self.prev_key_padding_mask = None;
+    }
 }
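
For context: reset_cache clears the per-layer cache that the BART decoder accumulates while generating token by token. A simplified view of the state being cleared, with field types assumed from the fields touched above (the real struct in src/bart/attention.rs may hold more than this):

use tch::Tensor;

pub struct LayerState {
    // Keys projected for all previously generated positions.
    pub prev_key: Option<Tensor>,
    // Values projected for all previously generated positions.
    pub prev_value: Option<Tensor>,
    // Padding mask matching the cached positions.
    pub prev_key_padding_mask: Option<Tensor>,
}
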
26 changes: 23 additions & 3 deletions src/bart/bart.rs
Expand Up @@ -181,9 +181,9 @@ impl BartModel {
};
let embedding_config = EmbeddingConfig { padding_idx: pad_token_id, ..Default::default() };
let embeddings: nn::Embedding = embedding(p / "shared",
config.vocab_size,
config.d_model,
embedding_config);
config.vocab_size,
config.d_model,
embedding_config);

let encoder = BartEncoder::new(p / "encoder", config);
let decoder = BartDecoder::new(p / "decoder", config, generation_mode);
@@ -292,6 +292,14 @@ impl BartModel {
          all_decoder_hidden_states, all_decoder_attentions,
          all_encoder_hidden_states, all_encoder_attentions)
     }
+
+    /// Resets the decoder cached keys and values. Should be run for every new generation using the model.
+    pub fn reset_cache(&mut self) {
+        for layer in self.get_decoder().get_layers() {
+            layer.get_self_attention().prev_state.as_mut().unwrap().reset_cache();
+            layer.get_encoder_attention().prev_state.as_mut().unwrap().reset_cache();
+        };
+    }
 }

/// # BART Model for conditional generation
@@ -416,6 +424,11 @@ impl BartForConditionalGeneration {
         let (encoder_hidden_states, _, _) = self.base_model.encoder.forward_t(input_ids, attention_mask, &self.base_model.embeddings, false);
         encoder_hidden_states
     }
+
+    /// Resets the decoder cached keys and values. Should be run for every new generation using the model.
+    pub fn reset_cache(&mut self) {
+        self.get_base_model().reset_cache()
+    }
 }

pub struct BartClassificationHead {
@@ -571,6 +584,13 @@ impl BartForSequenceClassification {
          all_decoder_hidden_states, all_decoder_attentions,
          all_encoder_hidden_states, all_encoder_attentions)
     }
+
+    pub(crate) fn get_base_model(&mut self) -> &mut BartModel { &mut self.base_model }
+
+    /// Resets the decoder cached keys and values. Should be run for every new generation using the model.
+    pub fn reset_cache(&mut self) {
+        self.get_base_model().reset_cache()
+    }
 }

impl LMHeadModel for BartForConditionalGeneration {
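
Taken together, these additions give every BART variant a way to drop stale decoder state. A sketch of the intended call pattern (the decoding steps themselves are elided; only reset_cache is taken from this diff):

use rust_bert::bart::BartForConditionalGeneration;

fn summarize_two_documents(model: &mut BartForConditionalGeneration) {
    model.reset_cache(); // start generation 1 from an empty cache
    // ... incremental decoding for document 1 populates the cache ...
    model.reset_cache(); // keys/values from document 1 must not leak into generation 2
    // ... incremental decoding for document 2 ...
}
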
8 changes: 7 additions & 1 deletion src/pipelines/generation.rs
@@ -505,6 +505,10 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer> for BartGenerator {
         };
         (None, encoder_outputs)
     }
+
+    fn reset_cache(&mut self) {
+        self.get_model().reset_cache();
+    }
 }

impl LanguageGenerator<BartForConditionalGeneration, RobertaVocab, RobertaTokenizer> for BartGenerator {}
@@ -1013,6 +1017,8 @@ mod private_generation_utils {
                 None => (None, None)
             }
         }
+
+        fn reset_cache(&mut self) {}
     }
}

@@ -1165,7 +1171,7 @@ pub trait LanguageGenerator<T: LMHeadModel, V: Vocab, U: Tokenizer<V>>: PrivateLanguageGenerator<T, V, U> {
             (input_ids, attention_mask)
         };

-
+        self.reset_cache();
         let decoded = no_grad(|| {
             if num_beams > 1 {
                 self.generate_beam_search(input_ids, encoder_outputs, cur_len, min_length as i64, max_length as i64, do_sample, early_stopping, temperature, top_k as i64, top_p, repetition_penalty,
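
The design here is a hook with a default no-op on the private generator trait: generators without a persistent decoder cache inherit the empty default, the BART generator overrides it to clear its cache, and generate() calls the hook once per invocation. An illustrative, standalone sketch of that pattern (the trait and struct names are placeholders, not rust-bert's actual private API):

trait Generator {
    // Default hook: stateless models have nothing to clear.
    fn reset_cache(&mut self) {}

    fn generate(&mut self) {
        // Every generation starts from a clean slate; for stateless
        // implementors this call compiles to nothing.
        self.reset_cache();
        // ... decoding loop would follow here ...
    }
}

struct CachedDecoder {
    cache: Option<Vec<f32>>,
}

impl Generator for CachedDecoder {
    // Stateful models override the hook to drop their cache.
    fn reset_cache(&mut self) {
        self.cache = None;
    }
}
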
113 changes: 57 additions & 56 deletions src/pipelines/question_answering.rs
@@ -311,63 +311,66 @@ impl QuestionAnsweringModel {

         for (start, end) in batch_indices {
             let batch_features = &features[start..end];
-            let mut input_ids: Vec<Tensor> = vec!();
-            let mut attention_masks: Vec<Tensor> = vec!();
-            for feature in batch_features {
-                input_ids.push(Tensor::of_slice(&feature.input_ids));
-                attention_masks.push(Tensor::of_slice(&feature.attention_mask));
-            }
-            let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
-            let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
+            let mut input_ids = Vec::with_capacity(batch_features.len());
+            let mut attention_masks = Vec::with_capacity(batch_features.len());

-            let (start_logits, end_logits, _, _) = no_grad(|| {
-                self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap()
-            });
-            let start_logits = start_logits.to(Device::Cpu);
-            let end_logits = end_logits.to(Device::Cpu);
-
-            let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
-                .iter()
-                .enumerate()
-                .map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
-                .collect();
-
-            let mut feature_id_start = 0;
-            for (example_id, max_feature_id) in example_index_to_feature_end_position {
-                let mut answers: Vec<Answer> = vec!();
-                let example = &examples[example_id];
-                for feature_idx in feature_id_start..max_feature_id {
-                    let feature = &batch_features[feature_idx as usize];
-                    let start = start_logits.get(feature_idx);
-                    let end = end_logits.get(feature_idx);
-                    let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs();
-
-                    let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
-                    let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
-
-                    let (starts, ends, scores) = self.decode(&start, &end, top_k);
-
-                    for idx in 0..starts.len() {
-                        let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
-                        let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
-                        let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
-
-                        let start = example.char_to_word_offset
-                            .iter()
-                            .position(|&v| v as usize == start_pos)
-                            .unwrap();
-
-                        let end = example.char_to_word_offset
-                            .iter()
-                            .rposition(|&v| v as usize == end_pos)
-                            .unwrap();
-
-                        answers.push(Answer { score: scores[idx], start, end, answer });
-                    }
-                }
-                feature_id_start = max_feature_id;
-                all_answers.push(answers[..(top_k as usize)].to_vec());
-            }
+            no_grad(|| {
+                for feature in batch_features {
+                    input_ids.push(Tensor::of_slice(&feature.input_ids));
+                    attention_masks.push(Tensor::of_slice(&feature.attention_mask));
+                }
+
+                let input_ids = Tensor::stack(&input_ids, 0).to(self.var_store.device());
+                let attention_masks = Tensor::stack(&attention_masks, 0).to(self.var_store.device());
+
+                let (start_logits, end_logits, _, _) = self.distilbert_qa.forward_t(Some(input_ids), Some(attention_masks), None, false).unwrap();
+
+                let start_logits = start_logits.detach();
+                let end_logits = end_logits.detach();
+                let example_index_to_feature_end_position: Vec<(usize, i64)> = batch_features
+                    .iter()
+                    .enumerate()
+                    .map(|(feature_index, feature)| (feature.example_index as usize, feature_index as i64 + 1))
+                    .collect();
+
+                let mut feature_id_start = 0;
+
+                for (example_id, max_feature_id) in example_index_to_feature_end_position {
+                    let mut answers: Vec<Answer> = vec!();
+                    let example = &examples[example_id];
+                    for feature_idx in feature_id_start..max_feature_id {
+                        let feature = &batch_features[feature_idx as usize];
+                        let start = start_logits.get(feature_idx);
+                        let end = end_logits.get(feature_idx);
+                        let p_mask = (Tensor::of_slice(&feature.p_mask) - 1).abs().to_device(start.device());
+
+                        let start: Tensor = start.exp() / start.exp().sum(Float) * &p_mask;
+                        let end: Tensor = end.exp() / end.exp().sum(Float) * &p_mask;
+
+                        let (starts, ends, scores) = self.decode(&start, &end, top_k);
+
+                        for idx in 0..starts.len() {
+                            let start_pos = feature.token_to_orig_map[&starts[idx]] as usize;
+                            let end_pos = feature.token_to_orig_map[&ends[idx]] as usize;
+                            let answer = example.doc_tokens[start_pos..end_pos + 1].join(" ");
+
+                            let start = example.char_to_word_offset
+                                .iter()
+                                .position(|&v| v as usize == start_pos)
+                                .unwrap();
+
+                            let end = example.char_to_word_offset
+                                .iter()
+                                .rposition(|&v| v as usize == end_pos)
+                                .unwrap();
+
+                            answers.push(Answer { score: scores[idx], start, end, answer });
+                        }
+                    }
+                    feature_id_start = max_feature_id;
+                    all_answers.push(answers[..(top_k as usize)].to_vec());
+                }
+            });
         }
         all_answers
     }
@@ -384,11 +387,9 @@ impl QuestionAnsweringModel {
         } else {
             candidates.argsort(0, true).slice(0, 0, top_k, 1)
         };
-
         let mut start: Vec<i64> = vec!();
         let mut end: Vec<i64> = vec!();
         let mut scores: Vec<f64> = vec!();
-
         for flat_index_position in 0..idx_sort.size()[0] {
             let flat_index = idx_sort.int64_value(&[flat_index_position]);
             scores.push(candidates.double_value(&[flat_index]));
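
Two changes in this file do the memory work: the whole batch loop now runs inside no_grad, so no autograd graph is recorded during inference, and the logits are detach()-ed instead of copied with .to(Device::Cpu), cutting any remaining graph link without a device transfer. A minimal, standalone illustration of the same pattern with tch (shapes and values are placeholders):

use tch::{no_grad, Device, Kind, Tensor};

fn main() {
    let logits = no_grad(|| {
        // Nothing built inside this closure records gradient history,
        // so inference carries no autograd bookkeeping.
        let input = Tensor::rand(&[2, 8], (Kind::Float, Device::Cpu));
        // detach() severs the tensor from any graph without moving it
        // to another device.
        (&input * 2.0).detach()
    });
    println!("{:?}", logits.size());
}
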
