Skip to content

Commit

Permalink
Merge branch 'main' into 18-handle-errors-given-from-api-responses
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinegb authored Feb 17, 2023
2 parents c814827 + 3a50807 commit 02c9816
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 48 deletions.
12 changes: 6 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ features = ["derive"]
[dependencies.openai_proc_macros]
path = "openai_proc_macros"

[dependencies.openai_utils]
path = "openai_utils"
[dependencies.openai_bootstrap]
path = "openai_bootstrap"

[dev-dependencies.tokio]
version = "1.25"
Expand Down
2 changes: 1 addition & 1 deletion openai_utils/Cargo.toml → openai_bootstrap/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[package]
name = "openai_utils"
name = "openai_bootstrap"
version.workspace = true
authors.workspace = true
edition.workspace = true
Expand Down
17 changes: 17 additions & 0 deletions openai_bootstrap/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
pub const BASE_URL: &str = "https://api.openai.com/v1";

#[macro_export]
macro_rules! authorization {
($request:expr) => {{
use dotenvy::dotenv;
use reqwest::{header::AUTHORIZATION, RequestBuilder};
use std::env;

dotenv().ok();

let token =
env::var("OPENAI_KEY").expect("environment variable `OPENAI_KEY` should be defined");

$request.header(AUTHORIZATION, format!("Bearer {token}"))
}};
}
4 changes: 2 additions & 2 deletions openai_proc_macros/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ convert_case = "0.6"
version = "0.11"
features = [ "blocking", "json" ]

[dependencies.openai_utils]
path = "../openai_utils"
[dependencies.openai_bootstrap]
path = "../openai_bootstrap"

[dependencies.serde]
version = "1.0"
Expand Down
16 changes: 8 additions & 8 deletions openai_proc_macros/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use convert_case::{Case, Casing};
use openai_bootstrap::{authorization, BASE_URL};
use proc_macro::TokenStream;
use quote::{ quote, format_ident };
use quote::{format_ident, quote};
use reqwest::blocking::Client;
use openai_utils::{ BASE_URL, authorization };
use serde::Deserialize;
use convert_case::{ Case, Casing };

#[derive(Deserialize)]
struct Models {
Expand All @@ -20,18 +20,18 @@ pub fn generate_model_id_enum(_input: TokenStream) -> TokenStream {
let client = Client::new();

let response: Models = authorization!(client.get(format!("{BASE_URL}/models")))
.send().unwrap().json().unwrap();
.send()
.unwrap()
.json()
.unwrap();

let mut model_id_idents = Vec::new();
let mut model_ids = Vec::new();
let mut model_indexes = Vec::new();
let mut index: u32 = 0;

for model in response.data {
if model.id.contains(':')
|| model.id.contains('.')
|| model.id.contains("deprecated")
{
if model.id.contains(':') || model.id.contains('.') || model.id.contains("deprecated") {
continue;
}

Expand Down
19 changes: 0 additions & 19 deletions openai_utils/src/lib.rs

This file was deleted.

2 changes: 1 addition & 1 deletion src/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
//! and can also return the probabilities of alternative tokens at each position.
use super::{handle_api, models::ModelID, ModifiedApiResponse, Usage};
use openai_utils::{authorization, BASE_URL};
use openai_bootstrap::{ BASE_URL, authorization };
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
Expand Down
4 changes: 2 additions & 2 deletions src/edits.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Given a prompt and an instruction, the model will return an edited version of the prompt.
use super::{handle_api, models::ModelID, ModifiedApiResponse, OpenAiError, Usage};
use openai_utils::{authorization, BASE_URL};
use openai_bootstrap::{ BASE_URL, authorization };
use reqwest::Client;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -91,6 +91,6 @@ mod tests {
assert_eq!(
edit.choices.first().unwrap(),
"What day of the week is it?\n"
)
);
}
}
79 changes: 75 additions & 4 deletions src/embeddings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
//! Related guide: [Embeddings](https://beta.openai.com/docs/guides/embeddings)
use super::{handle_api, models::ModelID, ModifiedApiResponse, Usage};
use openai_utils::{authorization, BASE_URL};
use openai_bootstrap::{authorization, BASE_URL};
use reqwest::Client;
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -43,12 +43,27 @@ impl Embeddings {

handle_api(request).await
}

pub fn distances(&self) -> Vec<f64> {
let mut distances = Vec::new();
let mut last_embedding: Option<&Embedding> = None;

for embedding in &self.data {
if let Some(other) = last_embedding {
distances.push(embedding.distance(other));
}

last_embedding = Some(embedding);
}

distances
}
}

#[derive(Deserialize)]
pub struct Embedding {
#[serde(rename = "embedding")]
pub vec: Vec<f32>,
pub vec: Vec<f64>,
}

impl Embedding {
Expand All @@ -60,6 +75,18 @@ impl Embedding {
Err(error) => Ok(Err(error)),
}
}

pub fn distance(&self, other: &Self) -> f64 {
let dot_product: f64 = self
.vec
.iter()
.zip(other.vec.iter())
.map(|(x, y)| x * y)
.sum();
let product_of_lengths = (self.vec.len() * other.vec.len()) as f64;

dot_product / product_of_lengths
}
}

#[cfg(test)]
Expand All @@ -80,7 +107,7 @@ mod tests {
.unwrap()
.unwrap();

assert!(!embeddings.data.first().unwrap().vec.is_empty())
assert!(!embeddings.data.first().unwrap().vec.is_empty());
}

#[tokio::test]
Expand All @@ -96,6 +123,50 @@ mod tests {
.unwrap()
.unwrap();

assert!(!embedding.vec.is_empty())
assert!(!embedding.vec.is_empty());
}

#[test]
fn right_angle() {
let embeddings = Embeddings {
data: vec![
Embedding {
vec: vec![1.0, 0.0, 0.0],
},
Embedding {
vec: vec![0.0, 1.0, 0.0],
},
],
model: ModelID::TextEmbeddingAda002,
usage: Usage {
prompt_tokens: 0,
completion_tokens: Some(0),
total_tokens: 0,
},
};

assert_eq!(embeddings.distances()[0], 0.0);
}

#[test]
fn non_right_angle() {
let embeddings = Embeddings {
data: vec![
Embedding {
vec: vec![1.0, 1.0, 0.0],
},
Embedding {
vec: vec![0.0, 1.0, 0.0],
},
],
model: ModelID::TextEmbeddingAda002,
usage: Usage {
prompt_tokens: 0,
completion_tokens: Some(0),
total_tokens: 0,
},
};

assert_ne!(embeddings.distances()[0], 0.0);
}
}
9 changes: 6 additions & 3 deletions src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
use super::{handle_api, ModifiedApiResponse};
use openai_proc_macros::generate_model_id_enum;
use openai_utils::{authorization, BASE_URL};
use openai_bootstrap::{ BASE_URL, authorization };
use reqwest::Client;
use serde::Deserialize;

Expand Down Expand Up @@ -99,7 +99,10 @@ mod tests {

let model = Model::new(ModelID::TextDavinci003).await.unwrap().unwrap();

assert_eq!(model.id, ModelID::TextDavinci003,)
assert_eq!(
model.id,
ModelID::TextDavinci003,
);
}

#[tokio::test]
Expand All @@ -116,6 +119,6 @@ mod tests {
assert_eq!(
model.id,
ModelID::Custom("davinci:ft-personal-2022-12-12-04-49-51".to_string()),
)
);
}
}

0 comments on commit 02c9816

Please sign in to comment.