176 lines
4.6 KiB
Python
176 lines
4.6 KiB
Python
import asyncio
from datetime import datetime, timezone
from typing import Any
from uuid import UUID

from bson import ObjectId
from fastapi import (
    Body,
    FastAPI,
    HTTPException,
    WebSocket,
    WebSocketDisconnect,
    status,
)
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field, ValidationError

from .db import PyObjectId, events
|
|
|
|
# FastAPI application instance; every HTTP and WebSocket route in this module
# is registered on it via @monitoring.<method>(...) decorators.
monitoring = FastAPI()
|
|
|
|
|
|
class ScreenContent(BaseModel):
    """Screen snapshot carried in the ``screen`` field of an ``Update``.

    Its ``.dict()`` form is what gets relayed to browser viewers, so the
    declaration order below is also the key order of that payload.
    """

    # Position and size -- presumably cursor x/y and terminal width/height;
    # TODO confirm against the computer-side sender.
    x: int
    y: int
    width: int
    height: int

    # Display state flags/values; exact semantics not visible here.
    blink: bool
    fg: int

    # Presumably one entry per screen row for text and colours -- TODO confirm.
    text: list[str]
    fg_color: list[str]
    bg_color: list[str]

    # Colour palette as integers.
    palette: list[int]
|
|
|
|
|
|
class Update(BaseModel):
    """Message a monitored computer sends over its WebSocket.

    Only messages whose ``screen`` is present are forwarded to viewers.
    """

    screen: ScreenContent | None
|
|
|
|
|
|
class WSManager:
    """Relay screen updates from monitored computers to browser viewers.

    Each computer (keyed by UUID) may hold at most one WebSocket; any number
    of browser sockets may watch the same UUID. Screen updates received from
    a computer are queued by :meth:`broadcast` and fanned out to that UUID's
    viewers by :meth:`queue_task`. A computer is sent ``viewer_connect`` when
    it gains its first viewer (or connects while viewers are waiting) and
    ``viewer_disconnect`` when its last viewer leaves.
    """

    def __init__(self):
        # Active computer socket per UUID.
        self.computers: dict[UUID, WebSocket] = dict()
        # Browser sockets currently watching each UUID.
        self.viewers: dict[UUID, set[WebSocket]] = dict()
        # (uuid, payload) pairs awaiting delivery by queue_task().
        # Annotated with typing.Any -- the original used the builtin `any`.
        self.queue: asyncio.Queue[tuple[UUID, Any]] = asyncio.Queue()

    async def send_connect(self, uuid: UUID):
        """Send ``viewer_connect`` to computer *uuid*; no-op if it is not connected."""
        if uuid not in self.computers:
            return

        await self.computers[uuid].send_json({"type": "viewer_connect"})

    async def send_disconnect(self, uuid: UUID):
        """Send ``viewer_disconnect`` to computer *uuid*; no-op if it is not connected."""
        if uuid not in self.computers:
            return

        await self.computers[uuid].send_json({"type": "viewer_disconnect"})

    async def queue_task(self):
        """Forward queued messages to every current viewer of their UUID, forever.

        Broadcasts only reach browsers while this coroutine is running.
        """
        print("[WS] queue task started")

        while True:
            (uuid, message) = await self.queue.get()

            if uuid not in self.viewers:
                continue

            # Snapshot: viewers may join or leave while the sends are awaited.
            viewers = tuple(self.viewers[uuid])

            # return_exceptions=True: a single failing viewer socket must not
            # raise out of this loop, which would stop delivery to everyone.
            results = await asyncio.gather(
                *(viewer.send_json(message) for viewer in viewers),
                return_exceptions=True,
            )
            for result in results:
                if isinstance(result, BaseException):
                    print(f"[WS] Failed to send update for {uuid}: {result!r}")

    async def broadcast(self, uuid: UUID, message):
        """Queue *message* for delivery to all viewers of *uuid*."""
        await self.queue.put((uuid, message))

    async def on_computer_connect(self, socket: WebSocket, uuid: UUID):
        """Serve one computer connection until it disconnects.

        A second connection for an already-connected UUID is closed
        immediately. Otherwise the socket is registered, told about any
        waiting viewers, and every valid ``Update`` with a ``screen`` is
        queued for broadcast. Invalid messages are logged and skipped.
        """
        if uuid in self.computers:
            print(f"[WS] Closing duplicate connection for {uuid}")
            await socket.close()
            return

        print(f"[WS] Computer {uuid} connected")
        self.computers[uuid] = socket

        try:
            if self.viewers.get(uuid):
                await self.send_connect(uuid)

            while True:
                try:
                    payload = await socket.receive_json()
                    update = Update.parse_obj(payload)

                    if update.screen:
                        await self.broadcast(uuid, update.screen.dict())

                except ValidationError as e:
                    print(f"[WS] Received invalid message from {uuid}:")
                    # Call the method; printing `e.json` only showed the bound method.
                    print(e.json())
                except WebSocketDisconnect:
                    break
        finally:
            # Unregister even if an unhandled exception escapes the loop;
            # otherwise this UUID would be rejected as a duplicate forever.
            if self.computers.get(uuid) is socket:
                del self.computers[uuid]
            print(f"[WS] Computer {uuid} disconnected")

    async def on_browser_connect(self, socket: WebSocket, uuid: UUID):
        """Serve one browser viewing computer *uuid* until it disconnects.

        Messages from the browser are read only to detect disconnection and
        are otherwise ignored.
        """
        print(f"[WS] Browser connected for {uuid}")

        viewers = self.viewers.setdefault(uuid, set())
        # Register before awaiting send_connect so a browser connecting during
        # that await does not also see an empty set and duplicate the message.
        viewers.add(socket)

        try:
            if len(viewers) == 1:
                await self.send_connect(uuid)

            while True:
                try:
                    await socket.receive_json()
                except WebSocketDisconnect:
                    break
        finally:
            # Always unregister: a stale socket left in the set would be sent
            # every future update by queue_task().
            viewers.discard(socket)
            if not viewers:
                # Drop empty sets so the dict does not grow with every UUID seen.
                if self.viewers.get(uuid) is viewers:
                    del self.viewers[uuid]
                await self.send_disconnect(uuid)
            print(f"[WS] Browser disconnected for {uuid}")
|
|
|
|
|
|
# Single manager instance shared by the computer and browser WebSocket
# endpoints, so both sides see the same connection registries and queue.
# NOTE(review): nothing visible in this module starts ws_manager.queue_task();
# confirm it is scheduled elsewhere (e.g. on app startup), otherwise queued
# broadcasts are never delivered to browsers.
ws_manager = WSManager()
|
|
|
|
|
|
@monitoring.websocket("/computer/{uuid}/ws")
async def computer_ws(socket: WebSocket, uuid: UUID):
    """WebSocket endpoint for a monitored computer identified by *uuid*.

    Completes the handshake, then lets the shared manager serve the
    connection until it ends.
    """
    await socket.accept()
    await ws_manager.on_computer_connect(socket, uuid)
|
|
|
|
|
|
@monitoring.websocket("/browser/{uuid}/ws")
async def browser_ws(socket: WebSocket, uuid: UUID):
    """WebSocket endpoint for a browser watching computer *uuid*.

    Completes the handshake, then lets the shared manager serve the viewer
    until it disconnects.
    """
    await socket.accept()
    await ws_manager.on_browser_connect(socket, uuid)
|
|
|
|
|
|
class Event(BaseModel):
    """Stored event returned by the ``/events`` and ``/push_event`` routes.

    ``id`` is exposed under the alias ``_id`` (the MongoDB document key);
    a fresh ``PyObjectId`` is generated when none is supplied.
    """

    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
    timestamp: datetime
    value: Any

    class Config:
        # Accept both "id" and "_id" when constructing the model.
        allow_population_by_field_name = True
        # Permit field types pydantic has no built-in validator for.
        arbitrary_types_allowed = True
        # Serialise BSON ObjectIds as plain strings in JSON output.
        json_encoders = {ObjectId: str}
|
|
|
|
|
|
@monitoring.get("/events", response_model=list[Event])
async def get_events():
    """List stored events, capped at the first 1000 documents."""
    print("get /events")
    cursor = events.find()
    return await cursor.to_list(1000)
|
|
|
|
|
|
@monitoring.get("/events/{id}", response_model=Event)
async def get_single_event(id: PyObjectId):
    """Return the event whose ``_id`` equals *id*, or respond 404 if absent."""
    # `id` must keep this name: it is bound to the {id} path segment.
    found = await events.find_one({"_id": id})
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND)
    return found
|
|
|
|
|
|
@monitoring.post(
    "/push_event", response_model=Event, status_code=status.HTTP_201_CREATED
)
async def push_event(value: Any = Body(...)):
    """Store the required request body as a new event and return it (201).

    The stored document is read back from the collection so the response
    includes the generated ``_id``.

    The timestamp is a timezone-aware UTC time. PyMongo treats naive
    datetimes as UTC, so a naive local ``datetime.now()`` would be stored as
    a shifted instant on any host whose local zone is not UTC.
    """
    event = {
        "timestamp": datetime.now(timezone.utc),
        "value": value,
    }
    new_event = await events.insert_one(event)
    created_event = await events.find_one({"_id": new_event.inserted_id})
    return created_event
|