forked from guillaume-be/rust-bert

Commit 6a3bfee (1 parent: 6282536)
Added pre-trained NER model to pipelines

Showing 11 changed files with 231 additions and 6 deletions.
.gitignore
@@ -1,3 +1,4 @@
+.idea/*
 # Generated by Cargo
 # will have compiled files and executables
 /target/
@@ -0,0 +1,37 @@
extern crate failure;
extern crate dirs;

use std::path::PathBuf;
use rust_bert::pipelines::ner::NERModel;
use tch::Device;


fn main() -> failure::Fallible<()> {
    // Resources paths
    let mut home: PathBuf = dirs::home_dir().unwrap();
    home.push("rustbert");
    home.push("bert-ner");
    let config_path = &home.as_path().join("config.json");
    let vocab_path = &home.as_path().join("vocab.txt");
    let weights_path = &home.as_path().join("model.ot");

    // Set-up model
    let device = Device::cuda_if_available();
    let ner_model = NERModel::new(vocab_path,
                                  config_path,
                                  weights_path,
                                  device)?;

    // Define input
    let input = [
        "My name is Amy. I live in Paris.",
        "Paris is a city in France."
    ];

    // Run model
    let output = ner_model.predict(input.to_vec());
    for entity in output {
        println!("{:?}", entity);
    }

    Ok(())
}
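
Running the example prints each detected Entity via the struct's derived Debug implementation, one line per token that the model tags with a label other than "O". With the default pre-trained weights a line looks something like Entity { word: "Amy", score: 0.99, label: "I-PER" } — the score and label shown here are illustrative, not captured output.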
@@ -2,4 +2,3 @@ pub mod distilbert;
 mod embeddings;
 mod attention;
 mod transformer;
-pub mod sentiment;
src/pipelines/mod.rs
@@ -0,0 +1,2 @@
pub mod sentiment;
pub mod ner;
src/pipelines/ner.rs
@@ -0,0 +1,101 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use rust_tokenizers::bert_tokenizer::BertTokenizer;
use std::path::Path;
use tch::nn::VarStore;
use rust_tokenizers::preprocessing::tokenizer::base_tokenizer::{TruncationStrategy, MultiThreadedTokenizer};
use crate::{BertForTokenClassification, BertConfig};
use std::collections::HashMap;
use crate::common::config::Config;
use tch::{Tensor, no_grad, Device};
use tch::kind::Kind::Float;


/// An entity detected by the model: the decoded token, the confidence
/// score for the predicted label, and the label itself.
#[derive(Debug)]
pub struct Entity {
    pub word: String,
    pub score: f64,
    pub label: String,
}

pub struct NERModel {
    tokenizer: BertTokenizer,
    bert_sequence_classifier: BertForTokenClassification,
    label_mapping: HashMap<i64, String>,
    var_store: VarStore,
}

impl NERModel {
    pub fn new(vocab_path: &Path, model_config_path: &Path, model_weight_path: &Path, device: Device)
               -> failure::Fallible<NERModel> {
        let tokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), false);
        let mut var_store = VarStore::new(device);
        let config = BertConfig::from_file(model_config_path);
        let bert_sequence_classifier = BertForTokenClassification::new(&var_store.root(), &config);
        let label_mapping = config.id2label.expect("No label dictionary (id2label) provided in configuration file");
        var_store.load(model_weight_path)?;
        Ok(NERModel { tokenizer, bert_sequence_classifier, label_mapping, var_store })
    }

    /// Tokenizes the input, right-pads every sequence with 0 to the length of
    /// the longest sequence in the batch, and stacks the result into one tensor.
    fn prepare_for_model(&self, input: Vec<&str>) -> Tensor {
        let tokenized_input = self.tokenizer.encode_list(input.to_vec(),
                                                         128,
                                                         &TruncationStrategy::LongestFirst,
                                                         0);
        let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
        let tokenized_input = tokenized_input
            .iter()
            .map(|input| input.token_ids.clone())
            .map(|mut input| {
                input.extend(vec![0; max_len - input.len()]);
                input
            })
            .map(|input| Tensor::of_slice(&(input)))
            .collect::<Vec<_>>();
        Tensor::stack(tokenized_input.as_slice(), 0).to(self.var_store.device())
    }

    pub fn predict(&self, input: Vec<&str>) -> Vec<Entity> {
        let input_tensor = self.prepare_for_model(input);
        let (output, _, _) = no_grad(|| {
            self.bert_sequence_classifier
                .forward_t(Some(input_tensor.copy()),
                           None,
                           None,
                           None,
                           None,
                           false)
        });
        let output = output.detach().to(Device::Cpu);
        // Softmax over the label dimension turns the logits into per-label probabilities
        let score: Tensor = output.exp() / output.exp().sum1(&[-1], true, Float);
        let labels_idx = &score.argmax(-1, true);

        let mut entities: Vec<Entity> = vec!();
        for sentence_idx in 0..labels_idx.size()[0] {
            let labels = labels_idx.get(sentence_idx);
            for position_idx in 0..labels.size()[0] {
                let label = labels.int64_value(&[position_idx]);
                // Label 0 is the "O" (non-entity) tag and is skipped
                if label != 0 {
                    entities.push(Entity {
                        word: rust_tokenizers::preprocessing::tokenizer::base_tokenizer::Tokenizer::decode(&self.tokenizer, vec!(input_tensor.int64_value(&[sentence_idx, position_idx])), true, true),
                        score: score.double_value(&[sentence_idx, position_idx, label]),
                        label: self.label_mapping.get(&label).expect("Index out of vocabulary bounds.").to_owned(),
                    });
                }
            }
        }
        entities
    }
}
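
The padding step inside prepare_for_model is easy to miss in the iterator chain, so here is a minimal sketch of the same idea in plain Rust with no tch dependency. pad_to_longest is a hypothetical helper written for illustration only; the pipeline does this inline before stacking the padded id vectors into a single tensor.

fn pad_to_longest(mut batch: Vec<Vec<i64>>) -> Vec<Vec<i64>> {
    // Right-pad every token-id sequence with 0 up to the longest sequence in the batch
    let max_len = batch.iter().map(|ids| ids.len()).max().unwrap_or(0);
    for ids in batch.iter_mut() {
        ids.extend(vec![0; max_len - ids.len()]);
    }
    batch
}

fn main() {
    let padded = pad_to_longest(vec![vec![101, 2023, 102], vec![101, 102]]);
    assert_eq!(padded, vec![vec![101, 2023, 102], vec![101, 102, 0]]);
    println!("{:?}", padded); // [[101, 2023, 102], [101, 102, 0]]
}

Note that the padded positions also pass through the classifier in predict; they are kept out of the results only when the model assigns them label 0, the "O" tag.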
File renamed without changes.
@@ -0,0 +1,46 @@
from transformers.file_utils import get_from_cache, S3_BUCKET_PREFIX
from transformers.pipelines import SUPPORTED_TASKS
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

ROOT_PATH = S3_BUCKET_PREFIX + '/' + SUPPORTED_TASKS['ner']['default']['model']['pt']

config_path = ROOT_PATH + '/config.json'
vocab_path = ROOT_PATH + '/vocab.txt'
weights_path = ROOT_PATH + '/pytorch_model.bin'

target_path = Path.home() / 'rustbert' / 'bert-ner'

temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(str(target_path), exist_ok=True)

config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
model_path = str(target_path / 'model.bin')

shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)

weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    # Old BERT checkpoints name the LayerNorm parameters gamma/beta;
    # rename them to the weight/bias convention expected on the Rust side
    k = k.replace("gamma", "weight").replace("beta", "bias")
    nps[k] = np.ascontiguousarray(v.cpu().numpy())

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

subprocess.call(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])
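
This script mirrors the Rust example's expectations: it pulls the default NER model (config, vocab and PyTorch weights) through the transformers cache, renames the legacy gamma/beta LayerNorm parameter names to weight/bias, dumps the tensors to model.npz, and finally invokes the repository's convert-tensor binary to produce the model.ot file that NERModel::new loads from ~/rustbert/bert-ner.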