nanobot-voice-interface/rtc_manager.py

110 lines
4 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations

import asyncio
import contextlib
import logging
from collections.abc import Awaitable, Callable
from typing import Any

from card_store import persist_card
from message_pipeline import chunk_tts_text, typed_message_from_gateway_event
from voice_rtc import WebRTCVoiceSession
# Async callback taking one positional argument and returning nothing.
# RtcSessionManager awaits callbacks of this type with the "chat_id" value of
# the card / workbench event that triggered the notification.
PublishCallback = Callable[[Any], Awaitable[None]]
class RtcSessionManager:
    """Manage the one active WebRTC voice session and its gateway event feed.

    At most one session is tracked at a time: every new offer first tears
    down the previous session, its gateway subscription, and the background
    task that forwards gateway events to that session.
    """

    # Class-level logger so instances carry no extra state for diagnostics.
    _logger = logging.getLogger(__name__)

    def __init__(
        self,
        *,
        gateway,
        publish_cards_changed: PublishCallback,
        publish_workbench_changed: PublishCallback,
        publish_sessions_changed: Callable[[], Awaitable[None]],
    ) -> None:
        """Store collaborators; no session exists until ``handle_offer``.

        Args:
            gateway: Object providing awaitable ``subscribe()``,
                ``unsubscribe(queue)`` and ``connect_nanobot()``; it is also
                passed to every ``WebRTCVoiceSession`` created here.
            publish_cards_changed: Awaited with the card's ``chat_id`` after
                a card event has been persisted.
            publish_workbench_changed: Awaited with the event's ``chat_id``
                for every workbench event.
            publish_sessions_changed: Awaited without arguments after a
                non-progress message whose role is ``"nanobot"``.
        """
        self._gateway = gateway
        self._publish_cards_changed = publish_cards_changed
        self._publish_workbench_changed = publish_workbench_changed
        self._publish_sessions_changed = publish_sessions_changed
        self._active_session: WebRTCVoiceSession | None = None
        self._active_queue: asyncio.Queue | None = None
        self._sender_task: asyncio.Task | None = None

    async def handle_offer(self, payload: dict[str, Any]) -> dict[str, Any] | None:
        """Replace the active session with one negotiated from *payload*.

        Args:
            payload: Offer data, passed unchanged to
                ``WebRTCVoiceSession.handle_offer``.

        Returns:
            The answer produced by the voice session, or ``None`` when the
            session produced no answer (everything set up for this offer is
            torn down again in that case).

        Raises:
            Anything raised by the voice session or the gateway. The
            partially set-up session is torn down before the error
            propagates, so no half-initialised session stays registered.
        """
        await self._close_active_session()

        queue = await self._gateway.subscribe()
        self._active_queue = queue

        established = False
        try:
            voice_session = WebRTCVoiceSession(gateway=self._gateway)
            self._active_session = voice_session
            self._sender_task = asyncio.create_task(
                self._sender_loop(queue, voice_session),
                name="rtc-sender",
            )

            answer = await voice_session.handle_offer(payload)
            if answer is None:
                return None

            await self._gateway.connect_nanobot()
            established = True
            return answer
        finally:
            # Covers both the "no answer" path and any exception above.
            if not established:
                await self._close_active_session()

    async def shutdown(self) -> None:
        """Tear down the active session, if there is one."""
        await self._close_active_session()

    async def _close_active_session(self) -> None:
        """Stop the sender task, close the session and drop the subscription.

        Each step runs even if an earlier one fails, and the manager's state
        is cleared up front, so a failure here can never leave stale
        references behind that would break the next offer.
        """
        sender_task, self._sender_task = self._sender_task, None
        voice_session, self._active_session = self._active_session, None
        queue, self._active_queue = self._active_queue, None

        try:
            if sender_task is not None:
                sender_task.cancel()
                try:
                    await sender_task
                except asyncio.CancelledError:
                    pass
                except Exception:
                    # The task had already died with its own error; awaiting
                    # it re-raises that error. Log it instead of letting it
                    # abort the remaining cleanup.
                    self._logger.exception("rtc sender task ended with an error")
        finally:
            try:
                if voice_session is not None:
                    await voice_session.close()
            finally:
                if queue is not None:
                    await self._gateway.unsubscribe(queue)

    async def _sender_loop(
        self,
        queue: asyncio.Queue,
        voice_session: WebRTCVoiceSession,
    ) -> None:
        """Forward gateway events from *queue* to *voice_session* forever.

        The loop only ends through task cancellation. A failure while
        handling a single event is logged and the loop moves on, so one bad
        event does not silence the rest of the session.
        """
        while True:
            event = await queue.get()
            try:
                await self._forward_event(event, voice_session)
            except Exception:
                # CancelledError derives from BaseException, so cancellation
                # still propagates and stops the loop.
                self._logger.exception("failed to forward gateway event")

    async def _forward_event(
        self,
        event: Any,
        voice_session: WebRTCVoiceSession,
    ) -> None:
        """Route one gateway event to speech output or the data channel."""
        role = event.role

        # Text-to-speech events go straight to the session's output queue.
        if role == "nanobot-tts-partial":
            await voice_session.queue_output_text(event.text, partial=True)
            return
        if role == "nanobot-tts-flush":
            await voice_session.flush_partial_output_text()
            return
        if role == "nanobot-tts":
            for segment in chunk_tts_text(event.text):
                await voice_session.queue_output_text(segment)
            return

        typed_event = typed_message_from_gateway_event(event.to_dict())
        if typed_event is None:
            return
        event_type = typed_event.get("type")

        if event_type == "card":
            persisted = persist_card(typed_event)
            if persisted is None:
                return
            # Send the persisted form, tagged as a card, not the raw event.
            payload = dict(persisted)
            payload["type"] = "card"
            await self._publish_cards_changed(payload.get("chat_id"))
            voice_session.send_to_datachannel(payload)
            return

        if event_type == "workbench":
            await self._publish_workbench_changed(typed_event.get("chat_id"))
            voice_session.send_to_datachannel(typed_event)
            return

        if (
            event_type == "message"
            and not typed_event.get("is_progress", False)
            and typed_event.get("role") == "nanobot"
        ):
            await self._publish_sessions_changed()
        # NOTE(review): assumes every remaining typed event is forwarded, not
        # only final nanobot messages — confirm against the original layout.
        voice_session.send_to_datachannel(typed_event)