Skip to content

Commit 0972d07

Browse files
committed
feat: repo license check
1 parent 09946ce commit 0972d07

File tree

2 files changed

+79
-4
lines changed

2 files changed

+79
-4
lines changed

src/github/mod.rs

+71
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::{
55
};
66
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
77
use serde::{Deserialize, Serialize};
8+
use serde_json::Value;
89
use std::io::Read;
910

1011
#[derive(Debug, Default, Serialize)]
@@ -150,6 +151,43 @@ const IGNORED_DIRECTORIES: &[&str] = &[
150151
"debug",
151152
];
152153

154+
const ALLOWED_LICENCES: &[&str] = &[
155+
"0bsd",
156+
"apache-2.0",
157+
"bsd-2-clause",
158+
"bsd-3-clause",
159+
"bsd-3-clause-clear",
160+
"bsd-4-clause",
161+
"isc",
162+
"mit",
163+
"mpl-2.0",
164+
"unlicense",
165+
"wtfpl",
166+
"zlib",
167+
];
168+
169+
pub async fn is_indexing_allowed(repository: &Repository) -> Result<bool> {
170+
let Repository { owner, name, .. } = repository;
171+
let url = format!("https://api.github.com/repos/{owner}/{name}/license");
172+
173+
//User-agent reference: https://docs.github.com/en/rest/overview/resources-in-the-rest-api?apiVersion=2022-11-28#user-agent-required
174+
let client = reqwest::Client::builder()
175+
.user_agent("open-sauced")
176+
.build()
177+
.unwrap();
178+
179+
let response = client.get(url).send().await?;
180+
match response.error_for_status() {
181+
Ok(response) => {
182+
let response_json = response.json::<Value>().await?;
183+
let license_key = response_json["license"]["key"].as_str().unwrap_or_default();
184+
let is_allowed: bool = ALLOWED_LICENCES.iter().any(|k| k.eq(&license_key));
185+
Ok(is_allowed)
186+
}
187+
Err(_) => Err(anyhow::anyhow!("Unable to fetch repository license")),
188+
}
189+
}
190+
153191
pub fn should_index(path: &str) -> bool {
154192
!(IGNORED_EXTENSIONS.iter().any(|ext| path.ends_with(ext))
155193
|| IGNORED_DIRECTORIES.iter().any(|dir| path.contains(dir)))
@@ -217,4 +255,37 @@ mod tests {
217255
let path = "path/to/file.tsx";
218256
assert!(should_index(path));
219257
}
258+
259+
#[tokio::test]
260+
async fn test_is_indexing_allowed() {
261+
// Permissible
262+
let repository = Repository {
263+
owner: "open-sauced".to_string(),
264+
name: "ai".to_string(),
265+
branch: "beta".to_string(),
266+
};
267+
268+
let is_allowed = is_indexing_allowed(&repository).await.unwrap_or_default();
269+
assert_eq!(is_allowed, true);
270+
271+
//Permissible
272+
let repository = Repository {
273+
owner: "facebook".to_string(),
274+
name: "react".to_string(),
275+
branch: "main".to_string(),
276+
};
277+
278+
let is_allowed = is_indexing_allowed(&repository).await.unwrap_or_default();
279+
assert_eq!(is_allowed, true);
280+
281+
//Imermissible
282+
let repository = Repository {
283+
owner: "open-sauced".to_string(),
284+
name: "guestbook".to_string(),
285+
branch: "main".to_string(),
286+
};
287+
288+
let is_allowed = is_indexing_allowed(&repository).await.unwrap_or_default();
289+
assert_eq!(is_allowed, false);
290+
}
220291
}

src/routes/mod.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@ pub mod events;
33

44
use crate::constants::SSE_CHANNEL_BUFFER_SIZE;
55
use crate::conversation::{Conversation, Query};
6-
use crate::github::fetch_repo_files;
6+
use crate::github::{fetch_repo_files, is_indexing_allowed};
77
use crate::routes::events::QueryEvent;
88
use crate::{db::RepositoryEmbeddingsDB, github::Repository};
99
use actix_web::web::Query as ActixQuery;
1010
use actix_web::HttpResponse;
1111
use actix_web::{
12-
error::ErrorNotFound,
12+
error::{ErrorBadRequest, ErrorForbidden, ErrorNotFound},
1313
get, post,
1414
web::{self, Json},
1515
Responder, Result,
@@ -26,7 +26,11 @@ async fn embeddings(
2626
data: Json<Repository>,
2727
db: web::Data<Arc<QdrantDB>>,
2828
model: web::Data<Arc<Onnx>>,
29-
) -> impl Responder {
29+
) -> Result<impl Responder> {
30+
if !is_indexing_allowed(&data).await.map_err(ErrorBadRequest)? {
31+
return Err(ErrorForbidden("Impermissible repository license"));
32+
}
33+
3034
let (sender, rx) = sse::channel(SSE_CHANNEL_BUFFER_SIZE);
3135

3236
actix_rt::spawn(async move {
@@ -58,7 +62,7 @@ async fn embeddings(
5862
}
5963
});
6064

61-
rx
65+
Ok(rx)
6266
}
6367

6468
#[post("/query")]

0 commit comments

Comments
 (0)