Skip to content

Commit

Permalink
- Addition of all-mini-lm-l6-v2 (guillaume-be#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be authored Nov 9, 2022
1 parent c6771d3 commit 2ffb600
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ All notable changes to this project will be documented in this file. The format

## [Unreleased]

## Added
- Addition of All-MiniLM-L6-V2 model weights

## Changed
- Addition of type aliases for the controlled generation (`PrefixAllowedFunction`) and zero-shot classification (`ZeroShotTemplate`)
- (BREAKING) `merges_resource` now optional for all pipelines
Expand Down
15 changes: 15 additions & 0 deletions src/bert/bert_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,11 @@ impl BertModelResources {
"all-mini-lm-l12-v2/model",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/rust_model.ot",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 model weights (`rust_model.ot`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/model",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/rust_model.ot",
);
}

impl BertConfigResources {
Expand Down Expand Up @@ -90,6 +95,11 @@ impl BertConfigResources {
"all-mini-lm-l12-v2/config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/config.json",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 transformer configuration (`config.json`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/config.json",
);
}

impl BertVocabResources {
Expand Down Expand Up @@ -118,6 +128,11 @@ impl BertVocabResources {
"all-mini-lm-l12-v2/vocab",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/vocab.txt",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 tokenizer vocabulary (`vocab.txt`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/vocab",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/vocab.txt",
);
}

#[derive(Debug, Serialize, Deserialize, Clone)]
Expand Down
29 changes: 29 additions & 0 deletions src/pipelines/sentence_embeddings/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,35 @@ impl From<SentenceEmbeddingsModelType> for SentenceEmbeddingsConfig {
device: Device::cuda_if_available(),
},

// All-MiniLM-L6-v2: BERT transformer with a pooling module only. No dense
// module and no tokenizer merges file are configured, so those fields are `None`.
SentenceEmbeddingsModelType::AllMiniLmL6V2 => SentenceEmbeddingsConfig {
modules_config_resource: Box::new(RemoteResource::from_pretrained(
SentenceEmbeddingsModulesConfigResources::ALL_MINI_LM_L6_V2,
)),
transformer_type: ModelType::Bert,
transformer_config_resource: Box::new(RemoteResource::from_pretrained(
BertConfigResources::ALL_MINI_LM_L6_V2,
)),
transformer_weights_resource: Box::new(RemoteResource::from_pretrained(
BertModelResources::ALL_MINI_LM_L6_V2,
)),
pooling_config_resource: Box::new(RemoteResource::from_pretrained(
SentenceEmbeddingsPoolingConfigResources::ALL_MINI_LM_L6_V2,
)),
dense_config_resource: None,
dense_weights_resource: None,
sentence_bert_config_resource: Box::new(RemoteResource::from_pretrained(
SentenceEmbeddingsConfigResources::ALL_MINI_LM_L6_V2,
)),
tokenizer_config_resource: Box::new(RemoteResource::from_pretrained(
SentenceEmbeddingsTokenizerConfigResources::ALL_MINI_LM_L6_V2,
)),
tokenizer_vocab_resource: Box::new(RemoteResource::from_pretrained(
BertVocabResources::ALL_MINI_LM_L6_V2,
)),
tokenizer_merges_resource: None,
device: Device::cuda_if_available(),
},

SentenceEmbeddingsModelType::AllDistilrobertaV1 => SentenceEmbeddingsConfig {
modules_config_resource: Box::new(RemoteResource::from_pretrained(
SentenceEmbeddingsModulesConfigResources::ALL_DISTILROBERTA_V1,
Expand Down
25 changes: 23 additions & 2 deletions src/pipelines/sentence_embeddings/resources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub enum SentenceEmbeddingsModelType {
DistiluseBaseMultilingualCased,
BertBaseNliMeanTokens,
AllMiniLmL12V2,
AllMiniLmL6V2,
AllDistilrobertaV1,
ParaphraseAlbertSmallV2,
SentenceT5Base,
Expand All @@ -41,6 +42,11 @@ impl SentenceEmbeddingsModulesConfigResources {
"all-mini-lm-l12-v2/sbert-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/modules.json",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 modules configuration (`modules.json`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/sbert-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/modules.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/sbert-config",
Expand Down Expand Up @@ -100,6 +106,11 @@ impl SentenceEmbeddingsPoolingConfigResources {
"all-mini-lm-l12-v2/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/1_Pooling/config.json",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 pooling configuration (`1_Pooling/config.json`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/sbert-pooling-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/1_Pooling/config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/sbert-pooling-config",
Expand Down Expand Up @@ -128,11 +139,16 @@ impl SentenceEmbeddingsConfigResources {
"bert-base-nli-mean-tokens/sbert-config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/sbert-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/sentence_bert_config.json",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 sentence-BERT configuration (`sentence_bert_config.json`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
// NOTE(review): this identifier is the same string used by
// `SentenceEmbeddingsModulesConfigResources::ALL_MINI_LM_L6_V2` (which points at
// `modules.json`) — confirm the two resources cannot overwrite each other locally.
"all-mini-lm-l6-v2/sbert-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/sentence_bert_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/sbert-config",
Expand Down Expand Up @@ -161,11 +177,16 @@ impl SentenceEmbeddingsTokenizerConfigResources {
"bert-base-nli-mean-tokens/tokenizer-config",
"https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/bert-base-nli-mean-tokens>. Modified with conversion to C-array format.
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L12_V2: (&'static str, &'static str) = (
"all-mini-lm-l12-v2/tokenizer-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L12-v2/resolve/main/tokenizer_config.json",
);
/// Identifier and remote URL pair for the All-MiniLM-L6-v2 tokenizer configuration (`tokenizer_config.json`).
///
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2>. Modified with conversion to C-array format.
pub const ALL_MINI_LM_L6_V2: (&'static str, &'static str) = (
"all-mini-lm-l6-v2/tokenizer-config",
"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/tokenizer_config.json",
);
/// Shared under Apache 2.0 license at <https://huggingface.co/sentence-transformers/all-distilroberta-v1>. Modified with conversion to C-array format.
pub const ALL_DISTILROBERTA_V1: (&'static str, &'static str) = (
"all-distilroberta-v1/tokenizer-config",
Expand Down

0 comments on commit 2ffb600

Please sign in to comment.