Skip to content

Commit

Permalink
T5 dependencies download and configuration setup
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Jun 30, 2020
1 parent d076ec6 commit 8e7696f
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "rust-bert"
version = "0.7.9"
version = "0.7.10"
authors = ["Guillaume Becquin <[email protected]>"]
edition = "2018"
description = "Ready-to-use NLP pipelines and transformer-based models (BERT, DistilBERT, GPT2,...)"
Expand Down Expand Up @@ -30,7 +30,7 @@ all-tests = []
features = [ "doc-only" ]

[dependencies]
rust_tokenizers = "~3.1.4"
rust_tokenizers = {version = "~3.1.5", path = "E:/Coding/backup-rust/rust-tokenizers/main"}
tch = "~0.1.7"
serde_json = "1.0.51"
serde = {version = "1.0.106", features = ["derive"]}
Expand Down
78 changes: 78 additions & 0 deletions examples/t5.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

extern crate failure;

use rust_bert::resources::{download_resource, RemoteResource, Resource};
use rust_bert::t5::{T5Config, T5ConfigResources, T5ModelResources, T5VocabResources};
use rust_bert::Config;
use rust_tokenizers::preprocessing::tokenizer::t5_tokenizer::T5Tokenizer;

fn main() -> failure::Fallible<()> {
// Resources paths
let config_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ConfigResources::T5_SMALL));
let vocab_resource =
Resource::Remote(RemoteResource::from_pretrained(T5VocabResources::T5_SMALL));
let weights_resource =
Resource::Remote(RemoteResource::from_pretrained(T5ModelResources::T5_SMALL));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let _weights_path = download_resource(&weights_resource)?;

// Set-up masked LM model
// let device = Device::Cpu;
// let mut vs = nn::VarStore::new(device);
let _tokenizer: T5Tokenizer = T5Tokenizer::from_file(vocab_path.to_str().unwrap(), true);
let _config = T5Config::from_file(config_path);

// let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
// vs.load(weights_path)?;

// Define input
// let input = [
// "Looks like one [MASK] is missing",
// "It was a very nice and [MASK] day",
// ];
// let tokenized_input =
// tokenizer.encode_list(input.to_vec(), 128, &TruncationStrategy::LongestFirst, 0);
// let max_len = tokenized_input
// .iter()
// .map(|input| input.token_ids.len())
// .max()
// .unwrap();
// let tokenized_input = tokenized_input
// .iter()
// .map(|input| input.token_ids.clone())
// .map(|mut input| {
// input.extend(vec![0; max_len - input.len()]);
// input
// })
// .map(|input| Tensor::of_slice(&(input)))
// .collect::<Vec<_>>();
// let input_tensor = Tensor::stack(tokenized_input.as_slice(), 0).to(device);
//
// // Forward pass
// let (output, _, _) =
// no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
// println!("{:?}", output.double_value(&[0, 0, 0]));
// // Print masked tokens
// let index_1 = output.get(0).get(4).argmax(0, false);
// let index_2 = output.get(1).get(7).argmax(0, false);
// let word_1 = tokenizer.vocab().id_to_token(&index_1.int64_value(&[]));
// let word_2 = tokenizer.vocab().id_to_token(&index_2.int64_value(&[]));
//
// println!("{} - {}", &index_1.int64_value(&[]), word_1); // Outputs "_them" : "Looks like one [them] is missing"
// println!("{} - {}", &index_2.int64_value(&[]), word_2); // Outputs "_enjoyable" : "It was a very nice and [enjoyable] day"

Ok(())
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ pub mod marian;
pub mod openai_gpt;
pub mod pipelines;
pub mod roberta;
pub mod t5;

pub use common::resources;
pub use common::Config;
3 changes: 3 additions & 0 deletions src/t5/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod t5;

pub use t5::{T5Config, T5ConfigResources, T5ModelResources, T5VocabResources};
114 changes: 114 additions & 0 deletions src/t5/t5.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use crate::Config;
use serde::{Deserialize, Serialize};

/// # T5 Pretrained model weight files
pub struct T5ModelResources;

/// # T5 Pretrained model config files
pub struct T5ConfigResources;

/// # T5 Pretrained model vocab files
pub struct T5VocabResources;

impl T5ModelResources {
/// Shared under Apache 2.0 license by the T5 Authors at https://github.com/google-research/text-to-text-transfer-transformer. Modified with conversion to C-array format.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/model.ot",
"https://cdn.huggingface.co/t5-small/rust_model.ot",
);
}

impl T5ConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/config.json",
"https://cdn.huggingface.co/t5-small/config.json",
);
}

impl T5VocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/text-to-text-transfer-transformer.
pub const T5_SMALL: (&'static str, &'static str) = (
"t5-small/spiece.model",
"https://s3.amazonaws.com/models.huggingface.co/bert/t5-spiece.model",
);
}

#[derive(Debug, Serialize, Deserialize)]
/// # T5 model configuration
/// Defines the T5 model architecture (e.g. number of layers, hidden layer size, label mapping...)
pub struct T5Config {
pub dropout_rate: f64,
pub d_model: i64,
pub d_ff: i64,
pub d_kv: i64,
pub decoder_start_token_id: Option<i64>,
pub eos_token_id: Option<i64>,
pub initializer_factor: f64,
pub is_encoder_decoder: Option<bool>,
pub layer_norm_epsilon: f64,
pub n_positions: i64,
pub num_heads: i64,
pub num_layers: i64,
pub output_past: Option<bool>,
pub pad_token_id: Option<i64>,
pub relative_attention_num_buckets: i64,
pub vocab_size: i64,
task_specific_params: TaskSpecificParams,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TaskSpecificParams {
summarization: Summarization,
translation_en_to_de: TranslationEnToDe,
translation_en_to_fr: TranslationEnToFr,
translation_en_to_ro: TranslationEnToRo,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct Summarization {
early_stopping: bool,
length_penalty: f64,
max_length: i64,
min_length: i64,
no_repeat_ngram_size: i64,
num_beams: i64,
prefix: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TranslationEnToDe {
early_stopping: bool,
max_length: i64,
num_beams: i64,
prefix: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TranslationEnToFr {
early_stopping: bool,
max_length: i64,
num_beams: i64,
prefix: String,
}

#[derive(Debug, Serialize, Deserialize)]
pub struct TranslationEnToRo {
early_stopping: bool,
max_length: i64,
num_beams: i64,
prefix: String,
}

impl Config<T5Config> for T5Config {}
48 changes: 48 additions & 0 deletions utils/download-dependencies_t5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from transformers import T5_PRETRAINED_MODEL_ARCHIVE_MAP, T5_PRETRAINED_CONFIG_ARCHIVE_MAP
from transformers.tokenization_t5 import PRETRAINED_VOCAB_FILES_MAP
from transformers.file_utils import get_from_cache
from pathlib import Path
import shutil
import os
import numpy as np
import torch
import subprocess

config_path = T5_PRETRAINED_CONFIG_ARCHIVE_MAP['t5-small']
vocab_path = PRETRAINED_VOCAB_FILES_MAP['vocab_file']['t5-small']
weights_path = T5_PRETRAINED_MODEL_ARCHIVE_MAP['t5-small']

target_path = Path.home() / 'rustbert' / 't5-small'

temp_config = get_from_cache(config_path)
temp_vocab = get_from_cache(vocab_path)
temp_weights = get_from_cache(weights_path)

os.makedirs(str(target_path), exist_ok=True)

config_path = str(target_path / 'config.json')
vocab_path = str(target_path / 'spiece.model')
model_path = str(target_path / 'model.bin')

shutil.copy(temp_config, config_path)
shutil.copy(temp_vocab, vocab_path)
shutil.copy(temp_weights, model_path)

weights = torch.load(temp_weights, map_location='cpu')
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
nps[k] = np.ascontiguousarray(v.cpu().numpy())

np.savez(target_path / 'model.npz', **nps)

source = str(target_path / 'model.npz')
target = str(target_path / 'model.ot')

toml_location = (Path(__file__).resolve() / '..' / '..' / 'Cargo.toml').resolve()

subprocess.call(
['cargo', 'run', '--bin=convert-tensor', '--manifest-path=%s' % toml_location, '--', source, target])

os.remove(str(target_path / 'model.bin'))
os.remove(str(target_path / 'model.npz'))

0 comments on commit 8e7696f

Please sign in to comment.