@@ -95,8 +95,15 @@ def decrypt(encrypted_info):
95
95
no_padding_decrypted_info = decrypted_info .rstrip (b' ' )
96
96
return no_padding_decrypted_info
97
97
98
+ def log_auth_request (ip , user_id ):
99
+ db_cursor .execute (
100
+ "INSERT INTO auth_logs (request_ip, user_id) VALUES (?, ?)" ,
101
+ (ip , user_id )
102
+ )
103
+ db_connection .commit ()
104
+
98
105
def int_to_base64 (value ):
99
- """Convert an integer to a Base64URL- encoded string."""
106
+ # converts int to a base64 encoded string.
100
107
value_hex = format (value , 'x' )
101
108
if len (value_hex ) % 2 == 1 :
102
109
value_hex = '0' + value_hex
@@ -105,7 +112,7 @@ def int_to_base64(value):
105
112
return encoded .decode ('utf-8' )
106
113
107
114
def jwks_response ():
108
- """Generate JWKS JSON from unexpired keys."""
115
+ # generate JWKS JSON from unexpired keys
109
116
print ("jwks_response called" , flush = True )
110
117
all_keys = get_valid_keys_for_jwks ()
111
118
keys = [
@@ -123,7 +130,7 @@ def jwks_response():
123
130
return json .dumps ({"keys" : keys })
124
131
125
132
def save_key_to_db (key , expiry , fixed_kid = None ):
126
- """ Saves key to db, encrypted"""
133
+ # Saves key to db, encrypted
127
134
pem_key = key .private_bytes (
128
135
encoding = serialization .Encoding .PEM ,
129
136
format = serialization .PrivateFormat .PKCS8 ,
@@ -137,7 +144,7 @@ def save_key_to_db(key, expiry, fixed_kid=None):
137
144
db_connection .commit ()
138
145
139
146
def get_key_from_db (expired = False ):
140
- """ Gets key from db then decrypts it"""
147
+ # Gets key from db then decrypts it
141
148
current_time = int (datetime .datetime .utcnow ().timestamp ())
142
149
db_cursor .execute (
143
150
"SELECT kid, key FROM keys WHERE exp {} ? ORDER BY exp {} LIMIT 1" .format (
@@ -153,15 +160,15 @@ def get_key_from_db(expired=False):
153
160
return None , None
154
161
155
162
def get_valid_keys_for_jwks ():
156
- """Retrieve all unexpired keys for JWKS."""
163
+ # gets unexpired keys
157
164
current_time = int (datetime .datetime .utcnow ().timestamp ())
158
165
db_cursor .execute ("SELECT kid, key FROM keys WHERE exp > ?" , (current_time ,))
159
166
result = db_cursor .fetchall ()
160
167
print ("All valid keys for JWKS:" , result , flush = True )
161
168
return result
162
169
163
170
def initialize_starter_keys ():
164
- """ Initialize one expired and one valid key in the database."""
171
+ # Initialize 1 valid and 1 expired key in db
165
172
print ("Initializing starter keys..." , flush = True )
166
173
current_time = int (datetime .datetime .utcnow ().timestamp ())
167
174
expired_time = current_time - 3600 # Expired an hour ago
@@ -178,7 +185,7 @@ def initialize_starter_keys():
178
185
print ("Valid key inserted" , flush = True )
179
186
180
187
def reset_database ():
181
- """ Reset the database by dropping and recreating the keys table."""
188
+ # Reset the database by dropping and recreating the keys table
182
189
print ("Resetting database..." , flush = True )
183
190
db_cursor .execute ("DROP TABLE IF EXISTS keys" )
184
191
db_cursor .execute ('''
@@ -201,33 +208,61 @@ def do_POST(self):
201
208
202
209
# AUTH
203
210
if parsed_path .path == "/auth" :
204
- expired = 'expired' in params
205
- kid , pem_key = get_key_from_db ( expired )
211
+ content_length = int ( self . headers [ 'Content-Length' ])
212
+ body = self . rfile . read ( content_length ). decode ( 'utf-8' )
206
213
207
- if pem_key is None :
208
- self .send_response (404 )
209
- self .end_headers ()
210
- self .wfile .write (b"Key not found" )
211
- return
212
214
213
- private_key = serialization .load_pem_private_key (pem_key , password = None )
214
- expiry_time = datetime .datetime .utcnow () + (datetime .timedelta (hours = 1 ) if not expired else datetime .timedelta (hours = - 1 ))
215
+ try :
216
+ data = json .loads (body )
217
+ username = data ["username" ]
218
+
219
+ user_id = self .get_user_id_by_username (username )
220
+ if user_id is None :
221
+ self .send_response (404 )
222
+ self .end_headers ()
223
+ self .wfile .write (b"User not found" )
224
+ return
225
+
226
+
227
+
228
+ expired = 'expired' in params
229
+ kid , pem_key = get_key_from_db (expired )
230
+
231
+ if pem_key is None :
232
+ self .send_response (404 )
233
+ self .end_headers ()
234
+ self .wfile .write (b"Key not found" )
235
+ return
236
+
237
+ private_key = serialization .load_pem_private_key (pem_key , password = None )
238
+ expiry_time = datetime .datetime .utcnow () + (datetime .timedelta (hours = 1 ) if not expired else datetime .timedelta (hours = - 1 ))
215
239
216
- headers = {"kid" : str (kid )}
217
- token_payload = {
218
- "user" : " username" ,
219
- "exp" : expiry_time .timestamp ()
220
- }
221
- encoded_jwt = jwt .encode (token_payload , private_key , algorithm = "RS256" , headers = headers )
240
+ headers = {"kid" : str (kid )}
241
+ token_payload = {
242
+ "user" : username ,
243
+ "exp" : expiry_time .timestamp ()
244
+ }
245
+ encoded_jwt = jwt .encode (token_payload , private_key , algorithm = "RS256" , headers = headers )
222
246
223
- self .send_response (200 )
224
- self .send_header ("Content-Type" , "application/json" )
225
- self .end_headers ()
226
- self .wfile .write (json .dumps ({"token" : encoded_jwt }).encode ("utf-8" ))
247
+ # Adding logging here
248
+ client_ip = self .client_address [0 ]
249
+ log_auth_request (client_ip , user_id )
250
+
251
+ self .send_response (200 )
252
+ self .send_header ("Content-Type" , "application/json" )
253
+ self .end_headers ()
254
+ self .wfile .write (json .dumps ({"token" : encoded_jwt }).encode ("utf-8" ))
255
+
256
+ except KeyError :
257
+ self .send_response (400 ) # username
258
+ self .end_headers ()
259
+ self .wfile .write (b"Missing 'username'" )
260
+ except json .JSONDecodeError :
261
+ self .send_response (400 ) # json
262
+ self .end_headers ()
263
+ self .wfile .write (b"Invalid JSON format" )
227
264
return
228
-
229
- self .send_response (405 )
230
- self .end_headers ()
265
+
231
266
232
267
# REGISTER
233
268
if parsed_path .path == "/register" :
@@ -239,12 +274,12 @@ def do_POST(self):
239
274
email = data ["email" ]
240
275
241
276
password = str (uuid .uuid4 ())
242
- hashed_password = PassHasher .hash (password )
277
+ password_hash = PassHasher .hash (password )
243
278
244
279
try :
245
280
db_cursor .execute (
246
- "INSERT INTO users (username, hashed_password , email) VALUES (?, ?, ?)" ,
247
- (username , hashed_password , email )
281
+ "INSERT INTO users (username, password_hash , email) VALUES (?, ?, ?)" ,
282
+ (username , password_hash , email )
248
283
)
249
284
db_connection .commit ()
250
285
@@ -253,7 +288,7 @@ def do_POST(self):
253
288
self .send_header ("Content-Type" , "application/json" )
254
289
self .end_headers ()
255
290
self .wfile .write (json .dumps ({"password" : password }).encode ("utf-8" ))
256
- except :
291
+ except sqlite3 . IntegrityError :
257
292
self .send_response (400 ) # error, throw 400
258
293
self .end_headers ()
259
294
self .wfile .write (b"Username and/or email address already exists" )
@@ -276,6 +311,11 @@ def do_GET(self):
276
311
277
312
self .send_response (405 )
278
313
self .end_headers ()
314
+
315
+ def get_user_id_by_username (self , username ):
316
+ db_cursor .execute ("SELECT id FROM users WHERE username = ?" , (username ,))
317
+ result = db_cursor .fetchone ()
318
+ return result [0 ] if result else None
279
319
280
320
# Initialize database and start server
281
321
reset_database ()
0 commit comments