Skip to content

Commit 4719e18

Browse files
authored
Add files via upload
1 parent 2f3b06e commit 4719e18

File tree

2 files changed

+133
-0
lines changed

2 files changed

+133
-0
lines changed

project2Test.PNG

14.8 KB
Loading

test2.py

+133
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import unittest
2+
import datetime
3+
import json
4+
import jwt
5+
import requests
6+
import sqlite3
7+
from threading import Thread
8+
from http.server import HTTPServer
9+
from cryptography.hazmat.primitives.asymmetric import rsa
10+
from project2 import MyServer, reset_database, save_key_to_db, get_key_from_db, jwks_response
11+
12+
HOST = "http://localhost:8080"
13+
DB_PATH = "totally_not_my_privateKeys.db"
14+
15+
class TestProject2Server(unittest.TestCase):
16+
17+
def setUp(self):
18+
reset_database()
19+
self.server = HTTPServer(("localhost", 8080), MyServer)
20+
self.server_thread = Thread(target=self.server.serve_forever)
21+
self.server_thread.daemon = True
22+
self.server_thread.start()
23+
print("Server started for testing...")
24+
25+
def tearDown(self):
26+
self.server.shutdown()
27+
self.server.server_close()
28+
self.server_thread.join()
29+
print("Server stopped after testing.")
30+
31+
def test_auth_endpoint_jwt_valid(self):
32+
"""Test JWT generation for a valid (non-expired) token."""
33+
response = requests.post(f"{HOST}/auth")
34+
self.assertEqual(response.status_code, 200)
35+
token = json.loads(response.text)["token"]
36+
decoded = jwt.decode(token, options={"verify_signature": False})
37+
self.assertEqual(decoded["user"], "username")
38+
self.assertGreater(decoded["exp"], datetime.datetime.utcnow().timestamp())
39+
40+
def test_auth_endpoint_jwt_expired(self):
41+
"""Test JWT generation for an expired token."""
42+
response = requests.post(f"{HOST}/auth?expired=true")
43+
self.assertEqual(response.status_code, 200)
44+
token = json.loads(response.text)["token"]
45+
decoded = jwt.decode(token, options={"verify_signature": False})
46+
self.assertEqual(decoded["user"], "username")
47+
self.assertLess(decoded["exp"], datetime.datetime.utcnow().timestamp())
48+
49+
def test_jwks_json_endpoint(self):
50+
"""Test that the JWKS endpoint returns valid JSON Web Key data."""
51+
response = requests.get(f"{HOST}/.well-known/jwks.json")
52+
self.assertEqual(response.status_code, 200)
53+
keys = json.loads(response.text)["keys"]
54+
self.assertGreater(len(keys), 0, "No keys returned in JWKS")
55+
for key in keys:
56+
self.assertIn("kid", key)
57+
self.assertIn("n", key)
58+
self.assertIn("e", key)
59+
self.assertEqual(key["alg"], "RS256")
60+
self.assertEqual(key["kty"], "RSA")
61+
self.assertEqual(key["use"], "sig")
62+
63+
def test_database_key_insertion_and_retrieval(self):
64+
"""Test saving and retrieving keys from the database."""
65+
expiry_time = int((datetime.datetime.utcnow() + datetime.timedelta(hours=1)).timestamp())
66+
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
67+
kid = save_key_to_db(private_key, expiry_time, fixed_kid=99) # Example fixed_kid provided
68+
69+
retrieved_kid, retrieved_key = get_key_from_db(expired=False)
70+
self.assertEqual(kid, retrieved_kid, "Retrieved key ID does not match saved key ID.")
71+
self.assertIsNotNone(retrieved_key, "No key was retrieved from the database.")
72+
73+
def test_jwks_response_only_unexpired(self):
74+
"""Test that only unexpired keys appear in the JWKS response."""
75+
expired_time = int((datetime.datetime.utcnow() - datetime.timedelta(hours=1)).timestamp())
76+
unexpired_time = int((datetime.datetime.utcnow() + datetime.timedelta(hours=1)).timestamp())
77+
78+
expired_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
79+
unexpired_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
80+
81+
save_key_to_db(expired_key, expired_time, fixed_kid=101)
82+
save_key_to_db(unexpired_key, unexpired_time, fixed_kid=102)
83+
84+
jwks = json.loads(jwks_response())["keys"]
85+
self.assertTrue(any(key["kid"] == "102" for key in jwks), "No unexpired keys found in JWKS")
86+
self.assertFalse(any(key["kid"] == "101" for key in jwks), "Expired keys should not appear in JWKS")
87+
88+
def test_unsupported_methods(self):
89+
"""Ensure unsupported HTTP methods return 405 or 501 status."""
90+
for method in [requests.put, requests.delete, requests.patch, requests.head]:
91+
response = method(f"{HOST}/auth")
92+
self.assertIn(response.status_code, [405, 501], "Unexpected status for unsupported method")
93+
94+
def test_database_cleared_between_tests(self):
95+
"""Verify that the database resets between tests."""
96+
conn = sqlite3.connect(DB_PATH)
97+
cursor = conn.cursor()
98+
cursor.execute("SELECT COUNT(*) FROM keys")
99+
row_count = cursor.fetchone()[0]
100+
conn.close()
101+
self.assertEqual(row_count, 0, "Database should be empty between tests")
102+
103+
def test_jwks_response_structure(self):
104+
"""Verify JWKS response format and contents."""
105+
response = requests.get(f"{HOST}/.well-known/jwks.json")
106+
self.assertEqual(response.status_code, 200)
107+
data = json.loads(response.text)
108+
self.assertIn("keys", data, "JWKS response missing 'keys' field")
109+
self.assertIsInstance(data["keys"], list, "'keys' field should be a list")
110+
if data["keys"]:
111+
first_key = data["keys"][0]
112+
self.assertIn("alg", first_key)
113+
self.assertIn("kty", first_key)
114+
self.assertIn("use", first_key)
115+
self.assertIn("kid", first_key)
116+
self.assertIn("n", first_key)
117+
self.assertIn("e", first_key)
118+
119+
def test_save_key_with_fixed_kid(self):
120+
"""Test saving a key with a specified fixed_kid."""
121+
expiry_time = int((datetime.datetime.utcnow() + datetime.timedelta(hours=1)).timestamp())
122+
private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
123+
save_key_to_db(private_key, expiry_time, fixed_kid=42)
124+
125+
# Verify that the key was saved with the specified kid
126+
cursor = sqlite3.connect(DB_PATH).cursor()
127+
cursor.execute("SELECT kid FROM keys WHERE kid = ?", (42,))
128+
result = cursor.fetchone()
129+
self.assertIsNotNone(result, "Key with specified kid=42 was not found in the database")
130+
self.assertEqual(result[0], 42, "Retrieved kid does not match the specified fixed_kid")
131+
132+
if __name__ == "__main__":
133+
unittest.main()

0 commit comments

Comments
 (0)