"""Host-side voice support: text sanitizers, STT/TTS backends, WebRTC audio.

Optional third-party backends (numpy, supertonic, faster-whisper, aiortc)
are imported defensively; the module-level ``*_AVAILABLE`` flags record
which ones loaded so the rest of the module can degrade gracefully.
"""

import asyncio
import audioop
import contextlib
import io
import os
import re
import shlex
import shutil
import subprocess
import tempfile
import wave
from dataclasses import dataclass
from fractions import Fraction
from typing import TYPE_CHECKING, Any, Awaitable, Callable

from wisper import WisperEvent

if TYPE_CHECKING:
    from supertonic_gateway import SuperTonicGateway

try:
    import numpy as np

    NUMPY_AVAILABLE = True
except Exception:  # pragma: no cover - runtime fallback when numpy is unavailable
    np = None  # type: ignore[assignment]
    NUMPY_AVAILABLE = False

try:
    from supertonic import TTS as SupertonicTTS

    SUPERTONIC_TTS_AVAILABLE = True
except Exception:  # pragma: no cover - runtime fallback when supertonic is unavailable
    SupertonicTTS = None  # type: ignore[assignment]
    SUPERTONIC_TTS_AVAILABLE = False

try:
    from faster_whisper import WhisperModel

    FASTER_WHISPER_AVAILABLE = True
except (
    Exception
):  # pragma: no cover - runtime fallback when faster-whisper is unavailable
    WhisperModel = None  # type: ignore[assignment]
    FASTER_WHISPER_AVAILABLE = False

try:
    from aiortc import RTCPeerConnection, RTCSessionDescription
    from aiortc.mediastreams import MediaStreamTrack
    from aiortc.sdp import candidate_from_sdp
    from av import AudioFrame

    AIORTC_AVAILABLE = True
except Exception:  # pragma: no cover - runtime fallback when aiortc is unavailable
    RTCPeerConnection = None  # type: ignore[assignment]
    RTCSessionDescription = None  # type: ignore[assignment]
    # Fall back to ``object`` so subclassing/annotation sites still resolve.
    MediaStreamTrack = object  # type: ignore[assignment,misc]
    candidate_from_sdp = None  # type: ignore[assignment]
    AudioFrame = None  # type: ignore[assignment]
    AIORTC_AVAILABLE = False

# Lines starting with these phrases are infrastructure chatter, not speech.
SPEECH_FILTER_RE = re.compile(
    r"^(spawned nanobot tui|stopped nanobot tui|nanobot tui exited|websocket)",
    re.IGNORECASE,
)
# Matches the "nanobot is thinking" status line anywhere in a line.
THINKING_STATUS_RE = re.compile(r"\bnanobot is thinking\b", re.IGNORECASE)
# Leading "you:" / "user:" speaker prefix.
USER_PREFIX_RE = re.compile(r"^(?:you|user)\s*:\s*", re.IGNORECASE)
# Optional "wisper:" prefix followed by a "voice transcript:" label.
VOICE_TRANSCRIPT_RE = re.compile(
    r"^(?:wisper\s*:\s*)?voice\s+transcript\s*:\s*",
    re.IGNORECASE,
)
# ANSI terminal escapes: two-character sequences and CSI sequences.
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
# ASCII control characters, including DEL.
CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]")
# Braille-pattern glyphs, commonly used as terminal spinner frames.
BRAILLE_SPINNER_RE = re.compile(r"[\u2800-\u28ff]")
# True when text ends with sentence-final punctuation (optional trailing ws).
SENTENCE_END_RE = re.compile(r"[.!?]\s*$")
# Characters allowed to reach the TTS engine verbatim; anything else is
# replaced with a space by _sanitize_tts_text().
TTS_ALLOWED_ASCII = set(
    "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789"
    " .,!?;:'\"()[]{}@#%&*+-_/<>|"
)


def _sanitize_tts_text(text: str) -> str:
    """Reduce terminal output to plain speakable text.

    Strips ANSI escapes, spinner glyphs, and control characters, replaces
    every character outside TTS_ALLOWED_ASCII (whitespace is kept) with a
    space, then collapses whitespace runs to single spaces.
    """
    cleaned = ANSI_ESCAPE_RE.sub(" ", text)
    cleaned = BRAILLE_SPINNER_RE.sub(" ", cleaned)
    cleaned = cleaned.replace("\u00a0", " ")  # non-breaking space -> plain space
    cleaned = cleaned.replace("•", " ")  # list bullets read poorly aloud
    cleaned = CONTROL_CHAR_RE.sub(" ", cleaned)
    cleaned = "".join(
        ch if (ch in TTS_ALLOWED_ASCII or ch.isspace()) else " " for ch in cleaned
    )
    cleaned = re.sub(r"\s+", " ", cleaned).strip()
    return cleaned


def _optional_int_env(name: str) -> int | None:
    """Return env var *name* parsed as int, or None when unset/blank.

    NOTE(review): a non-numeric value raises ValueError rather than being
    treated as unset — confirm this fail-fast behavior is intended.
    """
    raw_value = os.getenv(name, "").strip()
    if not raw_value:
        return None
    return int(raw_value)


@dataclass(slots=True)
class PCMChunk:
    """A chunk of 16-bit signed PCM audio."""

    pcm: bytes  # raw s16 samples, interleaved when channels > 1
    sample_rate: int  # frames per second
    channels: int = 1  # channel count; mono by default


if AIORTC_AVAILABLE:

    class QueueAudioTrack(MediaStreamTrack):
        """Outbound aiortc audio track fed from an asyncio queue of PCM.

        ``enqueue_pcm()`` normalizes arbitrary mono/stereo s16 PCM to the
        track's sample rate and slices it into fixed-size frames;
        ``recv()`` hands frames to aiortc paced to real time, substituting
        silence when the queue is empty.
        """

        kind = "audio"

        def __init__(self, sample_rate: int = 48_000, frame_ms: int = 20) -> None:
            super().__init__()
            self._sample_rate = sample_rate
            self._frame_ms = max(1, frame_ms)
            self._samples_per_frame = max(1, (sample_rate * frame_ms) // 1000)
            self._bytes_per_frame = self._samples_per_frame * 2  # s16 mono
            self._queue: asyncio.Queue[bytes] = asyncio.Queue()
            self._timestamp = 0  # running PTS, in samples
            self._resample_state = None  # audioop.ratecv continuation state
            self._resample_source_rate: int | None = None
            # Silence padding inserted before audio that follows an idle gap,
            # giving the receiver's jitter buffer time to fill.
            self._lead_in_ms = max(
                0, int(os.getenv("HOST_RTC_OUTBOUND_LEAD_IN_MS", "120"))
            )
            # ceil(lead_in_ms / frame_ms)
            self._lead_in_frames = (
                self._lead_in_ms + self._frame_ms - 1
            ) // self._frame_ms
            self._lead_in_idle_s = max(
                0.1, float(os.getenv("HOST_RTC_OUTBOUND_IDLE_S", "0.6"))
            )
            self._last_enqueue_at = 0.0
            self._closed = False
            self._frame_duration_s = frame_ms / 1000.0
            self._last_recv_at = 0.0

        async def enqueue_pcm(
            self, pcm: bytes, sample_rate: int, channels: int = 1
        ) -> None:
            """Queue s16 PCM for transmission, converting to the track format."""
            if self._closed or not pcm:
                return
            now = asyncio.get_running_loop().time()
            # Prepend lead-in silence when starting after an idle period so
            # the first real audio is not clipped by jitter-buffer warm-up.
            should_add_lead_in = (
                self._lead_in_frames > 0
                and self._queue.empty()
                and (
                    self._last_enqueue_at <= 0.0
                    or (now - self._last_enqueue_at) >= self._lead_in_idle_s
                )
            )
            if should_add_lead_in:
                silence = b"\x00" * self._bytes_per_frame
                for _index in range(self._lead_in_frames):
                    await self._queue.put(silence)
            mono = pcm
            if channels > 1:
                mono = audioop.tomono(mono, 2, 0.5, 0.5)
            if sample_rate != self._sample_rate:
                # audioop rate conversion state is only valid when source/destination rates stay the same.
                if self._resample_source_rate != sample_rate:
                    self._resample_state = None
                    self._resample_source_rate = sample_rate
                mono, self._resample_state = audioop.ratecv(
                    mono,
                    2,
                    1,
                    sample_rate,
                    self._sample_rate,
                    self._resample_state,
                )
            else:
                self._resample_state = None
                self._resample_source_rate = None
            if not mono:
                return
            # Slice into fixed-size frames, zero-padding the final partial one.
            for start in range(0, len(mono), self._bytes_per_frame):
                chunk = mono[start : start + self._bytes_per_frame]
                if len(chunk) < self._bytes_per_frame:
                    chunk += b"\x00" * (self._bytes_per_frame - len(chunk))
                await self._queue.put(chunk)
            self._last_enqueue_at = now

        async def recv(self) -> AudioFrame:
            """Return the next audio frame, paced to real time."""
            if self._closed:
                raise asyncio.CancelledError
            # Pace frame delivery to real-time to prevent RTP burst sends.
            # Without pacing, when TTS enqueues audio faster than real-time,
            # aiortc sends RTP packets in a burst and the browser's jitter
            # buffer skips ahead, causing the user to only hear the tail end.
            loop = asyncio.get_running_loop()
            now = loop.time()
            if self._last_recv_at > 0.0:
                elapsed = now - self._last_recv_at
                remaining = self._frame_duration_s - elapsed
                if remaining > 0.001:
                    await asyncio.sleep(remaining)
            try:
                payload = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                # Keep the stream alive with silence when no audio is queued.
                payload = b"\x00" * self._bytes_per_frame
            self._last_recv_at = loop.time()
            frame = AudioFrame(
                format="s16", layout="mono", samples=self._samples_per_frame
            )
            frame.planes[0].update(payload)
            frame.sample_rate = self._sample_rate
            frame.time_base = Fraction(1, self._sample_rate)
            frame.pts = self._timestamp
            self._timestamp += self._samples_per_frame
            return frame

        def stop(self) -> None:
            """Mark the track closed and stop the underlying media track."""
            self._closed = True
            super().stop()

else:

    class QueueAudioTrack:  # pragma: no cover - used only when aiortc is unavailable
        """No-op stand-in with the same surface, used when aiortc is absent."""

        async def enqueue_pcm(
            self, pcm: bytes, sample_rate: int, channels: int = 1
        ) -> None:
            return

        def stop(self) -> None:
            return


def _write_temp_wav(pcm: bytes, sample_rate: int, channels: int) -> str:
    """Write s16 PCM into a temporary WAV file and return its path.

    The caller is responsible for deleting the file.
    """
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
        tmp_path = tmp_file.name
    with wave.open(tmp_path, "wb") as wav_file:
        wav_file.setnchannels(max(1, channels))
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm)
    return tmp_path


class CommandSpeechToText:
    """STT backend that shells out to an operator-supplied HOST_STT_COMMAND."""

    def __init__(self) -> None:
        self._command_template = os.getenv("HOST_STT_COMMAND", "").strip()

    @property
    def enabled(self) -> bool:
        # Usable whenever a command template is configured.
        return bool(self._command_template)

    async def transcribe_pcm(
        self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
    ) -> str | None:
        """Transcribe s16 PCM via the external command; None when unavailable."""
        if not self.enabled or not pcm:
            return None
        return await asyncio.to_thread(
            self._transcribe_blocking, pcm, sample_rate, channels
        )

    def unavailable_reason(self) -> str:
        """Human-readable explanation for a missing transcript."""
        if not self._command_template:
            return "HOST_STT_COMMAND is not configured."
        return "HOST_STT_COMMAND failed to produce transcript."
    def _transcribe_blocking(
        self, pcm: bytes, sample_rate: int, channels: int
    ) -> str | None:
        """Run the configured shell command against a temp WAV of *pcm*.

        The template's ``{input_wav}`` placeholder (or, absent that, an
        appended argument) receives the shell-quoted WAV path.  Returns the
        command's stripped stdout, or None when it printed nothing.  Raises
        RuntimeError when the command exits non-zero.
        """
        tmp_path: str | None = None
        try:
            tmp_path = _write_temp_wav(
                pcm=pcm, sample_rate=sample_rate, channels=channels
            )
            command = self._command_template
            if "{input_wav}" in command:
                command = command.replace("{input_wav}", shlex.quote(tmp_path))
            else:
                command = f"{command} {shlex.quote(tmp_path)}"
            result = subprocess.run(
                command,
                shell=True,  # template is operator-supplied shell syntax
                capture_output=True,
                text=True,
                check=False,
            )
            if result.returncode != 0:
                stderr = result.stderr.strip() or "unknown error"
                raise RuntimeError(f"STT command failed: {stderr}")
            transcript = result.stdout.strip()
            return transcript or None
        finally:
            # Best-effort cleanup of the temp WAV.
            if tmp_path and os.path.exists(tmp_path):
                with contextlib.suppress(OSError):
                    os.unlink(tmp_path)


class FasterWhisperSpeechToText:
    """STT backend wrapping faster-whisper, configured via HOST_STT_* env vars.

    The model loads lazily on first use (or via ``warmup()``); load failures
    are remembered in ``init_error`` instead of raising at construction.
    """

    def __init__(self) -> None:
        self._model_name = os.getenv("HOST_STT_MODEL", "base.en").strip() or "base.en"
        self._device = os.getenv("HOST_STT_DEVICE", "auto").strip() or "auto"
        self._compute_type = (
            os.getenv("HOST_STT_COMPUTE_TYPE", "int8").strip() or "int8"
        )
        self._language = os.getenv("HOST_STT_LANGUAGE", "en").strip()
        self._beam_size = max(1, int(os.getenv("HOST_STT_BEAM_SIZE", "2")))
        self._best_of = max(1, int(os.getenv("HOST_STT_BEST_OF", "2")))
        # Any value other than the listed "off" spellings enables VAD.
        self._vad_filter = os.getenv("HOST_STT_VAD_FILTER", "0").strip() not in {
            "0",
            "false",
            "False",
            "no",
            "off",
        }
        self._temperature = float(os.getenv("HOST_STT_TEMPERATURE", "0.0"))
        self._log_prob_threshold = float(
            os.getenv("HOST_STT_LOG_PROB_THRESHOLD", "-1.0")
        )
        self._no_speech_threshold = float(
            os.getenv("HOST_STT_NO_SPEECH_THRESHOLD", "0.6")
        )
        self._compression_ratio_threshold = float(
            os.getenv("HOST_STT_COMPRESSION_RATIO_THRESHOLD", "2.4")
        )
        # Empty prompt env var means "no prompt" (falls through to None).
        self._initial_prompt = (
            os.getenv(
                "HOST_STT_INITIAL_PROMPT",
                "Transcribe brief spoken English precisely. Prefer common words over sound effects.",
            ).strip()
            or None
        )
        self._model: Any = None  # lazily-created WhisperModel
        self._init_error: str | None = None  # last model-load failure, if any
        self._lock = asyncio.Lock()  # serializes model load and transcription

    @property
    def enabled(self) -> bool:
        # True when the faster-whisper package imported successfully.
        return FASTER_WHISPER_AVAILABLE and WhisperModel is not None

    @property
    def init_error(self) -> str | None:
        return self._init_error

    async def transcribe_pcm(
        self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
    ) -> str | None:
        """Transcribe s16 PCM off-thread; None when disabled or *pcm* empty."""
        if not self.enabled or not pcm:
            return None
        async with self._lock:
            return await asyncio.to_thread(
                self._transcribe_blocking, pcm, sample_rate, channels
            )

    async def warmup(self) -> None:
        """Load the model ahead of the first transcription request."""
        if not self.enabled:
            return
        async with self._lock:
            await asyncio.to_thread(self._initialize_blocking)

    def _initialize_blocking(self) -> None:
        """Create the WhisperModel once; record failures in ``_init_error``."""
        if self._model is not None:
            return
        if not self.enabled or WhisperModel is None:
            return
        try:
            self._model = WhisperModel(
                self._model_name,
                device=self._device,
                compute_type=self._compute_type,
            )
            self._init_error = None
        except Exception as exc:
            self._init_error = str(exc)
            self._model = None

    def _transcribe_blocking(
        self, pcm: bytes, sample_rate: int, channels: int
    ) -> str | None:
        """Transcribe synchronously; caller runs this in a worker thread.

        Prefers the in-memory numpy path (downmix + resample to 16 kHz
        float32); falls back to a temp WAV file when numpy is unavailable.
        Raises RuntimeError when the model previously failed to load.
        """
        self._initialize_blocking()
        if self._model is None:
            if self._init_error:
                raise RuntimeError(
                    f"faster-whisper initialization failed: {self._init_error}"
                )
            return None
        if NUMPY_AVAILABLE and np is not None:
            mono = pcm
            if channels > 1:
                mono = audioop.tomono(mono, 2, 0.5, 0.5)
            if sample_rate != 16_000:
                mono, _ = audioop.ratecv(
                    mono,
                    2,
                    1,
                    sample_rate,
                    16_000,
                    None,
                )
            # Normalize s16 to float32 in [-1, 1) as faster-whisper expects.
            audio = np.frombuffer(mono, dtype=np.int16).astype(np.float32) / 32768.0
            if audio.size == 0:
                return None
            segments, _info = self._model.transcribe(
                audio,
                language=self._language or None,
                beam_size=self._beam_size,
                best_of=self._best_of,
                vad_filter=self._vad_filter,
                condition_on_previous_text=False,
                without_timestamps=True,
                initial_prompt=self._initial_prompt,
                temperature=self._temperature,
                log_prob_threshold=self._log_prob_threshold,
                no_speech_threshold=self._no_speech_threshold,
                compression_ratio_threshold=self._compression_ratio_threshold,
            )
            transcript_parts: list[str] = []
            for segment in segments:
                text = str(getattr(segment, "text", "")).strip()
                if text:
                    transcript_parts.append(text)
            transcript = " ".join(transcript_parts).strip()
            return transcript or None
        # Fallback path: write a temp WAV and let faster-whisper decode it.
        tmp_path: str | None = None
        try:
            tmp_path = _write_temp_wav(
                pcm=pcm, sample_rate=sample_rate, channels=channels
            )
            segments, _info = self._model.transcribe(
                tmp_path,
                language=self._language or None,
                beam_size=self._beam_size,
                best_of=self._best_of,
                vad_filter=self._vad_filter,
                condition_on_previous_text=False,
                without_timestamps=True,
                initial_prompt=self._initial_prompt,
                temperature=self._temperature,
                log_prob_threshold=self._log_prob_threshold,
                no_speech_threshold=self._no_speech_threshold,
                compression_ratio_threshold=self._compression_ratio_threshold,
            )
            transcript_parts: list[str] = []
            for segment in segments:
                text = str(getattr(segment, "text", "")).strip()
                if text:
                    transcript_parts.append(text)
            transcript = " ".join(transcript_parts).strip()
            return transcript or None
        finally:
            if tmp_path and os.path.exists(tmp_path):
                with contextlib.suppress(OSError):
                    os.unlink(tmp_path)


class HostSpeechToText:
    """Facade selecting between the faster-whisper and command STT backends.

    HOST_STT_PROVIDER chooses "faster-whisper", "command", or "auto"
    (try faster-whisper first, then fall back to the command backend).
    """

    def __init__(self) -> None:
        provider = (
            os.getenv("HOST_STT_PROVIDER", "faster-whisper").strip()
            or "faster-whisper"
        ).lower()
        if provider not in {"faster-whisper", "command", "auto"}:
            provider = "auto"  # unknown values degrade to auto-selection
        self._provider = provider
        self._faster_whisper = FasterWhisperSpeechToText()
        self._command = CommandSpeechToText()

    @property
    def enabled(self) -> bool:
        if self._provider == "faster-whisper":
            return self._faster_whisper.enabled
        if self._provider == "command":
            return self._command.enabled
        # auto: enabled when either backend is usable.
        return self._faster_whisper.enabled or self._command.enabled

    async def transcribe_pcm(
        self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
    ) -> str | None:
        """Transcribe via the configured provider(s); None when all fail."""
        if self._provider in {"faster-whisper", "auto"}:
            transcript = await self._faster_whisper.transcribe_pcm(
                pcm=pcm,
                sample_rate=sample_rate,
                channels=channels,
            )
            if transcript:
                return transcript
            if self._provider == "faster-whisper":
                return None
        if self._provider in {"command", "auto"}:
            return await self._command.transcribe_pcm(
                pcm=pcm,
                sample_rate=sample_rate,
                channels=channels,
            )
        return None

    async def warmup(self) -> None:
        """Pre-load the faster-whisper model when it may be used."""
        if self._provider in {"faster-whisper", "auto"}:
            await self._faster_whisper.warmup()

    def unavailable_reason(self) -> str:
        """Most specific human-readable reason no transcript was produced."""
        if self._provider == "faster-whisper":
            if not self._faster_whisper.enabled:
                return "faster-whisper package is not available."
            if self._faster_whisper.init_error:
                return f"faster-whisper initialization failed: {self._faster_whisper.init_error}"
            return "faster-whisper did not return transcript."
        if self._provider == "command":
            return self._command.unavailable_reason()
        # auto provider: report the most specific known failure.
        if self._faster_whisper.init_error:
            return f"faster-whisper initialization failed: {self._faster_whisper.init_error}"
        if self._command.enabled:
            return "HOST_STT_COMMAND failed to produce transcript."
        if not self._faster_whisper.enabled:
            return "faster-whisper package is not available."
        return "No STT provider is configured."
class SupertonicTextToSpeech:
    """TTS backend wrapping the supertonic engine, configured via env vars.

    The engine and voice style load lazily under a lock; load failures are
    remembered in ``init_error``.  Requires both supertonic and numpy.
    """

    def __init__(self) -> None:
        self._model = (
            os.getenv("SUPERTONIC_MODEL", "supertonic-2").strip() or "supertonic-2"
        )
        self._voice_style_name = (
            os.getenv("SUPERTONIC_VOICE_STYLE", "M1").strip() or "M1"
        )
        self._lang = os.getenv("SUPERTONIC_LANG", "en").strip() or "en"
        self._total_steps = int(os.getenv("SUPERTONIC_TOTAL_STEPS", "5"))
        self._speed = float(os.getenv("SUPERTONIC_SPEED", "1.05"))
        # None when the env vars are unset; forwarded to the engine as-is.
        self._intra_op_num_threads = _optional_int_env("SUPERTONIC_INTRA_OP_THREADS")
        self._inter_op_num_threads = _optional_int_env("SUPERTONIC_INTER_OP_THREADS")
        # Any value other than the listed "off" spellings enables downloads.
        self._auto_download = os.getenv(
            "SUPERTONIC_AUTO_DOWNLOAD", "1"
        ).strip() not in {
            "0",
            "false",
            "False",
            "no",
            "off",
        }
        self._engine: Any = None  # lazily-created SupertonicTTS
        self._voice_style: Any = None  # lazily-resolved voice style object
        self._init_error: str | None = None  # last engine-load failure, if any
        self._lock = asyncio.Lock()  # serializes init and synthesis

    @property
    def enabled(self) -> bool:
        # Needs the supertonic package plus numpy for sample conversion.
        return (
            SUPERTONIC_TTS_AVAILABLE and SupertonicTTS is not None and NUMPY_AVAILABLE
        )

    @property
    def init_error(self) -> str | None:
        return self._init_error

    async def synthesize(self, text: str) -> PCMChunk | None:
        """Synthesize *text* off-thread; None when disabled or text is blank."""
        if not self.enabled:
            return None
        clean_text = " ".join(text.split())
        if not clean_text:
            return None
        async with self._lock:
            return await asyncio.to_thread(self._synthesize_blocking, clean_text)

    def _synthesize_blocking(self, text: str) -> PCMChunk | None:
        """Synthesize synchronously and convert engine output to s16 PCM.

        Retries once with re-sanitized text when the engine rejects input
        with an "unsupported character" ValueError; other errors propagate.
        """
        self._initialize_blocking()
        if self._engine is None or self._voice_style is None or np is None:
            return None
        text = _sanitize_tts_text(text)
        if not text:
            return None
        try:
            wav, _duration = self._engine.synthesize(
                text,
                voice_style=self._voice_style,
                lang=self._lang,
                total_steps=self._total_steps,
                speed=self._speed,
            )
        except ValueError as exc:
            message = str(exc)
            if "unsupported character" not in message.lower():
                raise
            fallback_text = self._sanitize_text_for_supertonic(text)
            # Re-raise when sanitizing produced nothing new to try.
            if not fallback_text or fallback_text == text:
                raise
            wav, _duration = self._engine.synthesize(
                fallback_text,
                voice_style=self._voice_style,
                lang=self._lang,
                total_steps=self._total_steps,
                speed=self._speed,
            )
        samples = np.asarray(wav)
        if samples.size == 0:
            return None
        channels = 1
        if samples.ndim == 0:
            samples = samples.reshape(1)
        elif samples.ndim == 1:
            channels = 1
        elif samples.ndim == 2:
            # Normalize to frames x channels so PCM bytes are correctly interleaved.
            dim0, dim1 = int(samples.shape[0]), int(samples.shape[1])
            if dim0 <= 2 and dim1 > dim0:
                # channels x frames -> transpose to frames x channels.
                channels = dim0
                samples = samples.T
            elif dim1 <= 2 and dim0 > dim1:
                channels = dim1
            else:
                # Ambiguous shape: flatten and treat as mono.
                channels = 1
                samples = samples.reshape(-1)
        else:
            channels = 1
            samples = samples.reshape(-1)
        if np.issubdtype(samples.dtype, np.floating):
            # Float samples assumed in [-1, 1]; clip then scale to s16.
            samples = np.clip(samples, -1.0, 1.0)
            samples = (samples * 32767.0).astype(np.int16)
        else:
            if samples.dtype != np.int16:
                samples = samples.astype(np.int16)
        pcm = samples.tobytes()
        return PCMChunk(
            pcm=pcm,
            # 24 kHz default when the engine does not expose a sample rate.
            sample_rate=int(getattr(self._engine, "sample_rate", 24_000)),
            channels=max(1, channels),
        )

    def _sanitize_text_for_supertonic(self, text: str) -> str:
        """Fallback sanitizer used after an unsupported-character error."""
        return _sanitize_tts_text(text)

    def _initialize_blocking(self) -> None:
        """Create the engine and voice style once; record failures."""
        if self._engine is not None and self._voice_style is not None:
            return
        if not self.enabled or SupertonicTTS is None:
            return
        try:
            engine = SupertonicTTS(
                model=self._model,
                auto_download=self._auto_download,
                intra_op_num_threads=self._intra_op_num_threads,
                inter_op_num_threads=self._inter_op_num_threads,
            )
            voice_style = engine.get_voice_style(self._voice_style_name)
        except Exception as exc:
            self._init_error = str(exc)
            return
        self._engine = engine
        self._voice_style = voice_style
        self._init_error = None


class HostTextToSpeech:
    """Facade selecting between supertonic, command, and espeak TTS backends.

    HOST_TTS_PROVIDER chooses "supertonic", "command", "espeak", or "auto"
    (try supertonic, then the command template, then espeak).
    """

    def __init__(self) -> None:
        provider = (
            os.getenv("HOST_TTS_PROVIDER", "supertonic").strip() or "supertonic"
        ).lower()
        if provider not in {"supertonic", "command", "espeak", "auto"}:
            provider = "auto"  # unknown values degrade to auto-selection
        self._provider = provider
        self._supertonic = SupertonicTextToSpeech()
        self._command_template = os.getenv("HOST_TTS_COMMAND", "").strip()
        self._espeak = shutil.which("espeak")  # path or None

    @property
    def enabled(self) -> bool:
        if self._provider == "supertonic":
            return self._supertonic.enabled
        if self._provider == "command":
            return bool(self._command_template)
        if self._provider == "espeak":
            return bool(self._espeak)
        # auto: enabled when any backend is usable.
        return self._supertonic.enabled or bool(self._command_template or self._espeak)

    async def synthesize(self, text: str) -> PCMChunk | None:
        """Synthesize via the configured provider(s); None when all fail."""
        clean_text = " ".join(text.split())
        if not clean_text:
            return None
        if self._provider in {"supertonic", "auto"}:
            audio = await self._supertonic.synthesize(clean_text)
            if audio:
                return audio
            if self._provider == "supertonic":
                return None
        if self._provider in {"command", "auto"} and self._command_template:
            return await asyncio.to_thread(self._synthesize_with_command, clean_text)
        if self._provider == "command":
            return None
        if self._provider in {"espeak", "auto"} and self._espeak:
            return await asyncio.to_thread(self._synthesize_with_espeak, clean_text)
        return None

    def unavailable_reason(self) -> str:
        """Most specific human-readable reason no audio was produced."""
        if self._provider == "supertonic":
            if not self._supertonic.enabled:
                return "supertonic package is not available."
            if self._supertonic.init_error:
                return (
                    f"supertonic initialization failed: {self._supertonic.init_error}"
                )
            return "supertonic did not return audio."
        if self._provider == "command":
            return "HOST_TTS_COMMAND is not configured."
        if self._provider == "espeak":
            return "espeak binary is not available."
        # auto provider: report the most specific known failure.
        if self._supertonic.init_error:
            return f"supertonic initialization failed: {self._supertonic.init_error}"
        if self._command_template:
            return "HOST_TTS_COMMAND failed to produce audio."
        if self._espeak:
            return "espeak failed to produce audio."
        return "No TTS provider is configured."
def _synthesize_with_command(self, text: str) -> PCMChunk | None: command = self._command_template if "{text}" in command: command = command.replace("{text}", shlex.quote(text)) else: command = f"{command} {shlex.quote(text)}" if "{output_wav}" in command: tmp_path: str | None = None try: with tempfile.NamedTemporaryFile( suffix=".wav", delete=False ) as tmp_file: tmp_path = tmp_file.name command_with_output = command.replace( "{output_wav}", shlex.quote(tmp_path) ) result = subprocess.run( command_with_output, shell=True, capture_output=True, text=True, check=False, ) if result.returncode != 0: stderr = result.stderr.strip() or "unknown error" raise RuntimeError(f"TTS command failed: {stderr}") return self._read_wav_file(tmp_path) finally: if tmp_path and os.path.exists(tmp_path): with contextlib.suppress(OSError): os.unlink(tmp_path) result = subprocess.run( command, shell=True, capture_output=True, check=False, ) if result.returncode != 0: stderr = result.stderr.decode(errors="ignore").strip() or "unknown error" raise RuntimeError(f"TTS command failed: {stderr}") return self._decode_wav_bytes(result.stdout) def _synthesize_with_espeak(self, text: str) -> PCMChunk | None: if not self._espeak: return None result = subprocess.run( [self._espeak, "--stdout", text], capture_output=True, check=False, ) if result.returncode != 0: stderr = result.stderr.decode(errors="ignore").strip() or "unknown error" raise RuntimeError(f"espeak failed: {stderr}") return self._decode_wav_bytes(result.stdout) def _read_wav_file(self, path: str) -> PCMChunk | None: try: with open(path, "rb") as wav_file: return self._decode_wav_bytes(wav_file.read()) except OSError: return None def _decode_wav_bytes(self, payload: bytes) -> PCMChunk | None: if not payload: return None with wave.open(io.BytesIO(payload), "rb") as wav_file: channels = wav_file.getnchannels() sample_width = wav_file.getsampwidth() sample_rate = wav_file.getframerate() pcm = wav_file.readframes(wav_file.getnframes()) if 
sample_width != 2: pcm = audioop.lin2lin(pcm, sample_width, 2) return PCMChunk(pcm=pcm, sample_rate=sample_rate, channels=max(1, channels)) SendJsonCallable = Callable[[dict[str, Any]], Awaitable[None]] class WebRTCVoiceSession: def __init__( self, gateway: "SuperTonicGateway", send_json: SendJsonCallable ) -> None: self._gateway = gateway self._send_json = send_json self._pc: RTCPeerConnection | None = None self._outbound_track: QueueAudioTrack | None = None self._incoming_audio_task: asyncio.Task[None] | None = None self._stt_worker_task: asyncio.Task[None] | None = None self._stt_warmup_task: asyncio.Task[None] | None = None self._stt = HostSpeechToText() self._tts = HostTextToSpeech() self._stt_segment_queue_size = max( 1, int(os.getenv("HOST_STT_SEGMENT_QUEUE_SIZE", "2")) ) self._stt_segments: asyncio.Queue[bytes] = asyncio.Queue( maxsize=self._stt_segment_queue_size ) self._tts_buffer = "" self._tts_flush_handle: asyncio.TimerHandle | None = None self._tts_flush_lock = asyncio.Lock() self._tts_buffer_lock = asyncio.Lock() self._tts_flush_delay_s = max( 0.08, float(os.getenv("HOST_TTS_FLUSH_DELAY_S", "0.45")) ) self._tts_sentence_flush_delay_s = max( 0.06, float(os.getenv("HOST_TTS_SENTENCE_FLUSH_DELAY_S", "0.15")), ) self._tts_min_chars = max(1, int(os.getenv("HOST_TTS_MIN_CHARS", "10"))) self._tts_max_wait_ms = max(300, int(os.getenv("HOST_TTS_MAX_WAIT_MS", "1800"))) self._tts_max_chunk_chars = max( 60, int(os.getenv("HOST_TTS_MAX_CHUNK_CHARS", "140")) ) self._tts_buffer_started_at = 0.0 self._closed = False self._stt_unavailable_notice_sent = False self._tts_unavailable_notice_sent = False self._audio_seen_notice_sent = False self._audio_format_notice_sent = False self._stt_first_segment_notice_sent = False self._ptt_timing_correction_notice_sent = False self._stt_min_ptt_ms = max( 120, int( os.getenv( "HOST_STT_MIN_PTT_MS", os.getenv("HOST_STT_MIN_SEGMENT_MS", "220") ) ), ) self._stt_max_ptt_ms = max( self._stt_min_ptt_ms, int( os.getenv( 
"HOST_STT_MAX_PTT_MS", os.getenv("HOST_STT_MAX_SEGMENT_MS", "12000") ) ), ) self._stt_suppress_during_tts = os.getenv( "HOST_STT_SUPPRESS_DURING_TTS", "1" ).strip() not in { "0", "false", "False", "no", "off", } self._stt_suppress_ms_after_tts = max( 0, int(os.getenv("HOST_STT_SUPPRESS_MS_AFTER_TTS", "300")), ) self._stt_suppress_until = 0.0 self._stt_backlog_notice_interval_s = max( 2.0, float(os.getenv("HOST_STT_BACKLOG_NOTICE_INTERVAL_S", "6.0")), ) self._last_stt_backlog_notice_at = 0.0 self._pending_ice_candidates: list[dict[str, Any] | None] = [] self._ptt_pressed = False def set_push_to_talk_pressed(self, pressed: bool) -> None: self._ptt_pressed = bool(pressed) async def queue_output_text(self, chunk: str) -> None: normalized_chunk = chunk.strip() if not normalized_chunk: return flush_delay = ( self._tts_sentence_flush_delay_s if SENTENCE_END_RE.search(normalized_chunk) else self._tts_flush_delay_s ) loop = asyncio.get_running_loop() async with self._tts_buffer_lock: if not self._pc or not self._outbound_track: return if not self._tts_buffer.strip(): self._tts_buffer_started_at = loop.time() self._tts_buffer = normalized_chunk else: # Keep line boundaries between streamed chunks so line-based filters stay accurate. 
self._tts_buffer = f"{self._tts_buffer}\n{normalized_chunk}" self._schedule_tts_flush_after(flush_delay) async def handle_offer(self, payload: dict[str, Any]) -> None: if not AIORTC_AVAILABLE or not RTCPeerConnection or not RTCSessionDescription: await self._send_json( { "type": "rtc-error", "message": "WebRTC backend unavailable on host (aiortc is not installed).", } ) return sdp = str(payload.get("sdp", "")).strip() rtc_type = str(payload.get("rtcType", "offer")).strip() or "offer" if not sdp: await self._send_json( {"type": "rtc-error", "message": "Missing SDP offer payload."} ) return await self._close_peer_connection() self._ptt_pressed = False peer_connection = RTCPeerConnection() self._pc = peer_connection self._outbound_track = QueueAudioTrack() peer_connection.addTrack(self._outbound_track) @peer_connection.on("connectionstatechange") def on_connectionstatechange() -> None: asyncio.create_task( self._send_json( { "type": "rtc-state", "state": peer_connection.connectionState, } ) ) @peer_connection.on("track") def on_track(track: MediaStreamTrack) -> None: if track.kind != "audio": return if self._incoming_audio_task: self._incoming_audio_task.cancel() self._incoming_audio_task = asyncio.create_task( self._consume_audio_track(track), name="voice-inbound-track", ) await peer_connection.setRemoteDescription( RTCSessionDescription(sdp=sdp, type=rtc_type) ) await self._drain_pending_ice_candidates(peer_connection) answer = await peer_connection.createAnswer() await peer_connection.setLocalDescription(answer) await self._wait_for_ice_gathering(peer_connection) local_description = peer_connection.localDescription sdp_answer = str(local_description.sdp or "") if sdp_answer: sdp_answer = ( sdp_answer.replace("\r\n", "\n") .replace("\r", "\n") .strip() .replace("\n", "\r\n") + "\r\n" ) await self._send_json( { "type": "rtc-answer", "sdp": sdp_answer, "rtcType": local_description.type, } ) if self._stt.enabled and not self._stt_worker_task: self._stt_worker_task = 
asyncio.create_task( self._stt_worker(), name="voice-stt-worker" ) if self._stt.enabled and ( self._stt_warmup_task is None or self._stt_warmup_task.done() ): self._stt_warmup_task = asyncio.create_task( self._warmup_stt(), name="voice-stt-warmup" ) elif not self._stt.enabled and not self._stt_unavailable_notice_sent: self._stt_unavailable_notice_sent = True await self._publish_system( f"Voice input backend unavailable. {self._stt.unavailable_reason()}" ) async def handle_ice_candidate(self, payload: dict[str, Any]) -> None: if not AIORTC_AVAILABLE: return raw_candidate = payload.get("candidate") candidate_payload: dict[str, Any] | None if raw_candidate in (None, ""): candidate_payload = None elif isinstance(raw_candidate, dict): candidate_payload = raw_candidate else: return if not self._pc or self._pc.remoteDescription is None: self._pending_ice_candidates.append(candidate_payload) return await self._apply_ice_candidate(self._pc, candidate_payload) async def _drain_pending_ice_candidates( self, peer_connection: RTCPeerConnection, ) -> None: if not self._pending_ice_candidates: return pending = list(self._pending_ice_candidates) self._pending_ice_candidates.clear() for pending_candidate in pending: await self._apply_ice_candidate(peer_connection, pending_candidate) async def _apply_ice_candidate( self, peer_connection: RTCPeerConnection, raw_candidate: dict[str, Any] | None, ) -> None: if raw_candidate is None: with contextlib.suppress(Exception): await peer_connection.addIceCandidate(None) return candidate_sdp = str(raw_candidate.get("candidate", "")).strip() if not candidate_sdp or not candidate_from_sdp: return if candidate_sdp.startswith("candidate:"): candidate_sdp = candidate_sdp[len("candidate:") :] try: candidate = candidate_from_sdp(candidate_sdp) candidate.sdpMid = raw_candidate.get("sdpMid") line_index = raw_candidate.get("sdpMLineIndex") candidate.sdpMLineIndex = ( int(line_index) if line_index is not None else None ) await 
peer_connection.addIceCandidate(candidate) except Exception as exc: await self._publish_system(f"Failed to add ICE candidate: {exc}") async def close(self) -> None: self._closed = True self._ptt_pressed = False if self._tts_flush_handle: self._tts_flush_handle.cancel() self._tts_flush_handle = None self._tts_buffer = "" self._tts_buffer_started_at = 0.0 if self._incoming_audio_task: self._incoming_audio_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._incoming_audio_task self._incoming_audio_task = None if self._stt_worker_task: self._stt_worker_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._stt_worker_task self._stt_worker_task = None if self._stt_warmup_task: self._stt_warmup_task.cancel() with contextlib.suppress(asyncio.CancelledError): await self._stt_warmup_task self._stt_warmup_task = None await self._close_peer_connection() def _schedule_tts_flush(self) -> None: if self._closed: return asyncio.create_task(self._flush_tts(), name="voice-tts-flush") def _schedule_tts_flush_after(self, delay_s: float) -> None: if self._tts_flush_handle: self._tts_flush_handle.cancel() loop = asyncio.get_running_loop() self._tts_flush_handle = loop.call_later( max(0.05, delay_s), self._schedule_tts_flush ) async def _flush_tts(self) -> None: async with self._tts_flush_lock: async with self._tts_buffer_lock: self._tts_flush_handle = None raw_text = self._tts_buffer buffer_started_at = self._tts_buffer_started_at self._tts_buffer = "" self._tts_buffer_started_at = 0.0 clean_text = self._clean_tts_text(raw_text) if not clean_text: return loop = asyncio.get_running_loop() now = loop.time() buffer_age_ms = int(max(0.0, now - buffer_started_at) * 1000) should_wait_for_more = ( len(clean_text) < self._tts_min_chars and not SENTENCE_END_RE.search(clean_text) and buffer_age_ms < self._tts_max_wait_ms ) if should_wait_for_more: async with self._tts_buffer_lock: if self._tts_buffer.strip(): self._tts_buffer = 
f"{clean_text}\n{self._tts_buffer}".strip() else: self._tts_buffer = clean_text if self._tts_buffer_started_at <= 0.0: self._tts_buffer_started_at = ( buffer_started_at if buffer_started_at > 0.0 else now ) self._schedule_tts_flush_after(self._tts_flush_delay_s) return if not self._outbound_track: return for part in self._chunk_tts_text(clean_text): try: audio = await self._tts.synthesize(part) except Exception as exc: await self._publish_system(f"Host TTS failed: {exc}") return if not audio: if not self._tts_unavailable_notice_sent: self._tts_unavailable_notice_sent = True await self._publish_system( f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}" ) return if not self._outbound_track: return self._extend_stt_suppression(audio) await self._outbound_track.enqueue_pcm( pcm=audio.pcm, sample_rate=audio.sample_rate, channels=audio.channels, ) def _extend_stt_suppression(self, audio: PCMChunk) -> None: if not self._stt_suppress_during_tts: return channels = max(1, int(audio.channels)) sample_rate = max(1, int(audio.sample_rate)) sample_count = len(audio.pcm) // (2 * channels) if sample_count <= 0: return duration_s = sample_count / float(sample_rate) cooldown_s = float(self._stt_suppress_ms_after_tts) / 1000.0 now = asyncio.get_running_loop().time() base = max(now, self._stt_suppress_until) self._stt_suppress_until = base + duration_s + cooldown_s async def _consume_audio_track(self, track: MediaStreamTrack) -> None: if not self._stt.enabled: try: while True: await track.recv() except asyncio.CancelledError: raise except Exception: return resample_state = None recording = False recording_started_at = 0.0 recording_truncated = False segment_ms = 0.0 segment_buffer = bytearray() try: while True: frame = await track.recv() pcm16, frame_ms, resample_state = self._frame_to_pcm16k_mono( frame, resample_state ) if not pcm16: continue if not self._audio_seen_notice_sent: self._audio_seen_notice_sent = True await self._publish_system("Receiving microphone 
audio on host.") if not self._audio_format_notice_sent: self._audio_format_notice_sent = True await self._publish_system( "Inbound audio frame stats: " f"sample_rate={int(getattr(frame, 'sample_rate', 0) or 0)}, " f"samples={int(getattr(frame, 'samples', 0) or 0)}, " f"time_base={getattr(frame, 'time_base', None)}." ) if ( self._stt_suppress_during_tts and asyncio.get_running_loop().time() < self._stt_suppress_until ): recording = False recording_started_at = 0.0 recording_truncated = False segment_ms = 0.0 segment_buffer = bytearray() continue if self._ptt_pressed: if not recording: recording = True recording_started_at = asyncio.get_running_loop().time() recording_truncated = False segment_ms = 0.0 segment_buffer = bytearray() if not recording_truncated: next_segment_ms = segment_ms + frame_ms if next_segment_ms <= self._stt_max_ptt_ms: segment_buffer.extend(pcm16) segment_ms = next_segment_ms else: recording_truncated = True await self._publish_system( "PTT max length reached; extra audio will be ignored until release." ) continue if recording: observed_duration_ms = max( 1.0, (asyncio.get_running_loop().time() - recording_started_at) * 1000.0, ) await self._finalize_ptt_segment( bytes(segment_buffer), segment_ms, observed_duration_ms=observed_duration_ms, ) recording = False recording_started_at = 0.0 recording_truncated = False segment_ms = 0.0 segment_buffer = bytearray() except asyncio.CancelledError: raise except Exception as exc: details = str(exc).strip() if details: await self._publish_system( f"Voice input stream ended ({exc.__class__.__name__}): {details}" ) else: await self._publish_system( f"Voice input stream ended ({exc.__class__.__name__})." 
) finally: if recording and segment_ms >= self._stt_min_ptt_ms: observed_duration_ms = max( 1.0, (asyncio.get_running_loop().time() - recording_started_at) * 1000.0, ) await self._finalize_ptt_segment( bytes(segment_buffer), segment_ms, observed_duration_ms=observed_duration_ms, ) async def _finalize_ptt_segment( self, pcm16: bytes, duration_ms: float, observed_duration_ms: float | None = None, ) -> None: if not pcm16 or duration_ms <= 0.0: return normalized_pcm = pcm16 normalized_duration_ms = duration_ms if observed_duration_ms is not None and observed_duration_ms > 0.0: duration_ratio = duration_ms / observed_duration_ms if duration_ratio < 0.70 or duration_ratio > 1.40: estimated_source_rate = int(round(16_000 * duration_ratio)) estimated_source_rate = max(8_000, min(96_000, estimated_source_rate)) candidate_rates = [ 8_000, 12_000, 16_000, 24_000, 32_000, 44_100, 48_000, ] nearest_source_rate = min( candidate_rates, key=lambda candidate: abs(candidate - estimated_source_rate), ) if nearest_source_rate != 16_000: normalized_pcm, _state = audioop.ratecv( pcm16, 2, 1, nearest_source_rate, 16_000, None, ) normalized_duration_ms = (len(normalized_pcm) / 2 / 16_000) * 1000.0 if not self._ptt_timing_correction_notice_sent: self._ptt_timing_correction_notice_sent = True await self._publish_system( "Corrected PTT timing mismatch " f"(estimated source={nearest_source_rate}Hz)." ) await self._enqueue_stt_segment( pcm16=normalized_pcm, duration_ms=normalized_duration_ms ) async def _enqueue_stt_segment(self, pcm16: bytes, duration_ms: float) -> None: if duration_ms < self._stt_min_ptt_ms: return if self._stt_segments.full(): with contextlib.suppress(asyncio.QueueEmpty): self._stt_segments.get_nowait() now = asyncio.get_running_loop().time() if ( now - self._last_stt_backlog_notice_at ) >= self._stt_backlog_notice_interval_s: self._last_stt_backlog_notice_at = now await self._publish_system( "Voice input backlog detected; dropping stale segment." 
) with contextlib.suppress(asyncio.QueueFull): self._stt_segments.put_nowait(pcm16) async def _stt_worker(self) -> None: while True: pcm16 = await self._stt_segments.get() if not self._stt_first_segment_notice_sent: self._stt_first_segment_notice_sent = True await self._publish_system( "Push-to-talk audio captured. Running host STT..." ) try: transcript = await self._stt.transcribe_pcm( pcm=pcm16, sample_rate=16_000, channels=1, ) except asyncio.CancelledError: raise except Exception as exc: await self._publish_system(f"Host STT failed: {exc}") continue if not transcript: continue transcript = transcript.strip() if not transcript: continue await self._gateway.bus.publish( WisperEvent(role="wisper", text=f"voice transcript: {transcript}") ) await self._gateway.send_user_message(transcript) async def _close_peer_connection(self) -> None: if self._outbound_track: self._outbound_track.stop() self._outbound_track = None if self._pc: await self._pc.close() self._pc = None self._pending_ice_candidates.clear() async def _wait_for_ice_gathering(self, peer_connection: RTCPeerConnection) -> None: if peer_connection.iceGatheringState == "complete": return completed = asyncio.Event() @peer_connection.on("icegatheringstatechange") def on_icegatheringstatechange() -> None: if peer_connection.iceGatheringState == "complete": completed.set() with contextlib.suppress(asyncio.TimeoutError): await asyncio.wait_for(completed.wait(), timeout=3) async def _warmup_stt(self) -> None: try: await self._stt.warmup() except asyncio.CancelledError: raise except Exception: return async def _publish_system(self, text: str) -> None: await self._gateway.bus.publish(WisperEvent(role="system", text=text)) def _clean_tts_text(self, raw_text: str) -> str: lines = [line.strip() for line in raw_text.splitlines() if line.strip()] useful_lines = [ line for line in lines if not SPEECH_FILTER_RE.match(line) and not THINKING_STATUS_RE.search(line) and not USER_PREFIX_RE.match(line) and not 
VOICE_TRANSCRIPT_RE.match(line) ] return _sanitize_tts_text(" ".join(useful_lines)) def _chunk_tts_text(self, text: str) -> list[str]: clean_text = " ".join(text.split()) if not clean_text: return [] max_chars = max(60, int(self._tts_max_chunk_chars)) if len(clean_text) <= max_chars: return [clean_text] sentence_parts = [ part.strip() for part in re.split(r"(?<=[.!?])\s+", clean_text) if part.strip() ] if not sentence_parts: sentence_parts = [clean_text] chunks: list[str] = [] for part in sentence_parts: if len(part) <= max_chars: chunks.append(part) else: chunks.extend(self._chunk_tts_words(part, max_chars)) return chunks def _chunk_tts_words(self, text: str, max_chars: int) -> list[str]: words = [word for word in text.split() if word] if not words: return [] chunks: list[str] = [] current_words: list[str] = [] current_len = 0 for word in words: if len(word) > max_chars: if current_words: chunks.append(" ".join(current_words)) current_words = [] current_len = 0 start = 0 while start < len(word): end = min(start + max_chars, len(word)) piece = word[start:end] if len(piece) == max_chars: chunks.append(piece) else: current_words = [piece] current_len = len(piece) start = end continue extra = len(word) if not current_words else (1 + len(word)) if current_words and (current_len + extra) > max_chars: chunks.append(" ".join(current_words)) current_words = [word] current_len = len(word) else: current_words.append(word) current_len += extra if current_words: chunks.append(" ".join(current_words)) return chunks def _frame_to_pcm16k_mono( self, frame: AudioFrame, resample_state: tuple[Any, ...] | None ) -> tuple[bytes, float, tuple[Any, ...] 
| None]: try: pcm = frame.to_ndarray(format="s16") except TypeError: pcm = frame.to_ndarray() if ( NUMPY_AVAILABLE and np is not None and getattr(pcm, "dtype", None) is not None ): if pcm.dtype != np.int16: if np.issubdtype(pcm.dtype, np.floating): pcm = np.clip(pcm, -1.0, 1.0) pcm = (pcm * 32767.0).astype(np.int16) else: pcm = pcm.astype(np.int16) if pcm.ndim == 1: mono = pcm.tobytes() elif pcm.ndim == 2: expected_channels = 0 if getattr(frame, "layout", None) is not None: with contextlib.suppress(Exception): expected_channels = len(frame.layout.channels) rows = int(pcm.shape[0]) cols = int(pcm.shape[1]) # Normalize to [frames, channels] to avoid accidental channel mis-detection. if expected_channels > 0: if rows == expected_channels: frames_channels = pcm.T elif cols == expected_channels: frames_channels = pcm else: frames_channels = pcm.reshape(-1, 1) else: if rows == 1: frames_channels = pcm.T elif cols == 1: frames_channels = pcm elif rows <= 8 and cols > rows: frames_channels = pcm.T elif cols <= 8 and rows > cols: frames_channels = pcm else: frames_channels = pcm.reshape(-1, 1) channel_count = ( int(frames_channels.shape[1]) if frames_channels.ndim == 2 else 1 ) if channel_count <= 1: mono = frames_channels.reshape(-1).tobytes() elif NUMPY_AVAILABLE and np is not None: mixed = frames_channels.astype(np.int32).mean(axis=1) mono = np.clip(mixed, -32768, 32767).astype(np.int16).tobytes() elif channel_count == 2: interleaved = frames_channels.reshape(-1).tobytes() mono = audioop.tomono(interleaved, 2, 0.5, 0.5) else: mono = frames_channels[:, 0].reshape(-1).tobytes() else: return b"", 0.0, resample_state source_rate = int( getattr(frame, "sample_rate", 0) or getattr(frame, "rate", 0) or 0 ) time_base = getattr(frame, "time_base", None) tb_rate = 0 if time_base is not None: with contextlib.suppress(Exception): numerator = int(getattr(time_base, "numerator", 0)) denominator = int(getattr(time_base, "denominator", 0)) if numerator == 1 and denominator > 0: tb_rate 
= denominator samples_per_channel = int(getattr(frame, "samples", 0) or 0) if samples_per_channel > 0: candidate_rates = [8_000, 16_000, 24_000, 32_000, 44_100, 48_000] inferred_rate = min( candidate_rates, key=lambda rate: abs((samples_per_channel / float(rate)) - 0.020), ) inferred_frame_ms = (samples_per_channel / float(inferred_rate)) * 1000.0 # If metadata suggests implausibly long frames, trust the inferred rate instead. if ( source_rate <= 0 or (samples_per_channel / float(max(1, source_rate))) * 1000.0 > 40.0 ): source_rate = inferred_rate elif abs(inferred_frame_ms - 20.0) <= 2.5 and source_rate not in { inferred_rate, tb_rate, }: source_rate = inferred_rate if tb_rate > 0 and (source_rate <= 0 or abs(tb_rate - source_rate) > 2_000): source_rate = tb_rate if source_rate <= 0: source_rate = 48_000 if source_rate != 16_000: mono, resample_state = audioop.ratecv( mono, 2, 1, source_rate, 16_000, resample_state, ) if not mono: return b"", 0.0, resample_state sample_count = len(mono) // 2 duration_ms = (sample_count / 16_000) * 1000 return mono, duration_ms, resample_state