|
| 1 | +import unittest |
| 2 | +import datetime |
| 3 | +import json |
| 4 | +import jwt |
| 5 | +import requests |
| 6 | +import sqlite3 |
| 7 | +from threading import Thread |
| 8 | +from http.server import HTTPServer |
| 9 | +from cryptography.hazmat.primitives.asymmetric import rsa |
| 10 | +from project2 import MyServer, reset_database, save_key_to_db, get_key_from_db, jwks_response |
| 11 | + |
| 12 | +HOST = "http://localhost:8080" |
| 13 | +DB_PATH = "totally_not_my_privateKeys.db" |
| 14 | + |
| 15 | +class TestProject2Server(unittest.TestCase): |
| 16 | + |
| 17 | + def setUp(self): |
| 18 | + reset_database() |
| 19 | + self.server = HTTPServer(("localhost", 8080), MyServer) |
| 20 | + self.server_thread = Thread(target=self.server.serve_forever) |
| 21 | + self.server_thread.daemon = True |
| 22 | + self.server_thread.start() |
| 23 | + print("Server started for testing...") |
| 24 | + |
| 25 | + def tearDown(self): |
| 26 | + self.server.shutdown() |
| 27 | + self.server.server_close() |
| 28 | + self.server_thread.join() |
| 29 | + print("Server stopped after testing.") |
| 30 | + |
| 31 | + def test_auth_endpoint_jwt_valid(self): |
| 32 | + """Test JWT generation for a valid (non-expired) token.""" |
| 33 | + response = requests.post(f"{HOST}/auth") |
| 34 | + self.assertEqual(response.status_code, 200) |
| 35 | + token = json.loads(response.text)["token"] |
| 36 | + decoded = jwt.decode(token, options={"verify_signature": False}) |
| 37 | + self.assertEqual(decoded["user"], "username") |
| 38 | + self.assertGreater(decoded["exp"], datetime.datetime.utcnow().timestamp()) |
| 39 | + |
| 40 | + def test_auth_endpoint_jwt_expired(self): |
| 41 | + """Test JWT generation for an expired token.""" |
| 42 | + response = requests.post(f"{HOST}/auth?expired=true") |
| 43 | + self.assertEqual(response.status_code, 200) |
| 44 | + token = json.loads(response.text)["token"] |
| 45 | + decoded = jwt.decode(token, options={"verify_signature": False}) |
| 46 | + self.assertEqual(decoded["user"], "username") |
| 47 | + self.assertLess(decoded["exp"], datetime.datetime.utcnow().timestamp()) |
| 48 | + |
| 49 | + def test_jwks_json_endpoint(self): |
| 50 | + """Test that the JWKS endpoint returns valid JSON Web Key data.""" |
| 51 | + response = requests.get(f"{HOST}/.well-known/jwks.json") |
| 52 | + self.assertEqual(response.status_code, 200) |
| 53 | + keys = json.loads(response.text)["keys"] |
| 54 | + self.assertGreater(len(keys), 0, "No keys returned in JWKS") |
| 55 | + for key in keys: |
| 56 | + self.assertIn("kid", key) |
| 57 | + self.assertIn("n", key) |
| 58 | + self.assertIn("e", key) |
| 59 | + self.assertEqual(key["alg"], "RS256") |
| 60 | + self.assertEqual(key["kty"], "RSA") |
| 61 | + self.assertEqual(key["use"], "sig") |
| 62 | + |
| 63 | + def test_database_key_insertion_and_retrieval(self): |
| 64 | + """Test saving and retrieving keys from the database.""" |
| 65 | + expiry_time = int((datetime.datetime.utcnow() + datetime.timedelta(hours=1)).timestamp()) |
| 66 | + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 67 | + kid = save_key_to_db(private_key, expiry_time, fixed_kid=99) # Example fixed_kid provided |
| 68 | + |
| 69 | + retrieved_kid, retrieved_key = get_key_from_db(expired=False) |
| 70 | + self.assertEqual(kid, retrieved_kid, "Retrieved key ID does not match saved key ID.") |
| 71 | + self.assertIsNotNone(retrieved_key, "No key was retrieved from the database.") |
| 72 | + |
| 73 | + def test_jwks_response_only_unexpired(self): |
| 74 | + """Test that only unexpired keys appear in the JWKS response.""" |
| 75 | + expired_time = int((datetime.datetime.utcnow() - datetime.timedelta(hours=1)).timestamp()) |
| 76 | + unexpired_time = int((datetime.datetime.utcnow() + datetime.timedelta(hours=1)).timestamp()) |
| 77 | + |
| 78 | + expired_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 79 | + unexpired_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 80 | + |
| 81 | + save_key_to_db(expired_key, expired_time, fixed_kid=101) |
| 82 | + save_key_to_db(unexpired_key, unexpired_time, fixed_kid=102) |
| 83 | + |
| 84 | + jwks = json.loads(jwks_response())["keys"] |
| 85 | + self.assertTrue(any(key["kid"] == "102" for key in jwks), "No unexpired keys found in JWKS") |
| 86 | + self.assertFalse(any(key["kid"] == "101" for key in jwks), "Expired keys should not appear in JWKS") |
| 87 | + |
| 88 | + def test_unsupported_methods(self): |
| 89 | + """Ensure unsupported HTTP methods return 405 or 501 status.""" |
| 90 | + for method in [requests.put, requests.delete, requests.patch, requests.head]: |
| 91 | + response = method(f"{HOST}/auth") |
| 92 | + self.assertIn(response.status_code, [405, 501], "Unexpected status for unsupported method") |
| 93 | + |
| 94 | + def test_database_cleared_between_tests(self): |
| 95 | + """Verify that the database resets between tests.""" |
| 96 | + conn = sqlite3.connect(DB_PATH) |
| 97 | + cursor = conn.cursor() |
| 98 | + cursor.execute("SELECT COUNT(*) FROM keys") |
| 99 | + row_count = cursor.fetchone()[0] |
| 100 | + conn.close() |
| 101 | + self.assertEqual(row_count, 0, "Database should be empty between tests") |
| 102 | + |
| 103 | + def test_jwks_response_structure(self): |
| 104 | + """Verify JWKS response format and contents.""" |
| 105 | + response = requests.get(f"{HOST}/.well-known/jwks.json") |
| 106 | + self.assertEqual(response.status_code, 200) |
| 107 | + data = json.loads(response.text) |
| 108 | + self.assertIn("keys", data, "JWKS response missing 'keys' field") |
| 109 | + self.assertIsInstance(data["keys"], list, "'keys' field should be a list") |
| 110 | + if data["keys"]: |
| 111 | + first_key = data["keys"][0] |
| 112 | + self.assertIn("alg", first_key) |
| 113 | + self.assertIn("kty", first_key) |
| 114 | + self.assertIn("use", first_key) |
| 115 | + self.assertIn("kid", first_key) |
| 116 | + self.assertIn("n", first_key) |
| 117 | + self.assertIn("e", first_key) |
| 118 | + |
| 119 | + def test_save_key_with_fixed_kid(self): |
| 120 | + """Test saving a key with a specified fixed_kid.""" |
| 121 | + expiry_time = int((datetime.datetime.utcnow() + datetime.timedelta(hours=1)).timestamp()) |
| 122 | + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) |
| 123 | + save_key_to_db(private_key, expiry_time, fixed_kid=42) |
| 124 | + |
| 125 | + # Verify that the key was saved with the specified kid |
| 126 | + cursor = sqlite3.connect(DB_PATH).cursor() |
| 127 | + cursor.execute("SELECT kid FROM keys WHERE kid = ?", (42,)) |
| 128 | + result = cursor.fetchone() |
| 129 | + self.assertIsNotNone(result, "Key with specified kid=42 was not found in the database") |
| 130 | + self.assertEqual(result[0], 42, "Retrieved kid does not match the specified fixed_kid") |
| 131 | + |
| 132 | +if __name__ == "__main__": |
| 133 | + unittest.main() |
0 commit comments