Skip to content

Commit

Permalink
Add federated JWK lookup to pepper service (aptos-labs#14425)
Browse files Browse the repository at this point in the history
* Add federated JWK lookup to pepper service

* small fixes

* fix match

* xclipppy
  • Loading branch information
heliuchuan authored Aug 27, 2024
1 parent cf7e840 commit 09ca1d6
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
25 changes: 25 additions & 0 deletions keyless/pepper/service/src/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,44 @@

use crate::{metrics::JWK_FETCH_SECONDS, Issuer, KeyID};
use anyhow::{anyhow, Result};
use aptos_keyless_pepper_common::jwt::parse;
use aptos_logger::warn;
use aptos_types::jwks::rsa::RSA_JWK;
use dashmap::DashMap;
use jsonwebtoken::DecodingKey;
use once_cell::sync::Lazy;
use regex::Regex;
use serde_json::Value;
use std::{sync::Arc, time::Duration};
use tokio::time::Instant;

static AUTH_0_REGEX: Lazy<Regex> =
Lazy::new(|| Regex::new(r"^https://[a-zA-Z0-9-]+\.us\.auth0\.com/$").unwrap());

/// The JWK in-mem cache.
pub static DECODING_KEY_CACHE: Lazy<DashMap<Issuer, DashMap<KeyID, Arc<RSA_JWK>>>> =
Lazy::new(DashMap::new);

pub async fn get_federated_jwk(jwt: &str) -> Result<Arc<RSA_JWK>> {
let payload = parse(jwt)?;

if !AUTH_0_REGEX.is_match(&payload.claims.iss) {
return Err(anyhow!("not a federated iss"));
}

let jwt_kid: String = match payload.header.kid {
Some(kid) => kid,
None => return Err(anyhow!("no kid found on jwt header")),
};

let jwk_url = format!("{}.well-known/jwks.json", &payload.claims.iss);
let keys = fetch_jwks(&jwk_url).await?;
let key = keys
.get(&jwt_kid)
.ok_or_else(|| anyhow!("unknown kid: {}", jwt_kid))?;
Ok(key.clone())
}

/// Send a request to a JWK endpoint and return its JWK map.
pub async fn fetch_jwks(jwk_url: &str) -> Result<DashMap<KeyID, Arc<RSA_JWK>>> {
let response = reqwest::get(jwk_url)
Expand Down
16 changes: 13 additions & 3 deletions keyless/pepper/service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ use aptos_types::{
},
};
use firestore::{async_trait, paths, struct_path::path};
use jsonwebtoken::{Algorithm::RS256, Validation};
use jsonwebtoken::{Algorithm::RS256, DecodingKey, Validation};
use jwk::get_federated_jwk;
use rand::thread_rng;
use serde::{Deserialize, Serialize};
use std::time::{SystemTime, UNIX_EPOCH};
Expand Down Expand Up @@ -357,13 +358,22 @@ async fn process_common(
.kid
.ok_or_else(|| BadRequest("missing kid in JWT".to_string()))?;

let sig_pub_key = jwk::cached_decoding_key(&claims.claims.iss, &key_id)
let cached_key = jwk::cached_decoding_key_as_rsa(&claims.claims.iss, &key_id);

let jwk = match cached_key {
Ok(key) => key,
Err(_) => get_federated_jwk(&jwt)
.await
.map_err(|e| BadRequest(format!("JWK not found: {e}")))?,
};
let jwk_decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
.map_err(|e| BadRequest(format!("JWK not found: {e}")))?;

let mut validation_with_sig_verification = Validation::new(RS256);
validation_with_sig_verification.validate_exp = false; // Don't validate the exp time
let _claims = jsonwebtoken::decode::<Claims>(
jwt.as_str(),
&sig_pub_key,
&jwk_decoding_key,
&validation_with_sig_verification,
) // Signature verification happens here.
.map_err(|e| BadRequest(format!("JWT signature verification failed: {e}")))?;
Expand Down

0 comments on commit 09ca1d6

Please sign in to comment.