Skip to content

Commit 9ef6aac

Browse files
committed
refactor: license info in forbidden response
1 parent 796d015 commit 9ef6aac

File tree

2 files changed

+36
-15
lines changed

2 files changed

+36
-15
lines changed

src/github/mod.rs

+32-12
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::{
55
};
66
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
77
use serde::{Deserialize, Serialize};
8-
use serde_json::Value;
8+
use serde_json::{json, Value};
99
use std::io::Read;
1010

1111
#[derive(Debug, Default, Serialize)]
@@ -151,7 +151,7 @@ const IGNORED_DIRECTORIES: &[&str] = &[
151151
"debug",
152152
];
153153

154-
const ALLOWED_LICENCES: &[&str] = &[
154+
const ALLOWED_LICENSES: &[&str] = &[
155155
"0bsd",
156156
"apache-2.0",
157157
"bsd-2-clause",
@@ -160,13 +160,18 @@ const ALLOWED_LICENCES: &[&str] = &[
160160
"bsd-4-clause",
161161
"isc",
162162
"mit",
163-
"mpl-2.0",
164163
"unlicense",
165164
"wtfpl",
166165
"zlib",
167166
];
168167

169-
pub async fn is_indexing_allowed(repository: &Repository) -> Result<bool> {
168+
#[derive(Serialize, Debug, Default)]
169+
pub struct LicenseFetchResponse {
170+
pub permissible: bool,
171+
pub error: Option<Value>,
172+
}
173+
174+
pub async fn fetch_license_info(repository: &Repository) -> Result<LicenseFetchResponse> {
170175
let Repository { owner, name, .. } = repository;
171176
let url = format!("https://api.github.com/repos/{owner}/{name}/license");
172177

@@ -181,8 +186,22 @@ pub async fn is_indexing_allowed(repository: &Repository) -> Result<bool> {
181186
Ok(response) => {
182187
let response_json = response.json::<Value>().await?;
183188
let license_key = response_json["license"]["key"].as_str().unwrap_or_default();
184-
let is_allowed: bool = ALLOWED_LICENCES.iter().any(|k| k.eq(&license_key));
185-
Ok(is_allowed)
189+
let permissible: bool = ALLOWED_LICENSES.iter().any(|k| k.eq(&license_key));
190+
191+
Ok(LicenseFetchResponse {
192+
permissible,
193+
error: if permissible {
194+
None
195+
} else {
196+
Some(json! {{
197+
"message": "Impermissible repository license",
198+
"license": {
199+
"name": response_json["license"]["name"],
200+
"url": response_json["html_url"]
201+
}
202+
}})
203+
},
204+
})
186205
}
187206
Err(_) => Err(anyhow::anyhow!("Unable to fetch repository license")),
188207
}
@@ -265,8 +284,9 @@ mod tests {
265284
branch: "beta".to_string(),
266285
};
267286

268-
let is_allowed = is_indexing_allowed(&repository).await.unwrap_or_default();
269-
assert_eq!(is_allowed, true);
287+
let license_info = fetch_license_info(&repository).await.unwrap_or_default();
288+
dbg!(&license_info);
289+
assert_eq!(license_info.permissible, true);
270290

271291
//Permissible
272292
let repository = Repository {
@@ -275,8 +295,8 @@ mod tests {
275295
branch: "main".to_string(),
276296
};
277297

278-
let is_allowed = is_indexing_allowed(&repository).await.unwrap_or_default();
279-
assert_eq!(is_allowed, true);
298+
let license_info = fetch_license_info(&repository).await.unwrap_or_default();
299+
assert_eq!(license_info.permissible, true);
280300

281301
//Impermissible
282302
let repository = Repository {
@@ -285,7 +305,7 @@ mod tests {
285305
branch: "main".to_string(),
286306
};
287307

288-
let is_allowed = is_indexing_allowed(&repository).await.unwrap_or_default();
289-
assert_eq!(is_allowed, false);
308+
let license_info = fetch_license_info(&repository).await.unwrap_or_default();
309+
assert_eq!(license_info.permissible, false);
290310
}
291311
}

src/routes/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ pub mod events;
33

44
use crate::constants::SSE_CHANNEL_BUFFER_SIZE;
55
use crate::conversation::{Conversation, Query};
6-
use crate::github::{fetch_repo_files, is_indexing_allowed};
6+
use crate::github::{fetch_license_info, fetch_repo_files};
77
use crate::routes::events::QueryEvent;
88
use crate::{db::RepositoryEmbeddingsDB, github::Repository};
99
use actix_web::web::Query as ActixQuery;
@@ -27,8 +27,9 @@ async fn embeddings(
2727
db: web::Data<Arc<QdrantDB>>,
2828
model: web::Data<Arc<Onnx>>,
2929
) -> Result<impl Responder> {
30-
if !is_indexing_allowed(&data).await.map_err(ErrorBadRequest)? {
31-
return Err(ErrorForbidden("Impermissible repository license"));
30+
let license_info = fetch_license_info(&data).await.map_err(ErrorBadRequest)?;
31+
if !license_info.permissible {
32+
return Err(ErrorForbidden(license_info.error.unwrap_or_default()));
3233
}
3334

3435
let (sender, rx) = sse::channel(SSE_CHANNEL_BUFFER_SIZE);

0 commit comments

Comments
 (0)