Skip to content

Commit

Permalink
Reroute all API interactions to handle_api()
Browse files Browse the repository at this point in the history
(Which handles api errors aswell)
  • Loading branch information
valentinegb committed Feb 17, 2023
1 parent dacfde2 commit 1c13b77
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 56 deletions.
27 changes: 16 additions & 11 deletions src/completions.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
//! Given a prompt, the model will return one or more predicted completions,
//! and can also return the probabilities of alternative tokens at each position.
use serde::{ Deserialize, Serialize };
use super::{ models::ModelID, Usage };
use std::collections::HashMap;
use super::{handle_api, models::ModelID, ModifiedApiResponse, Usage};
use openai_utils::{authorization, BASE_URL};
use reqwest::Client;
use openai_utils::{ BASE_URL, authorization };
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Deserialize)]
pub struct Completion {
Expand All @@ -18,18 +18,17 @@ pub struct Completion {

impl Completion {
/// Creates a completion for the provided prompt and parameters
pub async fn new(body: &CreateCompletionRequestBody<'_>) -> Result<Self, reqwest::Error> {
pub async fn new(body: &CreateCompletionRequestBody<'_>) -> ModifiedApiResponse<Self> {
if let Some(enabled) = body.stream {
if enabled {
todo!("the `stream` field is not yet implemented");
}
}

let client = Client::builder().build()?;
let request = authorization!(client.post(format!("{BASE_URL}/completions"))).json(body);

authorization!(client.post(format!("{BASE_URL}/completions")))
.json(body)
.send().await?.json().await
handle_api(request).await
}
}

Expand Down Expand Up @@ -154,15 +153,21 @@ mod tests {
#[tokio::test]
async fn completion() {
dotenv().ok();

let completion = Completion::new(&CreateCompletionRequestBody {
model: ModelID::TextDavinci003,
prompt: "Say this is a test",
max_tokens: Some(7),
temperature: Some(0.0),
..Default::default()
}).await.unwrap();
})
.await
.unwrap()
.unwrap();

assert_eq!(completion.choices.first().unwrap().text, "\n\nThis is indeed a test")
assert_eq!(
completion.choices.first().unwrap().text,
"\n\nThis is indeed a test"
);
}
}
35 changes: 22 additions & 13 deletions src/edits.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
//! Given a prompt and an instruction, the model will return an edited version of the prompt.
use serde::{ Deserialize, Serialize };
use super::{ Usage, models::ModelID };
use super::{handle_api, models::ModelID, ModifiedApiResponse, OpenAiError, Usage};
use openai_utils::{authorization, BASE_URL};
use reqwest::Client;
use openai_utils::{ BASE_URL, authorization };
use serde::{Deserialize, Serialize};

#[derive(Deserialize)]
pub struct Edit {
Expand All @@ -16,18 +16,21 @@ pub struct Edit {
}

impl Edit {
pub async fn new(body: &CreateEditRequestBody<'_>) -> Result<Self, reqwest::Error> {
pub async fn new(body: &CreateEditRequestBody<'_>) -> ModifiedApiResponse<Self> {
let client = Client::builder().build()?;
let request = authorization!(client.post(format!("{BASE_URL}/edits"))).json(body);
let response: Result<Self, OpenAiError> = handle_api(request).await?;

let mut edit: Self = authorization!(client.post(format!("{BASE_URL}/edits")))
.json(body)
.send().await?.json().await?;
match response {
Ok(mut edit) => {
for choice in &edit.choices_bad {
edit.choices.push(choice.text.clone());
}

for choice in &edit.choices_bad {
edit.choices.push(choice.text.clone());
Ok(Ok(edit))
}
Err(_) => Ok(response),
}

Ok(edit)
}
}

Expand Down Expand Up @@ -80,8 +83,14 @@ mod tests {
instruction: "Fix the spelling mistakes",
temperature: Some(0.0),
..Default::default()
}).await.unwrap();
})
.await
.unwrap()
.unwrap();

assert_eq!(edit.choices.first().unwrap(), "What day of the week is it?\n")
assert_eq!(
edit.choices.first().unwrap(),
"What day of the week is it?\n"
)
}
}
39 changes: 22 additions & 17 deletions src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
//!
//! Related guide: [Embeddings](https://beta.openai.com/docs/guides/embeddings)
use serde::{ Deserialize, Serialize };
use super::{handle_api, models::ModelID, ModifiedApiResponse, Usage};
use openai_utils::{authorization, BASE_URL};
use reqwest::Client;
use super::{ models::ModelID, Usage };
use openai_utils::{ BASE_URL, authorization };
use serde::{Deserialize, Serialize};

#[derive(Serialize)]
struct CreateEmbeddingsRequestBody<'a> {
Expand Down Expand Up @@ -36,12 +36,12 @@ impl Embeddings {
/// Each input must not exceed 8192 tokens in length.
/// * `user` - A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse.
/// [Learn more](https://beta.openai.com/docs/guides/safety-best-practices/end-user-ids).
pub async fn new(model: ModelID, input: Vec<&str>, user: &str) -> Result<Self, reqwest::Error> {
pub async fn new(model: ModelID, input: Vec<&str>, user: &str) -> ModifiedApiResponse<Self> {
let client = Client::builder().build()?;
let request = authorization!(client.post(format!("{BASE_URL}/embeddings")))
.json(&CreateEmbeddingsRequestBody { model, input, user });

authorization!(client.post(format!("{BASE_URL}/embeddings")))
.json(&CreateEmbeddingsRequestBody { model, input, user })
.send().await?.json().await
handle_api(request).await
}
}

Expand All @@ -52,14 +52,13 @@ pub struct Embedding {
}

impl Embedding {
pub async fn new(model: ModelID, input: &str, user: &str) -> Result<Self, reqwest::Error> {
let embeddings = Embeddings::new(model, vec![input], user);
pub async fn new(model: ModelID, input: &str, user: &str) -> ModifiedApiResponse<Self> {
let response = Embeddings::new(model, vec![input], user).await?;

Ok(
embeddings
.await.expect("should create embeddings")
.data.swap_remove(0)
)
match response {
Ok(mut embeddings) => Ok(Ok(embeddings.data.swap_remove(0))),
Err(error) => Ok(Err(error)),
}
}
}

Expand All @@ -76,7 +75,10 @@ mod tests {
ModelID::TextEmbeddingAda002,
vec!["The food was delicious and the waiter..."],
"",
).await.unwrap();
)
.await
.unwrap()
.unwrap();

assert!(!embeddings.data.first().unwrap().vec.is_empty())
}
Expand All @@ -89,8 +91,11 @@ mod tests {
ModelID::TextEmbeddingAda002,
"The food was delicious and the waiter...",
"",
).await.unwrap();
)
.await
.unwrap()
.unwrap();

assert!(!embedding.vec.is_empty())
}
}
}
46 changes: 45 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use serde::Deserialize;
use reqwest::RequestBuilder;
use serde::{de::DeserializeOwned, Deserialize};

pub mod completions;
pub mod edits;
Expand All @@ -11,3 +12,46 @@ pub struct Usage {
pub completion_tokens: Option<u16>,
pub total_tokens: u32,
}

#[derive(Deserialize)]
struct ErrorResponse {
error: OpenAiError,
}

#[derive(Deserialize, Debug)]
pub struct OpenAiError {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
pub param: Option<String>,
pub code: Option<String>,
}

impl std::fmt::Display for OpenAiError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.message)
}
}

impl std::error::Error for OpenAiError {}

#[derive(Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
Ok(T),
Err(ErrorResponse),
}

type ModifiedApiResponse<T> = Result<Result<T, OpenAiError>, reqwest::Error>;

async fn handle_api<T>(request: RequestBuilder) -> ModifiedApiResponse<T>
where
T: DeserializeOwned,
{
let api_response: ApiResponse<T> = request.send().await?.json().await?;

match api_response {
ApiResponse::Ok(t) => Ok(Ok(t)),
ApiResponse::Err(error) => Ok(Err(error.error)),
}
}
29 changes: 15 additions & 14 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
//! You can refer to the [Models](https://beta.openai.com/docs/models)
//! documentation to understand what models are available and the differences between them.
use serde::Deserialize;
use reqwest::Client;
use super::{handle_api, ModifiedApiResponse};
use openai_proc_macros::generate_model_id_enum;
use openai_utils::{ BASE_URL, authorization };
use openai_utils::{authorization, BASE_URL};
use reqwest::Client;
use serde::Deserialize;

#[derive(Deserialize)]
pub struct Model {
Expand All @@ -20,11 +21,11 @@ pub struct Model {
impl Model {
//! Retrieves a model instance,
//! providing basic information about the model such as the owner and permissioning.
pub async fn new(id: ModelID) -> Result<Model, reqwest::Error> {
pub async fn new(id: ModelID) -> ModifiedApiResponse<Self> {
let client = Client::builder().build()?;
let request = authorization!(client.get(format!("{BASE_URL}/models/{id}")));

authorization!(client.get(format!("{BASE_URL}/models/{id}")))
.send().await?.json().await
handle_api(request).await
}
}

Expand Down Expand Up @@ -96,21 +97,21 @@ mod tests {
async fn model() {
dotenv().ok();

let model = Model::new(ModelID::TextDavinci003).await.unwrap();
let model = Model::new(ModelID::TextDavinci003).await.unwrap().unwrap();

assert_eq!(
model.id,
ModelID::TextDavinci003,
)
assert_eq!(model.id, ModelID::TextDavinci003,)
}

#[tokio::test]
async fn custom_model() {
dotenv().ok();

let model = Model::new(
ModelID::Custom("davinci:ft-personal-2022-12-12-04-49-51".to_string())
).await.unwrap();
let model = Model::new(ModelID::Custom(
"davinci:ft-personal-2022-12-12-04-49-51".to_string(),
))
.await
.unwrap()
.unwrap();

assert_eq!(
model.id,
Expand Down

0 comments on commit 1c13b77

Please sign in to comment.