nanobot-voice-interface/voice_rtc.py

1657 lines
59 KiB
Python
Raw Normal View History

2026-02-28 22:12:04 -05:00
import asyncio
import audioop
import contextlib
import io
import os
import re
import shlex
import shutil
import subprocess
import tempfile
import wave
from dataclasses import dataclass
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Awaitable, Callable
from wisper import WisperEvent
if TYPE_CHECKING:
from supertonic_gateway import SuperTonicGateway
try:
import numpy as np
NUMPY_AVAILABLE = True
except Exception: # pragma: no cover - runtime fallback when numpy is unavailable
np = None # type: ignore[assignment]
NUMPY_AVAILABLE = False
try:
from supertonic import TTS as SupertonicTTS
SUPERTONIC_TTS_AVAILABLE = True
except Exception: # pragma: no cover - runtime fallback when supertonic is unavailable
SupertonicTTS = None # type: ignore[assignment]
SUPERTONIC_TTS_AVAILABLE = False
try:
from faster_whisper import WhisperModel
FASTER_WHISPER_AVAILABLE = True
except (
Exception
): # pragma: no cover - runtime fallback when faster-whisper is unavailable
WhisperModel = None # type: ignore[assignment]
FASTER_WHISPER_AVAILABLE = False
try:
from aiortc import RTCPeerConnection, RTCSessionDescription
from aiortc.mediastreams import MediaStreamTrack
from aiortc.sdp import candidate_from_sdp
from av import AudioFrame
AIORTC_AVAILABLE = True
except Exception: # pragma: no cover - runtime fallback when aiortc is unavailable
RTCPeerConnection = None # type: ignore[assignment]
RTCSessionDescription = None # type: ignore[assignment]
MediaStreamTrack = object # type: ignore[assignment,misc]
candidate_from_sdp = None # type: ignore[assignment]
AudioFrame = None # type: ignore[assignment]
AIORTC_AVAILABLE = False
SPEECH_FILTER_RE = re.compile(
r"^(spawned nanobot tui|stopped nanobot tui|nanobot tui exited|websocket)",
re.IGNORECASE,
)
THINKING_STATUS_RE = re.compile(r"\bnanobot is thinking\b", re.IGNORECASE)
USER_PREFIX_RE = re.compile(r"^(?:you|user)\s*:\s*", re.IGNORECASE)
VOICE_TRANSCRIPT_RE = re.compile(
r"^(?:wisper\s*:\s*)?voice\s+transcript\s*:\s*",
re.IGNORECASE,
)
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]")
BRAILLE_SPINNER_RE = re.compile(r"[\u2800-\u28ff]")
SENTENCE_END_RE = re.compile(r"[.!?]\s*$")
TTS_ALLOWED_ASCII = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789"
" .,!?;:'\"()[]{}@#%&*+-_/<>|"
)
def _sanitize_tts_text(text: str) -> str:
cleaned = ANSI_ESCAPE_RE.sub(" ", text)
cleaned = BRAILLE_SPINNER_RE.sub(" ", cleaned)
cleaned = cleaned.replace("\u00a0", " ")
cleaned = cleaned.replace("", " ")
cleaned = CONTROL_CHAR_RE.sub(" ", cleaned)
cleaned = "".join(
ch if (ch in TTS_ALLOWED_ASCII or ch.isspace()) else " " for ch in cleaned
)
cleaned = re.sub(r"\s+", " ", cleaned).strip()
return cleaned
def _optional_int_env(name: str) -> int | None:
raw_value = os.getenv(name, "").strip()
if not raw_value:
return None
return int(raw_value)
@dataclass(slots=True)
class PCMChunk:
    """A chunk of signed 16-bit PCM audio with its format metadata."""

    # Raw interleaved s16 sample bytes (2 bytes per sample per channel).
    pcm: bytes
    # Samples per second, per channel.
    sample_rate: int
    # Interleaved channel count; mono by default.
    channels: int = 1
if AIORTC_AVAILABLE:

    class QueueAudioTrack(MediaStreamTrack):
        """Outbound aiortc audio track fed by an asyncio queue of PCM frames.

        TTS output is enqueued (downmixed, resampled to the track rate, and
        split into fixed-size frames); ``recv`` hands frames to aiortc paced
        to real time, substituting silence when the queue is empty.
        """

        kind = "audio"

        def __init__(self, sample_rate: int = 48_000, frame_ms: int = 20) -> None:
            super().__init__()
            self._sample_rate = sample_rate
            self._frame_ms = max(1, frame_ms)
            # Frame size in samples and bytes for mono s16 at the track rate.
            self._samples_per_frame = max(1, (sample_rate * frame_ms) // 1000)
            self._bytes_per_frame = self._samples_per_frame * 2
            self._queue: asyncio.Queue[bytes] = asyncio.Queue()
            # Running RTP timestamp (in samples) for outgoing frames.
            self._timestamp = 0
            # audioop.ratecv carry-over state plus the source rate it is valid for.
            self._resample_state = None
            self._resample_source_rate: int | None = None
            # Silence lead-in inserted when audio starts after an idle gap so
            # the receiver's jitter buffer has time to fill before speech.
            self._lead_in_ms = max(
                0, int(os.getenv("HOST_RTC_OUTBOUND_LEAD_IN_MS", "120"))
            )
            self._lead_in_frames = (
                self._lead_in_ms + self._frame_ms - 1
            ) // self._frame_ms
            # Minimum quiet period before a new lead-in is considered.
            self._lead_in_idle_s = max(
                0.1, float(os.getenv("HOST_RTC_OUTBOUND_IDLE_S", "0.6"))
            )
            self._last_enqueue_at = 0.0
            self._closed = False
            self._frame_duration_s = frame_ms / 1000.0
            self._last_recv_at = 0.0

        async def enqueue_pcm(
            self, pcm: bytes, sample_rate: int, channels: int = 1
        ) -> None:
            """Queue s16 PCM for sending; downmixes, resamples, and frames it."""
            if self._closed or not pcm:
                return
            now = asyncio.get_running_loop().time()
            # Add lead-in silence only when the queue is drained and we have
            # been idle long enough (or have never enqueued anything yet).
            should_add_lead_in = (
                self._lead_in_frames > 0
                and self._queue.empty()
                and (
                    self._last_enqueue_at <= 0.0
                    or (now - self._last_enqueue_at) >= self._lead_in_idle_s
                )
            )
            if should_add_lead_in:
                silence = b"\x00" * self._bytes_per_frame
                for _index in range(self._lead_in_frames):
                    await self._queue.put(silence)
            mono = pcm
            if channels > 1:
                mono = audioop.tomono(mono, 2, 0.5, 0.5)
            if sample_rate != self._sample_rate:
                # audioop rate conversion state is only valid when source/destination rates stay the same.
                if self._resample_source_rate != sample_rate:
                    self._resample_state = None
                    self._resample_source_rate = sample_rate
                mono, self._resample_state = audioop.ratecv(
                    mono,
                    2,
                    1,
                    sample_rate,
                    self._sample_rate,
                    self._resample_state,
                )
            else:
                self._resample_state = None
                self._resample_source_rate = None
            if not mono:
                return
            # Split into fixed frames, zero-padding the trailing partial frame.
            for start in range(0, len(mono), self._bytes_per_frame):
                chunk = mono[start : start + self._bytes_per_frame]
                if len(chunk) < self._bytes_per_frame:
                    chunk += b"\x00" * (self._bytes_per_frame - len(chunk))
                await self._queue.put(chunk)
            self._last_enqueue_at = now

        async def recv(self) -> AudioFrame:
            """Return the next audio frame, paced to one frame per frame_ms."""
            if self._closed:
                raise asyncio.CancelledError
            # Pace frame delivery to real-time to prevent RTP burst sends.
            # Without pacing, when TTS enqueues audio faster than real-time,
            # aiortc sends RTP packets in a burst and the browser's jitter
            # buffer skips ahead, causing the user to only hear the tail end.
            loop = asyncio.get_running_loop()
            now = loop.time()
            if self._last_recv_at > 0.0:
                elapsed = now - self._last_recv_at
                remaining = self._frame_duration_s - elapsed
                if remaining > 0.001:
                    await asyncio.sleep(remaining)
            try:
                payload = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                # Keep the track alive with silence when no TTS audio is pending.
                payload = b"\x00" * self._bytes_per_frame
            self._last_recv_at = loop.time()
            frame = AudioFrame(
                format="s16", layout="mono", samples=self._samples_per_frame
            )
            frame.planes[0].update(payload)
            frame.sample_rate = self._sample_rate
            frame.time_base = Fraction(1, self._sample_rate)
            frame.pts = self._timestamp
            self._timestamp += self._samples_per_frame
            return frame

        def stop(self) -> None:
            """Mark the track closed, then stop the underlying MediaStreamTrack."""
            self._closed = True
            super().stop()

else:

    class QueueAudioTrack:  # pragma: no cover - used only when aiortc is unavailable
        """No-op stand-in so callers can run without aiortc installed."""

        async def enqueue_pcm(
            self, pcm: bytes, sample_rate: int, channels: int = 1
        ) -> None:
            return

        def stop(self) -> None:
            return
def _write_temp_wav(pcm: bytes, sample_rate: int, channels: int) -> str:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
tmp_path = tmp_file.name
with wave.open(tmp_path, "wb") as wav_file:
wav_file.setnchannels(max(1, channels))
wav_file.setsampwidth(2)
wav_file.setframerate(sample_rate)
wav_file.writeframes(pcm)
return tmp_path
class CommandSpeechToText:
def __init__(self) -> None:
self._command_template = os.getenv("HOST_STT_COMMAND", "").strip()
@property
def enabled(self) -> bool:
return bool(self._command_template)
async def transcribe_pcm(
self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
) -> str | None:
if not self.enabled or not pcm:
return None
return await asyncio.to_thread(
self._transcribe_blocking, pcm, sample_rate, channels
)
def unavailable_reason(self) -> str:
if not self._command_template:
return "HOST_STT_COMMAND is not configured."
return "HOST_STT_COMMAND failed to produce transcript."
def _transcribe_blocking(
self, pcm: bytes, sample_rate: int, channels: int
) -> str | None:
tmp_path: str | None = None
try:
tmp_path = _write_temp_wav(
pcm=pcm, sample_rate=sample_rate, channels=channels
)
command = self._command_template
if "{input_wav}" in command:
command = command.replace("{input_wav}", shlex.quote(tmp_path))
else:
command = f"{command} {shlex.quote(tmp_path)}"
result = subprocess.run(
command,
shell=True,
capture_output=True,
text=True,
check=False,
)
if result.returncode != 0:
stderr = result.stderr.strip() or "unknown error"
raise RuntimeError(f"STT command failed: {stderr}")
transcript = result.stdout.strip()
return transcript or None
finally:
if tmp_path and os.path.exists(tmp_path):
with contextlib.suppress(OSError):
os.unlink(tmp_path)
class FasterWhisperSpeechToText:
    """Local speech-to-text backend built on the faster-whisper library.

    All tuning knobs come from HOST_STT_* environment variables. The model is
    loaded lazily on first use (or via ``warmup``) and an asyncio lock
    serializes initialization and transcription across callers.
    """

    def __init__(self) -> None:
        self._model_name = os.getenv("HOST_STT_MODEL", "base.en").strip() or "base.en"
        self._device = os.getenv("HOST_STT_DEVICE", "auto").strip() or "auto"
        self._compute_type = (
            os.getenv("HOST_STT_COMPUTE_TYPE", "int8").strip() or "int8"
        )
        self._language = os.getenv("HOST_STT_LANGUAGE", "en").strip()
        self._beam_size = max(1, int(os.getenv("HOST_STT_BEAM_SIZE", "2")))
        self._best_of = max(1, int(os.getenv("HOST_STT_BEST_OF", "2")))
        # Any value other than the listed "off" spellings enables VAD filtering.
        self._vad_filter = os.getenv("HOST_STT_VAD_FILTER", "0").strip() not in {
            "0",
            "false",
            "False",
            "no",
            "off",
        }
        self._temperature = float(os.getenv("HOST_STT_TEMPERATURE", "0.0"))
        self._log_prob_threshold = float(
            os.getenv("HOST_STT_LOG_PROB_THRESHOLD", "-1.0")
        )
        self._no_speech_threshold = float(
            os.getenv("HOST_STT_NO_SPEECH_THRESHOLD", "0.6")
        )
        self._compression_ratio_threshold = float(
            os.getenv("HOST_STT_COMPRESSION_RATIO_THRESHOLD", "2.4")
        )
        # A blank HOST_STT_INITIAL_PROMPT disables the prompt entirely.
        self._initial_prompt = (
            os.getenv(
                "HOST_STT_INITIAL_PROMPT",
                "Transcribe brief spoken English precisely. Prefer common words over sound effects.",
            ).strip()
            or None
        )
        # Lazily-constructed WhisperModel; None until first successful init.
        self._model: Any = None
        # Message from the last failed initialization, if any.
        self._init_error: str | None = None
        # Serializes model init and transcription across concurrent callers.
        self._lock = asyncio.Lock()

    @property
    def enabled(self) -> bool:
        """True when the faster-whisper package imported successfully."""
        return FASTER_WHISPER_AVAILABLE and WhisperModel is not None

    @property
    def init_error(self) -> str | None:
        """Message from the most recent failed model initialization, or None."""
        return self._init_error

    async def transcribe_pcm(
        self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
    ) -> str | None:
        """Transcribe s16 PCM in a worker thread; None when disabled or empty."""
        if not self.enabled or not pcm:
            return None
        async with self._lock:
            return await asyncio.to_thread(
                self._transcribe_blocking, pcm, sample_rate, channels
            )

    async def warmup(self) -> None:
        """Load the model ahead of time so the first utterance is not delayed."""
        if not self.enabled:
            return
        async with self._lock:
            await asyncio.to_thread(self._initialize_blocking)

    def _initialize_blocking(self) -> None:
        # Idempotent lazy init: a loaded model is kept for the object's
        # lifetime; failures are recorded in _init_error instead of raised.
        if self._model is not None:
            return
        if not self.enabled or WhisperModel is None:
            return
        try:
            self._model = WhisperModel(
                self._model_name,
                device=self._device,
                compute_type=self._compute_type,
            )
            self._init_error = None
        except Exception as exc:
            self._init_error = str(exc)
            self._model = None

    def _transcribe_blocking(
        self, pcm: bytes, sample_rate: int, channels: int
    ) -> str | None:
        """Blocking transcription; raises RuntimeError when model init failed."""
        self._initialize_blocking()
        if self._model is None:
            if self._init_error:
                raise RuntimeError(
                    f"faster-whisper initialization failed: {self._init_error}"
                )
            return None
        if NUMPY_AVAILABLE and np is not None:
            # Fast path: hand the model a float32 waveform directly, avoiding
            # a temp WAV file. Downmix to mono and resample to 16 kHz first.
            mono = pcm
            if channels > 1:
                mono = audioop.tomono(mono, 2, 0.5, 0.5)
            if sample_rate != 16_000:
                mono, _ = audioop.ratecv(
                    mono,
                    2,
                    1,
                    sample_rate,
                    16_000,
                    None,
                )
            # s16 -> float32 in [-1, 1), the input format faster-whisper expects.
            audio = np.frombuffer(mono, dtype=np.int16).astype(np.float32) / 32768.0
            if audio.size == 0:
                return None
            segments, _info = self._model.transcribe(
                audio,
                language=self._language or None,
                beam_size=self._beam_size,
                best_of=self._best_of,
                vad_filter=self._vad_filter,
                condition_on_previous_text=False,
                without_timestamps=True,
                initial_prompt=self._initial_prompt,
                temperature=self._temperature,
                log_prob_threshold=self._log_prob_threshold,
                no_speech_threshold=self._no_speech_threshold,
                compression_ratio_threshold=self._compression_ratio_threshold,
            )
            transcript_parts: list[str] = []
            for segment in segments:
                text = str(getattr(segment, "text", "")).strip()
                if text:
                    transcript_parts.append(text)
            transcript = " ".join(transcript_parts).strip()
            return transcript or None
        # Fallback path without numpy: go through a temporary WAV file.
        tmp_path: str | None = None
        try:
            tmp_path = _write_temp_wav(
                pcm=pcm, sample_rate=sample_rate, channels=channels
            )
            segments, _info = self._model.transcribe(
                tmp_path,
                language=self._language or None,
                beam_size=self._beam_size,
                best_of=self._best_of,
                vad_filter=self._vad_filter,
                condition_on_previous_text=False,
                without_timestamps=True,
                initial_prompt=self._initial_prompt,
                temperature=self._temperature,
                log_prob_threshold=self._log_prob_threshold,
                no_speech_threshold=self._no_speech_threshold,
                compression_ratio_threshold=self._compression_ratio_threshold,
            )
            transcript_parts: list[str] = []
            for segment in segments:
                text = str(getattr(segment, "text", "")).strip()
                if text:
                    transcript_parts.append(text)
            transcript = " ".join(transcript_parts).strip()
            return transcript or None
        finally:
            if tmp_path and os.path.exists(tmp_path):
                with contextlib.suppress(OSError):
                    os.unlink(tmp_path)
class HostSpeechToText:
    """Facade routing STT requests to faster-whisper and/or a command backend.

    HOST_STT_PROVIDER selects "faster-whisper", "command", or "auto" (try
    faster-whisper first, then fall back to the command backend). Unknown
    values degrade to "auto".
    """

    def __init__(self) -> None:
        configured = (
            os.getenv("HOST_STT_PROVIDER", "faster-whisper").strip() or "faster-whisper"
        ).lower()
        if configured not in {"faster-whisper", "command", "auto"}:
            configured = "auto"
        self._provider = configured
        self._faster_whisper = FasterWhisperSpeechToText()
        self._command = CommandSpeechToText()

    @property
    def enabled(self) -> bool:
        """True when at least one backend selected by the provider is usable."""
        if self._provider == "command":
            return self._command.enabled
        if self._provider == "faster-whisper":
            return self._faster_whisper.enabled
        return self._faster_whisper.enabled or self._command.enabled

    async def transcribe_pcm(
        self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
    ) -> str | None:
        """Return a transcript from the first backend that yields one, else None."""
        try_whisper = self._provider != "command"
        try_command = self._provider != "faster-whisper"
        if try_whisper:
            text = await self._faster_whisper.transcribe_pcm(
                pcm=pcm,
                sample_rate=sample_rate,
                channels=channels,
            )
            if text:
                return text
        if try_command:
            return await self._command.transcribe_pcm(
                pcm=pcm,
                sample_rate=sample_rate,
                channels=channels,
            )
        return None

    async def warmup(self) -> None:
        """Pre-load the faster-whisper model whenever it may be used."""
        if self._provider != "command":
            await self._faster_whisper.warmup()

    def unavailable_reason(self) -> str:
        """Best-effort explanation of why transcription is not producing text."""
        if self._provider == "faster-whisper":
            if not self._faster_whisper.enabled:
                return "faster-whisper package is not available."
            if self._faster_whisper.init_error:
                return f"faster-whisper initialization failed: {self._faster_whisper.init_error}"
            return "faster-whisper did not return transcript."
        if self._provider == "command":
            return self._command.unavailable_reason()
        # "auto": surface the most specific known failure first.
        if self._faster_whisper.init_error:
            return f"faster-whisper initialization failed: {self._faster_whisper.init_error}"
        if self._command.enabled:
            return "HOST_STT_COMMAND failed to produce transcript."
        if not self._faster_whisper.enabled:
            return "faster-whisper package is not available."
        return "No STT provider is configured."
class SupertonicTextToSpeech:
    """Text-to-speech backend backed by the supertonic engine.

    Engine, voice style, and synthesis knobs come from SUPERTONIC_*
    environment variables. Initialization is lazy; failures are recorded in
    ``init_error`` rather than raised.
    """

    def __init__(self) -> None:
        self._model = (
            os.getenv("SUPERTONIC_MODEL", "supertonic-2").strip() or "supertonic-2"
        )
        self._voice_style_name = (
            os.getenv("SUPERTONIC_VOICE_STYLE", "M1").strip() or "M1"
        )
        self._lang = os.getenv("SUPERTONIC_LANG", "en").strip() or "en"
        self._total_steps = int(os.getenv("SUPERTONIC_TOTAL_STEPS", "5"))
        self._speed = float(os.getenv("SUPERTONIC_SPEED", "1.05"))
        # Optional ONNX-runtime thread counts; None lets the engine decide.
        self._intra_op_num_threads = _optional_int_env("SUPERTONIC_INTRA_OP_THREADS")
        self._inter_op_num_threads = _optional_int_env("SUPERTONIC_INTER_OP_THREADS")
        # Any value other than the listed "off" spellings enables auto-download.
        self._auto_download = os.getenv(
            "SUPERTONIC_AUTO_DOWNLOAD", "1"
        ).strip() not in {
            "0",
            "false",
            "False",
            "no",
            "off",
        }
        # Lazily-initialized engine and voice style objects.
        self._engine: Any = None
        self._voice_style: Any = None
        self._init_error: str | None = None
        # Serializes lazy init and synthesis across concurrent callers.
        self._lock = asyncio.Lock()

    @property
    def enabled(self) -> bool:
        """True when both supertonic and numpy imported successfully."""
        return (
            SUPERTONIC_TTS_AVAILABLE and SupertonicTTS is not None and NUMPY_AVAILABLE
        )

    @property
    def init_error(self) -> str | None:
        """Message from the most recent failed engine initialization, or None."""
        return self._init_error

    async def synthesize(self, text: str) -> PCMChunk | None:
        """Synthesize *text* in a worker thread; None when disabled or empty."""
        if not self.enabled:
            return None
        clean_text = " ".join(text.split())
        if not clean_text:
            return None
        async with self._lock:
            return await asyncio.to_thread(self._synthesize_blocking, clean_text)

    def _synthesize_blocking(self, text: str) -> PCMChunk | None:
        """Blocking synthesis: run the engine and normalize output to s16 PCM."""
        self._initialize_blocking()
        if self._engine is None or self._voice_style is None or np is None:
            return None
        text = _sanitize_tts_text(text)
        if not text:
            return None
        try:
            wav, _duration = self._engine.synthesize(
                text,
                voice_style=self._voice_style,
                lang=self._lang,
                total_steps=self._total_steps,
                speed=self._speed,
            )
        except ValueError as exc:
            # Retry once with a re-sanitized string when the engine rejects a
            # character; re-raise any other ValueError.
            message = str(exc)
            if "unsupported character" not in message.lower():
                raise
            # NOTE(review): text already went through _sanitize_tts_text above
            # and the fallback applies the same function, so fallback_text will
            # usually equal text and re-raise — confirm this is intended.
            fallback_text = self._sanitize_text_for_supertonic(text)
            if not fallback_text or fallback_text == text:
                raise
            wav, _duration = self._engine.synthesize(
                fallback_text,
                voice_style=self._voice_style,
                lang=self._lang,
                total_steps=self._total_steps,
                speed=self._speed,
            )
        samples = np.asarray(wav)
        if samples.size == 0:
            return None
        channels = 1
        if samples.ndim == 0:
            samples = samples.reshape(1)
        elif samples.ndim == 1:
            channels = 1
        elif samples.ndim == 2:
            # Normalize to frames x channels so PCM bytes are correctly interleaved.
            dim0, dim1 = int(samples.shape[0]), int(samples.shape[1])
            if dim0 <= 2 and dim1 > dim0:
                # channels x frames layout: transpose to frames x channels.
                channels = dim0
                samples = samples.T
            elif dim1 <= 2 and dim0 > dim1:
                channels = dim1
            else:
                # Ambiguous layout: flatten and treat as mono.
                channels = 1
                samples = samples.reshape(-1)
        else:
            channels = 1
            samples = samples.reshape(-1)
        if np.issubdtype(samples.dtype, np.floating):
            # Float waveform assumed in [-1, 1]; clip then scale to s16 range.
            samples = np.clip(samples, -1.0, 1.0)
            samples = (samples * 32767.0).astype(np.int16)
        else:
            if samples.dtype != np.int16:
                samples = samples.astype(np.int16)
        pcm = samples.tobytes()
        return PCMChunk(
            pcm=pcm,
            # Engine-reported rate; assumes 24 kHz when the attribute is absent.
            sample_rate=int(getattr(self._engine, "sample_rate", 24_000)),
            channels=max(1, channels),
        )

    def _sanitize_text_for_supertonic(self, text: str) -> str:
        """Fallback sanitizer used when the engine rejects a character."""
        return _sanitize_tts_text(text)

    def _initialize_blocking(self) -> None:
        # Idempotent lazy init; failures are recorded in _init_error.
        if self._engine is not None and self._voice_style is not None:
            return
        if not self.enabled or SupertonicTTS is None:
            return
        try:
            engine = SupertonicTTS(
                model=self._model,
                auto_download=self._auto_download,
                intra_op_num_threads=self._intra_op_num_threads,
                inter_op_num_threads=self._inter_op_num_threads,
            )
            voice_style = engine.get_voice_style(self._voice_style_name)
        except Exception as exc:
            self._init_error = str(exc)
            return
        self._engine = engine
        self._voice_style = voice_style
        self._init_error = None
class HostTextToSpeech:
    """Facade routing TTS synthesis to supertonic, a shell command, or espeak.

    HOST_TTS_PROVIDER selects "supertonic", "command", "espeak", or "auto"
    (try supertonic, then HOST_TTS_COMMAND, then espeak). Unknown values
    degrade to "auto".
    """

    def __init__(self) -> None:
        provider = (
            os.getenv("HOST_TTS_PROVIDER", "supertonic").strip() or "supertonic"
        ).lower()
        if provider not in {"supertonic", "command", "espeak", "auto"}:
            provider = "auto"
        self._provider = provider
        self._supertonic = SupertonicTextToSpeech()
        self._command_template = os.getenv("HOST_TTS_COMMAND", "").strip()
        # Absolute path to the espeak binary, or None when not installed.
        self._espeak = shutil.which("espeak")

    @property
    def enabled(self) -> bool:
        """True when at least one backend selected by the provider is usable."""
        if self._provider == "supertonic":
            return self._supertonic.enabled
        if self._provider == "command":
            return bool(self._command_template)
        if self._provider == "espeak":
            return bool(self._espeak)
        return self._supertonic.enabled or bool(self._command_template or self._espeak)

    async def synthesize(self, text: str) -> PCMChunk | None:
        """Synthesize *text* via the first backend that yields audio, else None."""
        clean_text = " ".join(text.split())
        if not clean_text:
            return None
        if self._provider in {"supertonic", "auto"}:
            audio = await self._supertonic.synthesize(clean_text)
            if audio:
                return audio
        if self._provider == "supertonic":
            return None
        if self._provider in {"command", "auto"} and self._command_template:
            return await asyncio.to_thread(self._synthesize_with_command, clean_text)
        if self._provider == "command":
            return None
        if self._provider in {"espeak", "auto"} and self._espeak:
            return await asyncio.to_thread(self._synthesize_with_espeak, clean_text)
        return None

    def unavailable_reason(self) -> str:
        """Best-effort explanation of why no audio was produced."""
        if self._provider == "supertonic":
            if not self._supertonic.enabled:
                return "supertonic package is not available."
            if self._supertonic.init_error:
                return (
                    f"supertonic initialization failed: {self._supertonic.init_error}"
                )
            return "supertonic did not return audio."
        if self._provider == "command":
            return "HOST_TTS_COMMAND is not configured."
        if self._provider == "espeak":
            return "espeak binary is not available."
        # "auto": surface the most specific known failure first.
        if self._supertonic.init_error:
            return f"supertonic initialization failed: {self._supertonic.init_error}"
        if self._command_template:
            return "HOST_TTS_COMMAND failed to produce audio."
        if self._espeak:
            return "espeak failed to produce audio."
        return "No TTS provider is configured."

    def _synthesize_with_command(self, text: str) -> PCMChunk | None:
        """Run HOST_TTS_COMMAND; reads WAV from {output_wav} or from stdout."""
        command = self._command_template
        # Substitute the text placeholder, or append the quoted text.
        if "{text}" in command:
            command = command.replace("{text}", shlex.quote(text))
        else:
            command = f"{command} {shlex.quote(text)}"
        if "{output_wav}" in command:
            # Command writes its WAV to a temp path we provide.
            tmp_path: str | None = None
            try:
                with tempfile.NamedTemporaryFile(
                    suffix=".wav", delete=False
                ) as tmp_file:
                    tmp_path = tmp_file.name
                command_with_output = command.replace(
                    "{output_wav}", shlex.quote(tmp_path)
                )
                # shell=True is intentional: operator-provided templates may
                # use pipes/redirection; substituted values are shlex-quoted.
                result = subprocess.run(
                    command_with_output,
                    shell=True,
                    capture_output=True,
                    text=True,
                    check=False,
                )
                if result.returncode != 0:
                    stderr = result.stderr.strip() or "unknown error"
                    raise RuntimeError(f"TTS command failed: {stderr}")
                return self._read_wav_file(tmp_path)
            finally:
                if tmp_path and os.path.exists(tmp_path):
                    with contextlib.suppress(OSError):
                        os.unlink(tmp_path)
        # No output placeholder: expect WAV bytes on stdout.
        result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            check=False,
        )
        if result.returncode != 0:
            stderr = result.stderr.decode(errors="ignore").strip() or "unknown error"
            raise RuntimeError(f"TTS command failed: {stderr}")
        return self._decode_wav_bytes(result.stdout)

    def _synthesize_with_espeak(self, text: str) -> PCMChunk | None:
        """Synthesize via the espeak CLI, reading WAV bytes from its stdout."""
        if not self._espeak:
            return None
        result = subprocess.run(
            [self._espeak, "--stdout", text],
            capture_output=True,
            check=False,
        )
        if result.returncode != 0:
            stderr = result.stderr.decode(errors="ignore").strip() or "unknown error"
            raise RuntimeError(f"espeak failed: {stderr}")
        return self._decode_wav_bytes(result.stdout)

    def _read_wav_file(self, path: str) -> PCMChunk | None:
        """Decode a WAV file from disk; None when it cannot be read."""
        try:
            with open(path, "rb") as wav_file:
                return self._decode_wav_bytes(wav_file.read())
        except OSError:
            return None

    def _decode_wav_bytes(self, payload: bytes) -> PCMChunk | None:
        """Decode in-memory WAV bytes into a PCMChunk, normalizing to 16-bit."""
        if not payload:
            return None
        with wave.open(io.BytesIO(payload), "rb") as wav_file:
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            sample_rate = wav_file.getframerate()
            pcm = wav_file.readframes(wav_file.getnframes())
            if sample_width != 2:
                # Convert any sample width to signed 16-bit for downstream code.
                pcm = audioop.lin2lin(pcm, sample_width, 2)
            return PCMChunk(pcm=pcm, sample_rate=sample_rate, channels=max(1, channels))
# Async callback used to push JSON-serializable control messages to the client.
SendJsonCallable = Callable[[dict[str, Any]], Awaitable[None]]
class WebRTCVoiceSession:
def __init__(
    self, gateway: "SuperTonicGateway", send_json: SendJsonCallable
) -> None:
    """Create a voice session bound to *gateway*, reporting via *send_json*."""
    self._gateway = gateway
    self._send_json = send_json
    # WebRTC plumbing; populated by handle_offer.
    self._pc: RTCPeerConnection | None = None
    self._outbound_track: QueueAudioTrack | None = None
    # Background worker tasks.
    self._incoming_audio_task: asyncio.Task[None] | None = None
    self._stt_worker_task: asyncio.Task[None] | None = None
    self._stt_warmup_task: asyncio.Task[None] | None = None
    self._stt = HostSpeechToText()
    self._tts = HostTextToSpeech()
    # Bounded queue of captured PCM segments awaiting transcription.
    self._stt_segment_queue_size = max(
        1, int(os.getenv("HOST_STT_SEGMENT_QUEUE_SIZE", "2"))
    )
    self._stt_segments: asyncio.Queue[bytes] = asyncio.Queue(
        maxsize=self._stt_segment_queue_size
    )
    # TTS debounce-buffer state.
    self._tts_buffer = ""
    self._tts_flush_handle: asyncio.TimerHandle | None = None
    self._tts_flush_lock = asyncio.Lock()
    self._tts_buffer_lock = asyncio.Lock()
    self._tts_flush_delay_s = max(
        0.08, float(os.getenv("HOST_TTS_FLUSH_DELAY_S", "0.45"))
    )
    # Shorter debounce when a chunk ends at sentence punctuation.
    self._tts_sentence_flush_delay_s = max(
        0.06,
        float(os.getenv("HOST_TTS_SENTENCE_FLUSH_DELAY_S", "0.15")),
    )
    self._tts_min_chars = max(1, int(os.getenv("HOST_TTS_MIN_CHARS", "10")))
    self._tts_max_wait_ms = max(300, int(os.getenv("HOST_TTS_MAX_WAIT_MS", "1800")))
    self._tts_max_chunk_chars = max(
        60, int(os.getenv("HOST_TTS_MAX_CHUNK_CHARS", "140"))
    )
    self._tts_buffer_started_at = 0.0
    self._closed = False
    # One-shot diagnostic flags so the client is not spammed with repeats.
    self._stt_unavailable_notice_sent = False
    self._tts_unavailable_notice_sent = False
    self._audio_seen_notice_sent = False
    self._audio_format_notice_sent = False
    self._stt_first_segment_notice_sent = False
    self._ptt_timing_correction_notice_sent = False
    # Push-to-talk segment length bounds; legacy *_SEGMENT_MS names honored.
    self._stt_min_ptt_ms = max(
        120,
        int(
            os.getenv(
                "HOST_STT_MIN_PTT_MS", os.getenv("HOST_STT_MIN_SEGMENT_MS", "220")
            )
        ),
    )
    self._stt_max_ptt_ms = max(
        self._stt_min_ptt_ms,
        int(
            os.getenv(
                "HOST_STT_MAX_PTT_MS", os.getenv("HOST_STT_MAX_SEGMENT_MS", "12000")
            )
        ),
    )
    # Drop microphone input while (and briefly after) TTS audio plays so the
    # assistant does not transcribe its own speech.
    self._stt_suppress_during_tts = os.getenv(
        "HOST_STT_SUPPRESS_DURING_TTS", "1"
    ).strip() not in {
        "0",
        "false",
        "False",
        "no",
        "off",
    }
    self._stt_suppress_ms_after_tts = max(
        0,
        int(os.getenv("HOST_STT_SUPPRESS_MS_AFTER_TTS", "300")),
    )
    # Event-loop-time deadline before which inbound audio is ignored.
    self._stt_suppress_until = 0.0
    self._stt_backlog_notice_interval_s = max(
        2.0,
        float(os.getenv("HOST_STT_BACKLOG_NOTICE_INTERVAL_S", "6.0")),
    )
    self._last_stt_backlog_notice_at = 0.0
    # ICE candidates received before the remote description was applied.
    self._pending_ice_candidates: list[dict[str, Any] | None] = []
    self._ptt_pressed = False
def set_push_to_talk_pressed(self, pressed: bool) -> None:
    """Record the client's push-to-talk state, coercing to a strict bool."""
    self._ptt_pressed = bool(pressed)
async def queue_output_text(self, chunk: str) -> None:
    """Buffer a streamed text chunk for TTS and schedule a delayed flush.

    Chunks ending at sentence punctuation use the shorter flush delay;
    others wait longer so small fragments can coalesce.
    """
    normalized_chunk = chunk.strip()
    if not normalized_chunk:
        return
    flush_delay = (
        self._tts_sentence_flush_delay_s
        if SENTENCE_END_RE.search(normalized_chunk)
        else self._tts_flush_delay_s
    )
    loop = asyncio.get_running_loop()
    async with self._tts_buffer_lock:
        # Drop text once the RTC session is gone; nothing could play it.
        if not self._pc or not self._outbound_track:
            return
        if not self._tts_buffer.strip():
            # First chunk of a new buffer: start the age clock.
            self._tts_buffer_started_at = loop.time()
            self._tts_buffer = normalized_chunk
        else:
            # Keep line boundaries between streamed chunks so line-based filters stay accurate.
            self._tts_buffer = f"{self._tts_buffer}\n{normalized_chunk}"
        self._schedule_tts_flush_after(flush_delay)
async def handle_offer(self, payload: dict[str, Any]) -> None:
    """Accept an SDP offer: build a fresh peer connection and send the answer.

    Also starts the STT worker/warmup tasks (or sends a one-time notice when
    STT is unavailable). Any previous peer connection is torn down first.
    """
    if not AIORTC_AVAILABLE or not RTCPeerConnection or not RTCSessionDescription:
        await self._send_json(
            {
                "type": "rtc-error",
                "message": "WebRTC backend unavailable on host (aiortc is not installed).",
            }
        )
        return
    sdp = str(payload.get("sdp", "")).strip()
    rtc_type = str(payload.get("rtcType", "offer")).strip() or "offer"
    if not sdp:
        await self._send_json(
            {"type": "rtc-error", "message": "Missing SDP offer payload."}
        )
        return
    # Renegotiation: drop any existing connection before building a new one.
    await self._close_peer_connection()
    self._ptt_pressed = False
    peer_connection = RTCPeerConnection()
    self._pc = peer_connection
    self._outbound_track = QueueAudioTrack()
    peer_connection.addTrack(self._outbound_track)

    @peer_connection.on("connectionstatechange")
    def on_connectionstatechange() -> None:
        # Relay connection-state changes to the client asynchronously.
        asyncio.create_task(
            self._send_json(
                {
                    "type": "rtc-state",
                    "state": peer_connection.connectionState,
                }
            )
        )

    @peer_connection.on("track")
    def on_track(track: MediaStreamTrack) -> None:
        if track.kind != "audio":
            return
        # Replace any previous consumer with one for the new track.
        if self._incoming_audio_task:
            self._incoming_audio_task.cancel()
        self._incoming_audio_task = asyncio.create_task(
            self._consume_audio_track(track),
            name="voice-inbound-track",
        )

    await peer_connection.setRemoteDescription(
        RTCSessionDescription(sdp=sdp, type=rtc_type)
    )
    # Candidates that trickled in before the remote description can apply now.
    await self._drain_pending_ice_candidates(peer_connection)
    answer = await peer_connection.createAnswer()
    await peer_connection.setLocalDescription(answer)
    await self._wait_for_ice_gathering(peer_connection)
    local_description = peer_connection.localDescription
    sdp_answer = str(local_description.sdp or "")
    if sdp_answer:
        # Normalize all line endings to CRLF, the form SDP requires.
        sdp_answer = (
            sdp_answer.replace("\r\n", "\n")
            .replace("\r", "\n")
            .strip()
            .replace("\n", "\r\n")
            + "\r\n"
        )
    await self._send_json(
        {
            "type": "rtc-answer",
            "sdp": sdp_answer,
            "rtcType": local_description.type,
        }
    )
    if self._stt.enabled and not self._stt_worker_task:
        self._stt_worker_task = asyncio.create_task(
            self._stt_worker(), name="voice-stt-worker"
        )
    if self._stt.enabled and (
        self._stt_warmup_task is None or self._stt_warmup_task.done()
    ):
        self._stt_warmup_task = asyncio.create_task(
            self._warmup_stt(), name="voice-stt-warmup"
        )
    elif not self._stt.enabled and not self._stt_unavailable_notice_sent:
        # One-time notice only; further offers stay quiet about it.
        self._stt_unavailable_notice_sent = True
        await self._publish_system(
            f"Voice input backend unavailable. {self._stt.unavailable_reason()}"
        )
async def handle_ice_candidate(self, payload: dict[str, Any]) -> None:
    """Apply a trickled ICE candidate, buffering it until the remote SDP is set.

    A None/"" candidate is the end-of-candidates marker and is forwarded as
    None; any other non-dict payload is silently ignored.
    """
    if not AIORTC_AVAILABLE:
        return
    incoming = payload.get("candidate")
    normalized: dict[str, Any] | None
    if isinstance(incoming, dict):
        normalized = incoming
    elif incoming is None or incoming == "":
        normalized = None
    else:
        return
    peer = self._pc
    if not peer or peer.remoteDescription is None:
        # Too early to apply; _drain_pending_ice_candidates flushes these.
        self._pending_ice_candidates.append(normalized)
        return
    await self._apply_ice_candidate(peer, normalized)
async def _drain_pending_ice_candidates(
    self,
    peer_connection: RTCPeerConnection,
) -> None:
    """Flush candidates buffered before the remote description was applied."""
    if not self._pending_ice_candidates:
        return
    queued = list(self._pending_ice_candidates)
    self._pending_ice_candidates.clear()
    for entry in queued:
        await self._apply_ice_candidate(peer_connection, entry)
async def _apply_ice_candidate(
    self,
    peer_connection: RTCPeerConnection,
    raw_candidate: dict[str, Any] | None,
) -> None:
    """Add one remote ICE candidate; None signals end-of-candidates."""
    if raw_candidate is None:
        # End-of-candidates marker; failures here are harmless.
        with contextlib.suppress(Exception):
            await peer_connection.addIceCandidate(None)
        return
    candidate_sdp = str(raw_candidate.get("candidate", "")).strip()
    if not candidate_sdp or not candidate_from_sdp:
        return
    # Browsers send "candidate:..." but the parser expects the bare value.
    if candidate_sdp.startswith("candidate:"):
        candidate_sdp = candidate_sdp[len("candidate:") :]
    try:
        candidate = candidate_from_sdp(candidate_sdp)
        candidate.sdpMid = raw_candidate.get("sdpMid")
        line_index = raw_candidate.get("sdpMLineIndex")
        candidate.sdpMLineIndex = (
            int(line_index) if line_index is not None else None
        )
        await peer_connection.addIceCandidate(candidate)
    except Exception as exc:
        # Report rather than crash: one bad candidate should not kill the session.
        await self._publish_system(f"Failed to add ICE candidate: {exc}")
async def close(self) -> None:
    """Tear down the session: stop TTS buffering, cancel workers, close the peer.

    Idempotent and safe to call when nothing was ever started. Tasks are
    cancelled and awaited in the same order as before: inbound audio, STT
    worker, STT warmup.
    """
    self._closed = True
    self._ptt_pressed = False
    if self._tts_flush_handle:
        self._tts_flush_handle.cancel()
        self._tts_flush_handle = None
    self._tts_buffer = ""
    self._tts_buffer_started_at = 0.0
    # Deduplicated cancel-and-await pattern (was three identical stanzas).
    self._incoming_audio_task = await self._cancel_task(self._incoming_audio_task)
    self._stt_worker_task = await self._cancel_task(self._stt_worker_task)
    self._stt_warmup_task = await self._cancel_task(self._stt_warmup_task)
    await self._close_peer_connection()

async def _cancel_task(self, task: asyncio.Task[None] | None) -> None:
    """Cancel *task* (if any), wait for it to finish, and return None."""
    if task:
        task.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await task
    return None
def _schedule_tts_flush(self) -> None:
    """Timer-callback target: kick off an async TTS flush unless closed."""
    if self._closed:
        return
    # Fire-and-forget: _flush_tts serializes itself via _tts_flush_lock.
    asyncio.create_task(self._flush_tts(), name="voice-tts-flush")
def _schedule_tts_flush_after(self, delay_s: float) -> None:
    """(Re)arm the flush timer; a newer chunk supersedes any pending timer."""
    if self._tts_flush_handle:
        self._tts_flush_handle.cancel()
    loop = asyncio.get_running_loop()
    # Clamp to at least 50 ms so back-to-back chunks don't thrash the loop.
    self._tts_flush_handle = loop.call_later(
        max(0.05, delay_s), self._schedule_tts_flush
    )
async def _flush_tts(self) -> None:
    """Drain the TTS buffer: clean it, maybe re-buffer short text, synthesize.

    Serialized by _tts_flush_lock so at most one flush runs at a time; the
    buffer is swapped out under _tts_buffer_lock so new chunks can keep
    arriving while synthesis is in progress.
    """
    async with self._tts_flush_lock:
        async with self._tts_buffer_lock:
            self._tts_flush_handle = None
            raw_text = self._tts_buffer
            buffer_started_at = self._tts_buffer_started_at
            self._tts_buffer = ""
            self._tts_buffer_started_at = 0.0
        clean_text = self._clean_tts_text(raw_text)
        if not clean_text:
            return
        loop = asyncio.get_running_loop()
        now = loop.time()
        buffer_age_ms = int(max(0.0, now - buffer_started_at) * 1000)
        # Short, sentence-incomplete text that has not waited too long goes
        # back into the buffer to coalesce with upcoming chunks.
        should_wait_for_more = (
            len(clean_text) < self._tts_min_chars
            and not SENTENCE_END_RE.search(clean_text)
            and buffer_age_ms < self._tts_max_wait_ms
        )
        if should_wait_for_more:
            async with self._tts_buffer_lock:
                if self._tts_buffer.strip():
                    # New chunks arrived meanwhile; keep ours in front.
                    self._tts_buffer = f"{clean_text}\n{self._tts_buffer}".strip()
                else:
                    self._tts_buffer = clean_text
                if self._tts_buffer_started_at <= 0.0:
                    # Preserve the original start time so age keeps accruing.
                    self._tts_buffer_started_at = (
                        buffer_started_at if buffer_started_at > 0.0 else now
                    )
                self._schedule_tts_flush_after(self._tts_flush_delay_s)
            return
        if not self._outbound_track:
            return
        for part in self._chunk_tts_text(clean_text):
            try:
                audio = await self._tts.synthesize(part)
            except Exception as exc:
                await self._publish_system(f"Host TTS failed: {exc}")
                return
            if not audio:
                # One-time notice; later silent failures stay quiet.
                if not self._tts_unavailable_notice_sent:
                    self._tts_unavailable_notice_sent = True
                    await self._publish_system(
                        f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}"
                    )
                return
            if not self._outbound_track:
                return
            # Keep the mic suppressed while this audio will be audible.
            self._extend_stt_suppression(audio)
            await self._outbound_track.enqueue_pcm(
                pcm=audio.pcm,
                sample_rate=audio.sample_rate,
                channels=audio.channels,
            )
def _extend_stt_suppression(self, audio: PCMChunk) -> None:
if not self._stt_suppress_during_tts:
return
channels = max(1, int(audio.channels))
sample_rate = max(1, int(audio.sample_rate))
sample_count = len(audio.pcm) // (2 * channels)
if sample_count <= 0:
return
duration_s = sample_count / float(sample_rate)
cooldown_s = float(self._stt_suppress_ms_after_tts) / 1000.0
now = asyncio.get_running_loop().time()
base = max(now, self._stt_suppress_until)
self._stt_suppress_until = base + duration_s + cooldown_s
    async def _consume_audio_track(self, track: MediaStreamTrack) -> None:
        """Consume inbound WebRTC audio and build push-to-talk STT segments.

        Frames are normalized to 16 kHz mono PCM via ``_frame_to_pcm16k_mono``;
        while push-to-talk is held, the PCM accumulates into a segment that is
        finalized on release. With STT disabled, frames are drained and
        discarded.
        """
        if not self._stt.enabled:
            # STT is off: keep reading so the transport does not back up,
            # but discard every frame.
            try:
                while True:
                    await track.recv()
            except asyncio.CancelledError:
                raise
            except Exception:
                return
        resample_state = None  # audioop.ratecv carry-over state between frames
        recording = False
        recording_started_at = 0.0
        recording_truncated = False  # set once the PTT max length is reached
        segment_ms = 0.0
        segment_buffer = bytearray()
        try:
            while True:
                frame = await track.recv()
                pcm16, frame_ms, resample_state = self._frame_to_pcm16k_mono(
                    frame, resample_state
                )
                if not pcm16:
                    continue
                # One-time diagnostics once decodable audio is observed.
                if not self._audio_seen_notice_sent:
                    self._audio_seen_notice_sent = True
                    await self._publish_system("Receiving microphone audio on host.")
                if not self._audio_format_notice_sent:
                    self._audio_format_notice_sent = True
                    await self._publish_system(
                        "Inbound audio frame stats: "
                        f"sample_rate={int(getattr(frame, 'sample_rate', 0) or 0)}, "
                        f"samples={int(getattr(frame, 'samples', 0) or 0)}, "
                        f"time_base={getattr(frame, 'time_base', None)}."
                    )
                # While the TTS suppression window is active, discard any
                # in-progress recording so the assistant's own speech cannot
                # be captured and transcribed.
                if (
                    self._stt_suppress_during_tts
                    and asyncio.get_running_loop().time() < self._stt_suppress_until
                ):
                    recording = False
                    recording_started_at = 0.0
                    recording_truncated = False
                    segment_ms = 0.0
                    segment_buffer = bytearray()
                    continue
                if self._ptt_pressed:
                    if not recording:
                        # Rising edge of push-to-talk: start a fresh segment.
                        recording = True
                        recording_started_at = asyncio.get_running_loop().time()
                        recording_truncated = False
                        segment_ms = 0.0
                        segment_buffer = bytearray()
                    if not recording_truncated:
                        next_segment_ms = segment_ms + frame_ms
                        if next_segment_ms <= self._stt_max_ptt_ms:
                            segment_buffer.extend(pcm16)
                            segment_ms = next_segment_ms
                        else:
                            # Cap the segment; audio past the limit is dropped
                            # until the button is released.
                            recording_truncated = True
                            await self._publish_system(
                                "PTT max length reached; extra audio will be ignored until release."
                            )
                    continue
                if recording:
                    # Falling edge: button released. The wall-clock duration
                    # lets the finalizer detect sample-rate/clock mismatches.
                    observed_duration_ms = max(
                        1.0,
                        (asyncio.get_running_loop().time() - recording_started_at)
                        * 1000.0,
                    )
                    await self._finalize_ptt_segment(
                        bytes(segment_buffer),
                        segment_ms,
                        observed_duration_ms=observed_duration_ms,
                    )
                    recording = False
                    recording_started_at = 0.0
                    recording_truncated = False
                    segment_ms = 0.0
                    segment_buffer = bytearray()
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            details = str(exc).strip()
            if details:
                await self._publish_system(
                    f"Voice input stream ended ({exc.__class__.__name__}): {details}"
                )
            else:
                await self._publish_system(
                    f"Voice input stream ended ({exc.__class__.__name__})."
                )
        finally:
            # Flush a segment that was still recording when the loop ended.
            # NOTE(review): these awaits run inside ``finally``; if the task
            # was cancelled, awaiting here re-raises CancelledError and may
            # skip the flush — confirm that is acceptable.
            if recording and segment_ms >= self._stt_min_ptt_ms:
                observed_duration_ms = max(
                    1.0,
                    (asyncio.get_running_loop().time() - recording_started_at) * 1000.0,
                )
                await self._finalize_ptt_segment(
                    bytes(segment_buffer),
                    segment_ms,
                    observed_duration_ms=observed_duration_ms,
                )
async def _finalize_ptt_segment(
self,
pcm16: bytes,
duration_ms: float,
observed_duration_ms: float | None = None,
) -> None:
if not pcm16 or duration_ms <= 0.0:
return
normalized_pcm = pcm16
normalized_duration_ms = duration_ms
if observed_duration_ms is not None and observed_duration_ms > 0.0:
duration_ratio = duration_ms / observed_duration_ms
if duration_ratio < 0.70 or duration_ratio > 1.40:
estimated_source_rate = int(round(16_000 * duration_ratio))
estimated_source_rate = max(8_000, min(96_000, estimated_source_rate))
candidate_rates = [
8_000,
12_000,
16_000,
24_000,
32_000,
44_100,
48_000,
]
nearest_source_rate = min(
candidate_rates,
key=lambda candidate: abs(candidate - estimated_source_rate),
)
if nearest_source_rate != 16_000:
normalized_pcm, _state = audioop.ratecv(
pcm16,
2,
1,
nearest_source_rate,
16_000,
None,
)
normalized_duration_ms = (len(normalized_pcm) / 2 / 16_000) * 1000.0
if not self._ptt_timing_correction_notice_sent:
self._ptt_timing_correction_notice_sent = True
await self._publish_system(
"Corrected PTT timing mismatch "
f"(estimated source={nearest_source_rate}Hz)."
)
await self._enqueue_stt_segment(
pcm16=normalized_pcm, duration_ms=normalized_duration_ms
)
async def _enqueue_stt_segment(self, pcm16: bytes, duration_ms: float) -> None:
if duration_ms < self._stt_min_ptt_ms:
return
if self._stt_segments.full():
with contextlib.suppress(asyncio.QueueEmpty):
self._stt_segments.get_nowait()
now = asyncio.get_running_loop().time()
if (
now - self._last_stt_backlog_notice_at
) >= self._stt_backlog_notice_interval_s:
self._last_stt_backlog_notice_at = now
await self._publish_system(
"Voice input backlog detected; dropping stale segment."
)
with contextlib.suppress(asyncio.QueueFull):
self._stt_segments.put_nowait(pcm16)
async def _stt_worker(self) -> None:
while True:
pcm16 = await self._stt_segments.get()
if not self._stt_first_segment_notice_sent:
self._stt_first_segment_notice_sent = True
await self._publish_system(
"Push-to-talk audio captured. Running host STT..."
)
try:
transcript = await self._stt.transcribe_pcm(
pcm=pcm16,
sample_rate=16_000,
channels=1,
)
except asyncio.CancelledError:
raise
except Exception as exc:
await self._publish_system(f"Host STT failed: {exc}")
continue
if not transcript:
continue
transcript = transcript.strip()
if not transcript:
continue
await self._gateway.bus.publish(
WisperEvent(role="wisper", text=f"voice transcript: {transcript}")
)
await self._gateway.send_user_message(transcript)
async def _close_peer_connection(self) -> None:
if self._outbound_track:
self._outbound_track.stop()
self._outbound_track = None
if self._pc:
await self._pc.close()
self._pc = None
self._pending_ice_candidates.clear()
async def _wait_for_ice_gathering(self, peer_connection: RTCPeerConnection) -> None:
if peer_connection.iceGatheringState == "complete":
return
completed = asyncio.Event()
@peer_connection.on("icegatheringstatechange")
def on_icegatheringstatechange() -> None:
if peer_connection.iceGatheringState == "complete":
completed.set()
with contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(completed.wait(), timeout=3)
async def _warmup_stt(self) -> None:
try:
await self._stt.warmup()
except asyncio.CancelledError:
raise
except Exception:
return
async def _publish_system(self, text: str) -> None:
await self._gateway.bus.publish(WisperEvent(role="system", text=text))
def _clean_tts_text(self, raw_text: str) -> str:
lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
useful_lines = [
line
for line in lines
if not SPEECH_FILTER_RE.match(line)
and not THINKING_STATUS_RE.search(line)
and not USER_PREFIX_RE.match(line)
and not VOICE_TRANSCRIPT_RE.match(line)
]
return _sanitize_tts_text(" ".join(useful_lines))
def _chunk_tts_text(self, text: str) -> list[str]:
clean_text = " ".join(text.split())
if not clean_text:
return []
max_chars = max(60, int(self._tts_max_chunk_chars))
if len(clean_text) <= max_chars:
return [clean_text]
sentence_parts = [
part.strip()
for part in re.split(r"(?<=[.!?])\s+", clean_text)
if part.strip()
]
if not sentence_parts:
sentence_parts = [clean_text]
chunks: list[str] = []
for part in sentence_parts:
if len(part) <= max_chars:
chunks.append(part)
else:
chunks.extend(self._chunk_tts_words(part, max_chars))
return chunks
def _chunk_tts_words(self, text: str, max_chars: int) -> list[str]:
words = [word for word in text.split() if word]
if not words:
return []
chunks: list[str] = []
current_words: list[str] = []
current_len = 0
for word in words:
if len(word) > max_chars:
if current_words:
chunks.append(" ".join(current_words))
current_words = []
current_len = 0
start = 0
while start < len(word):
end = min(start + max_chars, len(word))
piece = word[start:end]
if len(piece) == max_chars:
chunks.append(piece)
else:
current_words = [piece]
current_len = len(piece)
start = end
continue
extra = len(word) if not current_words else (1 + len(word))
if current_words and (current_len + extra) > max_chars:
chunks.append(" ".join(current_words))
current_words = [word]
current_len = len(word)
else:
current_words.append(word)
current_len += extra
if current_words:
chunks.append(" ".join(current_words))
return chunks
    def _frame_to_pcm16k_mono(
        self, frame: AudioFrame, resample_state: tuple[Any, ...] | None
    ) -> tuple[bytes, float, tuple[Any, ...] | None]:
        """Convert one inbound audio frame to 16 kHz mono signed-16-bit PCM.

        Returns ``(pcm, duration_ms, resample_state)``. ``resample_state`` is
        the audioop.ratecv continuation state and must be threaded through
        successive calls so resampling stays continuous across frames.
        Returns ``(b"", 0.0, state)`` for frames that cannot be interpreted.
        """
        try:
            pcm = frame.to_ndarray(format="s16")
        except TypeError:
            # Fall back for av builds whose to_ndarray() accepts no format arg.
            pcm = frame.to_ndarray()
        # Coerce to int16; float samples are assumed normalized to [-1, 1].
        if (
            NUMPY_AVAILABLE
            and np is not None
            and getattr(pcm, "dtype", None) is not None
        ):
            if pcm.dtype != np.int16:
                if np.issubdtype(pcm.dtype, np.floating):
                    pcm = np.clip(pcm, -1.0, 1.0)
                    pcm = (pcm * 32767.0).astype(np.int16)
                else:
                    pcm = pcm.astype(np.int16)
        if pcm.ndim == 1:
            mono = pcm.tobytes()
        elif pcm.ndim == 2:
            # Work out the channel count, preferring the frame's declared layout.
            expected_channels = 0
            if getattr(frame, "layout", None) is not None:
                with contextlib.suppress(Exception):
                    expected_channels = len(frame.layout.channels)
            rows = int(pcm.shape[0])
            cols = int(pcm.shape[1])
            # Normalize to [frames, channels] to avoid accidental channel mis-detection.
            if expected_channels > 0:
                if rows == expected_channels:
                    frames_channels = pcm.T
                elif cols == expected_channels:
                    frames_channels = pcm
                else:
                    frames_channels = pcm.reshape(-1, 1)
            else:
                # No layout info: guess orientation from shape. A dimension
                # of <= 8 is treated as plausibly being the channel axis;
                # otherwise the data is flattened and treated as mono.
                if rows == 1:
                    frames_channels = pcm.T
                elif cols == 1:
                    frames_channels = pcm
                elif rows <= 8 and cols > rows:
                    frames_channels = pcm.T
                elif cols <= 8 and rows > cols:
                    frames_channels = pcm
                else:
                    frames_channels = pcm.reshape(-1, 1)
            channel_count = (
                int(frames_channels.shape[1]) if frames_channels.ndim == 2 else 1
            )
            # Downmix to mono: numpy mean when available, audioop for stereo,
            # else just keep the first channel.
            if channel_count <= 1:
                mono = frames_channels.reshape(-1).tobytes()
            elif NUMPY_AVAILABLE and np is not None:
                mixed = frames_channels.astype(np.int32).mean(axis=1)
                mono = np.clip(mixed, -32768, 32767).astype(np.int16).tobytes()
            elif channel_count == 2:
                interleaved = frames_channels.reshape(-1).tobytes()
                mono = audioop.tomono(interleaved, 2, 0.5, 0.5)
            else:
                mono = frames_channels[:, 0].reshape(-1).tobytes()
        else:
            # Unexpected dimensionality; drop the frame.
            return b"", 0.0, resample_state
        # Determine the true source sample rate; frame metadata can be
        # missing or wrong, so cross-check several signals.
        source_rate = int(
            getattr(frame, "sample_rate", 0) or getattr(frame, "rate", 0) or 0
        )
        time_base = getattr(frame, "time_base", None)
        tb_rate = 0
        if time_base is not None:
            with contextlib.suppress(Exception):
                numerator = int(getattr(time_base, "numerator", 0))
                denominator = int(getattr(time_base, "denominator", 0))
                # A 1/N time base is taken to mean N samples per second.
                if numerator == 1 and denominator > 0:
                    tb_rate = denominator
        samples_per_channel = int(getattr(frame, "samples", 0) or 0)
        if samples_per_channel > 0:
            # Infer the rate that makes this frame closest to 20 ms
            # (presumably the expected frame duration — TODO confirm).
            candidate_rates = [8_000, 16_000, 24_000, 32_000, 44_100, 48_000]
            inferred_rate = min(
                candidate_rates,
                key=lambda rate: abs((samples_per_channel / float(rate)) - 0.020),
            )
            inferred_frame_ms = (samples_per_channel / float(inferred_rate)) * 1000.0
            # If metadata suggests implausibly long frames, trust the inferred rate instead.
            if (
                source_rate <= 0
                or (samples_per_channel / float(max(1, source_rate))) * 1000.0 > 40.0
            ):
                source_rate = inferred_rate
            elif abs(inferred_frame_ms - 20.0) <= 2.5 and source_rate not in {
                inferred_rate,
                tb_rate,
            }:
                source_rate = inferred_rate
        # The time-base rate overrides when it disagrees strongly (>2 kHz).
        if tb_rate > 0 and (source_rate <= 0 or abs(tb_rate - source_rate) > 2_000):
            source_rate = tb_rate
        if source_rate <= 0:
            source_rate = 48_000
        if source_rate != 16_000:
            mono, resample_state = audioop.ratecv(
                mono,
                2,
                1,
                source_rate,
                16_000,
                resample_state,
            )
        if not mono:
            return b"", 0.0, resample_state
        sample_count = len(mono) // 2
        duration_ms = (sample_count / 16_000) * 1000
        return mono, duration_ms, resample_state