Skip to content

Commit

Permalink
Merge pull request guillaume-be#65 from guillaume-be/additional_ner_m…
Browse files Browse the repository at this point in the history
…odels

Additional ner models
  • Loading branch information
guillaume-be authored Jul 26, 2020
2 parents b27918a + d3cda4a commit 873fa35
Show file tree
Hide file tree
Showing 40 changed files with 637 additions and 138 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]

[dependencies]
rust_tokenizers = "~3.1.5"
rust_tokenizers = "~3.1.6"
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}
Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,8 @@ Output:
```

#### 7. Named Entity Recognition
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
Models are currently available for English, German, Spanish and Dutch.
```rust
let ner_model = NERModel::new(Default::default())?;

Expand Down
18 changes: 18 additions & 0 deletions examples/download_all_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,23 @@ fn download_bert_qa() -> failure::Fallible<()> {
Ok(())
}

fn download_xlm_roberta_ner_german() -> failure::Fallible<()> {
    // Shared under Apache 2.0 license by the HuggingFace Inc. team at https://huggingface.co/models.
    // Fetch the configuration, vocabulary and weights for the German XLM-RoBERTa NER model,
    // in that order, short-circuiting on the first download failure.
    let resources = [
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaConfigResources::XLM_ROBERTA_NER_DE,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaVocabResources::XLM_ROBERTA_NER_DE,
        )),
        Resource::Remote(RemoteResource::from_pretrained(
            RobertaModelResources::XLM_ROBERTA_NER_DE,
        )),
    ];
    for resource in resources.iter() {
        let _ = download_resource(resource)?;
    }
    Ok(())
}

fn main() -> failure::Fallible<()> {
let _ = download_distil_gpt2();
let _ = download_distilbert_sst2();
Expand All @@ -366,6 +383,7 @@ fn main() -> failure::Fallible<()> {
let _ = download_t5_small();
let _ = download_roberta_qa();
let _ = download_bert_qa();
let _ = download_xlm_roberta_ner_german();

Ok(())
}
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
torch == 1.5.0
transformers == 2.8.0
transformers == 2.10.0
43 changes: 41 additions & 2 deletions src/pipelines/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{
};
use rust_tokenizers::preprocessing::tokenizer::marian_tokenizer::MarianTokenizer;
use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer;
use rust_tokenizers::preprocessing::tokenizer::xlm_roberta_tokenizer::XLMRobertaTokenizer;
use rust_tokenizers::preprocessing::vocab::albert_vocab::AlbertVocab;
use rust_tokenizers::preprocessing::vocab::marian_vocab::MarianVocab;
use rust_tokenizers::preprocessing::vocab::t5_vocab::T5Vocab;
use rust_tokenizers::{
AlbertTokenizer, BertTokenizer, BertVocab, RobertaTokenizer, RobertaVocab, TokenizedInput,
TruncationStrategy,
TruncationStrategy, XLMRobertaVocab,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand All @@ -46,6 +47,7 @@ pub enum ModelType {
Bert,
DistilBert,
Roberta,
XLMRoberta,
Electra,
Marian,
T5,
Expand Down Expand Up @@ -74,6 +76,8 @@ pub enum TokenizerOption {
Bert(BertTokenizer),
/// Roberta Tokenizer
Roberta(RobertaTokenizer),
/// Roberta Tokenizer
XLMRoberta(XLMRobertaTokenizer),
/// Marian Tokenizer
Marian(MarianTokenizer),
/// T5 Tokenizer
Expand All @@ -86,7 +90,9 @@ impl ConfigOption {
/// Interface method to load a configuration from file
pub fn from_file(model_type: ModelType, path: &Path) -> Self {
match model_type {
ModelType::Bert | ModelType::Roberta => ConfigOption::Bert(BertConfig::from_file(path)),
ModelType::Bert | ModelType::Roberta | ModelType::XLMRoberta => {
ConfigOption::Bert(BertConfig::from_file(path))
}
ModelType::DistilBert => ConfigOption::DistilBert(DistilBertConfig::from_file(path)),
ModelType::Electra => ConfigOption::Electra(ElectraConfig::from_file(path)),
ModelType::Marian => ConfigOption::Marian(BartConfig::from_file(path)),
Expand Down Expand Up @@ -140,6 +146,9 @@ impl TokenizerOption {
lower_case,
)),
ModelType::T5 => TokenizerOption::T5(T5Tokenizer::from_file(vocab_path, lower_case)),
ModelType::XLMRoberta => {
TokenizerOption::XLMRoberta(XLMRobertaTokenizer::from_file(vocab_path, lower_case))
}
ModelType::Albert => TokenizerOption::Albert(AlbertTokenizer::from_file(
vocab_path,
lower_case,
Expand All @@ -153,6 +162,7 @@ impl TokenizerOption {
match *self {
Self::Bert(_) => ModelType::Bert,
Self::Roberta(_) => ModelType::Roberta,
Self::XLMRoberta(_) => ModelType::XLMRoberta,
Self::Marian(_) => ModelType::Marian,
Self::T5(_) => ModelType::T5,
Self::Albert(_) => ModelType::Albert,
Expand Down Expand Up @@ -180,6 +190,9 @@ impl TokenizerOption {
Self::T5(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::XLMRoberta(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Self::Albert(ref tokenizer) => {
tokenizer.encode_list(text_list, max_len, truncation_strategy, stride)
}
Expand All @@ -193,6 +206,7 @@ impl TokenizerOption {
Self::Roberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Marian(ref tokenizer) => tokenizer.tokenize(text),
Self::T5(ref tokenizer) => tokenizer.tokenize(text),
Self::XLMRoberta(ref tokenizer) => tokenizer.tokenize(text),
Self::Albert(ref tokenizer) => tokenizer.tokenize(text),
}
}
Expand Down Expand Up @@ -237,6 +251,16 @@ impl TokenizerOption {
mask_1,
mask_2,
),
Self::XLMRoberta(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
offsets_1,
offsets_2,
original_offsets_1,
original_offsets_2,
mask_1,
mask_2,
),
Self::Marian(ref tokenizer) => tokenizer.build_input_with_special_tokens(
tokens_1,
tokens_2,
Expand Down Expand Up @@ -277,6 +301,7 @@ impl TokenizerOption {
Self::Roberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Marian(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::T5(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::XLMRoberta(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
Self::Albert(ref tokenizer) => tokenizer.convert_tokens_to_ids(tokens),
}
}
Expand All @@ -298,6 +323,13 @@ impl TokenizerOption {
.get(RobertaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*tokenizer
.vocab()
.special_values
.get(XLMRobertaVocab::pad_value())
.expect("PAD token not found in vocabulary"),
),
Self::Marian(ref tokenizer) => Some(
*tokenizer
.vocab()
Expand Down Expand Up @@ -339,6 +371,13 @@ impl TokenizerOption {
.get(RobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::XLMRoberta(ref tokenizer) => Some(
*tokenizer
.vocab()
.special_values
.get(XLMRobertaVocab::sep_value())
.expect("SEP token not found in vocabulary"),
),
Self::Albert(ref tokenizer) => Some(
*tokenizer
.vocab()
Expand Down
83 changes: 37 additions & 46 deletions src/pipelines/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -795,21 +795,19 @@ impl PrivateLanguageGenerator<BartForConditionalGeneration, RobertaVocab, Robert
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
let mut new_past = vec![];
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => {
Some(self_layer_state.reorder_cache(beam_indices))
}
None => None,
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => {
Some(encoder_layer_state.reorder_cache(beam_indices))
}
None => None,
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
}
}
None => {}
Expand Down Expand Up @@ -1072,21 +1070,19 @@ impl PrivateLanguageGenerator<MarianForConditionalGeneration, MarianVocab, Maria
match past {
Cache::BARTCache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
let mut new_past = vec![];
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => {
Some(self_layer_state.reorder_cache(beam_indices))
}
None => None,
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => {
Some(encoder_layer_state.reorder_cache(beam_indices))
}
None => None,
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
}
}
None => {}
Expand Down Expand Up @@ -1306,21 +1302,19 @@ impl PrivateLanguageGenerator<T5ForConditionalGeneration, T5Vocab, T5Tokenizer>
match past {
Cache::T5Cache(old_cache_option) => match old_cache_option {
Some(old_cache) => {
let mut new_past = vec![];
for (self_layer_state, encoder_layer_state) in old_cache.into_iter() {
let new_self_layer_state = match self_layer_state {
Some(self_layer_state) => {
Some(self_layer_state.reorder_cache(beam_indices))
}
None => None,
if self_layer_state.is_some() {
self_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
let new_encoder_layer_state = match encoder_layer_state {
Some(encoder_layer_state) => {
Some(encoder_layer_state.reorder_cache(beam_indices))
}
None => None,
if encoder_layer_state.is_some() {
encoder_layer_state
.as_mut()
.unwrap()
.reorder_cache(beam_indices)
};
new_past.push((new_self_layer_state, new_encoder_layer_state));
}
}
None => {}
Expand Down Expand Up @@ -2147,18 +2141,15 @@ pub(crate) mod private_generation_utils {
) -> Option<Tensor> {
match past {
Cache::None => None,
Cache::GPT2Cache(cached_decoder_state) => {
match cached_decoder_state {
Some(value) => {
// let mut reordered_past = vec!();
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None
Cache::GPT2Cache(cached_decoder_state) => match cached_decoder_state {
Some(value) => {
for layer_past in value.iter_mut() {
*layer_past = layer_past.index_select(1, beam_indices);
}
None => None,
None
}
}
None => None,
},
Cache::BARTCache(_) => {
panic!("Not implemented");
}
Expand Down
3 changes: 2 additions & 1 deletion src/pipelines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@
//! ```
//!
//! #### 7. Named Entity Recognition
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text. The default NER model is an English BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz).
//! Additional pre-trained models are available for English, German, Spanish and Dutch.
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//! # fn main() -> failure::Fallible<()> {
Expand Down
59 changes: 58 additions & 1 deletion src/pipelines/ner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,19 @@

//! # Named Entity Recognition pipeline
//! Extracts entities (Person, Location, Organization, Miscellaneous) from text.
//! BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! Pretrained models are available for the following languages:
//! - English
//! - German
//! - Spanish
//! - Dutch
//!
//! The default NER model is an English BERT cased large model finetuned on CoNLL03, contributed by the [MDZ Digital Library team at the Bavarian State Library](https://github.com/dbmdz)
//! All resources for this model can be downloaded using the Python utility script included in this repository.
//! 1. Set-up a Python virtual environment and install dependencies (in ./requirements.txt)
//! 2. Run the conversion script python /utils/download-dependencies_bert_ner.py.
//! The dependencies will be downloaded to the user's home directory, under ~/rustbert/bert-ner
//!
//! The example below illustrates how to run the default English NER model
//! ```no_run
//! use rust_bert::pipelines::ner::NERModel;
//! # fn main() -> failure::Fallible<()> {
Expand Down Expand Up @@ -60,6 +67,56 @@
//! ]
//! # ;
//! ```
//!
//! To run the pipeline for another language, change the NERModel configuration from its default:
//!
//! ```no_run
//! use rust_bert::pipelines::common::ModelType;
//! use rust_bert::pipelines::ner::NERModel;
//! use rust_bert::pipelines::token_classification::TokenClassificationConfig;
//! use rust_bert::resources::{RemoteResource, Resource};
//! use rust_bert::roberta::{
//! RobertaConfigResources, RobertaModelResources, RobertaVocabResources,
//! };
//! use tch::Device;
//!
//! # fn main() -> failure::Fallible<()> {
//! let ner_config = TokenClassificationConfig {
//! model_type: ModelType::XLMRoberta,
//! model_resource: Resource::Remote(RemoteResource::from_pretrained(
//! RobertaModelResources::XLM_ROBERTA_NER_DE,
//! )),
//! config_resource: Resource::Remote(RemoteResource::from_pretrained(
//! RobertaConfigResources::XLM_ROBERTA_NER_DE,
//! )),
//! vocab_resource: Resource::Remote(RemoteResource::from_pretrained(
//! RobertaVocabResources::XLM_ROBERTA_NER_DE,
//! )),
//! lower_case: false,
//! device: Device::cuda_if_available(),
//! ..Default::default()
//! };
//!
//! let ner_model = NERModel::new(ner_config)?;
//!
//! // Define input
//! let input = [
//! "Mein Name ist Amélie. Ich lebe in Paris.",
//! "Paris ist eine Stadt in Frankreich.",
//! ];
//! let output = ner_model.predict(&input);
//! # Ok(())
//! # }
//! ```
//! The XLMRoberta models for the languages are defined as follows:
//!
//! | **Language** | **Model name** |
//! | :-----: | :----: |
//! | English | XLM_ROBERTA_NER_EN |
//! | German | XLM_ROBERTA_NER_DE |
//! | Spanish | XLM_ROBERTA_NER_ES |
//! | Dutch | XLM_ROBERTA_NER_NL |
//!
use crate::pipelines::token_classification::{TokenClassificationConfig, TokenClassificationModel};

Expand Down
Loading

0 comments on commit 873fa35

Please sign in to comment.