Skip to content

Commit

Permalink
flatten ApiResponseOrError to avoid nested Results and Errors
Browse files Browse the repository at this point in the history
implement From traits for std:io and reqwest Errors

Added missing_file test
  • Loading branch information
kcberg committed Jun 4, 2023
1 parent ec78d1b commit efd2d36
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 57 deletions.
1 change: 0 additions & 1 deletion examples/chat_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ async fn main() {
let chat_completion = ChatCompletion::builder("gpt-3.5-turbo", messages.clone())
.create()
.await
.unwrap()
.unwrap();
let returned_message = chat_completion.choices.first().unwrap().message.clone();

Expand Down
1 change: 0 additions & 1 deletion examples/completions_cli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ async fn main() {
.max_tokens(1024)
.create()
.await
.unwrap()
.unwrap();

let response = &completion.choices.first().unwrap().text;
Expand Down
1 change: 0 additions & 1 deletion src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ mod tests {
.temperature(0.0)
.create()
.await
.unwrap()
.unwrap();

assert_eq!(
Expand Down
1 change: 0 additions & 1 deletion src/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ mod tests {
.temperature(0.0)
.create()
.await
.unwrap()
.unwrap();

assert_eq!(
Expand Down
5 changes: 2 additions & 3 deletions src/edits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ impl Edit {
edit.choices.push(choice.text.clone());
}

Ok(Ok(edit))
Ok(edit)
}
Err(_) => Ok(response),
Err(_) => response,
}
}

Expand Down Expand Up @@ -102,7 +102,6 @@ mod tests {
.temperature(0.0)
.create()
.await
.unwrap()
.unwrap();

assert_eq!(
Expand Down
10 changes: 2 additions & 8 deletions src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,8 @@ impl Embeddings {

impl Embedding {
pub async fn create(model: &str, input: &str, user: &str) -> ApiResponseOrError<Self> {
let response = Embeddings::create(model, vec![input], user).await?;

match response {
Ok(mut embeddings) => Ok(Ok(embeddings.data.swap_remove(0))),
Err(error) => Ok(Err(error)),
}
let mut embeddings = Embeddings::create(model, vec![input], user).await?;
Ok(embeddings.data.swap_remove(0))
}

pub fn distance(&self, other: &Self) -> f64 {
Expand Down Expand Up @@ -111,7 +107,6 @@ mod tests {
"",
)
.await
.unwrap()
.unwrap();

assert!(!embeddings.data.first().unwrap().vec.is_empty());
Expand All @@ -128,7 +123,6 @@ mod tests {
"",
)
.await
.unwrap()
.unwrap();

assert!(!embedding.vec.is_empty());
Expand Down
63 changes: 27 additions & 36 deletions src/files.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::path::{Path, PathBuf};
use std::path::Path;

use derive_builder::Builder;
use reqwest::multipart::{Form, Part};
use serde::{Deserialize, Serialize};

use crate::{openai_delete, openai_get, openai_post_multipart, OpenAiError};
use crate::{openai_delete, openai_get, openai_post_multipart};

use super::ApiResponseOrError;

Expand Down Expand Up @@ -44,23 +44,15 @@ impl File {
async fn create(request: &FileUploadRequest) -> ApiResponseOrError<Self> {
let purpose = request.purpose.clone();
let upload_file_path = Path::new(request.file_name.as_str());
let upload_file_path = upload_file_path.canonicalize().unwrap();
if !upload_file_path.exists() {
return Ok(Err(file_not_found_error(&upload_file_path)));
}
let upload_file_path = upload_file_path.canonicalize()?;
let simple_name = upload_file_path
.file_name()
.unwrap()
.to_str()
.unwrap()
.to_string()
.clone();
let async_file = match tokio::fs::File::open(upload_file_path).await {
Ok(f) => f,
Err(e) => {
return Ok(Err(io_error(e)));
}
};
let async_file = tokio::fs::File::open(upload_file_path).await?;
let file_part = Part::stream(async_file)
.file_name(simple_name)
.mime_str("application/jsonl")?;
Expand All @@ -77,24 +69,6 @@ impl File {
}
}

fn file_not_found_error(file_path: &PathBuf) -> OpenAiError {
OpenAiError {
message: format!("File {} not found", file_path.display()),
error_type: "internal".to_string(),
param: None,
code: None,
}
}

fn io_error(err: std::io::Error) -> OpenAiError {
OpenAiError {
message: format!("IO Error {}", err.to_string()),
error_type: "internal".to_string(),
param: None,
code: None,
}
}

impl FileUploadBuilder {
pub async fn create(self) -> ApiResponseOrError<File> {
File::create(&self.build().unwrap()).await
Expand Down Expand Up @@ -132,21 +106,38 @@ mod tests {
async fn upload_file() {
dotenv().ok();
set_key(env::var("OPENAI_KEY").unwrap());
let file_upload = test_upload_builder().create().await.unwrap().unwrap();
let file_upload = test_upload_builder().create().await.unwrap();
println!(
"upload: {}",
serde_json::to_string_pretty(&file_upload).unwrap()
);
assert_eq!(file_upload.id.as_bytes()[..5], *"file-".as_bytes())
}

#[tokio::test]
async fn missing_file() {
dotenv().ok();
set_key(env::var("OPENAI_KEY").unwrap());
let test_builder = File::builder()
.file_name("test_data/missing_file.jsonl")
.purpose("fine-tune");
let response = test_builder.create().await;
assert!(response.is_err());
let openapi_err = response.err().unwrap();
assert_eq!(openapi_err.error_type, "io");
assert_eq!(
openapi_err.message,
"No such file or directory (os error 2)"
)
}

#[tokio::test]
async fn list_files() {
dotenv().ok();
set_key(env::var("OPENAI_KEY").unwrap());
// ensure at least one file exists
test_upload_builder().create().await.unwrap().unwrap();
let openai_files = Files::list().await.unwrap().unwrap();
test_upload_builder().create().await.unwrap();
let openai_files = Files::list().await.unwrap();
let file_count = openai_files.data.len();
assert!(file_count > 0);
for openai_file in &openai_files.data {
Expand All @@ -164,15 +155,15 @@ mod tests {
dotenv().ok();
set_key(env::var("OPENAI_KEY").unwrap());
// ensure at least one file exists
test_upload_builder().create().await.unwrap().unwrap();
test_upload_builder().create().await.unwrap();
// wait to avoid recent upload still processing error
tokio::time::sleep(Duration::from_secs(5)).await;
let openai_files = Files::list().await.unwrap().unwrap();
let openai_files = Files::list().await.unwrap();
assert!(openai_files.data.len() > 0);
let mut files = openai_files.data;
files.sort_by(|a, b| a.created_at.cmp(&b.created_at));
for file in files {
let deleted_file = File::delete(file.id.as_str()).await.unwrap().unwrap();
let deleted_file = File::delete(file.id.as_str()).await.unwrap();
assert!(deleted_file.deleted);
println!("deleted: {} {}", deleted_file.id, deleted_file.deleted)
}
Expand Down
32 changes: 28 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::sync::Mutex;

use reqwest::multipart::Form;
use reqwest::{header::AUTHORIZATION, Client, Method, RequestBuilder};
use reqwest_eventsource::{CannotCloneRequestError, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::sync::Mutex;

pub mod chat;
pub mod completions;
Expand All @@ -25,6 +26,17 @@ pub struct OpenAiError {
pub code: Option<String>,
}

impl OpenAiError {
fn new(message: String, error_type: String) -> OpenAiError {
OpenAiError {
message,
error_type,
param: None,
code: None,
}
}
}

impl std::fmt::Display for OpenAiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
Expand All @@ -47,7 +59,19 @@ pub struct Usage {
pub total_tokens: u32,
}

type ApiResponseOrError<T> = Result<Result<T, OpenAiError>, reqwest::Error>;
type ApiResponseOrError<T> = Result<T, OpenAiError>;

impl From<reqwest::Error> for OpenAiError {
fn from(value: reqwest::Error) -> Self {
OpenAiError::new(value.to_string(), "reqwest".to_string())
}
}

impl From<std::io::Error> for OpenAiError {
fn from(value: std::io::Error) -> Self {
OpenAiError::new(value.to_string(), "io".to_string())
}
}

async fn openai_request<F, T>(method: Method, route: &str, builder: F) -> ApiResponseOrError<T>
where
Expand All @@ -67,8 +91,8 @@ where
.await?;

match api_response {
ApiResponse::Ok(t) => Ok(Ok(t)),
ApiResponse::Err { error } => Ok(Err(error)),
ApiResponse::Ok(t) => Ok(t),
ApiResponse::Err { error } => Err(error),
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ mod tests {
dotenv().ok();
set_key(env::var("OPENAI_KEY").unwrap());

let model = Model::from("text-davinci-003").await.unwrap().unwrap();
let model = Model::from("text-davinci-003").await.unwrap();

assert_eq!(model.id, "text-davinci-003");
}
Expand Down
1 change: 0 additions & 1 deletion src/moderations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ mod tests {
.model("text-moderation-latest")
.create()
.await
.unwrap()
.unwrap();

assert_eq!(
Expand Down

0 comments on commit efd2d36

Please sign in to comment.