forked from guillaume-be/rust-bert
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
T5 dependencies download and configuration setup
- Loading branch information
1 parent
d076ec6
commit 8e7696f
Showing
6 changed files
with
246 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[package] | ||
name = "rust-bert" | ||
version = "0.7.9" | ||
version = "0.7.10" | ||
authors = ["Guillaume Becquin <[email protected]>"] | ||
edition = "2018" | ||
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)" | ||
|
@@ -30,7 +30,7 @@ all-tests = [] | |
features = [ "doc-only" ] | ||
|
||
[dependencies] | ||
rust_tokenizers = "~3.1.4" | ||
# NOTE(review): `path` is an absolute location on a local `E:` drive, so this dependency
# cannot resolve on any other machine; confirm it is removed (or made relative) before merging.
rust_tokenizers = {version = "~3.1.5", path = "E:/Coding/backup-rust/rust-tokenizers/main"} | ||
tch = "~0.1.7" | ||
serde_json = "1.0.51" | ||
serde = {version = "1.0.106", features = ["derive"]} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. | ||
// Copyright 2020 Guillaume Becquin | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
extern crate failure; | ||
|
||
use rust_bert::resources::{download_resource, RemoteResource, Resource}; | ||
use rust_bert::t5::{T5Config, T5ConfigResources, T5ModelResources, T5VocabResources}; | ||
use rust_bert::Config; | ||
use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer; | ||
|
||
fn main() -> failure::Fallible<()> { | ||
// Resources paths | ||
let config_resource = | ||
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL)); | ||
let vocab_resource = | ||
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL)); | ||
let weights_resource = | ||
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL)); | ||
let config_path = download_resource(&config_resource)?; | ||
let vocab_path = download_resource(&vocab_resource)?; | ||
let _weights_path = download_resource(&weights_resource)?; | ||
|
||
// Set-up masked LM model | ||
// let device = Device::Cpu; | ||
// let mut vs = nn::VarStore::new(device); | ||
let _tokenizer: T5Tokenizer = T5Tokenizer::from_file(vocab_path.to_str().unwrap(), true); | ||
let _config = T5Config::from_file(config_path); | ||
|
||
// let albert_model = AlbertForMaskedLM::new(&vs.root(), &config); | ||
// vs.load(weights_path)?; | ||
|
||
// Define input | ||
// let input = [ | ||
// "Looks like one [MASK] is missing", | ||
// "It was a very nice and [MASK] day", | ||
// ]; | ||
// let tokenized_input = | ||
// tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0); | ||
// let max_len = tokenized_input | ||
// .iter() | ||
// .map(|input| input.token_ids.len()) | ||
// .max() | ||
// .unwrap(); | ||
// let tokenized_input = tokenized_input | ||
// .iter() | ||
// .map(|input| input.token_ids.clone()) | ||
// .map(|mut input| { | ||
// input.extend(vec![0; max_len - input.len()]); | ||
// input | ||
// }) | ||
// .map(|input| Tensor::of_slice(&(input))) | ||
// .collect::<Vec<_>>(); | ||
// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device); | ||
// | ||
// // Forward pass | ||
// let (output, _, _) = | ||
// no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false)); | ||
// println!("{:?}", output.double_value(&[0, 0, 0])); | ||
// // Print masked tokens | ||
// let index_1 = output.get(0).get(4).argmax(0, false); | ||
// let index_2 = output.get(1).get(7).argmax(0, false); | ||
// let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[])); | ||
// let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[])); | ||
// | ||
// println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing" | ||
// println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day" | ||
|
||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
mod t5; | ||
|
||
pub use t5::{T5Config, T5ConfigResources, T5ModelResources, T5VocabResources}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team. | ||
// Copyright 2020 Guillaume Becquin | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
use crate::Config; | ||
use serde::{Deserialize, Serialize}; | ||
|
||
/// # T5 Pretrained model weight files
///
/// Each constant is a pair: a relative file name first, the download URL second.
pub struct T5ModelResources;

impl T5ModelResources {
    /// Shared under Apache 2.0 license by the T5 Authors at <https://github.com/google-research/text-to-text-transfer-transformer>. Modified with conversion to C-array format.
    pub const T5_SMALL: (&str, &str) = (
        "t5-small/model.ot",
        "https://cdn.huggingface.co/t5-small/rust_model.ot",
    );
}

/// # T5 Pretrained model config files
///
/// Each constant is a pair: a relative file name first, the download URL second.
pub struct T5ConfigResources;

impl T5ConfigResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
    pub const T5_SMALL: (&str, &str) = (
        "t5-small/config.json",
        "https://cdn.huggingface.co/t5-small/config.json",
    );
}

/// # T5 Pretrained model vocab files
///
/// Each constant is a pair: a relative file name first, the download URL second.
pub struct T5VocabResources;

impl T5VocabResources {
    /// Shared under Apache 2.0 license by the Google team at <https://github.com/google-research/text-to-text-transfer-transformer>.
    pub const T5_SMALL: (&str, &str) = (
        "t5-small/spiece.model",
        "https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
    );
}
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
/// # T5 model configuration | ||
/// Defines the T5 model architecture (e.g. number of layers, hidden layer size, label mapping...) | ||
pub struct T5Config { | ||
pub dropout_rate: f64, | ||
pub d_model: i64, | ||
pub d_ff: i64, | ||
pub d_kv: i64, | ||
pub decoder_start_token_id: Option<i64>, | ||
pub eos_token_id: Option<i64>, | ||
pub initializer_factor: f64, | ||
pub is_encoder_decoder: Option<bool>, | ||
pub layer_norm_epsilon: f64, | ||
pub n_positions: i64, | ||
pub num_heads: i64, | ||
pub num_layers: i64, | ||
pub output_past: Option<bool>, | ||
pub pad_token_id: Option<i64>, | ||
pub relative_attention_num_buckets: i64, | ||
pub vocab_size: i64, | ||
task_specific_params: TaskSpecificParams, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct TaskSpecificParams { | ||
summarization: Summarization, | ||
translation_en_to_de: TranslationEnToDe, | ||
translation_en_to_fr: TranslationEnToFr, | ||
translation_en_to_ro: TranslationEnToRo, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct Summarization { | ||
early_stopping: bool, | ||
length_penalty: f64, | ||
max_length: i64, | ||
min_length: i64, | ||
no_repeat_ngram_size: i64, | ||
num_beams: i64, | ||
prefix: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct TranslationEnToDe { | ||
early_stopping: bool, | ||
max_length: i64, | ||
num_beams: i64, | ||
prefix: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct TranslationEnToFr { | ||
early_stopping: bool, | ||
max_length: i64, | ||
num_beams: i64, | ||
prefix: String, | ||
} | ||
|
||
#[derive(Debug, Serialize, Deserialize)] | ||
pub struct TranslationEnToRo { | ||
early_stopping: bool, | ||
max_length: i64, | ||
num_beams: i64, | ||
prefix: String, | ||
} | ||
|
||
/// Implements the crate's `Config` trait for `T5Config`; the body is empty, so only the
/// trait's provided methods are used.
impl Config<T5Config> for T5Config {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
from transformers import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP | ||
from transformers.tokenization_t5 import PRETRAINED_VOCAB_FILES_MAP | ||
from transformers.file_utils import get_from_cache | ||
from pathlib import Path | ||
import shutil | ||
import os | ||
import numpy as np | ||
import torch | ||
import subprocess | ||
|
||
# Fetches the pretrained `t5-small` files through the transformers cache, copies the
# configuration and vocabulary to ~/rustbert/t5-small, and converts the PyTorch weights
# to `model.ot` (via an intermediate `model.npz`) using the crate's `convert-tensor` binary.

# Remote locations of the pretrained files
config_url = T5_PRETRAINED_CONFIG_ARCHIVE_MAP['t5-small']
vocab_url = PRETRAINED_VOCAB_FILES_MAP['vocab_file']['t5-small']
weights_url = T5_PRETRAINED_MODEL_ARCHIVE_MAP['t5-small']

target_path = Path.home() / 'rustbert' / 't5-small'

# Local copies of the remote files, as returned by the transformers cache helper
temp_config = get_from_cache(config_url)
temp_vocab = get_from_cache(vocab_url)
temp_weights = get_from_cache(weights_url)

os.makedirs(str(target_path), exist_ok=True)

config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'spiece.model')
model_path = str(target_path / 'model.bin')

shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)

# NOTE(review): `torch.load` unpickles the file, which can execute arbitrary code;
# only run this script on checkpoints obtained from a trusted source.
weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
    # Rename `gamma`/`beta` parameters to `weight`/`bias`
    k = k.replace("gamma", "weight").replace("beta", "bias")
    nps[k] = np.ascontiguousarray(v.cpu().numpy())

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

np.savez(source, **nps)

# Cargo.toml sits one directory above the directory that contains this script
toml_location = Path(__file__).resolve().parents[1] / 'Cargo.toml'

# `check=True` aborts with CalledProcessError when the conversion fails, so the
# intermediate files below are only deleted after `model.ot` was actually produced.
subprocess.run(
    ['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target],
    check=True)

# Remove the intermediate files; only config.json, spiece.model and model.ot are kept
os.remove(model_path)
os.remove(source)