Merge pull request #85 from B-urb/fix/api-paging
feat(paperless): add pagination support for processing documents
B-urb authored Aug 13, 2024
2 parents bbad21b + 7fd5d0f commit ab010f5
Showing 2 changed files with 129 additions and 72 deletions.
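
The change replaces the single-shot document fetch with a paged loop: get_data_from_paperless now returns the whole Response<Document> (including the next URL) instead of only the results vector, and process_documents keeps calling get_next_data_from_paperless until next is exhausted, handing each page to process_documents_batch. The sketch below illustrates that flow in isolation; PagedResponse, the trimmed Document struct, and process_all_pages are assumptions inferred from how data.results and data.next are used in the diff (the real types live in src/paperless.rs), and the auth token header is omitted for brevity.

// Minimal sketch of the paging flow added by this commit, assuming a
// paperless-ngx style payload with a `results` array and an absolute `next`
// URL. Requires reqwest with the "json" feature and serde with "derive".
use serde::Deserialize;

#[derive(Deserialize, Debug)]
struct PagedResponse<T> {
    results: Vec<T>,      // documents on the current page
    next: Option<String>, // URL of the next page, None on the last page
}

#[derive(Deserialize, Debug)]
struct Document {
    id: u32,
    content: String,
}

async fn process_all_pages(
    client: &reqwest::Client,
    first_page_url: &str,
) -> Result<(), Box<dyn std::error::Error>> {
    // Fetch the first page, then keep following `next` until it runs out.
    let mut page: PagedResponse<Document> =
        client.get(first_page_url).send().await?.json().await?;

    loop {
        for document in &page.results {
            // Stand-in for the per-document work done in
            // process_documents_batch (LLM call + custom field update).
            println!("processing document {} ({} chars)", document.id, document.content.len());
        }

        match page.next.take() {
            Some(url) => {
                // The API returns a full URL, so it can be requested as-is.
                page = client.get(&url).send().await?.json().await?;
            }
            None => break, // last page reached
        }
    }
    Ok(())
}
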
95 changes: 57 additions & 38 deletions src/main.rs
@@ -16,7 +16,7 @@ use serde::{Deserialize, Serialize};
use serde_json::{Value};
use std::env;
use crate::llm_api::generate_response;
use crate::paperless::{get_data_from_paperless, query_custom_fields, update_document_fields};
use crate::paperless::{get_data_from_paperless, get_next_data_from_paperless, query_custom_fields, update_document_fields};
use substring::Substring;

#[derive(Serialize, Deserialize, Debug, Clone)]
@@ -101,12 +101,11 @@ fn init_ollama_client(host: &str, port: u16, secure_endpoint: bool) -> Ollama {

// Refactor the main process into a function for better readability
async fn process_documents(client: &Client, ollama: &Ollama, model: &str, base_url: &str, filter: &str) -> Result<(), Box<dyn std::error::Error>> {

let language= env::var("LANGUAGE").unwrap_or_else(|_| "EN".to_string()).to_uppercase();
let language = env::var("LANGUAGE").unwrap_or_else(|_| "EN".to_string()).to_uppercase();
let base_prompt;

match language.as_ref(){
"DE"=> base_prompt = "Bitte ziehe die Metadaten aus dem bereitgestelltem Dokument \
match language.as_ref() {
"DE" => base_prompt = "Bitte ziehe die Metadaten aus dem bereitgestelltem Dokument \
und antworte im JSON format. \
Die Felder, welche ich brauche sind:\
title,topic,sender,recipient,urgency(mit werten entweder n/a oder low oder medium oder high),\
@@ -117,7 +116,7 @@ async fn process_documents(client: &Client, ollama: &Ollama, model: &str, base_u
(keine verschachtelten Objekte) vorliegen, um von einem anderen Programm direkt analysiert werden zu können. \
Also keine zusätzlichen Texte oder Erklärungen, der Antworttext sollte mit eckigen Klammern beginnen und enden, \
die das JSON-Objekt umfassen ".to_string(),
_=> base_prompt = "Please extract metadata\
_ => base_prompt = "Please extract metadata\
from the provided document and return it in JSON format.\
The fields I need are:\
title,topic,sender,recipient,urgency(with value either n/a or low or medium or high),\
@@ -130,52 +129,72 @@ async fn process_documents(client: &Client, ollama: &Ollama, model: &str, base_u
delimiting the json object ".to_string()
};

let prompt_base = env::var("BASE_PROMPT").unwrap_or_else(|_| base_prompt.to_string());
let prompt_base = env::var("BASE_PROMPT").unwrap_or_else(|_| base_prompt.to_string());

let mode_env = env::var("MODE").unwrap_or_else(|_| "0".to_string());
let mode_int = mode_env.parse::<i32>().unwrap_or(0);
let mode = Mode::from_int(mode_int);
let fields = query_custom_fields(client, base_url).await?;
match get_data_from_paperless(&client, &base_url, filter).await {
Ok(data) => {
for document in data {
slog_scope::trace!("Document Content: {}", document.content);
slog_scope::info!("Generate Response with LLM {}", "model");
slog_scope::debug!("with Prompt: {}", prompt_base);

match generate_response(ollama, &model.to_string(), &prompt_base.to_string(), &document).await {
Ok(res) => {
// Log the response from the generate_response call
slog_scope::debug!("LLM Response: {}", res.response);

match extract_json_object(&res.response) {
Ok(json_str) => {
// Log successful JSON extraction
slog_scope::debug!("Extracted JSON Object: {}", json_str);

match serde_json::from_str(&json_str) {
Ok(json) => update_document_fields(client, document.id, &fields, &json, base_url, mode).await?,
Err(e) => {
slog_scope::error!("Error parsing llm response json {}", e.to_string());
slog_scope::debug!("JSON String was: {}", &json_str);
}
}
}
Err(e) => slog_scope::error!("{}", e),
Ok(mut data) => {
loop {
process_documents_batch(&data.results, ollama, model, &prompt_base, client, &fields, base_url, mode).await?;

if let Some(url) = data.next {
match get_next_data_from_paperless(&client, url.as_str()).await {
Ok(next_data) => {
data = next_data;
}
Err(e) => {
slog_scope::error!("Error while interacting with paperless: {}", e);
break;
}
},
Err(e) => {
slog_scope::error!("Error generating llm response: {}", e);
continue;
}
} else {
break;
}
}
},
}
Err(e) => slog_scope::error!("Error while interacting with paperless: {}", e),
}
Ok(())
}

async fn process_documents_batch(documents: &Vec<Document>, ollama: &Ollama, model: &str, prompt_base: &String, client: &Client, fields: &Vec<Field>, base_url: &str, mode: Mode) -> Result<(), Box<dyn std::error::Error>> {
Ok(for document in documents {
slog_scope::trace!("Document Content: {}", document.content);
slog_scope::info!("Generate Response with LLM {}", "model");
slog_scope::debug!("with Prompt: {}", prompt_base);

match generate_response(ollama, &model.to_string(), &prompt_base.to_string(), &document).await {
Ok(res) => {
// Log the response from the generate_response call
slog_scope::debug!("LLM Response: {}", res.response);

match extract_json_object(&res.response) {
Ok(json_str) => {
// Log successful JSON extraction
slog_scope::debug!("Extracted JSON Object: {}", json_str);

match serde_json::from_str(&json_str) {
Ok(json) => update_document_fields(client, document.id, &fields, &json, base_url, mode).await?,
Err(e) => {
slog_scope::error!("Error parsing llm response json {}", e.to_string());
slog_scope::debug!("JSON String was: {}", &json_str);
}
}
}
Err(e) => slog_scope::error!("{}", e),
}
}
Err(e) => {
slog_scope::error!("Error generating llm response: {}", e);
continue;
}
}
})
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
logger::init(); // Initializes the global logger
106 changes: 72 additions & 34 deletions src/paperless.rs
@@ -10,7 +10,7 @@ pub async fn get_data_from_paperless(
client: &Client,
url: &str,
filter: &str,
) -> Result<Vec<Document>, Box<dyn StdError + Send + Sync>> {
) -> Result<Response<Document>, Box<dyn StdError + Send + Sync>> {
// Read token from environment
//Define filter string
let filter = filter;
@@ -34,22 +34,7 @@ pub async fn get_data_from_paperless(
//let error_part = value.pointer("/results/0").unwrap();
//println!("Error part: {}", error_part);
// Parse the JSON string into the Response struct
let data: Result<Response<Document>, _> = serde_json::from_str(json);
match data {
Ok(data) => {
slog_scope::info!("Successfully retrieved {} Documents", data.results.len());
Ok(data.results)
}
Err(e) => {
let column = e.column();
let start = (column as isize - 30).max(0) as usize;
let end = (column + 30).min(json.len());
slog_scope::error!("Error while creating json of document response from paperless {}", e);
slog_scope::error!("Error at column {}: {}", column, &json[start..end]);
slog_scope::trace!("Error occured in json {}", &json);
Err(e.into()) // Remove the semicolon here
}
}
return parse_document_response(json);
}
Err(e) => {
slog_scope::error!("Error while fetching documents from paperless: {}",e);
@@ -58,6 +43,60 @@ pub async fn get_data_from_paperless(
}
}

pub async fn get_next_data_from_paperless(client: &Client,
url: &str,
) -> Result<Response<Document>, Box<dyn StdError + Send + Sync>> {
// Read token from environment
//Define filter string
slog_scope::info!("Retrieve next page {}", url);
let response = client.get(format!("{}", url)).send().await?;


let response_result = response.error_for_status();
match response_result {
Ok(data) => {
let body = data.text().await?;
slog_scope::trace!("Response from server while fetching documents: {}", body);

// Remove the "Document content: " prefix
let json = body.trim_start_matches("Document content: ");
//println!("{}",json);
// Parse the JSON string into a generic JSON structure
//let value: serde_json::Value = serde_json::from_str(json).unwrap();

// Print the part of the JSON structure that's causing the error
//let error_part = value.pointer("/results/0").unwrap();
//println!("Error part: {}", error_part);
// Parse the JSON string into the Response struct
return parse_document_response(json);
}
Err(e) => {
slog_scope::error!("Error while fetching documents from paperless: {}",e);
Err(e.into())
}
}
}


pub fn parse_document_response(json: &str) -> Result<Response<Document>, Box<dyn StdError + Send + Sync>> {
let data: Result<Response<Document>, _> = serde_json::from_str(json);
match data {
Ok(data) => {
slog_scope::info!("Successfully retrieved {} Documents", data.results.len());
Ok(data)
}
Err(e) => {
let column = e.column();
let start = (column as isize - 30).max(0) as usize;
let end = (column + 30).min(json.len());
slog_scope::error!("Error while creating json of document response from paperless {}", e);
slog_scope::error!("Error at column {}: {}", column, &json[start..end]);
slog_scope::trace!("Error occured in json {}", &json);
Err(e.into()) // Remove the semicolon here
}
}
}

pub async fn query_custom_fields(
client: &Client,
base_url: &str,
@@ -110,7 +149,7 @@ pub async fn update_document_fields(
) -> Result<(), Box<dyn std::error::Error>> {
let mut custom_fields = Vec::new();

// Use `if let` to conditionally execute code if the 'tagged' field is found.
// Use `if let` to conditionally execute code if the 'tagged' field is found.
let field = match fields.iter().find(|&f| f.name == "tagged") {
Some(field) => field,
None => {
@@ -124,7 +163,7 @@
value: Some(serde_json::json!(true)),
};

// Add this tagged_field to your custom_fields collection or use it as needed.
// Add this tagged_field to your custom_fields collection or use it as needed.
custom_fields.push(tagged_field);

for (key, value) in metadata {
@@ -135,29 +174,28 @@
if let Some(field) = fields.iter().find(|&f| f.name == *key) {
let custom_field = convert_field_to_custom_field(value, field);
custom_fields.push(custom_field);
}
else {
} else {
if matches!(mode, Mode::Create) {
slog_scope::info!("Creating field: {}", key);
let create_field = CreateField {
name: key.clone(),
data_type: "Text".to_string(),
default_value: None,
};
match create_custom_field(client, &create_field, base_url).await
{
Ok(new_field) => {
let custom_field = convert_field_to_custom_field(value, &new_field);
custom_fields.push(custom_field)
},
Err(e) => {
slog_scope::error!("Error: {} creating custom field: {}, skipping...",e, key)
}
}
match create_custom_field(client, &create_field, base_url).await
{
Ok(new_field) => {
let custom_field = convert_field_to_custom_field(value, &new_field);
custom_fields.push(custom_field)
}
Err(e) => {
slog_scope::error!("Error: {} creating custom field: {}, skipping...",e, key)
}
}
}
}
}
// Check if tagged_field_id has a value and then proceed.
// Check if tagged_field_id has a value and then proceed.

let mut payload = serde_json::Map::new();

@@ -172,7 +210,7 @@
let url = format!("{}/api/documents/{}/", base_url, document_id);
slog_scope::info!("Updating document with ID: {}", document_id);
slog_scope::debug!("Request Payload: {}", map_to_string(&payload));

for (key, value) in &payload {
slog_scope::debug!("{}: {}", key, value);
}
@@ -227,7 +265,7 @@ pub async fn create_custom_field(
match field {
Ok(field) => {
Ok(field.results[0].clone()) // TODO: improve
},
}
Err(e) => {
slog_scope::debug!("Creating field response: {}", body);
slog_scope::error!("Error parsing response from new field: {}", e);
