
Commit

Keyword extraction pipeline fixes (guillaume-be#393)
* - Fix keyword extraction pipeline skipping 1-character tokens
- Fix keyword extraction pipeline excluding all n-grams longer than 2
- Expose a new optional `tokenizer_forbidden_ngram_chars` option to filter out n-grams containing forbidden characters. Defaults to punctuation (excluding hyphens)

* expanded punctuation defaults with parentheses
guillaume-be authored Jun 11, 2023
1 parent 75cf52d commit 14b0bb4
Showing 4 changed files with 22 additions and 4 deletions.
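
For orientation, here is a minimal usage sketch of the option introduced by this commit. It assumes the usual `KeywordExtractionModel::new`/`predict` entry points and a `Keyword { text, score }` output type, none of which appear in this diff; note that the default configuration resolves a sentence-embeddings model.

```rust
use rust_bert::pipelines::keywords_extraction::{KeywordExtractionConfig, KeywordExtractionModel};
use rust_bert::RustBertError;

fn main() -> Result<(), RustBertError> {
    // Characters that disqualify an n-gram candidate; leaving the field as `None`
    // falls back to the punctuation default introduced by this commit.
    let forbidden: [char; 6] = ['.', '!', '?', ',', ';', ':'];

    let config = KeywordExtractionConfig {
        // Consider 1-, 2- and 3-grams; before this fix, sizes > 2 produced no candidates.
        ngram_range: (1, 3),
        tokenizer_forbidden_ngram_chars: Some(&forbidden),
        ..Default::default()
    };

    let model = KeywordExtractionModel::new(config)?;
    let output = model.predict(&["Rust is a multi-paradigm, general-purpose programming language."])?;
    for document_keywords in output {
        for keyword in document_keywords {
            println!("{} ({:.3})", keyword.text, keyword.score);
        }
    }
    Ok(())
}
```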
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,8 @@
 All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

 ## [Unreleased]
+## Fixed
+- (BREAKING) Fixed the keyword extraction pipeline for n-gram sizes > 2. Added a new configuration option `tokenizer_forbidden_ngram_chars` to specify characters that should be excluded from n-grams (allows filtering n-grams spanning multiple sentences).

 ## [0.21.0] - 2023-06-03
 ## Added
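
A hedged sketch of what the new knob means for callers (only the field name and the `Default` impl shown later in this commit are taken from the source): leaving `tokenizer_forbidden_ngram_chars` as `None` applies the punctuation default, while an explicit empty slice disables the filter and keeps n-grams that span sentence boundaries as candidates.

```rust
use rust_bert::pipelines::keywords_extraction::KeywordExtractionConfig;

// Hypothetical helper: a config that keeps sentence-spanning n-grams as candidates.
fn permissive_config() -> KeywordExtractionConfig<'static> {
    KeywordExtractionConfig {
        // An explicit empty slice means "no forbidden characters"; `None` would instead
        // fall back to the punctuation default added by this commit.
        tokenizer_forbidden_ngram_chars: Some(&[]),
        ..Default::default()
    }
}

fn main() {
    let _config = permissive_config();
}
```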
4 changes: 4 additions & 0 deletions src/pipelines/keywords_extraction/pipeline.rs
@@ -73,6 +73,8 @@ pub struct KeywordExtractionConfig<'a> {
     pub tokenizer_stopwords: Option<HashSet<&'a str>>,
     /// Optional tokenization regex pattern. Defaults to sequence of word characters.
     pub tokenizer_pattern: Option<Regex>,
+    /// Optional list of characters that should not be included in ngrams (useful to filter ngrams spanning over punctuation marks).
+    pub tokenizer_forbidden_ngram_chars: Option<&'a [char]>,
     /// `KeywordScorerType` used to rank keywords.
     pub scorer_type: KeywordScorerType,
     /// N-gram range (inclusive) for keywords. (1, 2) would consider all 1 and 2 word grams for keyword candidates.
@@ -99,6 +101,7 @@ impl Default for KeywordExtractionConfig<'_> {
             sentence_embeddings_config,
             tokenizer_stopwords: None,
             tokenizer_pattern: None,
+            tokenizer_forbidden_ngram_chars: None,
             scorer_type: KeywordScorerType::CosineSimilarity,
             ngram_range: (1, 1),
             num_keywords: 5,
@@ -167,6 +170,7 @@ impl<'a> KeywordExtractionModel<'a> {
             config.tokenizer_stopwords,
             config.tokenizer_pattern,
             do_lower_case,
+            config.tokenizer_forbidden_ngram_chars,
         );
         Ok(Self {
             sentence_embeddings_model,
13 changes: 9 additions & 4 deletions src/pipelines/keywords_extraction/tokenizer.rs
@@ -4,27 +4,32 @@ use rust_tokenizers::{Offset, OffsetSize};
 use std::borrow::Cow;
 use std::collections::{HashMap, HashSet};

-const DEFAULT_REGEX_PATTERN: &str = r"(?u)\b\w\w+\b";
+const DEFAULT_REGEX_PATTERN: &str = r"(?u)\b\w+\b";
+const PUNCTUATION: [char; 12] = ['.', '!', '?', ',', ':', ';', '(', ')', '[', ']', '{', '}'];

 pub struct StopWordsTokenizer<'a> {
     stopwords: HashSet<&'a str>,
     pattern: Regex,
     do_lower_case: bool,
+    forbidden_ngram_chars: &'a [char],
 }

 impl<'a> StopWordsTokenizer<'a> {
     pub fn new(
         stopwords: Option<HashSet<&'a str>>,
         pattern: Option<Regex>,
         do_lower_case: bool,
+        forbidden_ngram_patterns: Option<&'a [char]>,
     ) -> Self {
         let stopwords = stopwords.unwrap_or_else(|| HashSet::from(ENGLISH_STOPWORDS));
         let pattern = pattern.unwrap_or_else(|| Regex::new(DEFAULT_REGEX_PATTERN).unwrap());
         let forbidden_ngram_chars = forbidden_ngram_patterns.unwrap_or(&PUNCTUATION);

         Self {
             stopwords,
             pattern,
             do_lower_case,
             forbidden_ngram_chars,
         }
     }

@@ -50,6 +55,9 @@ impl<'a> StopWordsTokenizer<'a> {
                     end: ngram.last().unwrap().end,
                 };
                 let mut ngram_text = Cow::from(&text[pos.begin as usize..pos.end as usize]);
+                if ngram_text.contains(self.forbidden_ngram_chars) {
+                    continue 'ngram_loop;
+                }
                 if self.do_lower_case {
                     ngram_text = Cow::from(ngram_text.to_lowercase());
                 }
@@ -66,9 +74,6 @@ impl<'a> StopWordsTokenizer<'a> {
                             continue 'ngram_loop;
                         }
                     }
-                    if ngram.last().unwrap().begin > ngram[0].end + 1 {
-                        continue;
-                    }
                 }
                 tokenized_text
                     .entry(ngram_text)
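
To make the new filtering rule concrete, here is a small self-contained sketch of the `&[char]` pattern used by `contains` above. The character set mirrors the `PUNCTUATION` default, and the hyphen is deliberately absent so hyphenated compounds remain candidates.

```rust
fn main() {
    // Mirror of the PUNCTUATION default above; the hyphen is intentionally not listed,
    // so hyphenated compounds stay eligible as keyword candidates.
    let forbidden: &[char] = &['.', '!', '?', ',', ':', ';', '(', ')', '[', ']', '{', '}'];

    // `str::contains` with a `&[char]` pattern matches any character in the slice,
    // which is how candidates spanning punctuation (e.g. a sentence boundary) are dropped.
    assert!("pipeline. It".contains(forbidden));
    // In-sentence candidates, including hyphenated ones, pass the check and are kept.
    assert!(!"keyword extraction".contains(forbidden));
    assert!(!"state-of-the-art model".contains(forbidden));

    println!("forbidden n-gram character filtering works as described");
}
```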
7 changes: 7 additions & 0 deletions src/pipelines/sentence_embeddings/pipeline.rs
@@ -350,6 +350,13 @@ impl SentenceEmbeddingsModel {
             tokens_ids,
             tokens_masks,
         } = self.tokenize(inputs);
+        if tokens_ids.is_empty() {
+            return Err(RustBertError::ValueError(
+                "No n-gram found in the document. \
+                 Try allowing smaller n-gram sizes or relaxing the stopword/forbidden characters criteria."
+                    .to_string(),
+            ));
+        }
         let tokens_ids = Tensor::stack(&tokens_ids, 0).to(self.var_store.device());
         let tokens_masks = Tensor::stack(&tokens_masks, 0).to(self.var_store.device());

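
A document that yields no candidate n-grams now surfaces early as a `RustBertError::ValueError` rather than reaching the `Tensor::stack` calls below with an empty list. A hedged sketch of how a caller might absorb that case (the `predict` signature and `Keyword` type are assumed, not shown in this diff):

```rust
use rust_bert::pipelines::keywords_extraction::{Keyword, KeywordExtractionModel};
use rust_bert::RustBertError;

// Hypothetical helper: treat "no candidate n-gram" documents as an empty keyword list
// instead of propagating the new ValueError.
fn keywords_or_empty(
    model: &KeywordExtractionModel<'_>,
    docs: &[&str],
) -> Result<Vec<Vec<Keyword>>, RustBertError> {
    match model.predict(docs) {
        Ok(keywords) => Ok(keywords),
        // The guard added in this commit reports empty candidate sets as a ValueError.
        Err(RustBertError::ValueError(_)) => Ok(docs.iter().map(|_| Vec::new()).collect()),
        Err(e) => Err(e),
    }
}
```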
