Skip to content

Commit 8f97ff8

Browse files
Remove async
1 parent 5acf1b9 commit 8f97ff8

File tree

11 files changed

+121
-74
lines changed

11 files changed

+121
-74
lines changed

Cargo.toml

+36-5
Original file line number | Diff line number | Diff line change
@@ -5,26 +5,57 @@ version = "0.1.14"
55
edition = "2021"
66

77
[dependencies]
8+
# Data Serialization
9+
serde = { version = "1.0.196", features = ["derive"] }
810
serde_json = "1.0.112"
11+
12+
# HTTP Client
913
reqwest = { version = "0.12.2", features = ["json", "blocking"] }
10-
serde = {version = "1.0.196", features = ["derive"]}
11-
pdf-extract = "0.7.4"
14+
15+
# Filesystem
1216
walkdir = "2.4.0"
17+
18+
# Regular Expressions
1319
regex = "1.10.3"
20+
21+
# Parallelism
1422
rayon = "1.8.1"
23+
24+
# Image Processing
1525
image = "0.25.1"
16-
hf-hub = "0.3.2"
26+
27+
# Natural Language Processing
1728
tokenizers = "0.15.2"
29+
30+
# PDF Processing
31+
pdf-extract = "0.7.4"
32+
33+
# Hugging Face Libraries
34+
hf-hub = "0.3.2"
1835
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
1936
candle-transformers = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
2037
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
38+
39+
# Error Handling
2140
anyhow = "1.0.81"
22-
tokio = {version = "1.37.0", features=["rt-multi-thread", "macros"]}
41+
42+
# Asynchronous Programming
43+
tokio = { version = "1.37.0", features = ["macros", "rt-multi-thread"] }
44+
45+
# Python Interoperability
2346
pyo3 = { version = "0.21" }
24-
intel-mkl-src = {version = "0.8.1", optional = true }
47+
48+
# Optional Dependency
49+
intel-mkl-src = { version = "0.8.1", optional = true }
50+
51+
# Markdown Processing
2552
markdown-parser = "0.1.2"
2653
markdown_to_text = "1.0.0"
54+
55+
# Web Scraping
2756
scraper = "0.19.0"
57+
58+
# Text Processing
2859
text-cleaner = "0.1.0"
2960

3061
[dev-dependencies]

examples/bert.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@ use std::{path::PathBuf, time::Instant};
33

44
fn main() {
55
let now = Instant::now();
6-
let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["md".to_string()])).unwrap();
6+
let out = embed_directory(PathBuf::from("test_files"), "Bert", Some(vec!["pdf".to_string()])).unwrap();
77
println!("{:?}", out);
88
let elapsed_time = now.elapsed();
99
println!("Elapsed Time: {}", elapsed_time.as_secs_f32());

examples/web_embed.rs

+2-2
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@ async fn main() {
88
let website_processor = website_processor::WebsiteProcesor;
99
let webpage = website_processor.process_website(url).await.unwrap();
1010
let embeder = embed_anything::embedding_model::bert::BertEmbeder::default();
11-
let embed_data = webpage.embed_webpage(&embeder).await.unwrap();
11+
let embed_data = webpage.embed_webpage(&embeder).unwrap();
1212
let embeddings: Vec<Vec<f32>> = embed_data.iter().map(|data| data.embedding.clone()).collect();
1313

1414
let embeddings = Tensor::from_vec(
@@ -18,7 +18,7 @@ async fn main() {
1818
).unwrap();
1919

2020
let query = vec!["how to use lstm for nlp".to_string()];
21-
let query_embedding: Vec<f32> = embeder.embed(&query, None).await.unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();
21+
let query_embedding: Vec<f32> = embeder.embed(&query, None).unwrap().iter().map(|data| data.embedding.clone()).flatten().collect();
2222

2323
let query_embedding_tensor = Tensor::from_vec(
2424
query_embedding.clone(),

src/embedding_model/bert.rs

+3-3
Original file line number | Diff line number | Diff line change
@@ -68,7 +68,7 @@ impl BertEmbeder {
6868
Ok(Tensor::stack(&token_ids, 0)?)
6969
}
7070

71-
pub async fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
71+
pub fn embed(&self, text_batch: &[String],metadata:Option<HashMap<String,String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
7272
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
7373
let token_type_ids = token_ids.zeros_like().unwrap();
7474
let embeddings = self.model.forward(&token_ids, &token_type_ids).unwrap();
@@ -89,7 +89,7 @@ impl Embed for BertEmbeder {
8989
fn embed(
9090
&self,
9191
text_batch: &[String],metadata: Option<HashMap<String,String>>
92-
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
92+
) -> Result<Vec<EmbedData>, anyhow::Error> {
9393
self.embed(text_batch, metadata)
9494
}
9595
}
@@ -99,7 +99,7 @@ impl TextEmbed for BertEmbeder {
9999
&self,
100100
text_batch: &[String],
101101
metadata: Option<HashMap<String,String>>
102-
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
102+
) -> Result<Vec<EmbedData>, anyhow::Error> {
103103
self.embed(text_batch, metadata)
104104
}
105105
}

src/embedding_model/clip.rs

+1-1
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,7 @@ impl EmbedImage for ClipEmbeder {
191191
}
192192

193193
impl Embed for ClipEmbeder {
194-
async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
194+
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
195195
let (input_ids, _vec_seq) = ClipEmbeder::tokenize_sequences(
196196
Some(text_batch.to_vec()),
197197
&self.tokenizer,

src/embedding_model/embed.rs

+7-7
Original file line number | Diff line number | Diff line change
@@ -60,12 +60,12 @@ pub enum Embeder {
6060
}
6161

6262
impl Embeder {
63-
pub async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
63+
pub fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
6464
match self {
65-
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
66-
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
67-
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata).await,
68-
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata).await,
65+
Embeder::OpenAI(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
66+
Embeder::Jina(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
67+
Embeder::Clip(embeder) => Embed::embed(embeder, text_batch, metadata),
68+
Embeder::Bert(embeder) => TextEmbed::embed(embeder, text_batch, metadata),
6969
}
7070
}
7171
}
@@ -76,12 +76,12 @@ pub trait Embed {
7676
&self,
7777
text_batch: &[String],
7878
metadata: Option<HashMap<String, String>>,
79-
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
79+
) ->Result<Vec<EmbedData>, anyhow::Error>;
8080

8181
}
8282

8383
pub trait TextEmbed {
84-
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>>;
84+
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error>;
8585
}
8686

8787
pub trait EmbedImage {

src/embedding_model/jina.rs

+3-3
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ impl JinaEmbeder {
7171
Ok(Tensor::stack(&token_ids, 0)?)
7272
}
7373

74-
async fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
74+
fn embed(&self, text_batch: &[String], metadata:Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, anyhow::Error> {
7575
let token_ids = self.tokenize_batch(text_batch, &self.model.device).unwrap();
7676
let embeddings = self.model.forward(&token_ids).unwrap();
7777

@@ -97,7 +97,7 @@ impl Embed for JinaEmbeder {
9797
&self,
9898
text_batch: &[String],
9999
metadata: Option<HashMap<String, String>>,
100-
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
100+
) -> Result<Vec<EmbedData>, anyhow::Error> {
101101
self.embed(text_batch, metadata)
102102
}
103103
}
@@ -107,7 +107,7 @@ impl TextEmbed for JinaEmbeder {
107107
&self,
108108
text_batch: &[String],
109109
metadata: Option<HashMap<String, String>>,
110-
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
110+
) -> Result<Vec<EmbedData>, anyhow::Error> {
111111
self.embed(text_batch, metadata)
112112
}
113113
}

src/embedding_model/openai.rs

+37-21
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,11 @@ impl Default for OpenAIEmbeder {
2222
}
2323

2424
impl Embed for OpenAIEmbeder {
25-
fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
25+
fn embed(
26+
&self,
27+
text_batch: &[String],
28+
metadata: Option<HashMap<String, String>>,
29+
) -> Result<Vec<EmbedData>, anyhow::Error> {
2630
self.embed(text_batch, metadata)
2731
}
2832
}
@@ -32,7 +36,7 @@ impl TextEmbed for OpenAIEmbeder {
3236
&self,
3337
text_batch: &[String],
3438
metadata: Option<HashMap<String, String>>,
35-
) -> impl std::future::Future<Output = Result<Vec<EmbedData>, reqwest::Error>> {
39+
) -> Result<Vec<EmbedData>, anyhow::Error> {
3640
self.embed(text_batch, metadata)
3741
}
3842
}
@@ -47,28 +51,41 @@ impl OpenAIEmbeder {
4751
}
4852
}
4953

50-
async fn embed(&self, text_batch: &[String], metadata: Option<HashMap<String, String>>) -> Result<Vec<EmbedData>, reqwest::Error> {
54+
fn embed(
55+
&self,
56+
text_batch: &[String],
57+
metadata: Option<HashMap<String, String>>,
58+
) -> Result<Vec<EmbedData>, anyhow::Error> {
5159
let client = Client::new();
52-
53-
let response = client
54-
.post(&self.url)
55-
.header("Content-Type", "application/json")
56-
.header("Authorization", format!("Bearer {}", self.api_key))
57-
.json(&json!({
58-
"input": text_batch,
59-
"model": "text-embedding-3-small",
60-
}))
61-
.send()
62-
.await?;
63-
64-
let data = response.json::<EmbedResponse>().await?;
65-
println!("{:?}", data.usage);
60+
let runtime = tokio::runtime::Builder::new_current_thread().enable_io()
61+
.build()
62+
.unwrap();
63+
64+
let data = runtime.block_on(async move {
65+
let response = client
66+
.post(&self.url)
67+
.header("Content-Type", "application/json")
68+
.header("Authorization", format!("Bearer {}", self.api_key))
69+
.json(&json!({
70+
"input": text_batch,
71+
"model": "text-embedding-3-small",
72+
}))
73+
.send()
74+
.await
75+
.unwrap();
76+
77+
let data = response.json::<EmbedResponse>().await.unwrap();
78+
println!("{:?}", data.usage);
79+
data
80+
});
6681

6782
let emb_data = data
6883
.data
6984
.iter()
7085
.zip(text_batch)
71-
.map(move |(data, text)| EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone()))
86+
.map(move |(data, text)| {
87+
EmbedData::new(data.embedding.clone(), Some(text.clone()), metadata.clone())
88+
})
7289
.collect::<Vec<_>>();
7390

7491
Ok(emb_data)
@@ -79,15 +96,14 @@ impl OpenAIEmbeder {
7996
mod tests {
8097
use super::*;
8198

82-
#[tokio::test]
83-
async fn test_openai_embed() {
99+
fn test_openai_embed() {
84100
let openai = OpenAIEmbeder::default();
85101
let text_batch = vec![
86102
"Once upon a time".to_string(),
87103
"The quick brown fox jumps over the lazy dog".to_string(),
88104
];
89105

90-
let embeddings = openai.embed(&text_batch, None).await.unwrap();
106+
let embeddings = openai.embed(&text_batch, None).unwrap();
91107
assert_eq!(embeddings.len(), 2);
92108
}
93109
}

src/file_embed.rs

+3-3
Original file line number | Diff line number | Diff line change
@@ -60,8 +60,8 @@ impl FileEmbeder {
6060

6161
}
6262

63-
pub async fn embed(&mut self, embeder: &Embeder, metadata: Option<HashMap< String, String>>) -> Result<(), reqwest::Error> {
64-
self.embeddings = embeder.embed(&self.chunks, metadata).await?;
63+
pub fn embed(&mut self, embeder: &Embeder, metadata: Option<HashMap< String, String>>) -> Result<(), anyhow::Error> {
64+
self.embeddings = embeder.embed(&self.chunks, metadata)?;
6565
Ok(())
6666
}
6767

@@ -91,7 +91,7 @@ mod tests {
9191
let embeder = Embeder::Bert(BertEmbeder::default());
9292
let mut file_embeder = FileEmbeder::new(file_path.to_string_lossy().to_string());
9393
file_embeder.split_into_chunks(&text, 100);
94-
file_embeder.embed(&embeder, None).await.unwrap();
94+
file_embeder.embed(&embeder, None).unwrap();
9595
assert_eq!(file_embeder.chunks.len(), 5);
9696
assert_eq!(file_embeder.embeddings.len(), 5);
9797
}

src/file_processor/website_processor.rs

+6-6
Original file line number | Diff line number | Diff line change
@@ -22,22 +22,22 @@ pub struct WebPage {
2222
}
2323

2424
impl WebPage {
25-
pub async fn embed_webpage<T: TextEmbed>(&self, embeder: &T) -> Result<Vec<EmbedData>, Error>{
25+
pub fn embed_webpage<T: TextEmbed>(&self, embeder: &T) -> Result<Vec<EmbedData>, Error>{
2626
let mut embed_data = Vec::new();
2727
let paragraph_embeddings = if let Some(paragraphs) = &self.paragraphs {
28-
self.embed_tag::<T>("p", paragraphs.to_vec(), &embeder).await.unwrap_or(Vec::new())
28+
self.embed_tag::<T>("p", paragraphs.to_vec(), &embeder).unwrap_or(Vec::new())
2929
} else {
3030
Vec::new()
3131
};
3232

3333
let header_embeddings = if let Some(headers) = &self.headers {
34-
self.embed_tag::<T>("h1", headers.to_vec(), &embeder).await.unwrap_or(Vec::new())
34+
self.embed_tag::<T>("h1", headers.to_vec(), &embeder).unwrap_or(Vec::new())
3535
} else {
3636
Vec::new()
3737
};
3838

3939
let code_embeddings = if let Some(codes) = &self.codes {
40-
self.embed_tag::<T>("code", codes.to_vec(), &embeder).await.unwrap_or(Vec::new())
40+
self.embed_tag::<T>("code", codes.to_vec(), &embeder).unwrap_or(Vec::new())
4141
} else {
4242
Vec::new()
4343
};
@@ -48,7 +48,7 @@ impl WebPage {
4848
Ok(embed_data)
4949
}
5050

51-
pub async fn embed_tag<T: TextEmbed>(&self,tag: &str, tag_content: Vec<String>, embeder: &T) -> Result<Vec<EmbedData>, Error> {
51+
pub fn embed_tag<T: TextEmbed>(&self,tag: &str, tag_content: Vec<String>, embeder: &T) -> Result<Vec<EmbedData>, Error> {
5252
let mut embed_data = Vec::new();
5353
for content in tag_content {
5454
let mut file_embeder = FileEmbeder::new(self.url.to_string());
@@ -84,7 +84,7 @@ impl WebPage {
8484

8585
let embeddings = embeder
8686
.embed(&chunks, Some(metadata_hashmap))
87-
.await
87+
8888
.unwrap_or(Vec::new());
8989
for embedding in embeddings {
9090
embed_data.push(embedding);

0 commit comments

Comments
 (0)