Skip to content

Commit

Permalink
Addition of BERT option for QA
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Jul 8, 2020
1 parent 2dad825 commit e6938e1
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 0 deletions.
57 changes: 57 additions & 0 deletions examples/question_answering_bert.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright 2019-present, the HuggingFace Inc. team, The Google AI Language Team and Facebook, Inc.
// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate failure;

use rust_bert::bert::{BertConfigResources, BertModelResources, BertVocabResources};
use rust_bert::pipelines::common::ModelType;
use rust_bert::pipelines::question_answering::{
QaInput, QuestionAnsweringConfig, QuestionAnsweringModel,
};
use rust_bert::resources::{RemoteResource, Resource};

/// Example: extractive question answering with a BERT model fine-tuned on SQuAD.
///
/// Builds a `QuestionAnsweringModel` from the remote `BERT_QA` resources, runs it on two
/// question/context pairs and prints the predicted answers with their `Debug` formatting.
///
/// # Errors
///
/// Returns the error raised by `QuestionAnsweringModel::new` if the model cannot be created.
fn main() -> failure::Fallible<()> {
    // Set-up Question Answering model.
    // Weights, configuration and vocabulary are all taken from the same `BERT_QA` checkpoint
    // so that the three files are guaranteed to describe the same model.
    let config = QuestionAnsweringConfig::new(
        ModelType::Bert,
        Resource::Remote(RemoteResource::from_pretrained(BertModelResources::BERT_QA)),
        Resource::Remote(RemoteResource::from_pretrained(BertConfigResources::BERT_QA)),
        Resource::Remote(RemoteResource::from_pretrained(BertVocabResources::BERT_QA)),
        None,  // merges resource only relevant with ModelType::Roberta
        false, // lowercase: the checkpoint is a cased model, keep the input casing
    );

    let qa_model = QuestionAnsweringModel::new(config)?;

    // Define input
    let qa_input_1 = QaInput {
        question: String::from("Where does Amy live ?"),
        context: String::from("Amy lives in Amsterdam"),
    };
    let qa_input_2 = QaInput {
        question: String::from("Where does Eric live"),
        context: String::from("While Amy lives in Amsterdam, Eric is in The Hague."),
    };

    // Get answer
    // NOTE(review): the trailing arguments are presumably `top_k = 1` and `batch_size = 32`;
    // confirm against the `predict` signature.
    let answers = qa_model.predict(&vec![qa_input_1, qa_input_2], 1, 32);
    println!("{:?}", answers);
    Ok(())
}
15 changes: 15 additions & 0 deletions src/bert/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ impl BertModelResources {
"bert-ner/model.ot",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/rust_model.ot",
);
/// Pretrained BERT (large, cased, whole-word-masking) weights fine-tuned on SQuAD for question answering.
///
/// Pair of string literals: a relative file name followed by the remote URL of the weights file.
/// NOTE(review): the first element is presumably the location used in the local cache - confirm
/// against `RemoteResource::from_pretrained`.
/// NOTE(review): this URL uses a `bert/<model>/rust_model.ot` layout that differs from the
/// `BERT_QA` configuration and vocabulary URLs - verify that it resolves.
///
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/model.ot",
"https://cdn.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad/rust_model.ot",
);
}

impl BertConfigResources {
Expand All @@ -57,6 +62,11 @@ impl BertConfigResources {
"bert-ner/config.json",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/config.json",
);
/// Configuration file of the BERT (large, cased, whole-word-masking) model fine-tuned on SQuAD for question answering.
///
/// Pair of string literals: a relative file name followed by the remote URL of the configuration file.
/// NOTE(review): the first element is presumably the location used in the local cache - confirm
/// against `RemoteResource::from_pretrained`.
///
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/config.json",
"https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
);
}

impl BertVocabResources {
Expand All @@ -70,6 +80,11 @@ impl BertVocabResources {
"bert-ner/vocab.txt",
"https://cdn.huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english/vocab.txt",
);
/// Vocabulary file of the BERT (large, cased, whole-word-masking) model fine-tuned on SQuAD for question answering.
///
/// Pair of string literals: a relative file name followed by the remote URL of the vocabulary file.
/// NOTE(review): the first element is presumably the location used in the local cache - confirm
/// against `RemoteResource::from_pretrained`.
///
/// Shared under Apache 2.0 license by Hugging Face Inc at <https://github.com/huggingface/transformers/tree/master/examples/question-answering>. Modified with conversion to C-array format.
pub const BERT_QA: (&'static str, &'static str) = (
"bert-qa/vocab.txt",
"https://cdn.huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad-vocab.txt",
);
}

#[allow(non_camel_case_types)]
Expand Down
31 changes: 31 additions & 0 deletions src/pipelines/question_answering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,37 @@ pub struct QuestionAnsweringConfig {
pub lower_case: bool,
}

impl QuestionAnsweringConfig {
    /// Builds a question answering configuration for the supplied model type.
    ///
    /// The device is not an argument: it is always set to `Device::cuda_if_available()`.
    ///
    /// # Arguments
    ///
    /// * `model_type` - `ModelType` indicating the model type to load (must match the actual data to be loaded!)
    /// * `model_resource` - The `Resource` pointing to the model to load (e.g. model.ot)
    /// * `config_resource` - The `Resource` pointing to the model configuration to load (e.g. config.json)
    /// * `vocab_resource` - The `Resource` pointing to the tokenizer's vocabulary to load (e.g. vocab.txt/vocab.json)
    /// * `merges_resource` - An optional `Resource` (`Option<Resource>`) pointing to the tokenizer's merges file to load (e.g. merges.txt), needed only for Roberta.
    /// * `lower_case` - A `bool` indicating whether the tokenizer should lower case all input (in case of a lower-cased model)
    pub fn new(
        model_type: ModelType,
        model_resource: Resource,
        config_resource: Resource,
        vocab_resource: Resource,
        merges_resource: Option<Resource>,
        lower_case: bool,
    ) -> Self {
        let device = Device::cuda_if_available();
        Self {
            device,
            lower_case,
            merges_resource,
            vocab_resource,
            config_resource,
            model_resource,
            model_type,
        }
    }
}

impl Default for QuestionAnsweringConfig {
fn default() -> QuestionAnsweringConfig {
QuestionAnsweringConfig {
Expand Down
48 changes: 48 additions & 0 deletions utils/download-dependencies_bert_qa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from transformers import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_MODEL_ARCHIVE_MAP
from transformers.tokenization_bert import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

# Downloads the configuration, vocabulary and weights of the BERT question answering
# checkpoint through the transformers cache, copies config/vocab into ~/rustbert/bert-qa and
# converts the PyTorch weights to the `model.ot` file via the `convert-tensor` cargo binary.
MODEL_NAME = "bert-large-cased-whole-word-masking-finetuned-squad"

# Remote locations of the three files, as registered in the transformers archive maps.
config_url = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP[MODEL_NAME]
vocab_url = PRETRAINED_VOCAB_FILES_MAP["vocab_file"][MODEL_NAME]
weights_url = BERT_PRETRAINED_MODEL_ARCHIVE_MAP[MODEL_NAME]

target_path = Path.home() / 'rustbert' / 'bert-qa'

temp_config = get_from_cache(config_url)
temp_vocab = get_from_cache(vocab_url)
temp_weights = get_from_cache(weights_url)

os.makedirs(str(target_path), exist_ok=True)

shutil.copy(temp_config, str(target_path / 'config.json'))
shutil.copy(temp_vocab, str(target_path / 'vocab.txt'))

# The weights are read straight from the cached download: copying them to the target
# directory first is unnecessary because only the converted file is kept.
weights = torch.load(temp_weights, map_location='cpu')

# Rename legacy parameter names (gamma/beta -> weight/bias) and make every array contiguous
# before serialisation.
nps = {
    name.replace("gamma", "weight").replace("beta", "bias"): np.ascontiguousarray(tensor.cpu().numpy())
    for name, tensor in weights.items()
}

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

np.savez(source, **nps)

# Cargo.toml sits one directory above the directory containing this script.
toml_location = Path(__file__).resolve().parents[1] / 'Cargo.toml'

try:
    # check=True raises CalledProcessError if the conversion fails, instead of silently
    # finishing without a usable model.ot.
    subprocess.run(
        ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
        check=True)
finally:
    # The intermediate numpy archive is never kept, whether or not the conversion succeeded.
    os.remove(source)

0 comments on commit e6938e1

Please sign in to comment.