From 3df5ea5d37b8ba6e84abd98a9b8127ac1e4dfbc7 Mon Sep 17 00:00:00 2001
From: sftse <92270426+sftse@users.noreply.github.com>
Date: Wed, 18 May 2022 19:44:50 +0000
Subject: [PATCH] Add token offset information to entities (#255)

* Add token offset information to entities

* replace unwrap by error propagation

Co-authored-by: guillaume-be
---
 src/lib.rs           |  5 +++++
 src/pipelines/mod.rs |  5 +++++
 src/pipelines/ner.rs | 16 ++++++++++++++++
 3 files changed, 26 insertions(+)

diff --git a/src/lib.rs b/src/lib.rs
index bb21e36bf..88c632197 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -466,6 +466,7 @@
 //! Output: \
 //! ```no_run
 //! # use rust_bert::pipelines::ner::Entity;
+//! # use rust_tokenizers::Offset;
 //! # let output =
 //! [
 //!     [
@@ -473,11 +474,13 @@
 //!             word: String::from("Amy"),
 //!             score: 0.9986,
 //!             label: String::from("I-PER"),
+//!             offset: Offset { begin: 11, end: 14 },
 //!         },
 //!         Entity {
 //!             word: String::from("Paris"),
 //!             score: 0.9985,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 26, end: 31 },
 //!         },
 //!     ],
 //!     [
@@ -485,11 +488,13 @@
 //!             word: String::from("Paris"),
 //!             score: 0.9988,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 0, end: 5 },
 //!         },
 //!         Entity {
 //!             word: String::from("France"),
 //!             score: 0.9993,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 19, end: 25 },
 //!         },
 //!     ],
 //! ]
diff --git a/src/pipelines/mod.rs b/src/pipelines/mod.rs
index a95b71b5e..ee84bdc48 100644
--- a/src/pipelines/mod.rs
+++ b/src/pipelines/mod.rs
@@ -335,6 +335,7 @@
 //! Output: \
 //! ```no_run
 //! # use rust_bert::pipelines::ner::Entity;
+//! # use rust_tokenizers::Offset;
 //! # let output =
 //! [
 //!     [
@@ -342,11 +343,13 @@
 //!             word: String::from("Amy"),
 //!             score: 0.9986,
 //!             label: String::from("I-PER"),
+//!             offset: Offset { begin: 11, end: 14 },
 //!         },
 //!         Entity {
 //!             word: String::from("Paris"),
 //!             score: 0.9985,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 26, end: 31 },
 //!         },
 //!     ],
 //!     [
@@ -354,11 +357,13 @@
 //!             word: String::from("Paris"),
 //!             score: 0.9988,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 0, end: 5 },
 //!         },
 //!         Entity {
 //!             word: String::from("France"),
 //!             score: 0.9993,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 19, end: 25 },
 //!         },
 //!     ],
 //! ]
diff --git a/src/pipelines/ner.rs b/src/pipelines/ner.rs
index 209fc9e62..690350f59 100644
--- a/src/pipelines/ner.rs
+++ b/src/pipelines/ner.rs
@@ -42,6 +42,7 @@
 //! Output: \
 //! ```no_run
 //! # use rust_bert::pipelines::ner::Entity;
+//! # use rust_tokenizers::Offset;
 //! # let output =
 //! [
 //!     [
@@ -49,11 +50,13 @@
 //!             word: String::from("Amy"),
 //!             score: 0.9986,
 //!             label: String::from("I-PER"),
+//!             offset: Offset { begin: 11, end: 14 },
 //!         },
 //!         Entity {
 //!             word: String::from("Paris"),
 //!             score: 0.9985,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 26, end: 31 },
 //!         },
 //!     ],
 //!     [
@@ -61,11 +64,13 @@
 //!             word: String::from("Paris"),
 //!             score: 0.9988,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 0, end: 5 },
 //!         },
 //!         Entity {
 //!             word: String::from("France"),
 //!             score: 0.9993,
 //!             label: String::from("I-LOC"),
+//!             offset: Offset { begin: 19, end: 25 },
 //!         },
 //!     ],
 //! ]
@@ -125,6 +130,7 @@ use crate::common::error::RustBertError;
 use crate::pipelines::token_classification::{
     Token, TokenClassificationConfig, TokenClassificationModel,
 };
+use rust_tokenizers::Offset;
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -136,6 +142,8 @@ pub struct Entity {
     pub score: f64,
     /// Entity label (e.g. ORG, LOC...)
     pub label: String,
+    /// Token offsets
+    pub offset: Offset,
 }
 
 //type alias for some backward compatibility
@@ -207,6 +215,7 @@ impl NERModel {
             .into_iter()
             .filter(|token| token.label != "O")
             .map(|token| Entity {
+                offset: token.offset.unwrap(),
                 word: token.text,
                 score: token.score,
                 label: token.label,
@@ -247,17 +256,20 @@ impl NERModel {
     /// ```no_run
     /// # use rust_bert::pipelines::question_answering::Answer;
     /// # use rust_bert::pipelines::ner::Entity;
+    /// # use rust_tokenizers::Offset;
     /// # let output =
     /// [[
     ///     Entity {
     ///         word: String::from("John Smith"),
     ///         score: 0.9747,
     ///         label: String::from("PER"),
+    ///         offset: Offset { begin: 6, end: 16 },
     ///     },
     ///     Entity {
     ///         word: String::from("Acme Corp"),
     ///         score: 0.8847,
     ///         label: String::from("I-LOC"),
+    ///         offset: Offset { begin: 23, end: 32 },
     ///     },
     /// ]]
     /// # ;
@@ -346,6 +358,10 @@ impl<'a> EntityBuilder<'a> {
                 .join(" "),
             score: entity_tokens.iter().map(|token| token.score).product(),
             label: label.to_string(),
+            offset: Offset {
+                begin: entity_tokens.first()?.offset?.begin,
+                end: entity_tokens.last()?.offset?.end,
+            },
         })
     } else {
         None
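
Note (not part of the patch): a minimal sketch of how downstream code might consume the new offset field once this change lands. The input sentence is assumed to match the "Amy"/"Paris" doc example above, and the entity values are copied from it; character-based offsets into the original text are assumed, matching the rust_tokenizers Offset convention.

use rust_bert::pipelines::ner::Entity;
use rust_tokenizers::Offset;

fn main() {
    // Input assumed to correspond to the doc example above (ASCII only here).
    let text = "My name is Amy. I live in Paris.";

    // An Entity shaped like the output of NERModel::predict with this patch applied.
    let entity = Entity {
        word: String::from("Amy"),
        score: 0.9986,
        label: String::from("I-PER"),
        offset: Offset { begin: 11, end: 14 },
    };

    // Recover the surface form from the original text via the offset
    // (character-based slicing, so this also behaves for non-ASCII input).
    let span: String = text
        .chars()
        .skip(entity.offset.begin as usize)
        .take((entity.offset.end - entity.offset.begin) as usize)
        .collect();

    assert_eq!(span, entity.word);
    println!(
        "{} [{}] at {}..{}",
        span, entity.label, entity.offset.begin, entity.offset.end
    );
}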