Commit 2a86d8b

Update embedding_model module and Cargo.toml
1 parent f46a112 commit 2a86d8b

12 files changed: +160 -140 lines changed

Cargo.lock

+83-70
Some generated files are not rendered by default.

Cargo.toml

+8 -4

@@ -3,12 +3,12 @@ name = "embed_anything"
 version = "0.1.5"
 edition = "2021"
 
-
+[lib]
+name = "embed_anything"
+crate-type = ["cdylib"]
 
 [dependencies]
-pyo3 = { version = "0.20", features = ["extension-module"] }
-pyo3-asyncio = { version = "0.20", features = ["tokio-runtime"] }
-tokio = "1.9"
+
 serde_json = "1.0.112"
 reqwest = { version = "0.12.2", features = ["json"] }
 futures = "0.3.30"
@@ -25,4 +25,8 @@ candle-transformers = { git = "https://github.com/huggingface/candle.git", versi
 candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0", features = ["mkl"] }
 anyhow = "1.0.81"
 intel-mkl-src = "0.8.1"
+candle-pyo3 = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
+tokio = {version = "1.37.0", features=["rt-multi-thread"]}
+pyo3 = { version = "0.21" }
+
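
A note on the [lib] addition: crate-type = ["cdylib"] is what lets Python load the compiled crate as a native extension, and the name under [lib] must match the #[pymodule] function so the interpreter can find the init symbol. Also visible in this hunk: the old pyo3 0.20 line enabled the extension-module feature, while the new pyo3 0.21 line omits it; whether that feature is enabled elsewhere (e.g. by the build tooling) is not visible in this diff. Below is a minimal sketch of how the pieces line up under pyo3 0.21's Bound API; the ping function is illustrative, not part of this commit:

// Minimal pyo3 0.21 extension module, for orientation only.
use pyo3::prelude::*;

#[pyfunction]
fn ping() -> PyResult<String> {
    Ok("pong".to_string())
}

// This function's name must match `name = "embed_anything"` under [lib],
// so that `import embed_anything` resolves to this module's init symbol.
#[pymodule]
fn embed_anything(m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(ping, m)?)?;
    Ok(())
}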

embed_anything-0.1.1.tar.gz

-16.6 KB
Binary file not shown.

src/embedding_model/bert.rs

+16 -13

@@ -1,3 +1,4 @@
+#[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
 use anyhow::Error as E;
@@ -12,25 +13,26 @@ pub struct BertEmbeder {
     pub model: BertModel,
     pub tokenizer: Tokenizer,
 }
-impl BertEmbeder {
-    pub fn default() -> anyhow::Result<Self> {
+
+impl Default for BertEmbeder {
+    fn default() -> Self {
         let device = Device::Cpu;
         let default_model = "sentence-transformers/all-MiniLM-L12-v2".to_string();
         let default_revision = "refs/pr/21".to_string();
         let (model_id, _revision) = (default_model, default_revision);
         let repo = Repo::model(model_id);
         let (config_filename, tokenizer_filename, weights_filename) = {
-            let api = Api::new()?;
+            let api = Api::new().unwrap();
             let api = api.repo(repo);
-            let config = api.get("config.json")?;
-            let tokenizer = api.get("tokenizer.json")?;
-            let weights = api.get("model.safetensors")?;
+            let config = api.get("config.json").unwrap();
+            let tokenizer = api.get("tokenizer.json").unwrap();
+            let weights = api.get("model.safetensors").unwrap();
 
             (config, tokenizer, weights)
         };
-        let config = std::fs::read_to_string(config_filename)?;
-        let mut config: Config = serde_json::from_str(&config)?;
-        let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
+        let config = std::fs::read_to_string(config_filename).unwrap();
+        let mut config: Config = serde_json::from_str(&config).unwrap();
+        let mut tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg).unwrap();
 
         let pp = PaddingParams {
             strategy: tokenizers::PaddingStrategy::BatchLongest,
@@ -39,14 +41,15 @@ impl BertEmbeder {
         tokenizer.with_padding(Some(pp));
 
         let vb =
-            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device).unwrap() };
 
         config.hidden_act = HiddenAct::GeluApproximate;
 
-        let model = BertModel::load(vb, &config)?;
-        Ok(BertEmbeder { model, tokenizer })
+        let model = BertModel::load(vb, &config).unwrap();
+        BertEmbeder { model, tokenizer }
     }
-
+}
+impl BertEmbeder {
     pub fn tokenize_batch(&self, text_batch: &[String], device: &Device) -> anyhow::Result<Tensor> {
         let tokens = self
             .tokenizer
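
The structural change here: the fallible pub fn default() -> anyhow::Result<Self> becomes an impl of the standard Default trait. Default::default() cannot return a Result, so every fallible step (Hub download, JSON parse, weight load) now panics through .unwrap() instead of propagating with ?. A common middle ground, sketched below under the illustrative name try_new (not part of this commit), keeps a fallible constructor and lets Default delegate to it:

use anyhow::Result;

pub struct Embeder; // model/tokenizer fields elided for the sketch

impl Embeder {
    // Fallible constructor: download and parse errors propagate with `?`.
    pub fn try_new() -> Result<Self> {
        // Hub download, config parsing, and weight loading would go here.
        Ok(Embeder)
    }
}

// Default delegates, so a panic can only originate at this one visible spot.
impl Default for Embeder {
    fn default() -> Self {
        Self::try_new().expect("failed to initialize default embedder")
    }
}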

src/embedding_model/clip.rs

+11 -12

@@ -1,7 +1,6 @@
+#[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
-use std::any;
-
 use anyhow::Error as E;
 
 
@@ -18,23 +17,23 @@ pub struct ClipEmbeder {
     pub tokenizer: Tokenizer,
 }
 
-impl ClipEmbeder {
-    pub fn default() -> anyhow::Result<Self> {
-        let api = hf_hub::api::sync::Api::new()?;
+impl Default for ClipEmbeder {
+    fn default() -> Self {
+        let api = hf_hub::api::sync::Api::new().unwrap();
         let api = api.repo(hf_hub::Repo::with_revision(
             "openai/clip-vit-base-patch32".to_string(),
             hf_hub::RepoType::Model,
             "refs/pr/15".to_string(),
         ));
-        let model_file = api.get("model.safetensors")?;
+        let model_file = api.get("model.safetensors").unwrap();
         let config = clip::ClipConfig::vit_base_patch32();
         let device = Device::Cpu;
         let vb = unsafe {
-            VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)?
+            VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device).unwrap()
         };
-        let model = clip::ClipModel::new(vb, &config)?;
-        let tokenizer = Self::get_tokenizer(None)?;
-        Ok(ClipEmbeder { model, tokenizer })
+        let model = clip::ClipModel::new(vb, &config).unwrap();
+        let tokenizer = Self::get_tokenizer(None).unwrap();
+        ClipEmbeder { model, tokenizer }
     }
 }
 
@@ -123,7 +122,7 @@ impl ClipEmbeder {
 
     fn load_images<T: AsRef<std::path::Path>>(
         &self,
-        paths: &Vec<T>,
+        paths: &[T],
         image_size: usize,
     ) -> anyhow::Result<Tensor> {
         let mut images = vec![];
@@ -144,7 +143,7 @@ impl ClipEmbeder {
 }
 
 impl EmbedImage for ClipEmbeder {
-    fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths: &Vec<T>) -> anyhow::Result<Vec<EmbedData>> {
+    fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths: &[T]) -> anyhow::Result<Vec<EmbedData>> {
         let config = clip::ClipConfig::vit_base_patch32();
 
         let images = self.load_images(image_paths, config.image_size).unwrap();
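
The &Vec<T> to &[T] changes in load_images and embed_image_batch follow the standard Rust guideline: a slice parameter accepts a Vec, a fixed-size array, or a subslice without copying, whereas &Vec<T> accepts only a Vec. A self-contained sketch (total_len is an illustrative function, not from this commit):

// Deref coercion lets callers pass several container shapes, no allocation.
fn total_len<T: AsRef<std::path::Path>>(paths: &[T]) -> usize {
    paths.iter().map(|p| p.as_ref().as_os_str().len()).sum()
}

fn main() {
    let owned: Vec<String> = vec!["a.png".into(), "b.png".into()];
    let fixed = ["c.png", "d.png"];
    assert_eq!(total_len(&owned), 10);     // &Vec<String> coerces to &[String]
    assert_eq!(total_len(&fixed), 10);     // arrays coerce too
    assert_eq!(total_len(&owned[..1]), 5); // so do subslices
}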

src/embedding_model/embed.rs

+1 -2

@@ -1,6 +1,5 @@
 use pyo3::prelude::*;
 use serde::Deserialize;
-use std::any;
 use std::collections::HashMap;
 
 
@@ -65,5 +64,5 @@ pub trait Embed {
 }
 
 pub trait EmbedImage {
-    fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths: &Vec<T>) -> anyhow::Result<Vec<EmbedData>>;
+    fn embed_image_batch<T: AsRef<std::path::Path>>(&self, image_paths: &[T]) -> anyhow::Result<Vec<EmbedData>>;
 }
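
Since EmbedImage is the trait an image backend implements, here is a self-contained sketch of an implementation against the updated slice signature. DummyEmbeder is hypothetical, and the stand-in EmbedData mirrors only the embedding/text fields used elsewhere in this commit; the real type lives in this module:

use anyhow::Result;

// Stand-ins so the sketch compiles on its own.
pub struct EmbedData {
    pub embedding: Vec<f32>,
    pub text: Option<String>,
}

pub trait EmbedImage {
    fn embed_image_batch<T: AsRef<std::path::Path>>(
        &self,
        image_paths: &[T],
    ) -> Result<Vec<EmbedData>>;
}

struct DummyEmbeder;

impl EmbedImage for DummyEmbeder {
    fn embed_image_batch<T: AsRef<std::path::Path>>(
        &self,
        image_paths: &[T],
    ) -> Result<Vec<EmbedData>> {
        // A real backend would load each image and run a vision model;
        // this stub returns zero vectors tagged with the file path.
        Ok(image_paths
            .iter()
            .map(|p| EmbedData {
                embedding: vec![0.0; 512],
                text: Some(p.as_ref().display().to_string()),
            })
            .collect())
    }
}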

src/embedding_model/jina.rs

+13 -11

@@ -1,9 +1,6 @@
 use anyhow::Error as E;
 use candle_core::{DType, Device, Tensor};
 use candle_nn::{Module, VarBuilder};
-// use rust_bert::pipelines::sentence_embeddings::{
-//     SentenceEmbeddingsBuilder, SentenceEmbeddingsModel, SentenceEmbeddingsModelType,
-// };
 use super::embed::{Embed, EmbedData};
 use candle_transformers::models::jina_bert::{BertModel, Config};
 use hf_hub::{Repo, RepoType};
@@ -12,30 +9,35 @@ pub struct JinaEmbeder {
     pub model: BertModel,
     pub tokenizer: Tokenizer,
 }
-impl JinaEmbeder {
-    pub fn default() -> anyhow::Result<Self> {
-        let api = hf_hub::api::sync::Api::new()?;
+
+impl Default for JinaEmbeder {
+    fn default() -> Self {
+        let api = hf_hub::api::sync::Api::new().unwrap();
         let model_file = api
             .repo(Repo::new(
                 "jinaai/jina-embeddings-v2-base-en".to_string(),
                 RepoType::Model,
             ))
-            .get("model.safetensors")?;
+            .get("model.safetensors")
+            .unwrap();
         let config = Config::v2_base();
 
         let device = Device::Cpu;
         let vb = unsafe {
-            VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device)?
+            VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F32, &device).unwrap()
         };
-        let model = BertModel::new(vb, &config)?;
-        let mut tokenizer = Self::get_tokenizer(None)?;
+        let model = BertModel::new(vb, &config).unwrap();
+        let mut tokenizer = Self::get_tokenizer(None).unwrap();
         let pp = tokenizers::PaddingParams {
             strategy: tokenizers::PaddingStrategy::BatchLongest,
             ..Default::default()
         };
         tokenizer.with_padding(Some(pp));
-        Ok(JinaEmbeder { model, tokenizer })
+        JinaEmbeder { model, tokenizer }
     }
+}
+
+impl JinaEmbeder {
 
     pub fn get_tokenizer(tokenizer: Option<String>) -> anyhow::Result<Tokenizer> {
         let tokenizer = match tokenizer {

src/lib.rs

+21 -22

@@ -1,37 +1,36 @@
+#[cfg(feature = "mkl")]
+extern crate intel_mkl_src;
+
 pub mod embedding_model;
 pub mod file_embed;
 pub mod parser;
 pub mod pdf_processor;
 
 use std::path::PathBuf;
 
-use embedding_model::{
-    clip::ClipEmbeder,
-    embed::{Embed, EmbedData, EmbedImage, Embeder},
-};
+use embedding_model::embed::{EmbedData, EmbedImage, Embeder};
 use file_embed::FileEmbeder;
 use parser::FileParser;
 use pyo3::{exceptions::PyValueError, prelude::*};
 use rayon::prelude::*;
+use tokio::runtime::Builder;
 
 #[pyfunction]
 pub fn embed_query(query: Vec<String>, embeder: &str) -> PyResult<Vec<EmbedData>> {
     let embedding_model = match embeder {
         "OpenAI" => Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()),
-        "Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default().unwrap()),
-        "Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default().unwrap()),
-        "Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default().unwrap()),
+        "Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
+        "Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default()),
+        "Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
         _ => {
             return Err(PyValueError::new_err(
                 "Invalid embedding model. Choose between OpenAI and AllMiniLmL12V2.",
             ))
         }
    };
+    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
 
-    let embeddings = tokio::runtime::Runtime::new()
-        .unwrap()
-        .block_on(embedding_model.embed(&query))
-        .unwrap();
+    let embeddings = runtime.block_on(embedding_model.embed(&query)).unwrap();
     Ok(embeddings)
 }
 
@@ -40,9 +39,9 @@ pub fn embed_query(query: Vec<String>, embeder: &str) -> PyResult<Vec<EmbedData>
 pub fn embed_file(file_name: &str, embeder: &str) -> PyResult<Vec<EmbedData>> {
     let embedding_model = match embeder {
         "OpenAI" => Embeder::OpenAI(embedding_model::openai::OpenAIEmbeder::default()),
-        "Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default().unwrap()),
-        "Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default().unwrap()),
-        "Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default().unwrap()),
+        "Jina" => Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
+        "Clip" => Embeder::Clip(embedding_model::clip::ClipEmbeder::default()),
+        "Bert" => Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
         _ => {
             return Err(PyValueError::new_err(
                 "Invalid embedding model. Choose between OpenAI and AllMiniLmL12V2.",
@@ -53,8 +52,8 @@ pub fn embed_file(file_name: &str, embeder: &str) -> PyResult<Vec<EmbedData>> {
     let mut file_embeder = FileEmbeder::new(file_name.to_string());
     let text = file_embeder.extract_text().unwrap();
     file_embeder.split_into_chunks(&text, 100);
-    tokio::runtime::Runtime::new()
-        .unwrap()
+    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
+    runtime
         .block_on(file_embeder.embed(&embedding_model))
         .unwrap();
     Ok(file_embeder.embeddings)
@@ -70,17 +69,17 @@ pub fn embed_directory(directory: PathBuf, embeder: &str) -> PyResult<Vec<EmbedD
         .unwrap(),
         "Jina" => emb(
             directory,
-            Embeder::Jina(embedding_model::jina::JinaEmbeder::default().unwrap()),
+            Embeder::Jina(embedding_model::jina::JinaEmbeder::default()),
         )
         .unwrap(),
         "Bert" => emb(
             directory,
-            Embeder::Bert(embedding_model::bert::BertEmbeder::default().unwrap()),
+            Embeder::Bert(embedding_model::bert::BertEmbeder::default()),
         )
         .unwrap(),
         "Clip" => emb_image(
             directory,
-            embedding_model::clip::ClipEmbeder::default().unwrap(),
+            embedding_model::clip::ClipEmbeder::default(),
         )
         .unwrap(),
 
@@ -96,7 +95,7 @@ pub fn embed_directory(directory: PathBuf, embeder: &str) -> PyResult<Vec<EmbedD
 
 /// A Python module implemented in Rust.
 #[pymodule]
-fn embed_anything(_py: Python, m: &PyModule) -> PyResult<()> {
+fn embed_anything(m: &Bound<'_,PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(embed_file, m)?)?;
     m.add_function(wrap_pyfunction!(embed_directory, m)?)?;
     m.add_function(wrap_pyfunction!(embed_query, m)?)?;
@@ -115,8 +114,8 @@ fn emb(directory: PathBuf, embedding_model: Embeder) -> PyResult<Vec<EmbedData>>
     let mut file_embeder = FileEmbeder::new(file.to_string());
     let text = file_embeder.extract_text().unwrap();
     file_embeder.split_into_chunks(&text, 100);
-    tokio::runtime::Runtime::new()
-        .unwrap()
+    let runtime = Builder::new_multi_thread().enable_all().build().unwrap();
+    runtime
         .block_on(file_embeder.embed(&embedding_model))
         .unwrap();
     file_embeder.embeddings
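
With pyo3-asyncio removed, each #[pyfunction] now bridges into async code by building a multi-thread tokio runtime (Builder::new_multi_thread().enable_all(), where enable_all turns on the I/O and time drivers that the reqwest-based embedders presumably need) and driving the future with block_on. Constructing a fresh runtime on every call works but carries setup cost per call; one common alternative, not what this commit does, is to cache a single runtime in a OnceLock. A sketch under that assumption:

use std::sync::OnceLock;
use tokio::runtime::{Builder, Runtime};

static RUNTIME: OnceLock<Runtime> = OnceLock::new();

fn runtime() -> &'static Runtime {
    RUNTIME.get_or_init(|| {
        Builder::new_multi_thread()
            .enable_all() // I/O + time drivers
            .build()
            .expect("failed to build tokio runtime")
    })
}

fn main() {
    // block_on runs an async future to completion from sync code, which is
    // how these #[pyfunction]s call the async embed methods.
    let answer = runtime().block_on(async { 40 + 2 });
    assert_eq!(answer, 42);
}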

test.py

+7 -6

@@ -1,21 +1,20 @@
 import os
-import time
 import numpy as np
-os.add_dll_directory(r'D:\libtorch\lib')
+# os.add_dll_directory(r'D:\test')
 from embed_anything import EmbedData
 import embed_anything
 from PIL import Image
-# start = time.time()
+import time
+
 # data:list[EmbedData] = embed_anything.embed_file("test_files/TUe_SOP_AI_2.pdf", embeder= "Bert")
 
 # embeddings = np.array([data.embedding for data in data])
 
-# end = time.time()
 
 # print(embeddings)
 # print("Time taken: ", end-start)
 
-
+start = time.time()
 data:list[EmbedData] = embed_anything.embed_directory("test_files", embeder= "Clip")
 
 embeddings = np.array([data.embedding for data in data])
@@ -29,5 +28,7 @@
 
 max_index = np.argmax(similarities)
 
-Image.open(data[max_index].text).show()
+# Image.open(data[max_index].text).show()
 print(data[max_index].text)
+end = time.time()
+print("Time taken: ", end-start)

test_files/TUe_SOP_AI_2.pdf

-149 KB
Binary file not shown.
-162 KB
Binary file not shown.
