-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathwebsocket.py
executable file
·225 lines (176 loc) · 7.37 KB
/
websocket.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
#!/usr/bin/env -S BACKEND_ENV=config.env python3
import json
from fastapi import Body, FastAPI, WebSocket, WebSocketDisconnect, status
from utils.jwt import decode_jwt
import asyncio
from hypercorn.config import Config
from utils.logger import changeLogFile, logging_schema
from utils.settings import settings
from hypercorn.asyncio import serve
from models.enums import Status as ResponseStatus
from fastapi.responses import JSONResponse
from fastapi.logger import logger
# Wildcard placeholder for optional channel-name components a subscriber
# does not specify (see ConnectionManager.generate_channel_name).
all_MKW = "__ALL__"


class ConnectionManager:
    """In-memory registry of websocket clients and their channel subscriptions.

    ``active_connections`` maps a user shortname to that user's websocket; a
    newer connection for the same shortname replaces the older entry.
    ``channels`` maps a channel name to the shortnames subscribed to it.  A
    user is subscribed to at most one channel at a time, because subscribing
    first drops every existing subscription of that user.  Channels left with
    no subscribers are discarded.
    """

    def __init__(self) -> None:
        # user shortname -> websocket of that user's (latest) connection
        self.active_connections: dict[str, "WebSocket"] = {}
        # channel name -> list of subscribed user shortnames
        self.channels: dict[str, list[str]] = {}

    async def connect(self, websocket: "WebSocket", user_shortname: str):
        """Accept *websocket* and register it as *user_shortname*'s connection.

        A connection previously registered under the same shortname is
        replaced in the registry (it is not closed here).
        """
        await websocket.accept()
        self.active_connections[user_shortname] = websocket

    def disconnect(self, user_shortname: str):
        """Forget *user_shortname*'s connection; a no-op if none is registered.

        Subscriptions are kept, so a user who reconnects under the same
        shortname still receives broadcasts for the channel it last joined.
        """
        self.active_connections.pop(user_shortname, None)

    async def send_message(self, message: str, user_shortname: str):
        """Send the text *message* to *user_shortname*.

        Returns:
            True if the user has a registered connection (the send was
            attempted), False otherwise.  Errors raised by the websocket while
            sending propagate to the caller.
        """
        websocket = self.active_connections.get(user_shortname)
        if websocket is None:
            return False
        await websocket.send_text(message)
        return True

    async def broadcast_message(self, message: str, channel_name: str):
        """Send *message* to every user subscribed to *channel_name*.

        Returns:
            False if the channel is unknown (never subscribed to, or emptied),
            True otherwise, even if some subscribers are not currently
            connected.
        """
        subscribers = self.channels.get(channel_name)
        if subscribers is None:
            return False
        # Iterate over a snapshot: every send awaits, and subscriptions may
        # change in another coroutine meanwhile. Iterating the live list
        # could skip or add recipients mid-pass.
        for user_shortname in list(subscribers):
            await self.send_message(message, user_shortname)
        return True

    def remove_all_subscriptions(self, username: str):
        """Remove *username* from every channel and drop channels left empty.

        Pruning empty channels keeps ``self.channels`` from growing without
        bound, since channel names are built from client-supplied fields.
        """
        remaining: dict[str, list[str]] = {}
        for channel_name, users in self.channels.items():
            kept = [user for user in users if user != username]
            if kept:
                remaining[channel_name] = kept
        self.channels = remaining

    def _username_for(self, websocket: "WebSocket") -> "str | None":
        """Return the shortname registered for *websocket*, or None if absent."""
        # Identity comparison: the registry stores the exact socket objects.
        for username, connection in self.active_connections.items():
            if connection is websocket:
                return username
        return None

    async def channel_unsubscribe(self, websocket: "WebSocket"):
        """Drop every subscription of the user owning *websocket* and ack it.

        Returns:
            False (nothing is sent) if *websocket* is not a registered
            connection, True otherwise.
        """
        username = self._username_for(websocket)
        if username is None:
            return False
        self.remove_all_subscriptions(username)
        unsubscribed_message = json.dumps({
            "type": "notification_unsubscribe",
            "message": {
                "status": "success"
            }
        })
        await self.send_message(unsubscribed_message, username)
        return True

    def generate_channel_name(self, msg: dict):
        """Build ``space:subpath:schema:action:state`` from a subscription message.

        ``space_name`` and ``subpath`` are required.  If either is missing,
        False is returned.  ``schema_shortname``, ``action_type`` and
        ``ticket_state`` are optional and default to ``all_MKW``.
        """
        if not {"space_name", "subpath"}.issubset(msg):
            return False
        space_name = msg["space_name"]
        subpath = msg["subpath"]
        schema_shortname = msg.get("schema_shortname", all_MKW)
        action_type = msg.get("action_type", all_MKW)
        ticket_state = msg.get("ticket_state", all_MKW)
        return f"{space_name}:{subpath}:{schema_shortname}:{action_type}:{ticket_state}"

    async def channel_subscribe(
        self,
        websocket: "WebSocket",
        msg_json: dict
    ):
        """Move the user owning *websocket* to the channel described by *msg_json*.

        Returns:
            False (nothing is sent) if the message lacks the required channel
            fields or *websocket* is not a registered connection.  Otherwise
            the user's previous subscriptions are replaced, a
            ``notification_subscription`` ack is sent, and True is returned.
        """
        channel_name = self.generate_channel_name(msg_json)
        if not channel_name:
            return False
        username = self._username_for(websocket)
        if username is None:
            return False
        # Remove old subscriptions first. remove_all_subscriptions discards
        # empty channels, so the target channel is (re)created afterwards.
        self.remove_all_subscriptions(username)
        self.channels.setdefault(channel_name, []).append(username)
        subscribed_message = json.dumps({
            "type": "notification_subscription",
            "message": {
                "status": "success"
            }
        })
        await self.send_message(subscribed_message, username)
        return True
# FastAPI application hosting the /ws websocket route and the HTTP helper routes.
app = FastAPI()
# Single process-wide registry of connected clients and channel subscriptions,
# shared by the websocket endpoint and the HTTP routes defined below.
websocket_manager = ConnectionManager()
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket, token: str):
    """Serve one client's websocket session at ``/ws?token=<jwt>``.

    The decoded JWT's ``shortname`` claim identifies the user.  After the
    connection is accepted and a ``connection_response`` ack is sent, incoming
    text frames are parsed as JSON and dispatched on their ``type`` field:
    ``notification_subscription`` (re)subscribes the user to a channel and
    ``notification_unsubscribe`` drops its subscriptions.  The session ends
    when the client disconnects or a frame fails to process.  Either way, the
    user's registration in ``websocket_manager`` is removed if it still
    belongs to this socket.
    """
    try:
        user_shortname = decode_jwt(token)["shortname"]
    except Exception:
        # A websocket route's return value is ignored, so the handshake must
        # be refused explicitly by closing before accept().
        await websocket.close(code=status.WS_1008_POLICY_VIOLATION)
        return

    try:
        await websocket_manager.connect(websocket, user_shortname)
    except Exception as e:
        logger.error(f"Error while accepting connection: {e.__str__()}", extra={"user_shortname": user_shortname})
        return

    success_connection_message = json.dumps({
        "type": "connection_response",
        "message": {
            "status": "success"
        }
    })
    try:
        await websocket_manager.send_message(success_connection_message, user_shortname)
        while True:
            msg_json = json.loads(await websocket.receive_text())
            # Only JSON objects can carry a "type"; other payloads are ignored.
            msg_type = msg_json.get("type") if isinstance(msg_json, dict) else None
            if msg_type == "notification_subscription":
                await websocket_manager.channel_subscribe(websocket, msg_json)
            elif msg_type == "notification_unsubscribe":
                await websocket_manager.channel_unsubscribe(websocket)
    except WebSocketDisconnect:
        # Must precede the generic handler: WebSocketDisconnect is an Exception.
        logger.info("WebSocket connection closed", extra={"user_shortname": user_shortname})
    except Exception as e:
        logger.error(f"Error while processing message: {e.__str__()}", extra={"user_shortname": user_shortname})
    finally:
        # Drop the registry entry only if it is still this socket. connect()
        # replaces the entry when the same user opens a newer connection, and
        # that newer one must not be evicted when this session ends.
        if websocket_manager.active_connections.get(user_shortname) is websocket:
            websocket_manager.disconnect(user_shortname)
@app.api_route(path="/send-message/{user_shortname}", methods=["post"])
async def send_message(user_shortname: str, message: dict = Body(...)):
    """Deliver the JSON request body to a single connected user.

    The body is wrapped as ``{"type": "message", "message": <body>}`` before
    being sent.  ``message_sent`` in the response is False when the user has
    no registered websocket connection.
    """
    envelope = {"type": "message", "message": message}
    delivered = await websocket_manager.send_message(json.dumps(envelope), user_shortname)
    payload = {"status": ResponseStatus.success, "message_sent": delivered}
    return JSONResponse(status_code=status.HTTP_200_OK, content=payload)
@app.api_route(path="/broadcast-to-channels", methods=["post"])
async def broadcast(data: dict = Body(...)):
    """Send ``data["message"]`` to the subscribers of each channel in ``data["channels"]``.

    Every listed channel is broadcast to, in order.  ``message_sent`` in the
    response is True if at least one channel was known to the manager.
    """
    outgoing = json.dumps({
        "type": "notification_subscription",
        "message": data["message"],
    })
    outcomes = [
        await websocket_manager.broadcast_message(outgoing, channel_name)
        for channel_name in data["channels"]
    ]
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"status": ResponseStatus.success, "message_sent": any(outcomes)},
    )
@app.api_route(path="/info", methods=["get"])
async def service_info():
    """Report the manager's current connections and channels.

    Both values are the ``str()`` of the underlying dicts, intended for
    debugging rather than machine parsing.
    """
    snapshot = {
        "connected_clients": str(websocket_manager.active_connections),
        "channels": str(websocket_manager.channels),
    }
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content={"status": ResponseStatus.success, "data": snapshot},
    )
@app.on_event("startup")
async def app_startup() -> None:
    """Log application startup and print a JSON stage marker to stdout."""
    stage_marker = '{"stage":"starting up"}'
    logger.info("Starting up")
    print(stage_marker)
@app.on_event("shutdown")
async def app_shutdown() -> None:
    """Log application shutdown and print a JSON stage marker to stdout."""
    stage_marker = '{"stage":"shutting down"}'
    logger.info("Application shutting down")
    print(stage_marker)
async def main():
    """Configure Hypercorn from ``settings`` and serve the FastAPI app."""
    server_config = Config()
    server_config.bind = [f"{settings.listening_host}:{settings.websocket_port}"]
    server_config.backlog = 200
    # NOTE(review): presumably points logging_schema at ws_log_file before it
    # is handed to Hypercorn below; confirm in utils.logger.
    changeLogFile(settings.ws_log_file)
    server_config.logconfig_dict = logging_schema
    server_config.errorlog = server_config.accesslog = logger
    await serve(app, server_config)  # type: ignore


if __name__ == "__main__":
    asyncio.run(main())