forked from StarlightSearch/EmbedAnything
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwebsite_processor.rs
175 lines (146 loc) · 4.87 KB
/
website_processor.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
use std::collections::{HashMap, HashSet};
use anyhow::Result;
use scraper::{Html, Selector};
use serde_json::json;
use url::Url;
use crate::{
embedding_model::embed::{EmbedData, TextEmbed},
text_loader::TextLoader,
};
#[derive(Debug)]
/// Content extracted from a single fetched web page, grouped by tag kind.
/// Each field is `Some` when the corresponding extraction ran, even if it
/// found nothing (the inner collection may be empty).
pub struct WebPage {
    // The URL the page was fetched from.
    pub url: String,
    // Text of the `<title>` element, if present.
    pub title: Option<String>,
    // Text of `<h1>`/`<h2>`/`<h3>` elements; the originating tag is not
    // retained per entry.
    pub headers: Option<Vec<String>>,
    // Text of `<p>` elements.
    pub paragraphs: Option<Vec<String>>,
    // Text of `<code>` elements.
    pub codes: Option<Vec<String>>,
    // Absolute, fragment-stripped URLs of `<a href>` targets.
    pub links: Option<HashSet<String>>,
}
impl WebPage {
    /// Embeds the page's paragraphs, headers and code blocks with `embeder`
    /// and returns the combined embeddings.
    ///
    /// # Errors
    /// Propagates any error from the underlying embedder.
    pub fn embed_webpage<T: TextEmbed>(&self, embeder: &T) -> Result<Vec<EmbedData>> {
        let mut embed_data = Vec::new();
        if let Some(paragraphs) = &self.paragraphs {
            embed_data.extend(self.embed_tag("p", paragraphs, embeder)?);
        }
        if let Some(headers) = &self.headers {
            // NOTE(review): headers are harvested from h1/h2/h3 together, so
            // they are all labeled as "h1" here — the source tag is not kept
            // per entry.
            embed_data.extend(self.embed_tag("h1", headers, embeder)?);
        }
        if let Some(codes) = &self.codes {
            embed_data.extend(self.embed_tag("code", codes, embeder)?);
        }
        Ok(embed_data)
    }

    /// Chunks each string in `tag_content` and embeds the chunks, attaching
    /// per-chunk metadata (`url`, tag `type`, and the `full_text` the chunk
    /// came from).
    ///
    /// Entries that produce no chunks are skipped.
    ///
    /// # Errors
    /// Propagates any error from the underlying embedder.
    pub fn embed_tag<T: TextEmbed>(
        &self,
        tag: &str,
        tag_content: &[String],
        embeder: &T,
    ) -> Result<Vec<EmbedData>> {
        let mut embed_data = Vec::new();
        for content in tag_content {
            // Skip content that yields no chunks (either the splitter
            // declined or it returned an empty list).
            let chunks = match TextLoader::split_into_chunks(content, 1000) {
                Some(chunks) if !chunks.is_empty() => chunks,
                _ => continue,
            };
            let tag_type = match tag {
                "h1" => "header",
                "h2" => "subheader",
                "h3" => "subsubheader",
                "code" => "code",
                // "p" and any unrecognized tag are treated as paragraphs.
                _ => "paragraph",
            };
            // Build the metadata map directly instead of going through a
            // serde_json serialize/deserialize round-trip — same contents,
            // no fallible conversion, fewer allocations.
            let metadata_hashmap: HashMap<String, String> = HashMap::from([
                ("url".to_string(), self.url.clone()),
                ("type".to_string(), tag_type.to_string()),
                ("full_text".to_string(), content.clone()),
            ]);
            let embeddings = embeder.embed(&chunks, Some(metadata_hashmap))?;
            embed_data.extend(embeddings);
        }
        Ok(embed_data)
    }
}
impl Default for WebPage {
fn default() -> Self {
Self {
url: "".to_string(),
title: None,
headers: None,
paragraphs: None,
codes: None,
links: None,
}
}
}
impl Default for WebsiteProcessor {
fn default() -> Self {
Self::new()
}
}
/// Fetches a web page over HTTP and extracts its title, headers, paragraphs,
/// code blocks and outgoing links into a [`WebPage`].
pub struct WebsiteProcessor;

impl WebsiteProcessor {
    /// Creates a new processor. The processor itself holds no state.
    pub fn new() -> Self {
        Self {}
    }

    /// Downloads `website` and extracts its content into a [`WebPage`].
    ///
    /// # Errors
    /// Returns an error when the HTTP request fails, the body cannot be read,
    /// or `website` is not a parseable URL.
    pub async fn process_website(&self, website: &str) -> Result<WebPage> {
        let response = reqwest::get(website).await?.text().await?;
        let document = Html::parse_document(&response);
        let headers = self.get_text_from_tag("h1,h2,h3", &document)?;
        let paragraphs = self.get_text_from_tag("p", &document)?;
        let codes = self.get_text_from_tag("code", &document)?;
        let links = self.extract_links(website, &document)?;
        let title = self.get_title(&document)?;
        let web_page = WebPage {
            url: website.to_string(),
            title,
            headers: Some(headers),
            paragraphs: Some(paragraphs),
            codes: Some(codes),
            links: Some(links),
        };
        Ok(web_page)
    }

    /// Collects the trimmed text of every element matching the CSS selector
    /// `tag`.
    ///
    /// # Errors
    /// Returns an error if `tag` is not a valid CSS selector (previously this
    /// panicked via `unwrap`).
    fn get_text_from_tag(&self, tag: &str, document: &Html) -> Result<Vec<String>> {
        let selector = Selector::parse(tag)
            .map_err(|e| anyhow::anyhow!("invalid CSS selector {tag:?}: {e}"))?;
        Ok(document
            .select(&selector)
            .map(|element| element.text().collect::<String>().trim().to_string())
            .collect())
    }

    /// Resolves every `<a href>` against `website` and returns the set of
    /// absolute URLs with fragments stripped.
    ///
    /// # Errors
    /// Returns an error only when `website` itself is not a valid URL.
    fn extract_links(&self, website: &str, document: &Html) -> Result<HashSet<String>> {
        let mut links = HashSet::new();
        let base_url = Url::parse(website)?;
        // "a" is a fixed, known-valid selector; a parse failure here would be
        // a programming error.
        let anchor = Selector::parse("a").expect("'a' is a valid CSS selector");
        for element in document.select(&anchor) {
            if let Some(href) = element.value().attr("href") {
                // Skip hrefs that cannot be resolved against the base URL.
                // Previously a single malformed href aborted the whole
                // extraction (and thus `process_website`) via `?`.
                if let Ok(mut link_url) = base_url.join(href) {
                    // Normalize: drop the fragment so "#section" variants of
                    // the same page deduplicate.
                    link_url.set_fragment(None);
                    links.insert(link_url.to_string());
                }
            }
        }
        Ok(links)
    }

    /// Returns the text of the first `<title>` element, or `None` when the
    /// document has no title.
    fn get_title(&self, document: &Html) -> Result<Option<String>> {
        let selector = Selector::parse("title").expect("'title' is a valid CSS selector");
        Ok(document
            .select(&selector)
            .next()
            .map(|title_element| title_element.text().collect::<String>()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): this is an integration test that performs a live HTTP
    // request to a third-party site — it requires network access and will
    // fail offline or if the page is taken down. It only checks that
    // processing succeeds, not the extracted content.
    #[tokio::test]
    async fn test_process_website() {
        let website_processor = WebsiteProcessor::new();
        let website = "https://www.scrapingbee.com/blog/web-scraping-rust/";
        let result = website_processor.process_website(website).await;
        assert!(result.is_ok());
    }
}