Skip to content

Commit

Permalink
Addition of get_tokenizer and get_tokenizer_mut methods for pipel…
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be authored May 1, 2023
1 parent c37eb32 commit 02967b1
Show file tree
Hide file tree
Showing 30 changed files with 243 additions and 1 deletion.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ All notable changes to this project will be documented in this file. The format
## [Unreleased]
## Added
- Addition of the [LongT5](https://arxiv.org/abs/2112.07916) model architecture and pretrained weights.
- Addition of `add_tokens` and `add_extra_ids` interafce methods to the `TokenizerOption`. Allow building most pipeline with custom tokenizer via `new_with_tokenizer`.
- Addition of `add_tokens` and `add_extra_ids` interface methods to the `TokenizerOption`. Allows building most pipelines with a custom tokenizer via `new_with_tokenizer`.
- Addition of `get_tokenizer` and `get_tokenizer_mut` methods to all pipelines, providing a (mutable) reference to the pipeline tokenizer.

## Changed
- Bumped the tokenizers dependency from 7.x to 8.x, exposing additional options for special token mapping and adding the NLLBTokenizer.
Expand Down
3 changes: 3 additions & 0 deletions src/bart/bart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1041,6 +1041,9 @@ impl PrivateLanguageGenerator for BartGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/gpt2/gpt2_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,9 @@ impl PrivateLanguageGenerator for GPT2Generator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/gpt_j/gpt_j_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,9 @@ impl PrivateLanguageGenerator for GptJGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/gpt_neo/gpt_neo_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -697,6 +697,9 @@ impl PrivateLanguageGenerator for GptNeoGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/longt5/longt5_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,9 @@ impl PrivateLanguageGenerator for LongT5Generator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/m2m_100/m2m_100_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,9 @@ impl PrivateLanguageGenerator for M2M100Generator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/marian/marian_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,9 @@ impl PrivateLanguageGenerator for MarianGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/mbart/mbart_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,9 @@ impl PrivateLanguageGenerator for MBartGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/openai_gpt/openai_gpt_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ impl PrivateLanguageGenerator for OpenAIGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
3 changes: 3 additions & 0 deletions src/pegasus/pegasus_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,9 @@ impl PrivateLanguageGenerator for PegasusConditionalGenerator {
/// Returns a shared reference to this generator's `tokenizer` field.
fn _get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}
/// Returns a mutable reference to this generator's `tokenizer` field.
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}
/// Returns a shared reference to the `nn::VarStore` held in this generator's `var_store` field.
fn get_var_store(&self) -> &nn::VarStore {
&self.var_store
}
Expand Down
8 changes: 8 additions & 0 deletions src/pipelines/conversation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -732,12 +732,20 @@ impl ConversationOption {
}
}

/// Borrow the tokenizer owned by the wrapped conversation generator.
///
/// Dispatches on the active variant and forwards to its `_get_tokenizer`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
    match self {
        Self::GPT2(generator) => generator._get_tokenizer(),
    }
}

/// Get a mutable reference to the model tokenizer.
///
/// Dispatches on the active variant and forwards to its `_get_tokenizer_mut`.
/// The return type is `&mut TokenizerOption` so that the tokenizer can be
/// modified in place through the returned borrow; a shared `&TokenizerOption`
/// here would make this accessor no more capable than `get_tokenizer`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
    match self {
        Self::GPT2(model_ref) => model_ref._get_tokenizer_mut(),
    }
}

/// Returns the `ModelType` for this ConversationOption
pub fn model_type(&self) -> ModelType {
match *self {
Expand Down
5 changes: 5 additions & 0 deletions src/pipelines/generation_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ pub(crate) mod private_generation_utils {

pub trait PrivateLanguageGenerator {
fn _get_tokenizer(&self) -> &TokenizerOption;
fn _get_tokenizer_mut(&mut self) -> &mut TokenizerOption;
fn get_var_store(&self) -> &nn::VarStore;
fn get_var_store_mut(&mut self) -> &mut nn::VarStore;
fn get_config(&self) -> &GenerateConfig;
Expand Down Expand Up @@ -2140,6 +2141,10 @@ pub trait LanguageGenerator: PrivateLanguageGenerator {
self._get_tokenizer()
}

/// Returns a mutable reference to the generator's tokenizer.
///
/// Public counterpart of `_get_tokenizer_mut`, to which it delegates.
fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self._get_tokenizer_mut()
}

/// Calls `half()` on the variable store returned by `get_var_store_mut`.
///
/// NOTE(review): presumably converts the stored weights to half precision —
/// confirm against the `nn::VarStore::half` documentation.
fn half(&mut self) {
self.get_var_store_mut().half();
}
Expand Down
11 changes: 11 additions & 0 deletions src/pipelines/keywords_extraction/pipeline.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use crate::pipelines::common::TokenizerOption;
/// Derived from https://github.com/MaartenGr/KeyBERT, shared under MIT License
///
/// Copyright (c) 2020, Maarten P. Grootendorst
Expand Down Expand Up @@ -178,6 +179,16 @@ impl<'a> KeywordExtractionModel<'a> {
})
}

/// Get a reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `sentence_embeddings_model`; this
/// forwards to its `get_tokenizer`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.sentence_embeddings_model.get_tokenizer()
}

/// Get a mutable reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `sentence_embeddings_model`; this
/// forwards to its `get_tokenizer_mut`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.sentence_embeddings_model.get_tokenizer_mut()
}

/// Extract keywords from a list of input texts.
///
/// # Arguments
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/masked_language.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,16 @@ impl MaskedLanguageModel {
})
}

/// Get a reference to the model tokenizer.
///
/// Borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}

/// Get a mutable reference to the model tokenizer.
///
/// Mutably borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}

/// Replace custom user-provided mask token by language model mask token.
fn replace_mask_token<'a, S>(
&self,
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/ner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,16 @@ impl NERModel {
})
}

/// Get a reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `token_classification_model`; this
/// forwards to its `get_tokenizer`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.token_classification_model.get_tokenizer()
}

/// Get a mutable reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `token_classification_model`; this
/// forwards to its `get_tokenizer_mut`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.token_classification_model.get_tokenizer_mut()
}

/// Extract entities from a text
///
/// # Arguments
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/pos_tagging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,16 @@ impl POSModel {
})
}

/// Get a reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `token_classification_model`; this
/// forwards to its `get_tokenizer`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.token_classification_model.get_tokenizer()
}

/// Get a mutable reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `token_classification_model`; this
/// forwards to its `get_tokenizer_mut`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.token_classification_model.get_tokenizer_mut()
}

/// Extract entities from a text
///
/// # Arguments
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/question_answering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,6 +684,16 @@ impl QuestionAnsweringModel {
})
}

/// Get a reference to the model tokenizer.
///
/// Borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}

/// Get a mutable reference to the model tokenizer.
///
/// Mutably borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}

/// Perform extractive question answering given a list of `QaInputs`
///
/// # Arguments
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/sentence_embeddings/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,16 @@ impl SentenceEmbeddingsModel {
})
}

/// Get a reference to the model tokenizer.
///
/// Borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}

/// Get a mutable reference to the model tokenizer.
///
/// Mutably borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}

/// Sets the tokenizer's truncation strategy
pub fn set_tokenizer_truncation(&mut self, truncation_strategy: TruncationStrategy) {
self.tokenizer_truncation_strategy = truncation_strategy;
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/sentiment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,16 @@ impl SentimentModel {
})
}

/// Get a reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `sequence_classification_model`;
/// this forwards to its `get_tokenizer`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.sequence_classification_model.get_tokenizer()
}

/// Get a mutable reference to the model tokenizer.
///
/// The tokenizer is owned by the wrapped `sequence_classification_model`;
/// this forwards to its `get_tokenizer_mut`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.sequence_classification_model.get_tokenizer_mut()
}

/// Extract sentiment from an array of text inputs
///
/// # Arguments
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/sequence_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,16 @@ impl SequenceClassificationModel {
})
}

/// Get a reference to the model tokenizer.
///
/// Borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}

/// Get a mutable reference to the model tokenizer.
///
/// Mutably borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}

fn prepare_for_model<'a, S>(&self, input: S) -> Tensor
where
S: AsRef<[&'a str]>,
Expand Down
33 changes: 33 additions & 0 deletions src/pipelines/summarization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ use crate::resources::ResourceProvider;
use crate::t5::T5Generator;

use crate::longt5::LongT5Generator;
use crate::pipelines::generation_utils::private_generation_utils::PrivateLanguageGenerator;
#[cfg(feature = "remote")]
use crate::{
bart::{BartConfigResources, BartMergesResources, BartModelResources, BartVocabResources},
Expand Down Expand Up @@ -290,6 +291,28 @@ impl SummarizationOption {
}
}

/// Borrow the tokenizer of whichever summarization generator is active.
///
/// Every variant forwards to the `_get_tokenizer` of the generator it wraps.
pub fn get_tokenizer(&self) -> &TokenizerOption {
    match self {
        Self::Bart(generator) => generator._get_tokenizer(),
        Self::T5(generator) => generator._get_tokenizer(),
        Self::LongT5(generator) => generator._get_tokenizer(),
        Self::ProphetNet(generator) => generator._get_tokenizer(),
        Self::Pegasus(generator) => generator._get_tokenizer(),
    }
}

/// Mutably borrow the tokenizer of whichever summarization generator is active.
///
/// Every variant forwards to the `_get_tokenizer_mut` of the generator it wraps.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
    match self {
        Self::Bart(generator) => generator._get_tokenizer_mut(),
        Self::T5(generator) => generator._get_tokenizer_mut(),
        Self::LongT5(generator) => generator._get_tokenizer_mut(),
        Self::ProphetNet(generator) => generator._get_tokenizer_mut(),
        Self::Pegasus(generator) => generator._get_tokenizer_mut(),
    }
}

/// Interface method to generate() of the particular models.
pub fn generate<S>(&self, prompt_texts: Option<&[S]>) -> Vec<String>
where
Expand Down Expand Up @@ -399,6 +422,16 @@ impl SummarizationModel {
Ok(SummarizationModel { model, prefix })
}

/// Get a reference to the model tokenizer.
///
/// Forwards to `get_tokenizer` on the wrapped `model`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.model.get_tokenizer()
}

/// Get a mutable reference to the model tokenizer.
///
/// Forwards to `get_tokenizer_mut` on the wrapped `model`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.model.get_tokenizer_mut()
}

/// Summarize texts provided
///
/// # Arguments
Expand Down
20 changes: 20 additions & 0 deletions src/pipelines/text_generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,18 @@ impl TextGenerationOption {
}
}

/// Mutably borrow the tokenizer of whichever text generator is active.
///
/// Every variant forwards to the `_get_tokenizer_mut` of the generator it wraps.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
    match self {
        Self::GPT(generator) => generator._get_tokenizer_mut(),
        Self::GPT2(generator) => generator._get_tokenizer_mut(),
        Self::GPTNeo(generator) => generator._get_tokenizer_mut(),
        Self::GPTJ(generator) => generator._get_tokenizer_mut(),
        Self::XLNet(generator) => generator._get_tokenizer_mut(),
        Self::Reformer(generator) => generator._get_tokenizer_mut(),
    }
}

/// Interface method to generate() of the particular models.
pub fn generate_indices<S>(
&self,
Expand Down Expand Up @@ -481,6 +493,14 @@ with people, even a bishop, begging for his blessing. <eod> </s> <eos>"
(prefix, min_length, max_length)
}

/// Get a reference to the model tokenizer.
///
/// Forwards to `get_tokenizer` on the wrapped `model`.
pub fn get_tokenizer(&self) -> &TokenizerOption {
self.model.get_tokenizer()
}

/// Get a mutable reference to the model tokenizer.
///
/// Forwards to `get_tokenizer_mut` on the wrapped `model`.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
self.model.get_tokenizer_mut()
}

/// Forwards to `half()` on the wrapped `model`.
///
/// NOTE(review): presumably switches the model weights to half precision —
/// confirm against the wrapped model's `half` implementation.
pub fn half(&mut self) {
self.model.half();
}
Expand Down
10 changes: 10 additions & 0 deletions src/pipelines/token_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -746,6 +746,16 @@ impl TokenClassificationModel {
})
}

/// Get a reference to the model tokenizer.
///
/// Borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer(&self) -> &TokenizerOption {
&self.tokenizer
}

/// Get a mutable reference to the model tokenizer.
///
/// Mutably borrows the pipeline's own `tokenizer` field.
pub fn get_tokenizer_mut(&mut self) -> &mut TokenizerOption {
&mut self.tokenizer
}

fn generate_features<S>(&self, input: S, example_index: usize) -> Vec<InputFeature>
where
S: AsRef<str>,
Expand Down
Loading

0 comments on commit 02967b1

Please sign in to comment.