feat: polish life os cards and voice stack

This commit is contained in:
kacper 2026-03-24 08:54:47 -04:00
parent 66362c7176
commit 0edf8c3fef
21 changed files with 3681 additions and 502 deletions

View file

@ -1,5 +1,6 @@
import asyncio
import audioop
import base64
import contextlib
import io
import json
@ -7,11 +8,15 @@ import os
import re
import shlex
import shutil
import socket
import subprocess
import sys
import tempfile
import time
import wave
from dataclasses import dataclass
from fractions import Fraction
from pathlib import Path
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from wisper import WisperEvent
@ -83,6 +88,9 @@ BRAILLE_SPINNER_RE = re.compile(r"[\u2800-\u28ff]")
TTS_ALLOWED_ASCII = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 .,!?;:'\"()[]{}@#%&*+-_/<>|"
)
TTS_WORD_RE = re.compile(r"[A-Za-z0-9]")
TTS_RETRY_BREAK_RE = re.compile(r"(?<=[.!?,;:])\s+")
TTS_PARTIAL_COMMIT_RE = re.compile(r"[.!?]\s*$|[,;:]\s+$")
LOCAL_ICE_GATHER_TIMEOUT_S = 0.35
@ -94,9 +102,39 @@ def _sanitize_tts_text(text: str) -> str:
cleaned = CONTROL_CHAR_RE.sub(" ", cleaned)
cleaned = "".join(ch if (ch in TTS_ALLOWED_ASCII or ch.isspace()) else " " for ch in cleaned)
cleaned = re.sub(r"\s+", " ", cleaned).strip()
if not TTS_WORD_RE.search(cleaned):
return ""
return cleaned
def _split_tts_retry_segments(text: str, max_chars: int = 120) -> list[str]:
clean = _sanitize_tts_text(text)
if not clean:
return []
parts = [part.strip() for part in TTS_RETRY_BREAK_RE.split(clean) if part.strip()]
if len(parts) <= 1:
words = clean.split()
if len(words) <= 1:
return []
parts = []
current = words[0]
for word in words[1:]:
candidate = f"{current} {word}"
if len(candidate) <= max_chars:
current = candidate
continue
parts.append(current)
current = word
parts.append(current)
compact_parts = [_sanitize_tts_text(part) for part in parts]
compact_parts = [part for part in compact_parts if part]
if len(compact_parts) <= 1:
return []
return compact_parts
def _coerce_message_metadata(raw: Any) -> dict[str, Any]:
def _coerce_jsonish(value: Any, depth: int = 0) -> Any:
if depth > 6:
@ -143,6 +181,12 @@ class PCMChunk:
channels: int = 1
@dataclass(slots=True)
class STTSegment:
pcm: bytes
metadata: dict[str, Any]
if AIORTC_AVAILABLE:
class QueueAudioTrack(MediaStreamTrack):
@ -275,6 +319,19 @@ if AIORTC_AVAILABLE:
self._closed = True
super().stop()
def clear(self) -> None:
while True:
try:
self._queue.get_nowait()
except asyncio.QueueEmpty:
break
self._last_enqueue_at = 0.0
self._idle_frames = 0
if self._playing:
self._playing = False
if self._on_playing_changed:
self._on_playing_changed(False)
else:
class QueueAudioTrack: # pragma: no cover - used only when aiortc is unavailable
@ -286,6 +343,9 @@ else:
def stop(self) -> None:
return
def clear(self) -> None:
return
def _write_temp_wav(pcm: bytes, sample_rate: int, channels: int) -> str:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
@ -706,13 +766,138 @@ class SupertonicTextToSpeech:
self._init_error = None
class MeloTTSTextToSpeech:
def __init__(self) -> None:
self._root_dir = Path(__file__).resolve().parent
self._workspace_dir = Path(
os.getenv("NANOBOT_WORKSPACE", str(Path.home() / ".nanobot"))
).expanduser()
self._socket_path = Path(
os.getenv("MELO_TTS_SOCKET", str(self._workspace_dir / "melotts.sock"))
).expanduser()
self._server_script = self._root_dir / "scripts" / "melotts_server.py"
self._server_log_path = self._workspace_dir / "logs" / "melotts-server.log"
self._startup_timeout_s = max(
5.0, float(os.getenv("MELO_TTS_SERVER_STARTUP_TIMEOUT_S", "120"))
)
self._init_error: str | None = None
self._lock = asyncio.Lock()
@property
def enabled(self) -> bool:
return self._server_script.exists()
@property
def init_error(self) -> str | None:
return self._init_error
async def synthesize(self, text: str) -> PCMChunk | None:
if not self.enabled:
return None
clean_text = " ".join(text.split())
if not clean_text:
return None
async with self._lock:
return await asyncio.to_thread(self._synthesize_blocking, clean_text)
def _synthesize_blocking(self, text: str) -> PCMChunk | None:
self._ensure_server_blocking()
response = self._rpc(
{
"action": "synthesize_pcm",
"text": text,
},
timeout_s=max(30.0, self._startup_timeout_s),
)
if not response.get("ok"):
raise RuntimeError(str(response.get("error", "MeloTTS synthesis failed")))
encoded_pcm = str(response.get("pcm", "")).strip()
if not encoded_pcm:
return None
pcm = base64.b64decode(encoded_pcm)
sample_rate = max(1, int(response.get("sample_rate", 44100)))
channels = max(1, int(response.get("channels", 1)))
return PCMChunk(pcm=pcm, sample_rate=sample_rate, channels=channels)
def _ensure_server_blocking(self) -> None:
if self._ping():
self._init_error = None
return
with contextlib.suppress(FileNotFoundError):
self._socket_path.unlink()
self._server_log_path.parent.mkdir(parents=True, exist_ok=True)
with self._server_log_path.open("a", encoding="utf-8") as log_handle:
proc = subprocess.Popen(
[sys.executable, str(self._server_script), "--socket-path", str(self._socket_path)],
cwd=str(self._root_dir),
stdin=subprocess.DEVNULL,
stdout=log_handle,
stderr=subprocess.STDOUT,
start_new_session=True,
)
deadline = time.time() + self._startup_timeout_s
while time.time() < deadline:
if self._ping():
self._init_error = None
return
exit_code = proc.poll()
if exit_code is not None:
self._init_error = (
f"MeloTTS server exited during startup with code {exit_code}. "
f"See {self._server_log_path}"
)
raise RuntimeError(self._init_error)
time.sleep(0.25)
self._init_error = (
f"MeloTTS server did not become ready within {self._startup_timeout_s:.0f}s."
)
raise RuntimeError(self._init_error)
def _ping(self) -> bool:
try:
response = self._rpc({"action": "ping"}, timeout_s=2.0)
except Exception:
return False
return bool(response.get("ok"))
def _rpc(self, payload: dict[str, Any], timeout_s: float) -> dict[str, Any]:
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
sock.settimeout(timeout_s)
try:
sock.connect(str(self._socket_path))
sock.sendall((json.dumps(payload) + "\n").encode("utf-8"))
chunks: list[bytes] = []
while True:
data = sock.recv(8192)
if not data:
break
chunks.append(data)
if b"\n" in data:
break
finally:
sock.close()
response = b"".join(chunks).decode("utf-8", errors="replace").strip()
if not response:
raise RuntimeError("empty response from MeloTTS server")
return json.loads(response)
class HostTextToSpeech:
def __init__(self) -> None:
provider = (os.getenv("HOST_TTS_PROVIDER", "supertonic").strip() or "supertonic").lower()
if provider not in {"supertonic", "command", "espeak", "auto"}:
if provider not in {"supertonic", "melotts", "command", "espeak", "auto"}:
provider = "auto"
self._provider = provider
self._supertonic = SupertonicTextToSpeech()
self._melotts = MeloTTSTextToSpeech()
self._command_template = os.getenv("HOST_TTS_COMMAND", "").strip()
self._espeak = shutil.which("espeak")
@ -720,11 +905,17 @@ class HostTextToSpeech:
def enabled(self) -> bool:
if self._provider == "supertonic":
return self._supertonic.enabled
if self._provider == "melotts":
return self._melotts.enabled
if self._provider == "command":
return bool(self._command_template)
if self._provider == "espeak":
return bool(self._espeak)
return self._supertonic.enabled or bool(self._command_template or self._espeak)
return (
self._supertonic.enabled
or self._melotts.enabled
or bool(self._command_template or self._espeak)
)
async def synthesize(self, text: str) -> PCMChunk | None:
clean_text = " ".join(text.split())
@ -738,6 +929,13 @@ class HostTextToSpeech:
if self._provider == "supertonic":
return None
if self._provider in {"melotts", "auto"}:
audio = await self._melotts.synthesize(clean_text)
if audio:
return audio
if self._provider == "melotts":
return None
if self._provider in {"command", "auto"} and self._command_template:
return await asyncio.to_thread(self._synthesize_with_command, clean_text)
if self._provider == "command":
@ -755,6 +953,12 @@ class HostTextToSpeech:
if self._supertonic.init_error:
return f"supertonic initialization failed: {self._supertonic.init_error}"
return "supertonic did not return audio."
if self._provider == "melotts":
if not self._melotts.enabled:
return "MeloTTS server script is not available."
if self._melotts.init_error:
return f"MeloTTS initialization failed: {self._melotts.init_error}"
return "MeloTTS did not return audio."
if self._provider == "command":
return "HOST_TTS_COMMAND is not configured."
if self._provider == "espeak":
@ -762,6 +966,8 @@ class HostTextToSpeech:
if self._supertonic.init_error:
return f"supertonic initialization failed: {self._supertonic.init_error}"
if self._melotts.init_error:
return f"MeloTTS initialization failed: {self._melotts.init_error}"
if self._command_template:
return "HOST_TTS_COMMAND failed to produce audio."
if self._espeak:
@ -862,11 +1068,12 @@ class WebRTCVoiceSession:
self._stt = HostSpeechToText()
self._tts = HostTextToSpeech()
self._stt_segment_queue_size = max(1, int(os.getenv("HOST_STT_SEGMENT_QUEUE_SIZE", "2")))
self._stt_segments: asyncio.Queue[bytes] = asyncio.Queue(
self._stt_segments: asyncio.Queue[STTSegment] = asyncio.Queue(
maxsize=self._stt_segment_queue_size
)
self._tts_chunks: list[str] = []
self._tts_partial_buffer = ""
self._tts_flush_handle: asyncio.TimerHandle | None = None
self._tts_flush_lock = asyncio.Lock()
self._tts_buffer_lock = asyncio.Lock()
@ -875,8 +1082,18 @@ class WebRTCVoiceSession:
self._tts_response_end_delay_s = max(
0.1, float(os.getenv("HOST_TTS_RESPONSE_END_DELAY_S", "0.5"))
)
self._tts_partial_commit_chars = max(
24, int(os.getenv("HOST_TTS_PARTIAL_COMMIT_CHARS", "72"))
)
self._closed = False
self._audio_debug = os.getenv("HOST_AUDIO_DEBUG", "0").strip() not in {
"0",
"false",
"False",
"no",
"off",
}
self._stt_unavailable_notice_sent = False
self._tts_unavailable_notice_sent = False
self._audio_seen_notice_sent = False
@ -925,20 +1142,65 @@ class WebRTCVoiceSession:
except Exception:
pass
async def queue_output_text(self, chunk: str) -> None:
normalized_chunk = chunk.strip()
if not normalized_chunk:
def _should_commit_partial_buffer(self) -> bool:
stripped = self._tts_partial_buffer.strip()
if not stripped:
return False
if len(stripped) >= self._tts_partial_commit_chars:
return True
return bool(TTS_PARTIAL_COMMIT_RE.search(self._tts_partial_buffer))
def _commit_partial_buffer_locked(self) -> None:
partial = self._tts_partial_buffer.strip()
self._tts_partial_buffer = ""
if partial:
self._tts_chunks.append(partial)
async def queue_output_text(self, chunk: str, *, partial: bool = False) -> None:
if not chunk:
return
async with self._tts_buffer_lock:
if not self._pc or not self._outbound_track:
return
if partial:
self._tts_partial_buffer += chunk
if self._should_commit_partial_buffer():
self._commit_partial_buffer_locked()
self._schedule_tts_flush_after(0.05, reset=True)
else:
self._schedule_tts_flush_after(self._tts_response_end_delay_s, reset=True)
return
normalized_chunk = chunk.strip()
if not normalized_chunk:
return
if self._tts_partial_buffer.strip():
self._commit_partial_buffer_locked()
# Keep line boundaries between streamed chunks so line-based filters
# stay accurate while avoiding repeated full-string copies.
self._tts_chunks.append(normalized_chunk)
# Reset the flush timer on every incoming chunk so the entire
# response is accumulated before synthesis begins. The timer
# fires once no new chunks arrive for the configured delay.
self._schedule_tts_flush_after(self._tts_response_end_delay_s)
# Flush in short rolling windows instead of waiting for the whole
# response so streamed Nanobot output can start speaking sooner.
self._schedule_tts_flush_after(self._tts_response_end_delay_s, reset=False)
async def flush_partial_output_text(self) -> None:
async with self._tts_buffer_lock:
if not self._pc or not self._outbound_track:
return
if not self._tts_partial_buffer.strip():
return
self._commit_partial_buffer_locked()
self._schedule_tts_flush_after(0.05, reset=True)
def interrupt_output(self) -> None:
if self._tts_flush_handle:
self._tts_flush_handle.cancel()
self._tts_flush_handle = None
self._tts_chunks.clear()
self._tts_partial_buffer = ""
self._stt_suppress_until = 0.0
if self._outbound_track:
self._outbound_track.clear()
async def handle_offer(self, payload: dict[str, Any]) -> dict[str, Any] | None:
if not AIORTC_AVAILABLE or not RTCPeerConnection or not RTCSessionDescription:
@ -980,7 +1242,10 @@ class WebRTCVoiceSession:
self._active_message_metadata = _coerce_message_metadata(msg.get("metadata", {}))
self.set_push_to_talk_pressed(bool(msg.get("pressed", False)))
elif msg_type == "command":
asyncio.create_task(self._gateway.send_command(str(msg.get("command", ""))))
command = str(msg.get("command", "")).strip()
if command == "reset":
self.interrupt_output()
asyncio.create_task(self._gateway.send_command(command))
elif msg_type == "card-response":
asyncio.create_task(
self._gateway.send_card_response(
@ -1037,6 +1302,7 @@ class WebRTCVoiceSession:
self._tts_flush_handle.cancel()
self._tts_flush_handle = None
self._tts_chunks.clear()
self._tts_partial_buffer = ""
if self._incoming_audio_task:
self._incoming_audio_task.cancel()
@ -1063,55 +1329,64 @@ class WebRTCVoiceSession:
return
asyncio.create_task(self._flush_tts(), name="voice-tts-flush")
def _schedule_tts_flush_after(self, delay_s: float) -> None:
def _schedule_tts_flush_after(self, delay_s: float, *, reset: bool = True) -> None:
if self._tts_flush_handle:
if not reset:
return
self._tts_flush_handle.cancel()
loop = asyncio.get_running_loop()
self._tts_flush_handle = loop.call_later(max(0.05, delay_s), self._schedule_tts_flush)
async def _flush_tts(self) -> None:
async with self._tts_flush_lock:
async with self._tts_buffer_lock:
self._tts_flush_handle = None
raw_text = "\n".join(self._tts_chunks)
self._tts_chunks.clear()
clean_text = self._clean_tts_text(raw_text)
if not clean_text:
return
if not self._outbound_track:
return
try:
audio = await self._tts.synthesize(clean_text)
except asyncio.CancelledError:
raise
except Exception as exc:
import traceback # noqa: local import in exception handler
traceback.print_exc()
# Restore the lost text so a future flush can retry it.
while True:
async with self._tts_buffer_lock:
self._tts_chunks.insert(0, clean_text)
await self._publish_system(f"TTS synthesis error: {exc}")
return
self._tts_flush_handle = None
if not self._tts_chunks and self._tts_partial_buffer.strip():
self._commit_partial_buffer_locked()
if not self._tts_chunks:
return
raw_text = self._tts_chunks.pop(0)
if not audio:
if not self._tts_unavailable_notice_sent:
self._tts_unavailable_notice_sent = True
await self._publish_system(
f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}"
)
return
clean_text = self._clean_tts_text(raw_text)
if not clean_text:
continue
if not self._outbound_track:
return
self._extend_stt_suppression(audio)
await self._outbound_track.enqueue_pcm(
pcm=audio.pcm,
sample_rate=audio.sample_rate,
channels=audio.channels,
)
if not self._outbound_track:
return
try:
audio = await self._tts.synthesize(clean_text)
except asyncio.CancelledError:
raise
except Exception as exc:
import traceback # noqa: local import in exception handler
traceback.print_exc()
retry_segments = _split_tts_retry_segments(clean_text)
if retry_segments:
async with self._tts_buffer_lock:
self._tts_chunks[0:0] = retry_segments
continue
await self._publish_system(f"TTS synthesis error: {exc}")
return
if not audio:
if not self._tts_unavailable_notice_sent:
self._tts_unavailable_notice_sent = True
await self._publish_system(
f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}"
)
return
if not self._outbound_track:
return
self._extend_stt_suppression(audio)
await self._outbound_track.enqueue_pcm(
pcm=audio.pcm,
sample_rate=audio.sample_rate,
channels=audio.channels,
)
def _extend_stt_suppression(self, audio: PCMChunk) -> None:
if not self._stt_suppress_during_tts:
@ -1152,13 +1427,13 @@ class WebRTCVoiceSession:
if not pcm16:
continue
if not self._audio_seen_notice_sent:
if self._audio_debug and not self._audio_seen_notice_sent:
self._audio_seen_notice_sent = True
await self._publish_system("Receiving microphone audio on host.")
await self._publish_debug("Receiving microphone audio on host.")
if not self._audio_format_notice_sent:
if self._audio_debug and not self._audio_format_notice_sent:
self._audio_format_notice_sent = True
await self._publish_system(
await self._publish_debug(
"Inbound audio frame stats: "
f"sample_rate={int(getattr(frame, 'sample_rate', 0) or 0)}, "
f"samples={int(getattr(frame, 'samples', 0) or 0)}, "
@ -1261,16 +1536,25 @@ class WebRTCVoiceSession:
None,
)
normalized_duration_ms = (len(normalized_pcm) / 2 / 16_000) * 1000.0
if not self._ptt_timing_correction_notice_sent:
if self._audio_debug and not self._ptt_timing_correction_notice_sent:
self._ptt_timing_correction_notice_sent = True
await self._publish_system(
await self._publish_debug(
"Corrected PTT timing mismatch "
f"(estimated source={nearest_source_rate}Hz)."
)
await self._enqueue_stt_segment(pcm16=normalized_pcm, duration_ms=normalized_duration_ms)
await self._enqueue_stt_segment(
pcm16=normalized_pcm,
duration_ms=normalized_duration_ms,
metadata=dict(self._active_message_metadata),
)
async def _enqueue_stt_segment(self, pcm16: bytes, duration_ms: float) -> None:
async def _enqueue_stt_segment(
self,
pcm16: bytes,
duration_ms: float,
metadata: dict[str, Any],
) -> None:
if duration_ms < self._stt_min_ptt_ms:
return
@ -1284,17 +1568,17 @@ class WebRTCVoiceSession:
await self._publish_system("Voice input backlog detected; dropping stale segment.")
with contextlib.suppress(asyncio.QueueFull):
self._stt_segments.put_nowait(pcm16)
self._stt_segments.put_nowait(STTSegment(pcm=pcm16, metadata=dict(metadata)))
async def _stt_worker(self) -> None:
while True:
pcm16 = await self._stt_segments.get()
if not self._stt_first_segment_notice_sent:
segment = await self._stt_segments.get()
if self._audio_debug and not self._stt_first_segment_notice_sent:
self._stt_first_segment_notice_sent = True
await self._publish_system("Push-to-talk audio captured. Running host STT...")
await self._publish_debug("Push-to-talk audio captured. Running host STT...")
try:
transcript = await self._stt.transcribe_pcm(
pcm=pcm16,
pcm=segment.pcm,
sample_rate=16_000,
channels=1,
)
@ -1317,7 +1601,7 @@ class WebRTCVoiceSession:
try:
await self._gateway.send_user_message(
transcript,
metadata=dict(self._active_message_metadata),
metadata=dict(segment.metadata),
)
except RuntimeError as exc:
if self._closed:
@ -1360,6 +1644,11 @@ class WebRTCVoiceSession:
async def _publish_system(self, text: str) -> None:
await self._gateway.bus.publish(WisperEvent(role="system", text=text))
async def _publish_debug(self, text: str) -> None:
if not self._audio_debug:
return
await self._publish_system(text)
async def _publish_agent_state(self, state: str) -> None:
await self._gateway.bus.publish(WisperEvent(role="agent-state", text=state))