Skip to content

Commit

Permalink
Addition of Electra resources
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed May 3, 2020
1 parent 5a1c1ae commit 029d4bd
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 21 deletions.
15 changes: 5 additions & 10 deletions examples/electra_discriminator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,17 @@
// limitations under the License.


use rust_bert::resources::{LocalResource, Resource, download_resource};
use std::path::PathBuf;
use rust_bert::electra::electra::{ElectraConfig, ElectraDiscriminator};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::electra::{ElectraConfig, ElectraDiscriminator, ElectraConfigResources, ElectraVocabResources, ElectraModelResources};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy};
use tch::{Tensor, Device, nn, no_grad};

fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("electra-discriminator");

let config_resource = Resource::Local(LocalResource { local_path: home.as_path().join("config.json") });
let vocab_resource = Resource::Local(LocalResource { local_path: home.as_path().join("vocab.txt") });
let weights_resource = Resource::Local(LocalResource { local_path: home.as_path().join("model.ot") });
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_DISCRIMINATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_DISCRIMINATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_DISCRIMINATOR));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
Expand Down
15 changes: 5 additions & 10 deletions examples/electra_masked_lm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,22 +13,17 @@
// limitations under the License.


use rust_bert::resources::{LocalResource, Resource, download_resource};
use std::path::PathBuf;
use rust_bert::electra::electra::{ElectraConfig, ElectraForMaskedLM};
use rust_bert::resources::{Resource, download_resource, RemoteResource};
use rust_bert::electra::electra::{ElectraConfig, ElectraForMaskedLM, ElectraModelResources, ElectraConfigResources, ElectraVocabResources};
use rust_bert::Config;
use rust_tokenizers::{BertTokenizer, Tokenizer, TruncationStrategy, Vocab};
use tch::{Tensor, Device, nn, no_grad};

fn main() -> failure::Fallible<()> {
// Resources paths
let mut home: PathBuf = dirs::home_dir().unwrap();
home.push("rustbert");
home.push("electra-generator");

let config_resource = Resource::Local(LocalResource {local_path: home.as_path().join("config.json")});
let vocab_resource = Resource::Local(LocalResource {local_path: home.as_path().join("vocab.txt")});
let weights_resource = Resource::Local(LocalResource {local_path: home.as_path().join("model.ot")});
let config_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraConfigResources::BASE_GENERATOR));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraVocabResources::BASE_GENERATOR));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(ElectraModelResources::BASE_GENERATOR));
let config_path = download_resource(&config_resource)?;
let vocab_path = download_resource(&vocab_resource)?;
let weights_path = download_resource(&weights_resource)?;
Expand Down
30 changes: 30 additions & 0 deletions src/electra/electra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,36 @@ use crate::bert::encoder::BertEncoder;
use crate::common::activations::{_gelu, _relu, _mish};
use crate::common::dropout::Dropout;

/// # Electra Pretrained model weight files
pub struct ElectraModelResources;

/// # Electra Pretrained model config files
pub struct ElectraConfigResources;

/// # Electra Pretrained model vocab files
pub struct ElectraVocabResources;

impl ElectraModelResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/model.ot", "https://cdn.huggingface.co/google/electra-base-generator/rust_model.ot");
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/model.ot", "https://cdn.huggingface.co/google/electra-base-discriminator/rust_model.ot");
}

impl ElectraConfigResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/config.json", "https://cdn.huggingface.co/google/electra-base-generator/config.json");
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/config.json", "https://cdn.huggingface.co/google/electra-base-discriminator/config.json");
}

impl ElectraVocabResources {
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_GENERATOR: (&'static str, &'static str) = ("electra-base-generator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-generator/vocab.txt");
/// Shared under Apache 2.0 license by the Google team at https://github.com/google-research/electra. Modified with conversion to C-array format.
pub const BASE_DISCRIMINATOR: (&'static str, &'static str) = ("electra-base-discriminator/vocab.txt", "https://cdn.huggingface.co/google/electra-base-discriminator/vocab.txt");
}

#[derive(Debug, Serialize, Deserialize)]
/// # Electra model configuration
/// Defines the Electra model architecture (e.g. number of layers, hidden layer size, label mapping...)
Expand Down
8 changes: 8 additions & 0 deletions tests/electra.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@




//#[test]
//fn electra_masked_lm() -> failure::Fallible<()> {
//
//}
1 change: 0 additions & 1 deletion utils/download-dependencies_electra-discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
nps = {}
for k, v in weights.items():
k = k.replace("gamma", "weight").replace("beta", "bias")
print(k)
nps[k] = np.ascontiguousarray(v.cpu().numpy())

np.savez(target_path / 'model.npz', **nps)
Expand Down

0 comments on commit 029d4bd

Please sign in to comment.