Skip to content

Commit

Permalink
multi-label zero-shot classification implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
guillaume-be committed Sep 5, 2020
1 parent 7914c01 commit 0ea9148
Showing 1 changed file with 84 additions and 24 deletions.
108 changes: 84 additions & 24 deletions src/pipelines/zero_shot_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -395,17 +395,17 @@ impl ZeroShotClassificationModel {
.map(|input| Tensor::of_slice(&(input)))
.collect::<Vec<_>>();

let mask = input_tensor
let tokenized_input_tensors =
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device());

let mask = tokenized_input_tensors
.ne(self
.tokenizer
.get_pad_id()
.expect("The Tokenizer used for zero shot classification should contain a PAD id"))
.to_kind(Bool);

(
Tensor::stack(tokenized_input_tensors.as_slice(), 0).to(self.var_store.device()),
mask,
)
(tokenized_input_tensors, mask)
}

/// Zero shot classification with 1 (and exactly 1) true label.
Expand Down Expand Up @@ -511,7 +511,7 @@ impl ZeroShotClassificationModel {
/// let input_sequence_2 = "The central bank is meeting today to discuss monetary policy.";
/// let candidate_labels = &["politics", "public health", "economics", "sports"];
///
/// let output = sequence_classification_model.predict(
/// let output = sequence_classification_model.predict_multilabel(
/// &[input_sentence, input_sequence_2],
/// candidate_labels,
/// None,
Expand All @@ -520,6 +520,64 @@ impl ZeroShotClassificationModel {
/// # Ok(())
/// # }
/// ```
/// outputs:
/// ```no_run
/// # use rust_bert::pipelines::sequence_classification::Label;
/// let output = [
/// [
/// Label {
/// text: "politics".to_string(),
/// score: 0.972,
/// id: 0,
/// sentence: 0,
/// },
/// Label {
/// text: "public health".to_string(),
/// score: 0.032,
/// id: 1,
/// sentence: 0,
/// },
/// Label {
/// text: "economics".to_string(),
/// score: 0.006,
/// id: 2,
/// sentence: 0,
/// },
/// Label {
/// text: "sports".to_string(),
/// score: 0.004,
/// id: 3,
/// sentence: 0,
/// },
/// ],
/// [
/// Label {
/// text: "politics".to_string(),
/// score: 0.975,
/// id: 0,
/// sentence: 1,
/// },
/// Label {
/// text: "economics".to_string(),
/// score: 0.852,
/// id: 2,
/// sentence: 1,
/// },
/// Label {
/// text: "public health".to_string(),
/// score: 0.0818,
/// id: 1,
/// sentence: 1,
/// },
/// Label {
/// text: "sports".to_string(),
/// score: 0.001,
/// id: 3,
/// sentence: 1,
/// },
/// ],
/// ];
/// ```
pub fn predict_multilabel(
&self,
inputs: &[&str],
Expand All @@ -540,26 +598,28 @@ impl ZeroShotClassificationModel {
);
output.view((num_inputs as i64, labels.len() as i64, -1i64))
});
let scores = output.slice(-1, 0, 3, 2).softmax(-1, Float).select(-1, -1);

let scores = output.softmax(1, Float).select(-1, -1);
let label_indices = scores.as_ref().argmax(-1, true).squeeze1(1);
label_indices.print();
let scores = scores
.gather(1, &label_indices.unsqueeze(-1), false)
.squeeze1(1);
let label_indices = label_indices.iter::<i64>().unwrap().collect::<Vec<i64>>();
let scores = scores.iter::<f64>().unwrap().collect::<Vec<f64>>();
let mut output_labels = vec![];
for sentence_idx in 0..num_inputs {
let mut sentence_labels = vec![];
let sentence_scores = scores
.select(0, sentence_idx as i64)
.iter::<f64>()
.unwrap()
.collect::<Vec<f64>>();

let mut output_labels: Vec<Label> = vec![];
for sentence_idx in 0..label_indices.len() {
let label_string = labels[label_indices[sentence_idx] as usize].to_string();
let label = Label {
text: label_string,
score: scores[sentence_idx],
id: label_indices[sentence_idx],
sentence: sentence_idx,
};
output_labels.push(label)
for (label_index, score) in sentence_scores.into_iter().enumerate() {
let label_string = labels[label_index].to_string();
let label = Label {
text: label_string,
score,
id: label_index as i64,
sentence: sentence_idx,
};
sentence_labels.push(label);
}
output_labels.push(sentence_labels);
}
output_labels
}
Expand Down

0 comments on commit 0ea9148

Please sign in to comment.