From 340be36ed9617ee2b8e82ad2780a4bb327c6236e Mon Sep 17 00:00:00 2001 From: guillaume-be Date: Sun, 30 Oct 2022 07:39:52 +0000 Subject: [PATCH] Mixed resources (#291) * - made `merges` resource optional for all pipelines - allow mixing local and remote resources for pipelines * Updated changelog * Fixed Clippy warnings --- CHANGELOG.md | 2 + benches/generation_benchmark.rs | 4 +- examples/generation_gpt_neo.rs | 2 +- examples/generation_reformer.rs | 5 +- examples/generation_xlnet.rs | 5 +- examples/summarization_bart.rs | 2 +- examples/summarization_pegasus.rs | 4 +- examples/summarization_prophetnet.rs | 4 +- examples/summarization_t5.rs | 2 +- examples/translation_m2m100.rs | 2 +- examples/translation_marian.rs | 2 +- examples/translation_mbart.rs | 4 +- examples/translation_t5.rs | 3 +- src/bart/bart_model.rs | 10 +++- src/gpt2/gpt2_model.rs | 10 +++- src/gpt_neo/gpt_neo_model.rs | 10 +++- src/gpt_neo/mod.rs | 2 +- src/m2m_100/m2m_100_model.rs | 10 +++- src/marian/marian_model.rs | 11 ++++- src/openai_gpt/openai_gpt_model.rs | 10 +++- src/pipelines/conversation.rs | 6 +-- src/pipelines/generation_utils.rs | 6 ++- src/pipelines/question_answering.rs | 28 ++++++----- src/pipelines/sequence_classification.rs | 14 +++--- src/pipelines/summarization.rs | 22 +++++---- src/pipelines/text_generation.rs | 22 +++++---- src/pipelines/token_classification.rs | 14 +++--- src/pipelines/translation/mod.rs | 2 +- .../translation/translation_builder.rs | 2 +- .../translation/translation_pipeline.rs | 47 +++++++++---------- src/pipelines/zero_shot_classification.rs | 14 +++--- src/prophetnet/mod.rs | 4 +- src/xlnet/mod.rs | 5 +- tests/bart.rs | 4 +- tests/gpt2.rs | 22 ++++----- tests/gpt_neo.rs | 2 +- tests/m2m100.rs | 2 +- tests/marian.rs | 2 +- tests/openai_gpt.rs | 8 ++-- tests/pegasus.rs | 4 +- tests/prophetnet.rs | 4 +- tests/reformer.rs | 5 +- tests/t5.rs | 5 +- tests/xlnet.rs | 5 +- 44 files changed, 203 insertions(+), 150 deletions(-) diff --git a/CHANGELOG.md 
b/CHANGELOG.md index eb92c3b50..aa43e7f10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file. The format ## Changed - Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`) +- (BREAKING) `merges_resource` now optional for all pipelines +- Allow mixing local and remote resources in pipelines ## Fixed - Fixed configuration check for RoBERTa models for sentence classification. diff --git a/benches/generation_benchmark.rs b/benches/generation_benchmark.rs index 197dffb8b..451f9c9eb 100644 --- a/benches/generation_benchmark.rs +++ b/benches/generation_benchmark.rs @@ -17,7 +17,9 @@ fn create_text_generation_model() -> TextGenerationModel { model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)), config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)), vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)), - merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)), + merges_resource: Some(Box::new(RemoteResource::from_pretrained( + Gpt2MergesResources::GPT2, + ))), min_length: 0, max_length: 30, do_sample: true, diff --git a/examples/generation_gpt_neo.rs b/examples/generation_gpt_neo.rs index 3cd13cea0..184ff35d0 100644 --- a/examples/generation_gpt_neo.rs +++ b/examples/generation_gpt_neo.rs @@ -41,7 +41,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), min_length: 10, max_length: 32, do_sample: false, diff --git a/examples/generation_reformer.rs b/examples/generation_reformer.rs index 70e814124..b37fa9613 100644 --- a/examples/generation_reformer.rs +++ b/examples/generation_reformer.rs @@ -30,9 +30,6 @@ fn main() -> anyhow::Result<()> { let vocab_resource = Box::new(RemoteResource::from_pretrained( 
ReformerVocabResources::CRIME_AND_PUNISHMENT, )); - let merges_resource = Box::new(RemoteResource::from_pretrained( - ReformerVocabResources::CRIME_AND_PUNISHMENT, - )); let model_resource = Box::new(RemoteResource::from_pretrained( ReformerModelResources::CRIME_AND_PUNISHMENT, )); @@ -41,7 +38,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: None, min_length: 100, max_length: 100, do_sample: true, diff --git a/examples/generation_xlnet.rs b/examples/generation_xlnet.rs index 0faa1d085..cc4e12a72 100644 --- a/examples/generation_xlnet.rs +++ b/examples/generation_xlnet.rs @@ -27,9 +27,6 @@ fn main() -> anyhow::Result<()> { let vocab_resource = Box::new(RemoteResource::from_pretrained( XLNetVocabResources::XLNET_BASE_CASED, )); - let merges_resource = Box::new(RemoteResource::from_pretrained( - XLNetVocabResources::XLNET_BASE_CASED, - )); let model_resource = Box::new(RemoteResource::from_pretrained( XLNetModelResources::XLNET_BASE_CASED, )); @@ -39,7 +36,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: None, max_length: 32, do_sample: false, num_beams: 3, diff --git a/examples/summarization_bart.rs b/examples/summarization_bart.rs index b829eef69..f35da0176 100644 --- a/examples/summarization_bart.rs +++ b/examples/summarization_bart.rs @@ -37,7 +37,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), num_beams: 1, length_penalty: 1.0, min_length: 56, diff --git a/examples/summarization_pegasus.rs b/examples/summarization_pegasus.rs index 8bfcd9769..48564aad8 100644 --- a/examples/summarization_pegasus.rs +++ b/examples/summarization_pegasus.rs @@ -33,8 +33,8 @@ fn main() -> anyhow::Result<()> { model_type: ModelType::Pegasus, model_resource: weights_resource, config_resource, - vocab_resource: vocab_resource.clone(), - 
merges_resource: vocab_resource, + vocab_resource, + merges_resource: None, length_penalty: 1.0, num_beams: 4, no_repeat_ngram_size: 3, diff --git a/examples/summarization_prophetnet.rs b/examples/summarization_prophetnet.rs index 2b332f687..d4d3a2a9f 100644 --- a/examples/summarization_prophetnet.rs +++ b/examples/summarization_prophetnet.rs @@ -35,8 +35,8 @@ fn main() -> anyhow::Result<()> { model_type: ModelType::ProphetNet, model_resource: weights_resource, config_resource, - vocab_resource: vocab_resource.clone(), - merges_resource: vocab_resource, + vocab_resource, + merges_resource: None, length_penalty: 1.2, num_beams: 4, no_repeat_ngram_size: 3, diff --git a/examples/summarization_t5.rs b/examples/summarization_t5.rs index 7d78a15b8..643d8c6df 100644 --- a/examples/summarization_t5.rs +++ b/examples/summarization_t5.rs @@ -26,8 +26,8 @@ fn main() -> anyhow::Result<()> { ModelType::T5, weights_resource, config_resource, - vocab_resource.clone(), vocab_resource, + None, ); let summarization_model = SummarizationModel::new(summarization_config)?; diff --git a/examples/translation_m2m100.rs b/examples/translation_m2m100.rs index cbcd289a1..dade34cd6 100644 --- a/examples/translation_m2m100.rs +++ b/examples/translation_m2m100.rs @@ -35,7 +35,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + Some(merges_resource), source_languages, target_languages, Device::cuda_if_available(), diff --git a/examples/translation_marian.rs b/examples/translation_marian.rs index f9640902b..9a8156aa2 100644 --- a/examples/translation_marian.rs +++ b/examples/translation_marian.rs @@ -36,7 +36,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + Some(merges_resource), source_languages, target_languages, Device::cuda_if_available(), diff --git a/examples/translation_mbart.rs b/examples/translation_mbart.rs index 7286713cc..c7d89a19d 100644 --- 
a/examples/translation_mbart.rs +++ b/examples/translation_mbart.rs @@ -26,8 +26,6 @@ fn main() -> anyhow::Result<()> { let config_resource = RemoteResource::from_pretrained(MBartConfigResources::MBART50_MANY_TO_MANY); let vocab_resource = RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY); - let merges_resource = - RemoteResource::from_pretrained(MBartVocabResources::MBART50_MANY_TO_MANY); let source_languages = MBartSourceLanguages::MBART50_MANY_TO_MANY; let target_languages = MBartTargetLanguages::MBART50_MANY_TO_MANY; @@ -37,7 +35,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + None, source_languages, target_languages, Device::cuda_if_available(), diff --git a/examples/translation_t5.rs b/examples/translation_t5.rs index 913cbe9de..aeca88dcd 100644 --- a/examples/translation_t5.rs +++ b/examples/translation_t5.rs @@ -22,7 +22,6 @@ fn main() -> anyhow::Result<()> { let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_BASE); let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_BASE); let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE); - let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_BASE); let source_languages = [ Language::English, @@ -42,7 +41,7 @@ fn main() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + None, source_languages, target_languages, Device::cuda_if_available(), diff --git a/src/bart/bart_model.rs b/src/bart/bart_model.rs index 555dbb5fd..fc2eb33de 100644 --- a/src/bart/bart_model.rs +++ b/src/bart/bart_model.rs @@ -1067,7 +1067,15 @@ impl BartGenerator { /// ``` pub fn new(generate_config: GenerateConfig) -> Result { let vocab_path = generate_config.vocab_resource.get_local_path()?; - let merges_path = generate_config.merges_resource.get_local_path()?; + let merges_path = generate_config + .merges_resource + .as_ref() + 
.ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "BART expects a merges resource to be provided".to_string(), + ) + })? + .get_local_path()?; let tokenizer = TokenizerOption::from_file( ModelType::Bart, diff --git a/src/gpt2/gpt2_model.rs b/src/gpt2/gpt2_model.rs index b53614e8c..2718f82fe 100644 --- a/src/gpt2/gpt2_model.rs +++ b/src/gpt2/gpt2_model.rs @@ -708,7 +708,15 @@ impl GPT2Generator { /// ``` pub fn new(generate_config: GenerateConfig) -> Result { let vocab_path = generate_config.vocab_resource.get_local_path()?; - let merges_path = generate_config.merges_resource.get_local_path()?; + let merges_path = generate_config + .merges_resource + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "GPT2 expects a merges resource to be provided".to_string(), + ) + })? + .get_local_path()?; let tokenizer = TokenizerOption::from_file( ModelType::GPT2, diff --git a/src/gpt_neo/gpt_neo_model.rs b/src/gpt_neo/gpt_neo_model.rs index f959b993b..388a5dc92 100644 --- a/src/gpt_neo/gpt_neo_model.rs +++ b/src/gpt_neo/gpt_neo_model.rs @@ -683,7 +683,15 @@ impl GptNeoGenerator { /// ``` pub fn new(generate_config: GenerateConfig) -> Result { let vocab_path = generate_config.vocab_resource.get_local_path()?; - let merges_path = generate_config.merges_resource.get_local_path()?; + let merges_path = generate_config + .merges_resource + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "GPT-Neo expects a merges resource to be provided".to_string(), + ) + })? + .get_local_path()?; let tokenizer = TokenizerOption::from_file( ModelType::GPTNeo, diff --git a/src/gpt_neo/mod.rs b/src/gpt_neo/mod.rs index cf326884a..b535ba16b 100644 --- a/src/gpt_neo/mod.rs +++ b/src/gpt_neo/mod.rs @@ -44,7 +44,7 @@ //! model_resource, //! config_resource, //! vocab_resource, -//! merges_resource, +//! merges_resource: Some(merges_resource), //! num_beams: 4, //! no_repeat_ngram_size: 3, //! 
device: Device::cuda_if_available(), diff --git a/src/m2m_100/m2m_100_model.rs b/src/m2m_100/m2m_100_model.rs index 6ba9c4e4c..8d223f3ab 100644 --- a/src/m2m_100/m2m_100_model.rs +++ b/src/m2m_100/m2m_100_model.rs @@ -617,7 +617,15 @@ impl M2M100Generator { /// ``` pub fn new(generate_config: GenerateConfig) -> Result { let vocab_path = generate_config.vocab_resource.get_local_path()?; - let merges_path = generate_config.merges_resource.get_local_path()?; + let merges_path = generate_config + .merges_resource + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "M2M100 expects a merges resource to be provided".to_string(), + ) + })? + .get_local_path()?; let tokenizer = TokenizerOption::from_file( ModelType::M2M100, diff --git a/src/marian/marian_model.rs b/src/marian/marian_model.rs index 8370c8011..a1bd826df 100644 --- a/src/marian/marian_model.rs +++ b/src/marian/marian_model.rs @@ -837,7 +837,16 @@ impl MarianGenerator { /// ``` pub fn new(generate_config: GenerateConfig) -> Result { let vocab_path = generate_config.vocab_resource.get_local_path()?; - let sentence_piece_path = generate_config.merges_resource.get_local_path()?; + let sentence_piece_path = generate_config + .merges_resource + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "Marian expects a merges (SentencePiece model) resource to be provided" + .to_string(), + ) + })? 
+ .get_local_path()?; let tokenizer = TokenizerOption::from_file( ModelType::Marian, diff --git a/src/openai_gpt/openai_gpt_model.rs b/src/openai_gpt/openai_gpt_model.rs index 12242b416..4da0d4dec 100644 --- a/src/openai_gpt/openai_gpt_model.rs +++ b/src/openai_gpt/openai_gpt_model.rs @@ -470,7 +470,15 @@ impl OpenAIGenerator { /// ``` pub fn new(generate_config: GenerateConfig) -> Result { let vocab_path = generate_config.vocab_resource.get_local_path()?; - let merges_path = generate_config.merges_resource.get_local_path()?; + let merges_path = generate_config + .merges_resource + .as_ref() + .ok_or_else(|| { + RustBertError::InvalidConfigurationError( + "GPT expects a merges resource to be provided".to_string(), + ) + })? + .get_local_path()?; let tokenizer = TokenizerOption::from_file( ModelType::OpenAiGpt, diff --git a/src/pipelines/conversation.rs b/src/pipelines/conversation.rs index 7af017a5a..f9b33a773 100644 --- a/src/pipelines/conversation.rs +++ b/src/pipelines/conversation.rs @@ -82,7 +82,7 @@ pub struct ConversationConfig { /// Vocab resource (default: DialoGPT-medium) pub vocab_resource: Box, /// Merges resource (default: DialoGPT-medium) - pub merges_resource: Box, + pub merges_resource: Option>, /// Minimum sequence length (default: 0) pub min_length: i64, /// Maximum sequence length (default: 20) @@ -131,9 +131,9 @@ impl Default for ConversationConfig { vocab_resource: Box::new(RemoteResource::from_pretrained( Gpt2VocabResources::DIALOGPT_MEDIUM, )), - merges_resource: Box::new(RemoteResource::from_pretrained( + merges_resource: Some(Box::new(RemoteResource::from_pretrained( Gpt2MergesResources::DIALOGPT_MEDIUM, - )), + ))), min_length: 0, max_length: 1000, min_length_for_response: 64, diff --git a/src/pipelines/generation_utils.rs b/src/pipelines/generation_utils.rs index 249ffc938..6fd21a78c 100644 --- a/src/pipelines/generation_utils.rs +++ b/src/pipelines/generation_utils.rs @@ -103,7 +103,7 @@ pub struct GenerateConfig { /// Vocab resource 
(default: pretrained GPT2 model) pub vocab_resource: Box, /// Merges resource (default: pretrained GPT2 model) - pub merges_resource: Box, + pub merges_resource: Option>, /// Minimum sequence length (default: 0) pub min_length: i64, /// Maximum sequence length (default: 20) @@ -143,7 +143,9 @@ impl Default for GenerateConfig { model_resource: Box::new(RemoteResource::from_pretrained(Gpt2ModelResources::GPT2)), config_resource: Box::new(RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2)), vocab_resource: Box::new(RemoteResource::from_pretrained(Gpt2VocabResources::GPT2)), - merges_resource: Box::new(RemoteResource::from_pretrained(Gpt2MergesResources::GPT2)), + merges_resource: Some(Box::new(RemoteResource::from_pretrained( + Gpt2MergesResources::GPT2, + ))), min_length: 0, max_length: 20, do_sample: true, diff --git a/src/pipelines/question_answering.rs b/src/pipelines/question_answering.rs index fff1e489e..742501e21 100644 --- a/src/pipelines/question_answering.rs +++ b/src/pipelines/question_answering.rs @@ -166,18 +166,20 @@ impl QuestionAnsweringConfig { /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * merges_resource - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. 
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) - pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: Option, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, lower_case: bool, strip_accents: impl Into>, add_prefix_space: impl Into>, ) -> QuestionAnsweringConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { QuestionAnsweringConfig { model_type, @@ -210,12 +212,12 @@ impl QuestionAnsweringConfig { /// * max_query_length - Optional maximum question token length. Defaults to 64. /// * doc_stride - Optional stride to apply if a sliding window is required to process the input context. Represents the number of overlapping tokens between sliding windows. This should be lower than the max_seq_length minus max_query_length (otherwise there is a risk for the sliding window not to progress). Defaults to 128. /// * max_answer_length - Optional maximum token length for the extracted answer. Defaults to 15. 
- pub fn custom_new( + pub fn custom_new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: Option, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, lower_case: bool, strip_accents: impl Into>, add_prefix_space: impl Into>, @@ -225,7 +227,9 @@ impl QuestionAnsweringConfig { max_answer_length: impl Into>, ) -> QuestionAnsweringConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { QuestionAnsweringConfig { model_type, diff --git a/src/pipelines/sequence_classification.rs b/src/pipelines/sequence_classification.rs index 472655670..700fb27f3 100644 --- a/src/pipelines/sequence_classification.rs +++ b/src/pipelines/sequence_classification.rs @@ -134,18 +134,20 @@ impl SequenceClassificationConfig { /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. 
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) - pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: Option, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, lower_case: bool, strip_accents: impl Into>, add_prefix_space: impl Into>, ) -> SequenceClassificationConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { SequenceClassificationConfig { model_type, diff --git a/src/pipelines/summarization.rs b/src/pipelines/summarization.rs index 27c17e05b..856cb3d68 100644 --- a/src/pipelines/summarization.rs +++ b/src/pipelines/summarization.rs @@ -92,7 +92,7 @@ pub struct SummarizationConfig { /// Vocab resource (default: pretrained BART model on CNN-DM) pub vocab_resource: Box, /// Merges resource (default: pretrained BART model on CNN-DM) - pub merges_resource: Box, + pub merges_resource: Option>, /// Minimum sequence length (default: 0) pub min_length: i64, /// Maximum sequence length (default: 20) @@ -135,22 +135,24 @@ impl SummarizationConfig { /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt). 
- pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: R, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, ) -> SummarizationConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { SummarizationConfig { model_type, model_resource: Box::new(model_resource), config_resource: Box::new(config_resource), vocab_resource: Box::new(vocab_resource), - merges_resource: Box::new(merges_resource), + merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>), min_length: 56, max_length: 142, do_sample: false, @@ -178,7 +180,9 @@ impl Default for SummarizationConfig { RemoteResource::from_pretrained(BartModelResources::BART_CNN), RemoteResource::from_pretrained(BartConfigResources::BART_CNN), RemoteResource::from_pretrained(BartVocabResources::BART_CNN), - RemoteResource::from_pretrained(BartMergesResources::BART_CNN), + Some(RemoteResource::from_pretrained( + BartMergesResources::BART_CNN, + )), ) } } diff --git a/src/pipelines/text_generation.rs b/src/pipelines/text_generation.rs index 2e7d22255..b793c0a5d 100644 --- a/src/pipelines/text_generation.rs +++ b/src/pipelines/text_generation.rs @@ -63,7 +63,7 @@ pub struct TextGenerationConfig { /// Vocab resource (default: pretrained BART model on CNN-DM) pub vocab_resource: Box, /// Merges resource (default: pretrained BART model on CNN-DM) - pub merges_resource: Box, + pub merges_resource: Option>, /// Minimum sequence length (default: 0) pub min_length: i64, /// Maximum sequence length (default: 20) @@ -106,22 +106,24 @@ impl TextGenerationConfig { /// * config_resource - The `ResourceProvider` pointing to the model configuration to load (e.g. config.json) /// * vocab_resource - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. 
vocab.txt/vocab.json) /// * merges_resource - The `ResourceProvider` pointing to the tokenizer's merge file or SentencePiece model to load (e.g. merges.txt). - pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: R, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, ) -> TextGenerationConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { TextGenerationConfig { model_type, model_resource: Box::new(model_resource), config_resource: Box::new(config_resource), vocab_resource: Box::new(vocab_resource), - merges_resource: Box::new(merges_resource), + merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>), min_length: 0, max_length: 20, do_sample: true, @@ -149,7 +151,9 @@ impl Default for TextGenerationConfig { RemoteResource::from_pretrained(Gpt2ModelResources::GPT2_MEDIUM), RemoteResource::from_pretrained(Gpt2ConfigResources::GPT2_MEDIUM), RemoteResource::from_pretrained(Gpt2VocabResources::GPT2_MEDIUM), - RemoteResource::from_pretrained(Gpt2MergesResources::GPT2_MEDIUM), + Some(RemoteResource::from_pretrained( + Gpt2MergesResources::GPT2_MEDIUM, + )), ) } } diff --git a/src/pipelines/token_classification.rs b/src/pipelines/token_classification.rs index 2c1e99431..685038356 100644 --- a/src/pipelines/token_classification.rs +++ b/src/pipelines/token_classification.rs @@ -254,19 +254,21 @@ impl TokenClassificationConfig { /// * vocab - The `ResourceProvider` pointing to the tokenizers' vocabulary to load (e.g. vocab.txt/vocab.json) /// * vocab - An optional `ResourceProvider` pointing to the tokenizers' merge file to load (e.g. merges.txt), needed only for Roberta. 
/// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) - pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: Option, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, lower_case: bool, strip_accents: impl Into>, add_prefix_space: impl Into>, label_aggregation_function: LabelAggregationOption, ) -> TokenClassificationConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { TokenClassificationConfig { model_type, diff --git a/src/pipelines/translation/mod.rs b/src/pipelines/translation/mod.rs index a48e0508e..5d856ce57 100644 --- a/src/pipelines/translation/mod.rs +++ b/src/pipelines/translation/mod.rs @@ -38,7 +38,7 @@ //! model_resource, //! config_resource, //! vocab_resource, -//! merges_resource, +//! Some(merges_resource), //! source_languages, //! target_languages, //! 
Device::cuda_if_available(), diff --git a/src/pipelines/translation/translation_builder.rs b/src/pipelines/translation/translation_builder.rs index 5f49af2b8..f02d4618b 100644 --- a/src/pipelines/translation/translation_builder.rs +++ b/src/pipelines/translation/translation_builder.rs @@ -383,7 +383,7 @@ impl TranslationModelBuilder { translation_resources.model_resource, translation_resources.config_resource, translation_resources.vocab_resource, - translation_resources.merges_resource, + Some(translation_resources.merges_resource), translation_resources.source_languages, translation_resources.target_languages, device, diff --git a/src/pipelines/translation/translation_pipeline.rs b/src/pipelines/translation/translation_pipeline.rs index ee906f749..6f8598232 100644 --- a/src/pipelines/translation/translation_pipeline.rs +++ b/src/pipelines/translation/translation_pipeline.rs @@ -380,7 +380,7 @@ pub struct TranslationConfig { /// Vocab resource pub vocab_resource: Box, /// Merges resource - pub merges_resource: Box, + pub merges_resource: Option>, /// Supported source languages pub source_languages: HashSet, /// Supported target languages @@ -428,11 +428,8 @@ impl TranslationConfig { /// # Example /// /// ```no_run - /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::marian::{ - /// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, - /// MarianVocabResources, - /// }; + /// # fn main() -> anyhow::Result<()> { /// + /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources}; /// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::translation::TranslationConfig; /// use rust_bert::resources::RemoteResource; @@ -441,6 +438,7 @@ impl TranslationConfig { /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH); /// let config_resource = 
RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH); /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH); + /// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH); /// /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; @@ -449,8 +447,8 @@ impl TranslationConfig { /// ModelType::Marian, /// model_resource, /// config_resource, - /// vocab_resource.clone(), /// vocab_resource, + /// Some(spm_resource), /// source_languages, /// target_languages, /// Device::cuda_if_available(), @@ -458,18 +456,20 @@ impl TranslationConfig { /// # Ok(()) /// # } /// ``` - pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: R, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, source_languages: S, target_languages: T, device: impl Into>, ) -> TranslationConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, S: AsRef<[Language]>, T: AsRef<[Language]>, { @@ -480,7 +480,7 @@ impl TranslationConfig { model_resource: Box::new(model_resource), config_resource: Box::new(config_resource), vocab_resource: Box::new(vocab_resource), - merges_resource: Box::new(merges_resource), + merges_resource: merges_resource.map(|r| Box::new(r) as Box<_>), source_languages: source_languages.as_ref().iter().cloned().collect(), target_languages: target_languages.as_ref().iter().cloned().collect(), device, @@ -786,11 +786,8 @@ impl TranslationModel { /// # Example /// /// ```no_run - /// # fn main() -> anyhow::Result<()> { - /// use rust_bert::marian::{ - /// MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, - /// MarianVocabResources, - /// }; + /// # 
fn main() -> anyhow::Result<()> { /// + /// use rust_bert::marian::{MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, MarianTargetLanguages, MarianVocabResources}; /// use rust_bert::pipelines::common::ModelType; /// use rust_bert::pipelines::translation::{TranslationConfig, TranslationModel}; /// use rust_bert::resources::RemoteResource; @@ -799,6 +796,7 @@ impl TranslationModel { /// let model_resource = RemoteResource::from_pretrained(MarianModelResources::ROMANCE2ENGLISH); /// let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH); /// let vocab_resource = RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH); + /// let spm_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH); /// /// let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; /// let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; @@ -807,8 +805,8 @@ impl TranslationModel { /// ModelType::Marian, /// model_resource, /// config_resource, - /// vocab_resource.clone(), /// vocab_resource, + /// Some(spm_resource), /// source_languages, /// target_languages, /// Device::cuda_if_available(), @@ -863,7 +861,7 @@ impl TranslationModel { /// model_resource, /// config_resource, /// vocab_resource, - /// merges_resource, + /// Some(merges_resource), /// source_languages, /// target_languages, /// Device::cuda_if_available(), @@ -911,8 +909,8 @@ impl TranslationModel { mod test { use super::*; use crate::marian::{ - MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianTargetLanguages, - MarianVocabResources, + MarianConfigResources, MarianModelResources, MarianSourceLanguages, MarianSpmResources, + MarianTargetLanguages, MarianVocabResources, }; use crate::resources::RemoteResource; @@ -923,6 +921,7 @@ mod test { let config_resource = RemoteResource::from_pretrained(MarianConfigResources::ROMANCE2ENGLISH); let vocab_resource = 
RemoteResource::from_pretrained(MarianVocabResources::ROMANCE2ENGLISH); + let merges_resource = RemoteResource::from_pretrained(MarianSpmResources::ROMANCE2ENGLISH); let source_languages = MarianSourceLanguages::ROMANCE2ENGLISH; let target_languages = MarianTargetLanguages::ROMANCE2ENGLISH; @@ -931,8 +930,8 @@ mod test { ModelType::Marian, model_resource, config_resource, - vocab_resource.clone(), vocab_resource, + Some(merges_resource), source_languages, target_languages, Device::cuda_if_available(), diff --git a/src/pipelines/zero_shot_classification.rs b/src/pipelines/zero_shot_classification.rs index 6d1f26b09..b0a62be16 100644 --- a/src/pipelines/zero_shot_classification.rs +++ b/src/pipelines/zero_shot_classification.rs @@ -159,18 +159,20 @@ impl ZeroShotClassificationConfig { /// * vocab - The `ResourceProvider` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json) /// * merges - An optional `ResourceProvider` pointing to the tokenizer's merge file to load (e.g. merges.txt), needed only for Roberta. /// * lower_case - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model) - pub fn new( + pub fn new( model_type: ModelType, - model_resource: R, - config_resource: R, - vocab_resource: R, - merges_resource: Option, + model_resource: RM, + config_resource: RC, + vocab_resource: RV, + merges_resource: Option, lower_case: bool, strip_accents: impl Into>, add_prefix_space: impl Into>, ) -> ZeroShotClassificationConfig where - R: ResourceProvider + Send + 'static, + RM: ResourceProvider + Send + 'static, + RC: ResourceProvider + Send + 'static, + RV: ResourceProvider + Send + 'static, { ZeroShotClassificationConfig { model_type, diff --git a/src/prophetnet/mod.rs b/src/prophetnet/mod.rs index 2d3ac3691..c46e5791e 100644 --- a/src/prophetnet/mod.rs +++ b/src/prophetnet/mod.rs @@ -38,8 +38,8 @@ //! model_type: ModelType::ProphetNet, //! model_resource: weights_resource, //! config_resource, -//! 
vocab_resource: vocab_resource.clone(), -//! merges_resource: vocab_resource, +//! vocab_resource, +//! merges_resource: None, //! length_penalty: 1.2, //! num_beams: 4, //! no_repeat_ngram_size: 3, diff --git a/src/xlnet/mod.rs b/src/xlnet/mod.rs index 06f9cff1f..74bf644c2 100644 --- a/src/xlnet/mod.rs +++ b/src/xlnet/mod.rs @@ -30,9 +30,6 @@ //! let vocab_resource = Box::new(RemoteResource::from_pretrained( //! XLNetVocabResources::XLNET_BASE_CASED, //! )); -//! let merges_resource = Box::new(RemoteResource::from_pretrained( -//! XLNetVocabResources::XLNET_BASE_CASED, -//! )); //! let model_resource = Box::new(RemoteResource::from_pretrained( //! XLNetModelResources::XLNET_BASE_CASED, //! )); @@ -41,7 +38,7 @@ //! model_resource, //! config_resource, //! vocab_resource, -//! merges_resource, +//! merges_resource: None, //! max_length: 56, //! do_sample: true, //! num_beams: 3, diff --git a/tests/bart.rs b/tests/bart.rs index a9c01d86e..fc009141b 100644 --- a/tests/bart.rs +++ b/tests/bart.rs @@ -93,7 +93,7 @@ fn bart_summarization_greedy() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), num_beams: 1, length_penalty: 1.0, min_length: 56, @@ -154,7 +154,7 @@ fn bart_summarization_beam_search() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), num_beams: 4, min_length: 56, max_length: 142, diff --git a/tests/gpt2.rs b/tests/gpt2.rs index 07eff2d18..c772a05c9 100644 --- a/tests/gpt2.rs +++ b/tests/gpt2.rs @@ -120,7 +120,7 @@ fn gpt2_generation_greedy() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 40, do_sample: false, num_beams: 1, @@ -152,7 +152,7 @@ fn gpt2_generation_beam_search() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: 
Some(merges_resource), max_length: 20, do_sample: false, num_beams: 5, @@ -196,7 +196,7 @@ fn gpt2_generation_beam_search_multiple_prompts_without_padding() -> anyhow::Res model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 20, do_sample: false, num_beams: 5, @@ -253,7 +253,7 @@ fn gpt2_generation_beam_search_multiple_prompts_with_padding() -> anyhow::Result model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 20, do_sample: false, num_beams: 5, @@ -309,7 +309,7 @@ fn gpt2_diverse_beam_search_multiple_prompts_with_padding() -> anyhow::Result<() model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), min_length: 10, max_length: 20, do_sample: false, @@ -382,7 +382,7 @@ fn gpt2_prefix_allowed_token_greedy() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), do_sample: false, num_beams: 1, device: Device::Cpu, @@ -432,7 +432,7 @@ fn gpt2_bad_tokens_greedy() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), do_sample: false, num_beams: 1, device: Device::Cpu, @@ -498,7 +498,7 @@ fn gpt2_bad_tokens_beam_search() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), do_sample: false, num_beams: 3, device: Device::Cpu, @@ -579,7 +579,7 @@ fn gpt2_prefix_allowed_token_beam_search() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), do_sample: false, num_beams: 3, device: Device::Cpu, @@ -629,7 +629,7 @@ fn gpt2_greedy_token_scores() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: 
Some(merges_resource), do_sample: false, num_beams: 1, device: Device::Cpu, @@ -685,7 +685,7 @@ fn gpt2_beam_search_token_scores() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), do_sample: false, num_beams: 2, device: Device::Cpu, diff --git a/tests/gpt_neo.rs b/tests/gpt_neo.rs index c10e79489..c3ab89391 100644 --- a/tests/gpt_neo.rs +++ b/tests/gpt_neo.rs @@ -128,7 +128,7 @@ fn test_generation_gpt_neo() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), min_length: 10, max_length: 32, do_sample: false, diff --git a/tests/m2m100.rs b/tests/m2m100.rs index bf90d6470..8a15b3d52 100644 --- a/tests/m2m100.rs +++ b/tests/m2m100.rs @@ -81,7 +81,7 @@ fn m2m100_translation() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + Some(merges_resource), source_languages, target_languages, Device::cuda_if_available(), diff --git a/tests/marian.rs b/tests/marian.rs index faf292c34..f2a1e3c27 100644 --- a/tests/marian.rs +++ b/tests/marian.rs @@ -26,7 +26,7 @@ fn test_translation() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + Some(merges_resource), source_languages, target_languages, Device::cuda_if_available(), diff --git a/tests/openai_gpt.rs b/tests/openai_gpt.rs index d1ebe51b1..15bc20fd5 100644 --- a/tests/openai_gpt.rs +++ b/tests/openai_gpt.rs @@ -122,7 +122,7 @@ fn openai_gpt_generation_greedy() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 40, do_sample: false, num_beams: 1, @@ -164,7 +164,7 @@ fn openai_gpt_generation_beam_search() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 20, do_sample: false, early_stopping: 
true, @@ -217,7 +217,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_without_padding() -> anyho model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 20, do_sample: false, early_stopping: true, @@ -286,7 +286,7 @@ fn openai_gpt_generation_beam_search_multiple_prompts_with_padding() -> anyhow:: model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: Some(merges_resource), max_length: 20, do_sample: false, num_beams: 5, diff --git a/tests/pegasus.rs b/tests/pegasus.rs index 3b159f898..b99ebbace 100644 --- a/tests/pegasus.rs +++ b/tests/pegasus.rs @@ -22,8 +22,8 @@ fn pegasus_summarization_greedy() -> anyhow::Result<()> { model_type: ModelType::Pegasus, model_resource, config_resource, - vocab_resource: vocab_resource.clone(), - merges_resource: vocab_resource, + vocab_resource, + merges_resource: None, num_beams: 4, no_repeat_ngram_size: 3, device: Device::cuda_if_available(), diff --git a/tests/prophetnet.rs b/tests/prophetnet.rs index 2725e1468..9d2415366 100644 --- a/tests/prophetnet.rs +++ b/tests/prophetnet.rs @@ -24,8 +24,8 @@ fn prophetnet_summarization_greedy() -> anyhow::Result<()> { model_type: ModelType::ProphetNet, model_resource: weights_resource, config_resource, - vocab_resource: vocab_resource.clone(), - merges_resource: vocab_resource, + vocab_resource, + merges_resource: None, length_penalty: 1.2, num_beams: 4, no_repeat_ngram_size: 3, diff --git a/tests/reformer.rs b/tests/reformer.rs index 9cceaf44f..5749444c8 100644 --- a/tests/reformer.rs +++ b/tests/reformer.rs @@ -39,9 +39,6 @@ fn test_generation_reformer() -> anyhow::Result<()> { let vocab_resource = Box::new(RemoteResource::from_pretrained( ReformerVocabResources::CRIME_AND_PUNISHMENT, )); - let merges_resource = Box::new(RemoteResource::from_pretrained( - ReformerVocabResources::CRIME_AND_PUNISHMENT, - )); let model_resource = Box::new(RemoteResource::from_pretrained( 
ReformerModelResources::CRIME_AND_PUNISHMENT, )); @@ -51,7 +48,7 @@ fn test_generation_reformer() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: None, min_length: 100, max_length: 100, do_sample: false, diff --git a/tests/t5.rs b/tests/t5.rs index 33dc5376b..f7f0a4fed 100644 --- a/tests/t5.rs +++ b/tests/t5.rs @@ -10,7 +10,6 @@ fn test_translation_t5() -> anyhow::Result<()> { let model_resource = RemoteResource::from_pretrained(T5ModelResources::T5_SMALL); let config_resource = RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL); let vocab_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL); - let merges_resource = RemoteResource::from_pretrained(T5VocabResources::T5_SMALL); let source_languages = [ Language::English, @@ -30,7 +29,7 @@ fn test_translation_t5() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + None, source_languages, target_languages, Device::cuda_if_available(), @@ -69,7 +68,7 @@ fn test_summarization_t5() -> anyhow::Result<()> { model_resource: Box::new(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)), config_resource: Box::new(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)), vocab_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), - merges_resource: Box::new(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)), + merges_resource: None, min_length: 30, max_length: 200, early_stopping: true, diff --git a/tests/xlnet.rs b/tests/xlnet.rs index 516c791b3..cda84c76b 100644 --- a/tests/xlnet.rs +++ b/tests/xlnet.rs @@ -202,9 +202,6 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> { let vocab_resource = Box::new(RemoteResource::from_pretrained( XLNetVocabResources::XLNET_BASE_CASED, )); - let merges_resource = Box::new(RemoteResource::from_pretrained( - XLNetVocabResources::XLNET_BASE_CASED, - )); let model_resource = 
Box::new(RemoteResource::from_pretrained( XLNetModelResources::XLNET_BASE_CASED, )); @@ -214,7 +211,7 @@ fn xlnet_generation_beam_search() -> anyhow::Result<()> { model_resource, config_resource, vocab_resource, - merges_resource, + merges_resource: None, max_length: 32, do_sample: false, num_beams: 3,