From d1df4ff6a212087d4393f65c70c3bfc7514e1f65 Mon Sep 17 00:00:00 2001 From: Kai Vogelgesang Date: Fri, 23 Sep 2022 00:31:18 +0200 Subject: [PATCH] Implement viewer_{dis,}connect events --- lua/main.tl | 48 +++++++++++++++++++++++-------------- lua/socket.tl | 7 ++++++ server/server/monitoring.py | 24 ++++++++++++++++++- 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/lua/main.tl b/lua/main.tl index f26992f..1929f5c 100644 --- a/lua/main.tl +++ b/lua/main.tl @@ -13,6 +13,7 @@ local socket = Socket.new(ENDPOINT) print("[MAIN] Setup framebuffer") +local prev_term = term.current() local orig_native = term.native local buffer = Framebuffer.wrap(orig_native()) term.native = function(): term.Redirect @@ -100,23 +101,34 @@ while shell_running do else e = table.pack(os.pullEventRaw()) end - - for pid = 1, #tasks do - local task = tasks[pid] - if task.filter == nil or task.filter == e[1] or e[1] == "terminate" then - local ok, param = coroutine.resume(task.coro, table.unpack(e as {any})) - if not ok then - term.redirect(orig_native()) - term.clear() - term.setCursorPos(1,1) - print("OMEGABIG OOF") - print(("pid %d"):format(pid)) - error(param, 0) - else - task.filter = param as string - end - if pid == 1 and coroutine.status(task.coro) == "dead" then - shell_running = false + + if e[1] == "websocket_message" and e[2] == ENDPOINT then + local payload = json.decode(e[3] as string) as table + if payload["type"] == "push_event" then + event_queue:push(payload["event"] as table) + elseif payload["type"] == "viewer_connect" then + socket:signal_viewer_connect(true) + elseif payload["type"] == "viewer_disconnect" then + socket:signal_viewer_connect(false) + end + else + for pid = 1, #tasks do + local task = tasks[pid] + if task.filter == nil or task.filter == e[1] or e[1] == "terminate" then + local ok, param = coroutine.resume(task.coro, table.unpack(e as {any})) + if not ok then + term.redirect(orig_native()) + term.clear() + term.setCursorPos(1,1) + print("OMEGABIG OOF") + print(("pid %d"):format(pid)) + error(param, 0) + else + task.filter = param as string + end + if pid == 1 and coroutine.status(task.coro) == "dead" then + shell_running = false + end end end end @@ -125,6 +137,6 @@ end socket:close() term.native = orig_native -term.redirect(term.native()) +term.redirect(prev_term) term.clear() term.setCursorPos(1,1) \ No newline at end of file diff --git a/lua/socket.tl b/lua/socket.tl index 7c86f21..7ecc127 100644 --- a/lua/socket.tl +++ b/lua/socket.tl @@ -20,6 +20,7 @@ local record Socket send: function(self: Socket, message: string) reconnect: function(self: Socket) close: function(self: Socket) + signal_viewer_connect: function(self: Socket, connected: boolean) _endpoint: string _callback: StateCallback _ws: http.Websocket @@ -73,6 +74,12 @@ impl.close = function(self: Socket) self._ws.close() end +impl.signal_viewer_connect = function(self: Socket, connected: boolean) + if self:is_bad_state() then return end --how? + local new_state: State = connected and "viewer_connected" or "ok" + self:_set_state(new_state) +end + local function new(endpoint: string): Socket return setmetatable({ state = "reset", diff --git a/server/server/monitoring.py b/server/server/monitoring.py index 6801f58..0f9ea11 100644 --- a/server/server/monitoring.py +++ b/server/server/monitoring.py @@ -31,6 +31,18 @@ class WSManager: self.viewers: dict[UUID, set[WebSocket]] = dict() self.queue: asyncio.Queue[tuple[UUID, any]] = asyncio.Queue() + async def send_connect(self, uuid: UUID): + if uuid not in self.computers: + return + + await self.computers[uuid].send_json({"type": "viewer_connect"}) + + async def send_disconnect(self, uuid: UUID): + if uuid not in self.computers: + return + + await self.computers[uuid].send_json({"type": "viewer_disconnect"}) + async def queue_task(self): print("[WS] queue task started") while True: @@ -53,6 +65,10 @@ class WSManager: print(f"[WS] Computer {uuid} connected") self.computers[uuid] = socket + + if len(self.viewers.get(uuid, [])) > 0: + await self.send_connect(uuid) + while True: try: data = await socket.receive_json() @@ -76,6 +92,9 @@ class WSManager: if uuid not in self.viewers: self.viewers[uuid] = set() + if len(self.viewers[uuid]) == 0: + await self.send_connect(uuid) + self.viewers[uuid].add(socket) while True: @@ -83,8 +102,11 @@ class WSManager: data = await socket.receive_json() except WebSocketDisconnect: break - + self.viewers[uuid].remove(socket) + if len(self.viewers[uuid]) == 0: + await self.send_disconnect(uuid) + print(f"[WS] Browser disconnected for {uuid}")