This commit is contained in:
kacper 2026-03-06 22:51:19 -05:00
parent 6acf267d48
commit b7614eb3f8
4794 changed files with 1280808 additions and 1546 deletions

View file

@ -2,6 +2,7 @@ import asyncio
import audioop
import contextlib
import io
import json
import os
import re
import shlex
@ -815,11 +816,11 @@ SendJsonCallable = Callable[[dict[str, Any]], Awaitable[None]]
class WebRTCVoiceSession:
def __init__(self, gateway: "SuperTonicGateway", send_json: SendJsonCallable) -> None:
def __init__(self, gateway: "SuperTonicGateway") -> None:
self._gateway = gateway
self._send_json = send_json
self._pc: RTCPeerConnection | None = None
self._dc: Any | None = None # RTCDataChannel (aiortc)
self._outbound_track: QueueAudioTrack | None = None
self._incoming_audio_task: asyncio.Task[None] | None = None
self._stt_worker_task: asyncio.Task[None] | None = None
@ -874,12 +875,22 @@ class WebRTCVoiceSession:
float(os.getenv("HOST_STT_BACKLOG_NOTICE_INTERVAL_S", "6.0")),
)
self._last_stt_backlog_notice_at = 0.0
self._pending_ice_candidates: list[dict[str, Any] | None] = []
self._ptt_pressed = False
def set_push_to_talk_pressed(self, pressed: bool) -> None:
    """Record whether the push-to-talk control is currently held.

    The value is coerced by truthiness, so the stored flag is always a
    real ``bool`` regardless of what the caller passes.
    """
    is_held = True if pressed else False
    self._ptt_pressed = is_held
def send_to_datachannel(self, payload: dict[str, Any]) -> None:
    """Send a JSON message over the DataChannel if it is open.

    Delivery is best-effort: nothing is sent when no channel is attached
    or the channel's ``readyState`` is not ``"open"``. A failure while
    serializing or sending is logged rather than raised, so callers never
    need to guard this call.

    Args:
        payload: Message to serialize with ``json.dumps`` and send.
    """
    # Local import: the module's import block is not fully visible here.
    import logging

    channel = self._dc
    if channel is None:
        return
    try:
        if channel.readyState != "open":
            return
        channel.send(json.dumps(payload))
    except Exception:
        # Still best-effort, but leave a trace instead of dropping the
        # message without any indication of what went wrong.
        logging.getLogger(__name__).debug(
            "Dropping DataChannel message of type %r",
            payload.get("type") if isinstance(payload, dict) else None,
            exc_info=True,
        )
async def queue_output_text(self, chunk: str) -> None:
normalized_chunk = chunk.strip()
if not normalized_chunk:
@ -895,21 +906,14 @@ class WebRTCVoiceSession:
# fires once no new chunks arrive for the configured delay.
self._schedule_tts_flush_after(self._tts_response_end_delay_s)
async def handle_offer(self, payload: dict[str, Any]) -> None:
async def handle_offer(self, payload: dict[str, Any]) -> dict[str, Any] | None:
if not AIORTC_AVAILABLE or not RTCPeerConnection or not RTCSessionDescription:
await self._send_json(
{
"type": "rtc-error",
"message": "WebRTC backend unavailable on host (aiortc is not installed).",
}
)
return
return None
sdp = str(payload.get("sdp", "")).strip()
rtc_type = str(payload.get("rtcType", "offer")).strip() or "offer"
if not sdp:
await self._send_json({"type": "rtc-error", "message": "Missing SDP offer payload."})
return
return None
await self._close_peer_connection()
self._ptt_pressed = False
@ -920,16 +924,32 @@ class WebRTCVoiceSession:
self._outbound_track._on_playing_changed = self._on_track_playing_changed
peer_connection.addTrack(self._outbound_track)
@peer_connection.on("connectionstatechange")
def on_connectionstatechange() -> None:
asyncio.create_task(
self._send_json(
{
"type": "rtc-state",
"state": peer_connection.connectionState,
}
)
)
@peer_connection.on("datachannel")
def on_datachannel(channel: Any) -> None:
if channel.label != "app":
return
self._dc = channel
@channel.on("message")
def on_message(raw: str) -> None:
try:
msg = json.loads(raw)
except Exception:
return
msg_type = str(msg.get("type", "")).strip()
if msg_type == "voice-ptt":
self.set_push_to_talk_pressed(bool(msg.get("pressed", False)))
elif msg_type == "command":
asyncio.create_task(self._gateway.send_command(str(msg.get("command", ""))))
elif msg_type == "ui-response":
asyncio.create_task(
self._gateway.send_ui_response(
str(msg.get("request_id", "")),
str(msg.get("value", "")),
)
)
elif msg_type == "ping":
self.send_to_datachannel({"type": "pong"})
@peer_connection.on("track")
def on_track(track: MediaStreamTrack) -> None:
@ -943,7 +963,6 @@ class WebRTCVoiceSession:
)
await peer_connection.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=rtc_type))
await self._drain_pending_ice_candidates(peer_connection)
answer = await peer_connection.createAnswer()
await peer_connection.setLocalDescription(answer)
await self._wait_for_ice_gathering(peer_connection)
@ -955,13 +974,6 @@ class WebRTCVoiceSession:
sdp_answer.replace("\r\n", "\n").replace("\r", "\n").strip().replace("\n", "\r\n")
+ "\r\n"
)
await self._send_json(
{
"type": "rtc-answer",
"sdp": sdp_answer,
"rtcType": local_description.type,
}
)
if self._stt.enabled and not self._stt_worker_task:
self._stt_worker_task = asyncio.create_task(self._stt_worker(), name="voice-stt-worker")
@ -973,62 +985,10 @@ class WebRTCVoiceSession:
f"Voice input backend unavailable. {self._stt.unavailable_reason()}"
)
async def handle_ice_candidate(self, payload: dict[str, Any]) -> None:
    """Accept one remote ICE candidate message.

    Does nothing when aiortc is unavailable. The ``"candidate"`` entry of
    *payload* is normalized: ``None`` or ``""`` becomes ``None``, a dict is
    used as-is, and any other type is ignored. The normalized candidate is
    applied immediately when a peer connection with a remote description
    exists; otherwise it is queued in ``_pending_ice_candidates``.
    """
    if not AIORTC_AVAILABLE:
        return

    incoming = payload.get("candidate")
    if incoming in (None, ""):
        normalized: dict[str, Any] | None = None
    elif isinstance(incoming, dict):
        normalized = incoming
    else:
        # Unsupported candidate shape: drop it.
        return

    peer = self._pc
    if peer and peer.remoteDescription is not None:
        await self._apply_ice_candidate(peer, normalized)
    else:
        # No remote description yet, so hold the candidate for later.
        self._pending_ice_candidates.append(normalized)
async def _drain_pending_ice_candidates(
    self,
    peer_connection: RTCPeerConnection,
) -> None:
    """Apply every queued ICE candidate to *peer_connection*, in order.

    The queue is copied and emptied before the first candidate is applied,
    so anything queued while this coroutine is awaiting stays in the queue
    rather than being processed by this call.
    """
    queue = self._pending_ice_candidates
    if not queue:
        return

    snapshot = [*queue]
    queue.clear()
    for queued_candidate in snapshot:
        await self._apply_ice_candidate(peer_connection, queued_candidate)
async def _apply_ice_candidate(
    self,
    peer_connection: RTCPeerConnection,
    raw_candidate: dict[str, Any] | None,
) -> None:
    """Add one ICE candidate to *peer_connection*.

    A ``None`` candidate is forwarded as ``addIceCandidate(None)`` with any
    error ignored (presumably the end-of-candidates signal; confirm against
    aiortc). Otherwise the ``"candidate"`` string is stripped, an optional
    leading ``"candidate:"`` is removed, and the result is parsed with
    ``candidate_from_sdp``. Blank strings, or a missing parser, are skipped
    silently. Any failure while parsing or adding is reported through
    ``_publish_system`` instead of being raised.
    """
    if raw_candidate is None:
        try:
            await peer_connection.addIceCandidate(None)
        except Exception:
            pass
        return

    sdp_line = str(raw_candidate.get("candidate", "")).strip()
    if not (sdp_line and candidate_from_sdp):
        return
    sdp_line = sdp_line.removeprefix("candidate:")

    try:
        parsed = candidate_from_sdp(sdp_line)
        parsed.sdpMid = raw_candidate.get("sdpMid")
        m_line_index = raw_candidate.get("sdpMLineIndex")
        if m_line_index is None:
            parsed.sdpMLineIndex = None
        else:
            parsed.sdpMLineIndex = int(m_line_index)
        await peer_connection.addIceCandidate(parsed)
    except Exception as error:
        await self._publish_system(f"Failed to add ICE candidate: {error}")
return {
"sdp": sdp_answer,
"rtcType": local_description.type,
}
async def close(self) -> None:
self._closed = True
@ -1317,6 +1277,8 @@ class WebRTCVoiceSession:
await self._gateway.send_user_message(transcript)
async def _close_peer_connection(self) -> None:
self._dc = None
if self._outbound_track:
self._outbound_track.stop()
self._outbound_track = None
@ -1325,8 +1287,6 @@ class WebRTCVoiceSession:
await self._pc.close()
self._pc = None
self._pending_ice_candidates.clear()
async def _wait_for_ice_gathering(self, peer_connection: RTCPeerConnection) -> None:
if peer_connection.iceGatheringState == "complete":
return