from __future__ import annotations

import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import Any

from card_store import persist_card
from message_pipeline import chunk_tts_text, typed_message_from_gateway_event
from voice_rtc import WebRTCVoiceSession

logger = logging.getLogger(__name__)

PublishCallback = Callable[[Any], Awaitable[None]]


class RtcSessionManager:
    """Own at most one WebRTC voice session and pump gateway events into it.

    Each accepted offer replaces the previous session: the old sender task is
    cancelled, the old voice session is closed, and the old gateway
    subscription is released before the new ones are created.
    """

    def __init__(
        self,
        *,
        gateway,
        publish_cards_changed: PublishCallback,
        publish_workbench_changed: PublishCallback,
        publish_sessions_changed: Callable[[], Awaitable[None]],
    ) -> None:
        """Store the gateway and the change-notification callbacks.

        Args:
            gateway: Object providing ``subscribe``, ``unsubscribe`` and
                ``connect_nanobot`` coroutines; also handed to every
                ``WebRTCVoiceSession`` created by this manager.
            publish_cards_changed: Awaited with the card's ``chat_id`` after a
                card event has been persisted.
            publish_workbench_changed: Awaited with the event's ``chat_id``
                for every workbench event.
            publish_sessions_changed: Awaited (no arguments) for every
                non-progress message whose role is ``"nanobot"``.
        """
        self._gateway = gateway
        self._publish_cards_changed = publish_cards_changed
        self._publish_workbench_changed = publish_workbench_changed
        self._publish_sessions_changed = publish_sessions_changed
        self._active_session: WebRTCVoiceSession | None = None
        self._active_queue: asyncio.Queue | None = None
        self._sender_task: asyncio.Task | None = None

    async def handle_offer(self, payload: dict[str, Any]) -> dict[str, Any] | None:
        """Replace the active session with one negotiated from ``payload``.

        Args:
            payload: Offer payload forwarded unchanged to
                ``WebRTCVoiceSession.handle_offer``.

        Returns:
            The answer produced by the voice session, or ``None`` when the
            session produced no answer (in which case it is torn down again).

        Raises:
            Exception: Whatever the voice session or the gateway raises during
                negotiation; the half-built session is torn down first.
        """
        await self._close_active_session()

        queue = await self._gateway.subscribe()
        self._active_queue = queue
        voice_session = WebRTCVoiceSession(gateway=self._gateway)
        self._active_session = voice_session
        # The sender starts before negotiation so events published while the
        # offer is being processed are consumed from the queue right away.
        self._sender_task = asyncio.create_task(
            self._sender_loop(queue, voice_session),
            name="rtc-sender",
        )

        try:
            answer = await voice_session.handle_offer(payload)
            if answer is not None:
                await self._gateway.connect_nanobot()
        except Exception:
            # Do not leave a subscribed queue, a running sender task and an
            # open session behind when negotiation fails.
            await self._close_active_session()
            raise

        if answer is None:
            await self._close_active_session()
            return None
        return answer

    async def shutdown(self) -> None:
        """Tear down the active session, if any."""
        await self._close_active_session()

    async def _close_active_session(self) -> None:
        """Cancel the sender task, close the session, drop the subscription.

        The three steps run in that order. State is detached from ``self``
        up front and every step is attempted even when an earlier one fails,
        so a broken session can never block later offers or shutdown.
        """
        task, self._sender_task = self._sender_task, None
        session, self._active_session = self._active_session, None
        queue, self._active_queue = self._active_queue, None

        try:
            if task is not None:
                task.cancel()
                try:
                    with contextlib.suppress(asyncio.CancelledError):
                        await task
                except Exception:
                    # The loop had already died on its own; awaiting the task
                    # re-raises that stale error. Record it and keep cleaning
                    # up instead of aborting the teardown.
                    logger.exception("rtc sender task failed before teardown")
        finally:
            try:
                if session is not None:
                    await session.close()
            finally:
                if queue is not None:
                    await self._gateway.unsubscribe(queue)

    async def _sender_loop(
        self,
        queue: asyncio.Queue,
        voice_session: WebRTCVoiceSession,
    ) -> None:
        """Forward gateway events from ``queue`` to ``voice_session`` forever.

        TTS roles are routed to the session's speech output. Everything else
        is converted with ``typed_message_from_gateway_event`` and, when that
        yields a message, sent over the data channel after the matching
        publish callback has been awaited. The loop only ends by cancellation
        or by an exception raised while handling an event.
        """
        while True:
            event = await queue.get()

            # --- Speech output -------------------------------------------
            if event.role == "nanobot-tts-partial":
                await voice_session.queue_output_text(event.text, partial=True)
                continue
            if event.role == "nanobot-tts-flush":
                await voice_session.flush_partial_output_text()
                continue
            if event.role == "nanobot-tts":
                for segment in chunk_tts_text(event.text):
                    await voice_session.queue_output_text(segment)
                continue

            # --- Data-channel messages -----------------------------------
            typed_event = typed_message_from_gateway_event(event.to_dict())
            if typed_event is None:
                continue

            if typed_event.get("type") == "card":
                persisted = persist_card(typed_event)
                if persisted is None:
                    continue
                # Send the persisted form of the card, re-tagged as a card.
                payload = dict(persisted)
                payload["type"] = "card"
                await self._publish_cards_changed(payload.get("chat_id"))
                voice_session.send_to_datachannel(payload)
                continue

            if typed_event.get("type") == "workbench":
                await self._publish_workbench_changed(typed_event.get("chat_id"))
                voice_session.send_to_datachannel(typed_event)
                continue

            if (
                typed_event.get("type") == "message"
                and not bool(typed_event.get("is_progress", False))
                and typed_event.get("role") == "nanobot"
            ):
                await self._publish_sessions_changed()
            # Every remaining typed event is forwarded, not just the final
            # nanobot messages that triggered the publish above.
            voice_session.send_to_datachannel(typed_event)