
Addition of ElectraForMaskedLM
guillaume-be committed May 1, 2020
1 parent 5bec254 commit 4334fa1
Showing 5 changed files with 100 additions and 9 deletions.
16 changes: 11 additions & 5 deletions examples/electra.rs
@@ -15,9 +15,9 @@

use rust_bert::resources::{LocalResource, Resource, download_resource};
use std::path::PathBuf;
use rust_bert::electra::electra::{ElectraConfig, ElectraModel};
use rust_bert::electra::electra::{ElectraConfig, ElectraForMaskedLM};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{Tensor, Device, nn, no_grad};

fn main() -> failure::Fallible<()> {
@@ -38,7 +38,7 @@ fn main() -> failure::Fallible<()> {
let mut vs = nn::VarStore::new(device);
let tokenizer: BertTokenizer = BertTokenizer::from_file(vocab_path.to_str().unwrap(), true);
let config = ElectraConfig::from_file(config_path);
let electra_model = ElectraModel::new(&(&vs.root() / "electra"), &config);
let electra_model = ElectraForMaskedLM::new(&vs.root(), &config);
vs.load(weights_path)?;

// Define input
@@ -66,10 +66,16 @@ fn main() -> failure::Fallible<()> {
None,
None,
false)
.unwrap()
});

output.print();
// Print masked tokens
let index_1 = output.get(0).get(4).argmax(0, false);
let index_2 = output.get(1).get(7).argmax(0, false);
let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));

println!("{}", word_1); // Outputs "thing" : "Looks like one [thing] is missing"
println!("{}", word_2);// Outputs "sunny" : "It was a very nice and [sunny] day"

Ok(())
}
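The diff above collapses the input-preparation step of examples/electra.rs that produces `input_tensor`. Below is a minimal, hedged sketch of what that step could look like, assuming the two masked sentences implied by the printed predictions and the `encode_list` call used by other rust-bert examples of this era; `tokenizer` and `device` are the variables created earlier in the example, and the exact collapsed lines may differ.

```rust
// Illustrative sketch only; not necessarily the code hidden by the collapsed diff lines.
let input = [
    "Looks like one [MASK] is missing",
    "It was a very nice and [MASK] day",
];
let tokenized_input = tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
let max_len = tokenized_input.iter().map(|input| input.token_ids.len()).max().unwrap();
let tokenized_input = tokenized_input
    .iter()
    .map(|input| input.token_ids.clone())
    .map(|mut input| {
        // Pad every sequence to the longest sequence in the batch
        input.extend(vec![0; max_len - input.len()]);
        input
    })
    .map(|input| Tensor::of_slice(&input))
    .collect::<Vec<_>>();
let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
```

The mask positions line up with the indices used above: token 4 of the first sequence and token 7 of the second, counting the leading [CLS].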
34 changes: 34 additions & 0 deletions src/electra/electra.rs
@@ -180,4 +180,38 @@ impl ElectraGeneratorHead {
let output = (self.activation)(&output);
output.apply(&self.layer_norm)
}
}

pub struct ElectraForMaskedLM {
electra: ElectraModel,
generator_head: ElectraGeneratorHead,
lm_head: nn::Linear,
}

impl ElectraForMaskedLM {
pub fn new(p: &nn::Path, config: &ElectraConfig) -> ElectraForMaskedLM {
let electra = ElectraModel::new(&(p / "electra"), config);
let generator_head = ElectraGeneratorHead::new(&(p / "generator_predictions"), config);
let lm_head = nn::linear(&(p / "generator_lm_head"), config.embedding_size, config.vocab_size, Default::default());

ElectraForMaskedLM { electra, generator_head, lm_head }
}

pub fn forward_t(&self,
input_ids: Option<Tensor>,
mask: Option<Tensor>,
token_type_ids: Option<Tensor>,
position_ids: Option<Tensor>,
input_embeds: Option<Tensor>,
train: bool)
-> (Tensor, Option<Vec<Tensor>>, Option<Vec<Tensor>>) {
let (hidden_states,
all_hidden_states,
all_attentions) = self.electra
.forward_t(input_ids, mask, token_type_ids, position_ids, input_embeds, train)
.unwrap();
let hidden_states = self.generator_head.forward(&hidden_states);
let hidden_states = hidden_states.apply(&self.lm_head);
(hidden_states, all_hidden_states, all_attentions)
}
}
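For orientation: the generator prediction head projects the ELECTRA hidden states from `hidden_size` down to `embedding_size`, and `generator_lm_head` then maps `embedding_size` to `vocab_size`, so `forward_t` returns masked-LM logits of shape `[batch_size, sequence_length, vocab_size]`. A hedged sketch of the call, with hypothetical variable and function names:

```rust
use rust_bert::electra::electra::ElectraForMaskedLM;
use tch::Tensor;

// Shape flow (illustrative comments only):
// input_ids:          [batch_size, sequence_length]
// ElectraModel:       [batch_size, sequence_length, hidden_size]
// generator_head:     [batch_size, sequence_length, embedding_size]
// generator_lm_head:  [batch_size, sequence_length, vocab_size]  -> masked-LM logits
fn masked_lm_logits(model: &ElectraForMaskedLM, input_ids: Tensor) -> Tensor {
    let (logits, _hidden_states, _attentions) =
        model.forward_t(Some(input_ids), None, None, None, None, false);
    logits
}
```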
5 changes: 4 additions & 1 deletion src/electra/embeddings.rs
@@ -72,7 +72,10 @@ impl ElectraEmbeddings {
None => (input_value.apply_t(&self.word_embeddings, train), input_value.size())
}
None => match input_embeds {
Some(embeds) => (embeds.copy(), vec!(embeds.size()[0], embeds.size()[1])),
Some(embeds) => {
let size = vec!(embeds.size()[0], embeds.size()[1]);
(embeds, size)
},
None => { return Err("Only one of input ids or input embeddings may be set"); }
}
};
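The change above replaces the earlier `embeds.copy()` with a move: the size vector is computed first, and the tensor is then returned by value. A standalone sketch of the pattern, not the actual function:

```rust
use tch::Tensor;

// Compute anything derived from the tensor before moving it, so no copy is needed.
fn embeddings_with_shape(embeds: Tensor) -> (Tensor, Vec<i64>) {
    let size = vec![embeds.size()[0], embeds.size()[1]];
    (embeds, size) // `embeds` is moved here; `size` was captured beforehand
}
```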
8 changes: 5 additions & 3 deletions utils/download-dependencies_electra.py
@@ -1,3 +1,5 @@
from transformers import ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP, ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_electra import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
@@ -6,9 +8,9 @@
import torch
import subprocess

config_path = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/electra-large-discriminator-finetuned-conll03-english/config.json"
vocab_path = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/electra-large-discriminator-finetuned-conll03-english/vocab.txt"
weights_path = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/electra-large-discriminator-finetuned-conll03-english/pytorch_model.bin"
config_path = ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP["google/electra-base-generator"]
vocab_path = PRETRAINED_VOCAB_FILES_MAP["vocab_file"]["google/electra-base-generator"]
weights_path = ELECTRA_PRETRAINED_MODEL_ARCHIVE_MAP["google/electra-base-generator"]

target_path = Path.home() / 'rustbert' / 'electra'

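Once this script has run, the converted files live under `~/rustbert/electra/`. A hedged sketch of how they might be wired up as local resources in the example (the helper function and the `home` parameter are illustrative; the example's actual resource setup sits in its collapsed lines):

```rust
use rust_bert::resources::{LocalResource, Resource};
use std::path::Path;

// Illustrative only: point the example at the files produced by the download script.
// `home` must be the expanded home directory; `~` is not resolved by Path::join.
fn electra_resources(home: &Path) -> (Resource, Resource, Resource) {
    let root = home.join("rustbert").join("electra");
    (
        Resource::Local(LocalResource { local_path: root.join("config.json") }),
        Resource::Local(LocalResource { local_path: root.join("vocab.txt") }),
        Resource::Local(LocalResource { local_path: root.join("model.ot") }),
    )
}
```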
46 changes: 46 additions & 0 deletions utils/download-dependencies_electra_ner.py
@@ -0,0 +1,46 @@
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

config_path = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/electra-large-discriminator-finetuned-conll03-english/config.json"
vocab_path = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/electra-large-discriminator-finetuned-conll03-english/vocab.txt"
weights_path = "https://s3.amazonaws.com/models.huggingface.co/bert/dbmdz/electra-large-discriminator-finetuned-conll03-english/pytorch_model.bin"

target_path = Path.home() / 'rustbert' / 'electra-ner'

temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(str(target_path), exist_ok=True)

config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'vocab.txt')
model_path = str(target_path / 'model.bin')

shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)

weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))
