@@ -22,7 +22,11 @@ impl Default for OpenAIEmbeder {
22
22
}
23
23
24
24
impl Embed for OpenAIEmbeder {
25
- fn embed ( & self , text_batch : & [ String ] , metadata : Option < HashMap < String , String > > ) -> impl std:: future:: Future < Output = Result < Vec < EmbedData > , reqwest:: Error > > {
25
+ fn embed (
26
+ & self ,
27
+ text_batch : & [ String ] ,
28
+ metadata : Option < HashMap < String , String > > ,
29
+ ) -> Result < Vec < EmbedData > , anyhow:: Error > {
26
30
self . embed ( text_batch, metadata)
27
31
}
28
32
}
@@ -32,7 +36,7 @@ impl TextEmbed for OpenAIEmbeder {
32
36
& self ,
33
37
text_batch : & [ String ] ,
34
38
metadata : Option < HashMap < String , String > > ,
35
- ) -> impl std :: future :: Future < Output = Result < Vec < EmbedData > , reqwest :: Error > > {
39
+ ) -> Result < Vec < EmbedData > , anyhow :: Error > {
36
40
self . embed ( text_batch, metadata)
37
41
}
38
42
}
@@ -47,28 +51,41 @@ impl OpenAIEmbeder {
47
51
}
48
52
}
49
53
50
- async fn embed ( & self , text_batch : & [ String ] , metadata : Option < HashMap < String , String > > ) -> Result < Vec < EmbedData > , reqwest:: Error > {
54
+ fn embed (
55
+ & self ,
56
+ text_batch : & [ String ] ,
57
+ metadata : Option < HashMap < String , String > > ,
58
+ ) -> Result < Vec < EmbedData > , anyhow:: Error > {
51
59
let client = Client :: new ( ) ;
52
-
53
- let response = client
54
- . post ( & self . url )
55
- . header ( "Content-Type" , "application/json" )
56
- . header ( "Authorization" , format ! ( "Bearer {}" , self . api_key) )
57
- . json ( & json ! ( {
58
- "input" : text_batch,
59
- "model" : "text-embedding-3-small" ,
60
- } ) )
61
- . send ( )
62
- . await ?;
63
-
64
- let data = response. json :: < EmbedResponse > ( ) . await ?;
65
- println ! ( "{:?}" , data. usage) ;
60
+ let runtime = tokio:: runtime:: Builder :: new_current_thread ( ) . enable_io ( )
61
+ . build ( )
62
+ . unwrap ( ) ;
63
+
64
+ let data = runtime. block_on ( async move {
65
+ let response = client
66
+ . post ( & self . url )
67
+ . header ( "Content-Type" , "application/json" )
68
+ . header ( "Authorization" , format ! ( "Bearer {}" , self . api_key) )
69
+ . json ( & json ! ( {
70
+ "input" : text_batch,
71
+ "model" : "text-embedding-3-small" ,
72
+ } ) )
73
+ . send ( )
74
+ . await
75
+ . unwrap ( ) ;
76
+
77
+ let data = response. json :: < EmbedResponse > ( ) . await . unwrap ( ) ;
78
+ println ! ( "{:?}" , data. usage) ;
79
+ data
80
+ } ) ;
66
81
67
82
let emb_data = data
68
83
. data
69
84
. iter ( )
70
85
. zip ( text_batch)
71
- . map ( move |( data, text) | EmbedData :: new ( data. embedding . clone ( ) , Some ( text. clone ( ) ) , metadata. clone ( ) ) )
86
+ . map ( move |( data, text) | {
87
+ EmbedData :: new ( data. embedding . clone ( ) , Some ( text. clone ( ) ) , metadata. clone ( ) )
88
+ } )
72
89
. collect :: < Vec < _ > > ( ) ;
73
90
74
91
Ok ( emb_data)
@@ -79,15 +96,14 @@ impl OpenAIEmbeder {
79
96
mod tests {
80
97
use super :: * ;
81
98
82
- #[ tokio:: test]
83
- async fn test_openai_embed ( ) {
99
+ fn test_openai_embed ( ) {
84
100
let openai = OpenAIEmbeder :: default ( ) ;
85
101
let text_batch = vec ! [
86
102
"Once upon a time" . to_string( ) ,
87
103
"The quick brown fox jumps over the lazy dog" . to_string( ) ,
88
104
] ;
89
105
90
- let embeddings = openai. embed ( & text_batch, None ) . await . unwrap ( ) ;
106
+ let embeddings = openai. embed ( & text_batch, None ) . unwrap ( ) ;
91
107
assert_eq ! ( embeddings. len( ) , 2 ) ;
92
108
}
93
109
}
0 commit comments