103 lines
3.2 KiB
Python
103 lines
3.2 KiB
Python
import asyncio
|
|
import contextlib
|
|
import json
|
|
from pathlib import Path
|
|
from typing import Any, Awaitable, Callable
|
|
|
|
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
from fastapi.responses import FileResponse, JSONResponse
|
|
|
|
from supertonic_gateway import SuperTonicGateway
|
|
from voice_rtc import WebRTCVoiceSession
|
|
|
|
|
|
# Directory containing this module; static assets are resolved relative to it
# so the app works regardless of the current working directory.
BASE_DIR = Path(__file__).resolve().parent
INDEX_PATH = BASE_DIR.joinpath("static", "index.html")

app = FastAPI(title="Nanobot SuperTonic Wisper Web")

# One gateway instance for the whole process; every WebSocket connection
# subscribes to it individually and it is shut down with the application.
gateway = SuperTonicGateway()
|
|
|
|
|
|
@app.get("/health")
async def health() -> JSONResponse:
    """Liveness endpoint that unconditionally answers ``{"status": "ok"}``."""
    body = {"status": "ok"}
    return JSONResponse(content=body)
|
|
|
|
|
|
@app.get("/")
async def index() -> FileResponse:
    """Return the ``static/index.html`` file that sits next to this module."""
    return FileResponse(path=INDEX_PATH)
|
|
|
|
|
|
@app.websocket("/ws/chat")
async def websocket_chat(websocket: WebSocket) -> None:
    """Serve one client connection on ``/ws/chat``.

    Subscribes to the shared gateway and starts a background task that
    forwards gateway events to the client. It then reads JSON text frames from
    the client and dispatches each one by its ``type`` field (``spawn``,
    ``stop``, ``rtc-offer``, ``rtc-ice-candidate``, ``voice-ptt``) until the
    client disconnects.

    Teardown always happens in this order: cancel the forwarding task, close
    the voice session, unsubscribe from the gateway. Every step runs even if
    an earlier step fails, or if the receive loop raises something other than
    ``WebSocketDisconnect``.
    """
    await websocket.accept()
    send_lock = asyncio.Lock()

    async def safe_send_json(payload: dict[str, Any]) -> None:
        # The receive loop, the forwarding task and the voice session all send
        # through this function; the lock keeps their writes from overlapping.
        async with send_lock:
            await websocket.send_json(payload)

    async with contextlib.AsyncExitStack() as cleanup:
        # Callbacks run in reverse registration order, and each one runs even if
        # a previous callback raised. A dead sender task or a failing close()
        # can therefore no longer leak the gateway subscription.
        queue = await gateway.subscribe()
        cleanup.push_async_callback(gateway.unsubscribe, queue)

        voice_session = WebRTCVoiceSession(gateway=gateway, send_json=safe_send_json)
        cleanup.push_async_callback(voice_session.close)

        sender = asyncio.create_task(_sender_loop(safe_send_json, queue, voice_session))
        cleanup.push_async_callback(_cancel_and_wait, sender)

        try:
            while True:
                raw_message = await websocket.receive_text()
                await _handle_client_message(raw_message, voice_session, safe_send_json)
        except WebSocketDisconnect:
            # Normal end of a session; cleanup is handled by the exit stack.
            pass


async def _cancel_and_wait(task: asyncio.Task) -> None:
    """Cancel *task* and wait until it has finished.

    The ``CancelledError`` produced by the cancellation is swallowed. Any other
    exception the task had already died with is re-raised so it is not lost.
    """
    task.cancel()
    with contextlib.suppress(asyncio.CancelledError):
        await task


async def _handle_client_message(
    raw_message: str,
    voice_session: WebRTCVoiceSession,
    send_json: Callable[[dict[str, Any]], Awaitable[None]],
) -> None:
    """Parse one text frame from the client and act on its ``type`` field.

    Malformed or unrecognised messages get a ``role: system`` notice sent
    back instead of an exception, so one bad frame does not end the session.
    """
    try:
        message = json.loads(raw_message)
    except json.JSONDecodeError:
        await send_json({"role": "system", "text": "Invalid JSON message.", "timestamp": ""})
        return

    if not isinstance(message, dict):
        # Valid JSON that is not an object (e.g. ``[]``, ``42``, ``null``) has no
        # ``.get``; without this check the handler would raise AttributeError.
        await send_json(
            {"role": "system", "text": "Message must be a JSON object.", "timestamp": ""}
        )
        return

    msg_type = str(message.get("type", "")).strip()
    if msg_type == "spawn":
        await gateway.spawn_tui()
    elif msg_type == "stop":
        await gateway.stop_tui()
    elif msg_type == "rtc-offer":
        await voice_session.handle_offer(message)
    elif msg_type == "rtc-ice-candidate":
        await voice_session.handle_ice_candidate(message)
    elif msg_type == "voice-ptt":
        voice_session.set_push_to_talk_pressed(bool(message.get("pressed", False)))
    else:
        await send_json(
            {
                "role": "system",
                "text": (
                    "Unknown message type. Use spawn, stop, rtc-offer, "
                    "rtc-ice-candidate, or voice-ptt."
                ),
                "timestamp": "",
            }
        )
|
|
|
|
|
|
@app.on_event("shutdown")
async def on_shutdown() -> None:
    """Shut down the process-wide gateway when the application stops."""
    # NOTE(review): FastAPI documents ``on_event`` as deprecated in favour of a
    # ``lifespan`` handler; migrating means changing how ``app`` is constructed.
    shared_gateway = gateway
    await shared_gateway.shutdown()
|
|
|
|
|
|
async def _sender_loop(
    send_json: Callable[[dict[str, Any]], Awaitable[None]],
    queue: asyncio.Queue,
    voice_session: WebRTCVoiceSession,
) -> None:
    """Drain gateway events from *queue* forever, routing each one.

    Events whose ``role`` is ``"nanobot-tts"`` have their ``text`` handed to
    ``voice_session.queue_output_text`` (presumably for speech output; confirm
    in voice_rtc). Every other event is serialized with ``to_dict()`` and sent
    through *send_json*. The loop only ends by cancellation or by an exception
    from one of those calls.
    """
    while True:
        item = await queue.get()
        if item.role != "nanobot-tts":
            await send_json(item.to_dict())
        else:
            await voice_session.queue_output_text(item.text)
|