Skip to content

Commit

Permalink
Added pre-trained NER model to pipelines
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Feb 24, 2020
1 parent 6282536 commit 6a3bfee
Show file tree
Hide file tree
Showing 11 changed files with 231 additions and 6 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
.idea/*
# Generated by Cargo
# will have compiled files and executables
/target/
Expand Down
37 changes: 37 additions & 0 deletions examples/ner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
extern crate failure;
extern crate dirs;

use std::path::PathBuf;
use rust_bert::pipelines::ner::NERModel;
use tch::Device;


/// Runs the pre-trained NER pipeline on two sample sentences and prints every
/// entity returned by the model.
///
/// The model resources (`config.json`, `vocab.txt` and `model.ot`) are read
/// from the `rustbert/bert-ner` folder of the user's home directory.
fn main() -> failure::Fallible<()> {
    // Locate the model resources
    let model_dir: PathBuf = dirs::home_dir()
        .unwrap()
        .join("rustbert")
        .join("bert-ner");
    let config_file = model_dir.join("config.json");
    let vocab_file = model_dir.join("vocab.txt");
    let weights_file = model_dir.join("model.ot");

    // Instantiate the model on the device selected by `cuda_if_available`
    let target_device = Device::cuda_if_available();
    let ner_model = NERModel::new(&vocab_file, &config_file, &weights_file, target_device)?;

    // Sample sentences
    let sentences = vec![
        "My name is Amy. I live in Paris.",
        "Paris is a city in France.",
    ];

    // Extract the entities and display them one per line
    ner_model
        .predict(sentences)
        .iter()
        .for_each(|entity| println!("{:?}", entity));

    Ok(())
}
37 changes: 37 additions & 0 deletions examples/sentiment.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
extern crate failure;
extern crate dirs;

use std::path::PathBuf;
use rust_bert::pipelines::ner::NERModel;
use tch::Device;


/// Loads the model stored under `rustbert/bert-ner` in the user's home
/// directory, runs it on two sample sentences and prints each prediction.
///
/// NOTE(review): this example lives in `examples/sentiment.rs` but drives
/// `NERModel` rather than a sentiment classifier - confirm this is intended.
fn main() -> failure::Fallible<()> {
    // Resource files expected next to each other in the model directory
    let resources: PathBuf = dirs::home_dir()
        .unwrap()
        .join("rustbert")
        .join("bert-ner");
    let config_file = resources.join("config.json");
    let vocab_file = resources.join("vocab.txt");
    let weights_file = resources.join("model.ot");

    // Model set-up on the device selected by `cuda_if_available`
    let target_device = Device::cuda_if_available();
    let model = NERModel::new(&vocab_file, &config_file, &weights_file, target_device)?;

    // Sentences to analyse
    let sentences = vec![
        "My name is Amy. I live in Paris.",
        "Paris is a city in France.",
    ];

    // Run the model and print every returned entity
    let predictions = model.predict(sentences);
    predictions
        .iter()
        .for_each(|prediction| println!("{:?}", prediction));

    Ok(())
}
4 changes: 2 additions & 2 deletions src/bert/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ pub struct BertConfig {
pub output_attentions: Option<bool>,
pub output_hidden_states: Option<bool>,
pub is_decoder: Option<bool>,
pub id2label: Option<HashMap<i32, String>>,
pub label2id: Option<HashMap<String, i32>>,
pub id2label: Option<HashMap<i64, String>>,
pub label2id: Option<HashMap<String, i64>>,
pub num_labels: Option<i64>,
}

Expand Down
1 change: 0 additions & 1 deletion src/distilbert/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,3 @@ pub mod distilbert;
mod embeddings;
mod attention;
mod transformer;
pub mod sentiment;
6 changes: 4 additions & 2 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ pub mod distilbert;
pub mod bert;
pub mod roberta;
pub mod common;
pub mod pipelines;

pub use distilbert::distilbert::{DistilBertConfig, DistilBertModel, DistilBertModelClassifier, DistilBertModelMaskedLM, DistilBertForTokenClassification, DistilBertForQuestionAnswering};
pub use distilbert::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};

pub use bert::bert::BertConfig;
pub use bert::bert::{BertModel, BertForSequenceClassification, BertForMaskedLM, BertForQuestionAnswering, BertForTokenClassification, BertForMultipleChoice};

pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};
pub use roberta::roberta::{RobertaForSequenceClassification, RobertaForMaskedLM, RobertaForQuestionAnswering, RobertaForTokenClassification, RobertaForMultipleChoice};

pub use pipelines::sentiment::{Sentiment, SentimentPolarity, SentimentClassifier};
2 changes: 1 addition & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::path::PathBuf;
use tch::Device;
use rust_bert::distilbert::sentiment::SentimentClassifier;
use rust_bert::pipelines::sentiment::SentimentClassifier;

extern crate failure;
extern crate dirs;
Expand Down
2 changes: 2 additions & 0 deletions src/pipelines/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pub mod sentiment;
pub mod ner;
101 changes: 101 additions & 0 deletions src/pipelines/ner.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


use rust_tokenizers::bert_tokenizer::BertTokenizer;
use std::path::Path;
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::{BertForTokenClassification, BertConfig};
use std::collections::HashMap;
use crate::common::config::Config;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;


/// A single token flagged as an entity by [`NERModel::predict`].
///
/// `Clone` and `PartialEq` are derived so that callers can copy predictions
/// and compare them (for example in tests).
#[derive(Debug, Clone, PartialEq)]
pub struct Entity {
    /// Text of the token, as decoded by the tokenizer.
    pub word: String,
    /// Softmax score of the predicted label for this token.
    pub score: f64,
    /// Label name looked up in the model configuration (`id2label`).
    pub label: String,
}

/// Named-entity recognition pipeline bundling a BERT tokenizer, a BERT token
/// classification model and the label names read from the model configuration.
pub struct NERModel {
/// Tokenizer used to encode input sentences and to decode token ids back to text.
tokenizer: BertTokenizer,
/// Token classification model (despite the field name, this is a
/// `BertForTokenClassification`, not a sequence classifier).
bert_sequence_classifier: BertForTokenClassification,
/// Mapping from label index to label name, taken from `id2label` in the configuration.
label_mapping: HashMap<i64, String>,
/// Variable store holding the model weights; also queried for the target device.
var_store: VarStore,
}

impl NERModel {
    /// Builds a NER model from a vocabulary file, a BERT configuration file and a weights file.
    ///
    /// # Arguments
    ///
    /// * `vocab_path` - vocabulary file handed to `BertTokenizer::from_file`
    /// * `model_config_path` - configuration file; it must define `id2label`
    /// * `model_weight_path` - weights file loaded into the variable store
    /// * `device` - device on which the variable store is created
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) when `vocab_path` is not valid UTF-8,
    /// when the configuration has no `id2label` mapping, or when the weights cannot
    /// be loaded.
    pub fn new(vocab_path: &Path, model_config_path: &Path, model_weight_path: &Path, device: Device)
               -> failure::Fallible<NERModel> {
        let vocab_path = vocab_path
            .to_str()
            .ok_or_else(|| failure::err_msg("Vocabulary path is not valid UTF-8"))?;
        let tokenizer = BertTokenizer::from_file(vocab_path, false);
        let mut var_store = VarStore::new(device);
        let config = BertConfig::from_file(model_config_path);
        // The model must be registered in the variable store before the weights are loaded.
        let bert_sequence_classifier = BertForTokenClassification::new(&var_store.root(), &config);
        let label_mapping = config.id2label
            .ok_or_else(|| failure::err_msg("No label dictionary (id2label) provided in configuration file"))?;
        var_store.load(model_weight_path)?;
        Ok(NERModel { tokenizer, bert_sequence_classifier, label_mapping, var_store })
    }

    /// Tokenizes `input` and stacks the token ids into a single tensor located on
    /// the model device.
    ///
    /// `encode_list` is called with a length argument of 128 and the
    /// `LongestFirst` truncation strategy. Shorter sequences are right-padded with
    /// `0` up to the longest sequence of the batch (presumably the padding token
    /// id - confirm against the vocabulary).
    ///
    /// # Panics
    ///
    /// Panics if `input` is empty; `predict` filters that case out beforehand.
    fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
        // `input` is already owned, so it is handed over without an extra copy.
        let tokenized_input = self.tokenizer.encode_list(input,
                                                         128,
                                                         &TruncationStrategy::LongestFirst,
                                                         0);
        let max_len = tokenized_input
            .iter()
            .map(|input| input.token_ids.len())
            .max()
            .expect("prepare_for_model requires at least one input sentence");
        let padded_input = tokenized_input
            .iter()
            .map(|input| {
                let mut token_ids = input.token_ids.clone();
                token_ids.resize(max_len, 0);
                Tensor::of_slice(&token_ids)
            })
            .collect::<Vec<_>>();
        Tensor::stack(padded_input.as_slice(), 0).to(self.var_store.device())
    }

    /// Runs the model on `input` and returns one [`Entity`] per token whose
    /// predicted label index is not `0`.
    ///
    /// Label index `0` is skipped (presumably the "outside" label - confirm
    /// against `id2label`). An empty `input` yields an empty vector.
    ///
    /// # Panics
    ///
    /// Panics if a predicted label index is missing from the `id2label` mapping.
    pub fn predict(&self, input: Vec<&str>) -> Vec<Entity> {
        // Nothing to tokenize or stack: avoid panicking on an empty batch.
        if input.is_empty() {
            return Vec::new();
        }

        let input_tensor = self.prepare_for_model(input);
        let (output, _, _) = no_grad(|| {
            self.bert_sequence_classifier
                .forward_t(Some(input_tensor.copy()),
                           None,
                           None,
                           None,
                           None,
                           false)
        });
        let output = output.detach().to(Device::Cpu);

        // Softmax over the last dimension; the exponential is computed only once.
        let exp_output = output.exp();
        let normalizer = exp_output.sum1(&[-1], true, Float);
        let score: Tensor = exp_output / normalizer;
        let labels_idx = &score.argmax(-1, true);

        let mut entities: Vec<Entity> = vec!();
        for sentence_idx in 0..labels_idx.size()[0] {
            let labels = labels_idx.get(sentence_idx);
            for position_idx in 0..labels.size()[0] {
                let label = labels.int64_value(&[position_idx]);
                if label != 0 {
                    let token_id = input_tensor.int64_value(&[sentence_idx, position_idx]);
                    entities.push(Entity {
                        word: rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer::decode(&self.tokenizer, vec!(token_id), true, true),
                        score: score.double_value(&[sentence_idx, position_idx, label]),
                        label: self.label_mapping.get(&label).expect("Index out of vocabulary bounds.").to_owned(),
                    });
                }
            }
        }
        entities
    }
}
File renamed without changes.
46 changes: 46 additions & 0 deletions utils/download-dependencies_bert_ner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from transformers.pipelines import SUPPORTED_TASKS
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

# Download the default pre-trained NER model used by the transformers pipelines,
# copy its files to ~/rustbert/bert-ner and convert the weights to the `.ot`
# format loaded by the Rust code.

# Remote location of the model files.
ROOT_PATH = S3_BUCKET_PREFIX + '/' + SUPPORTED_TASKS['ner']['default']['model']['pt']

config_path = ROOT_PATH + '/config.json'
vocab_path = ROOT_PATH + '/vocab.txt'
weights_path = ROOT_PATH + '/pytorch_model.bin'

target_path = Path.home() / 'rustbert' / 'bert-ner'

# Fetch the remote files through the transformers cache.
temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(str(target_path), exist_ok=True)

# From here on the *_path names refer to the local copies.
config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
model_path = str(target_path / 'model.bin')

shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)

# Export the weights as a NumPy archive, renaming "gamma"/"beta" parameters
# to "weight"/"bias".
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    k = k.replace("gamma", "weight").replace("beta", "bias")
    nps[k] = np.ascontiguousarray(v.cpu().numpy())

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

# Cargo.toml sits one directory above the folder containing this script.
toml_location = Path(__file__).resolve().parents[1] / 'Cargo.toml'

# check_call raises CalledProcessError when the conversion fails, so a broken
# or missing model.ot does not go unnoticed.
subprocess.check_call(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

0 comments on commit 6a3bfee

Please sign in to comment.