Added possibility to reload a conversation from a snapshot
guillaume-be committed Oct 17, 2020
1 parent f7da9dc commit 7b84ff5
Showing 1 changed file with 101 additions and 13 deletions.
114 changes: 101 additions & 13 deletions src/pipelines/conversation.rs
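Taken together, the new `encode_prompts` and `load_from_history` methods introduced below allow a conversation to be restored from a saved transcript before generating the next turn. A minimal end-to-end sketch, adapted from the doc examples added in this commit (the final `generate_responses` call relies on the pipeline's pre-existing API):

```rust
use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};

fn main() -> anyhow::Result<()> {
    let model = ConversationModel::new(Default::default())?;
    let mut conversation_manager = ConversationManager::new();

    // A saved transcript, alternating user inputs and generated responses.
    let history = [
        "Going to the movies tonight - any suggestions?",
        "The Big Lebowski",
        "Is it an action movie?",
    ];

    // Re-encode the transcript and restore the conversation state from it.
    let encoded_history = model.encode_prompts(&history);
    let conversation_id = conversation_manager.create_empty();
    conversation_manager
        .get(&conversation_id)
        .unwrap()
        .load_from_history(history, encoded_history);

    // The reloaded conversation ends on a user input, so the next
    // generation step answers "Is it an action movie?".
    let output = model.generate_responses(&mut conversation_manager);
    println!("{:?}", output);
    Ok(())
}
```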
@@ -144,7 +144,7 @@ pub struct Conversation {
/// New user input that needs to be processed
pub new_user_input: Option<String>,
/// History of the tokens passed as input and generated so far, used as context for the next turn's generation
- pub history: Vec<i64>,
+ pub history: Vec<Vec<i64>>,
}
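The switch from a flat `Vec<i64>` to `Vec<Vec<i64>>` is the core of this change: each conversation round keeps its own token vector, which is what makes rebuilding a conversation round by round from a snapshot possible. A minimal sketch of the intended layout (token ids are made up for illustration):

```rust
// One inner vector per conversation round, alternating user input and
// generated response (ids below are illustrative only):
let history: Vec<Vec<i64>> = vec![
    vec![8262, 616, 257],  // round 1: encoded user input
    vec![464, 4403, 1616], // round 1: encoded model response
];
```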

impl Conversation {
@@ -336,6 +336,71 @@ impl Conversation {
None
}
}

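/// Appends one transcript round to the conversation, routing the text to the
/// generated responses or the user inputs depending on the current state, and
/// recording its token ids as a new round in the history.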
fn append(&mut self, text: &str, ids: &[i64]) {
match &self.new_user_input {
Some(_) => {
if self.past_user_inputs.len() >= self.generated_responses.len() {
self.generated_responses.push(text.to_string());
} else {
self.mark_processed();
let _ = self.add_user_input(text);
}
}
None => {
let _ = self.add_user_input(text);
}
}
self.history.push(ids.to_vec());
}

/// Initializes a conversation from a prior state. It is assumed that a conversation always
/// starts from a user interaction.
///
/// # Arguments
/// - texts: sequence of strings, alternating between past user inputs and past generated responses.
/// - ids: sequence of sequences of ids, alternating between past user inputs and past generated responses.
/// These can be generated via a `ConversationModel`'s `encode_prompts`.
///
/// # Example:
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// let model = ConversationModel::new(Default::default())?;
///
/// let mut conversation_manager = ConversationManager::new();
/// let history = [
/// "Going to the movies tonight - any suggestions?",
/// "The Big Lebowski",
/// "Is it an action movie?",
/// ];
/// let encoded_history = model.encode_prompts(&history);
///
/// let conversation_1_id = conversation_manager.create_empty();
/// let _ = conversation_manager
/// .get(&conversation_1_id)
/// .unwrap()
/// .load_from_history(history, encoded_history);
/// # Ok(())
/// # }
/// ```
pub fn load_from_history<'a, ST, SI, STR, SIN>(&mut self, texts: ST, ids: SI)
where
ST: AsRef<[STR]>,
SI: AsRef<[SIN]>,
STR: AsRef<str>,
SIN: AsRef<[i64]>,
{
for (round_text, round_ids) in texts.as_ref().iter().zip(ids.as_ref().iter()) {
self.append(round_text.as_ref(), round_ids.as_ref());
}

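// With an odd number of rounds the transcript ends on a pending user input;
// its ids are dropped from the history here since they will be re-encoded
// and concatenated with the history at the next generation step.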
if texts.as_ref().len() % 2 == 1 {
self.history.pop();
}
}
}

/// Data structure allowing the management of conversations and main input to the dialogue model.
@@ -668,32 +733,36 @@ impl ConversationModel {

let history = active_conversations
.iter()
- .map(|c| &c.history)
+ .map(|c| c.history.iter().flatten().copied().collect())
.collect_vec();

let prompt_ids = self.encode_prompts(texts.as_slice());
- let input_tensor = self.concat_input_history(prompt_ids, history);
+ let input_tensor = self.concat_input_history(&prompt_ids, history);
let input_length = *input_tensor.size().last().unwrap() as usize;
let mut generated = self.model.generate_from_ids_and_past(input_tensor, None);
let removed_padding_quantities = self.clean_padding_indices(&mut generated);

let mut output = HashMap::with_capacity(active_uuid.len());

- for (((conversation, generated_sequence), uuid), removed_padding) in
-     active_conversations
-         .into_iter()
-         .zip(generated.into_iter())
-         .zip(active_uuid.into_iter())
-         .zip(removed_padding_quantities.into_iter())
+ for (
+     ((conversation, (generated_sequence, conversation_prompt_ids)), uuid),
+     removed_padding,
+ ) in active_conversations
+     .into_iter()
+     .zip(generated.into_iter().zip(prompt_ids.into_iter()))
+     .zip(active_uuid.into_iter())
+     .zip(removed_padding_quantities.into_iter())
{
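// Keep only the tokens generated beyond the (padding-adjusted) input
// prompt: these form the model's response for this turn.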
let generated_response = &generated_sequence[input_length - removed_padding.0..];
conversation
.generated_responses
.push(self.model.get_tokenizer().decode(
- generated_sequence[input_length - removed_padding.0..].to_vec(),
+ generated_response.to_vec(),
true,
true,
));
- conversation.history = generated_sequence;
+ conversation.history.push(conversation_prompt_ids);
+ conversation.history.push(generated_response.to_vec());
conversation.mark_processed();
output.insert(uuid, conversation.get_last_response().unwrap());
}
@@ -726,7 +795,7 @@ impl ConversationModel {
removed_tokens
}

- fn concat_input_history(&self, inputs: Vec<Vec<i64>>, history: Vec<&Vec<i64>>) -> Tensor {
+ fn concat_input_history(&self, inputs: &[Vec<i64>], history: Vec<Vec<i64>>) -> Tensor {
// Concatenates the history token indices with new user input
let pad_token = self.model.get_pad_id().unwrap_or(self.eos_token_id);

@@ -791,7 +860,26 @@ impl ConversationModel {
*eos_indices.first().unwrap_or(&0usize)
}

- fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
/// Encodes prompts into vectors of token indices to be processed by the model. This method may be used to
/// initialize the history of a conversation with a prior state.
///
/// # Example:
///
/// ```no_run
/// # fn main() -> anyhow::Result<()> {
/// use rust_bert::pipelines::conversation::{ConversationManager, ConversationModel};
/// use rust_bert::pipelines::generation::LanguageGenerator;
/// let model = ConversationModel::new(Default::default())?;
/// let history = [
/// "Going to the movies tonight - any suggestions?",
/// "The Big Lebowski",
/// "Is it an action movie?",
/// ];
/// let encoded_history = model.encode_prompts(&history);
/// # Ok(())
/// # }
/// ```
+ pub fn encode_prompts(&self, texts: &[&str]) -> Vec<Vec<i64>> {
// Encode the user prompt into token ids
let tokens = self.model.get_tokenizer().tokenize_list(texts.to_vec());
