"""Manage the active WebRTC voice session and relay gateway events to it."""
from __future__ import annotations

import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import Any

from card_store import persist_card
from message_pipeline import chunk_tts_text, typed_message_from_gateway_event
from voice_rtc import WebRTCVoiceSession


# Signature shared by the async change notifiers: one positional argument,
# awaited by the caller, with no meaningful result.
PublishCallback = Callable[
    [Any],
    Awaitable[None],
]


class RtcSessionManager:
    """Own at most one active WebRTC voice session and relay gateway events to it.

    Each :meth:`handle_offer` call tears down any previous session, subscribes a
    fresh queue on the gateway, creates a ``WebRTCVoiceSession`` and starts a
    background task that forwards gateway events to that session: TTS text goes
    to the session's audio output, typed events go to its data channel.
    """

    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        *,
        gateway,
        publish_cards_changed: PublishCallback,
        publish_workbench_changed: PublishCallback,
        publish_sessions_changed: Callable[[], Awaitable[None]],
    ) -> None:
        """Store the gateway and the change-notification callbacks.

        Args:
            gateway: Event gateway. The members used directly here are the
                coroutines ``subscribe()``, ``unsubscribe(queue)`` and
                ``connect_nanobot()``; it is also handed to each
                ``WebRTCVoiceSession``.
            publish_cards_changed: Awaited with the persisted card's
                ``chat_id`` (possibly ``None``).
            publish_workbench_changed: Awaited with a workbench event's
                ``chat_id`` (possibly ``None``).
            publish_sessions_changed: Awaited with no arguments after a final
                (non-progress) ``nanobot`` message.
        """
        self._gateway = gateway
        self._publish_cards_changed = publish_cards_changed
        self._publish_workbench_changed = publish_workbench_changed
        self._publish_sessions_changed = publish_sessions_changed
        self._active_session: WebRTCVoiceSession | None = None
        self._active_queue: asyncio.Queue | None = None
        self._sender_task: asyncio.Task | None = None

    async def handle_offer(self, payload: dict[str, Any]) -> dict[str, Any] | None:
        """Replace the active session with a new one negotiated from *payload*.

        Returns:
            The answer returned by ``WebRTCVoiceSession.handle_offer``, or
            ``None`` when it returned ``None`` (the new session is then closed
            and ``connect_nanobot()`` is not called).

        If session creation, negotiation or ``connect_nanobot()`` raises, or
        this call is cancelled, the partially built session (queue
        subscription, sender task, voice session) is torn down before the
        exception propagates.
        """
        await self._close_active_session()

        queue = await self._gateway.subscribe()
        self._active_queue = queue
        try:
            voice_session = WebRTCVoiceSession(gateway=self._gateway)
            self._active_session = voice_session
            self._sender_task = asyncio.create_task(
                self._sender_loop(queue, voice_session),
                name="rtc-sender",
            )

            answer = await voice_session.handle_offer(payload)
            if answer is None:
                await self._close_if_current(queue)
                return None

            await self._gateway.connect_nanobot()
        except BaseException:
            # BaseException so a cancelled request also releases its resources.
            await self._close_if_current(queue)
            raise
        return answer

    async def shutdown(self) -> None:
        """Tear down the active session, if any."""
        await self._close_active_session()

    async def _close_if_current(self, queue: asyncio.Queue) -> None:
        """Close the active session only while it is still the one subscribed on *queue*.

        A concurrent :meth:`handle_offer` may already have torn our session down
        and installed its own; closing then would destroy the newer session.
        """
        if self._active_queue is queue:
            await self._close_active_session()

    async def _close_active_session(self) -> None:
        """Cancel the sender task, close the voice session, then unsubscribe the queue.

        The manager's references are detached before anything is awaited, so a
        concurrent caller never double-closes resources that are already being
        torn down. Every stage runs even if an earlier one fails; an exception
        from ``close()`` or ``unsubscribe()`` still propagates afterwards.
        """
        task, self._sender_task = self._sender_task, None
        session, self._active_session = self._active_session, None
        queue, self._active_queue = self._active_queue, None

        try:
            if task is not None:
                task.cancel()
                # asyncio.wait never re-raises the task's own outcome, so a
                # sender that already crashed cannot abort teardown, while a
                # cancellation of *this* coroutine still propagates.
                await asyncio.wait({task})
                if not task.cancelled() and task.exception() is not None:
                    self._logger.error(
                        "rtc sender task failed", exc_info=task.exception()
                    )
        finally:
            try:
                if session is not None:
                    await session.close()
            finally:
                if queue is not None:
                    await self._gateway.unsubscribe(queue)

    async def _sender_loop(
        self,
        queue: asyncio.Queue,
        voice_session: WebRTCVoiceSession,
    ) -> None:
        """Forward every event from *queue* to *voice_session* until cancelled.

        An exception while handling one event is logged and the loop moves on to
        the next event. Without this, a single bad event would silently stop
        all further output for the session.
        """
        while True:
            event = await queue.get()
            try:
                await self._dispatch_event(event, voice_session)
            except Exception:
                # CancelledError is a BaseException, so cancellation still ends the loop.
                self._logger.exception(
                    "rtc sender failed to handle event (role=%r)",
                    getattr(event, "role", None),
                )

    async def _dispatch_event(
        self, event: Any, voice_session: WebRTCVoiceSession
    ) -> None:
        """Route one gateway event to the voice session's audio output or data channel."""
        role = event.role
        if role == "nanobot-tts-partial":
            await voice_session.queue_output_text(event.text, partial=True)
            return
        if role == "nanobot-tts-flush":
            await voice_session.flush_partial_output_text()
            return
        if role == "nanobot-tts":
            for segment in chunk_tts_text(event.text):
                await voice_session.queue_output_text(segment)
            return

        typed_event = typed_message_from_gateway_event(event.to_dict())
        if typed_event is None:
            return

        event_type = typed_event.get("type")
        if event_type == "card":
            persisted = persist_card(typed_event)
            if persisted is None:
                return
            payload = dict(persisted)
            payload["type"] = "card"
            await self._publish_cards_changed(payload.get("chat_id"))
            voice_session.send_to_datachannel(payload)
            return

        if event_type == "workbench":
            await self._publish_workbench_changed(typed_event.get("chat_id"))
            voice_session.send_to_datachannel(typed_event)
            return

        if (
            event_type == "message"
            and not bool(typed_event.get("is_progress", False))
            and typed_event.get("role") == "nanobot"
        ):
            await self._publish_sessions_changed()
        # NOTE(review): assumed to apply to every remaining typed event, not only
        # final nanobot messages (the early returns above exist so cards and
        # workbench events are not sent twice) — confirm against the original layout.
        voice_session.send_to_datachannel(typed_event)
