1657 lines
59 KiB
Python
1657 lines
59 KiB
Python
|
|
import asyncio
|
||
|
|
import audioop
|
||
|
|
import contextlib
|
||
|
|
import io
|
||
|
|
import os
|
||
|
|
import re
|
||
|
|
import shlex
|
||
|
|
import shutil
|
||
|
|
import subprocess
|
||
|
|
import tempfile
|
||
|
|
import wave
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from fractions import Fraction
|
||
|
|
from typing import TYPE_CHECKING, Any, Awaitable, Callable
|
||
|
|
|
||
|
|
from wisper import WisperEvent
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from supertonic_gateway import SuperTonicGateway
|
||
|
|
|
||
|
|
|
||
|
|
try:
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
NUMPY_AVAILABLE = True
|
||
|
|
except Exception: # pragma: no cover - runtime fallback when numpy is unavailable
|
||
|
|
np = None # type: ignore[assignment]
|
||
|
|
NUMPY_AVAILABLE = False
|
||
|
|
|
||
|
|
|
||
|
|
try:
|
||
|
|
from supertonic import TTS as SupertonicTTS
|
||
|
|
|
||
|
|
SUPERTONIC_TTS_AVAILABLE = True
|
||
|
|
except Exception: # pragma: no cover - runtime fallback when supertonic is unavailable
|
||
|
|
SupertonicTTS = None # type: ignore[assignment]
|
||
|
|
SUPERTONIC_TTS_AVAILABLE = False
|
||
|
|
|
||
|
|
|
||
|
|
try:
|
||
|
|
from faster_whisper import WhisperModel
|
||
|
|
|
||
|
|
FASTER_WHISPER_AVAILABLE = True
|
||
|
|
except (
|
||
|
|
Exception
|
||
|
|
): # pragma: no cover - runtime fallback when faster-whisper is unavailable
|
||
|
|
WhisperModel = None # type: ignore[assignment]
|
||
|
|
FASTER_WHISPER_AVAILABLE = False
|
||
|
|
|
||
|
|
|
||
|
|
try:
|
||
|
|
from aiortc import RTCPeerConnection, RTCSessionDescription
|
||
|
|
from aiortc.mediastreams import MediaStreamTrack
|
||
|
|
from aiortc.sdp import candidate_from_sdp
|
||
|
|
from av import AudioFrame
|
||
|
|
|
||
|
|
AIORTC_AVAILABLE = True
|
||
|
|
except Exception: # pragma: no cover - runtime fallback when aiortc is unavailable
|
||
|
|
RTCPeerConnection = None # type: ignore[assignment]
|
||
|
|
RTCSessionDescription = None # type: ignore[assignment]
|
||
|
|
MediaStreamTrack = object # type: ignore[assignment,misc]
|
||
|
|
candidate_from_sdp = None # type: ignore[assignment]
|
||
|
|
AudioFrame = None # type: ignore[assignment]
|
||
|
|
AIORTC_AVAILABLE = False
|
||
|
|
|
||
|
|
|
||
|
|
SPEECH_FILTER_RE = re.compile(
|
||
|
|
r"^(spawned nanobot tui|stopped nanobot tui|nanobot tui exited|websocket)",
|
||
|
|
re.IGNORECASE,
|
||
|
|
)
|
||
|
|
THINKING_STATUS_RE = re.compile(r"\bnanobot is thinking\b", re.IGNORECASE)
|
||
|
|
USER_PREFIX_RE = re.compile(r"^(?:you|user)\s*:\s*", re.IGNORECASE)
|
||
|
|
VOICE_TRANSCRIPT_RE = re.compile(
|
||
|
|
r"^(?:wisper\s*:\s*)?voice\s+transcript\s*:\s*",
|
||
|
|
re.IGNORECASE,
|
||
|
|
)
|
||
|
|
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
|
||
|
|
CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]")
|
||
|
|
BRAILLE_SPINNER_RE = re.compile(r"[\u2800-\u28ff]")
|
||
|
|
SENTENCE_END_RE = re.compile(r"[.!?]\s*$")
|
||
|
|
TTS_ALLOWED_ASCII = set(
|
||
|
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||
|
|
"abcdefghijklmnopqrstuvwxyz"
|
||
|
|
"0123456789"
|
||
|
|
" .,!?;:'\"()[]{}@#%&*+-_/<>|"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def _sanitize_tts_text(text: str) -> str:
|
||
|
|
cleaned = ANSI_ESCAPE_RE.sub(" ", text)
|
||
|
|
cleaned = BRAILLE_SPINNER_RE.sub(" ", cleaned)
|
||
|
|
cleaned = cleaned.replace("\u00a0", " ")
|
||
|
|
cleaned = cleaned.replace("•", " ")
|
||
|
|
cleaned = CONTROL_CHAR_RE.sub(" ", cleaned)
|
||
|
|
cleaned = "".join(
|
||
|
|
ch if (ch in TTS_ALLOWED_ASCII or ch.isspace()) else " " for ch in cleaned
|
||
|
|
)
|
||
|
|
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||
|
|
return cleaned
|
||
|
|
|
||
|
|
|
||
|
|
def _optional_int_env(name: str) -> int | None:
|
||
|
|
raw_value = os.getenv(name, "").strip()
|
||
|
|
if not raw_value:
|
||
|
|
return None
|
||
|
|
return int(raw_value)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass(slots=True)
class PCMChunk:
    """A chunk of signed 16-bit PCM audio with its format metadata."""

    # Raw PCM payload (2 bytes per sample, interleaved when multi-channel).
    pcm: bytes
    # Samples per second for this chunk.
    sample_rate: int
    # Number of interleaved channels; defaults to mono.
    channels: int = 1
|
||
|
|
|
||
|
|
|
||
|
|
if AIORTC_AVAILABLE:

    class QueueAudioTrack(MediaStreamTrack):
        """Outbound aiortc audio track fed from an internal PCM queue.

        Producers push arbitrary-rate PCM via enqueue_pcm(); recv() hands
        aiortc fixed-size mono s16 frames at ``sample_rate``, paced to
        real time, emitting silence when the queue is empty.
        """

        kind = "audio"

        def __init__(self, sample_rate: int = 48_000, frame_ms: int = 20) -> None:
            super().__init__()
            self._sample_rate = sample_rate
            self._frame_ms = max(1, frame_ms)
            self._samples_per_frame = max(1, (sample_rate * frame_ms) // 1000)
            # 2 bytes per sample (s16 mono).
            self._bytes_per_frame = self._samples_per_frame * 2
            self._queue: asyncio.Queue[bytes] = asyncio.Queue()
            self._timestamp = 0
            # Carry-over state for audioop.ratecv across enqueue calls.
            self._resample_state = None
            self._resample_source_rate: int | None = None
            # Silence lead-in inserted before a fresh burst of audio so the
            # receiver's jitter buffer does not clip the first syllables.
            self._lead_in_ms = max(
                0, int(os.getenv("HOST_RTC_OUTBOUND_LEAD_IN_MS", "120"))
            )
            self._lead_in_frames = (
                self._lead_in_ms + self._frame_ms - 1
            ) // self._frame_ms
            # How long the queue must have been idle before a new lead-in.
            self._lead_in_idle_s = max(
                0.1, float(os.getenv("HOST_RTC_OUTBOUND_IDLE_S", "0.6"))
            )
            self._last_enqueue_at = 0.0
            self._closed = False
            self._frame_duration_s = frame_ms / 1000.0
            self._last_recv_at = 0.0

        async def enqueue_pcm(
            self, pcm: bytes, sample_rate: int, channels: int = 1
        ) -> None:
            """Downmix/resample *pcm* to the track format and queue it as frames."""
            if self._closed or not pcm:
                return

            now = asyncio.get_running_loop().time()
            # Only prepend silence when starting fresh after an idle period.
            should_add_lead_in = (
                self._lead_in_frames > 0
                and self._queue.empty()
                and (
                    self._last_enqueue_at <= 0.0
                    or (now - self._last_enqueue_at) >= self._lead_in_idle_s
                )
            )
            if should_add_lead_in:
                silence = b"\x00" * self._bytes_per_frame
                for _index in range(self._lead_in_frames):
                    await self._queue.put(silence)

            mono = pcm
            if channels > 1:
                mono = audioop.tomono(mono, 2, 0.5, 0.5)

            if sample_rate != self._sample_rate:
                # audioop rate conversion state is only valid when source/destination rates stay the same.
                if self._resample_source_rate != sample_rate:
                    self._resample_state = None
                    self._resample_source_rate = sample_rate
                mono, self._resample_state = audioop.ratecv(
                    mono,
                    2,
                    1,
                    sample_rate,
                    self._sample_rate,
                    self._resample_state,
                )
            else:
                self._resample_state = None
                self._resample_source_rate = None

            if not mono:
                return

            # Split into fixed-size frames, zero-padding the final partial one.
            for start in range(0, len(mono), self._bytes_per_frame):
                chunk = mono[start : start + self._bytes_per_frame]
                if len(chunk) < self._bytes_per_frame:
                    chunk += b"\x00" * (self._bytes_per_frame - len(chunk))
                await self._queue.put(chunk)
            self._last_enqueue_at = now

        async def recv(self) -> AudioFrame:
            """Return the next 20 ms (by default) audio frame, paced to real time."""
            if self._closed:
                raise asyncio.CancelledError

            # Pace frame delivery to real-time to prevent RTP burst sends.
            # Without pacing, when TTS enqueues audio faster than real-time,
            # aiortc sends RTP packets in a burst and the browser's jitter
            # buffer skips ahead, causing the user to only hear the tail end.
            loop = asyncio.get_running_loop()
            now = loop.time()
            if self._last_recv_at > 0.0:
                elapsed = now - self._last_recv_at
                remaining = self._frame_duration_s - elapsed
                if remaining > 0.001:
                    await asyncio.sleep(remaining)

            # Never block: emit silence when no PCM is queued.
            try:
                payload = self._queue.get_nowait()
            except asyncio.QueueEmpty:
                payload = b"\x00" * self._bytes_per_frame

            self._last_recv_at = loop.time()

            frame = AudioFrame(
                format="s16", layout="mono", samples=self._samples_per_frame
            )
            frame.planes[0].update(payload)
            frame.sample_rate = self._sample_rate
            frame.time_base = Fraction(1, self._sample_rate)
            frame.pts = self._timestamp
            self._timestamp += self._samples_per_frame
            return frame

        def stop(self) -> None:
            """Mark the track closed so recv() stops producing frames."""
            self._closed = True
            super().stop()

else:

    class QueueAudioTrack:  # pragma: no cover - used only when aiortc is unavailable
        """No-op stand-in exposing the same surface when aiortc is missing."""

        async def enqueue_pcm(
            self, pcm: bytes, sample_rate: int, channels: int = 1
        ) -> None:
            return

        def stop(self) -> None:
            return
|
||
|
|
|
||
|
|
|
||
|
|
def _write_temp_wav(pcm: bytes, sample_rate: int, channels: int) -> str:
|
||
|
|
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
|
||
|
|
tmp_path = tmp_file.name
|
||
|
|
with wave.open(tmp_path, "wb") as wav_file:
|
||
|
|
wav_file.setnchannels(max(1, channels))
|
||
|
|
wav_file.setsampwidth(2)
|
||
|
|
wav_file.setframerate(sample_rate)
|
||
|
|
wav_file.writeframes(pcm)
|
||
|
|
return tmp_path
|
||
|
|
|
||
|
|
|
||
|
|
class CommandSpeechToText:
|
||
|
|
def __init__(self) -> None:
|
||
|
|
self._command_template = os.getenv("HOST_STT_COMMAND", "").strip()
|
||
|
|
|
||
|
|
@property
|
||
|
|
def enabled(self) -> bool:
|
||
|
|
return bool(self._command_template)
|
||
|
|
|
||
|
|
async def transcribe_pcm(
|
||
|
|
self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
|
||
|
|
) -> str | None:
|
||
|
|
if not self.enabled or not pcm:
|
||
|
|
return None
|
||
|
|
return await asyncio.to_thread(
|
||
|
|
self._transcribe_blocking, pcm, sample_rate, channels
|
||
|
|
)
|
||
|
|
|
||
|
|
def unavailable_reason(self) -> str:
|
||
|
|
if not self._command_template:
|
||
|
|
return "HOST_STT_COMMAND is not configured."
|
||
|
|
return "HOST_STT_COMMAND failed to produce transcript."
|
||
|
|
|
||
|
|
def _transcribe_blocking(
|
||
|
|
self, pcm: bytes, sample_rate: int, channels: int
|
||
|
|
) -> str | None:
|
||
|
|
tmp_path: str | None = None
|
||
|
|
try:
|
||
|
|
tmp_path = _write_temp_wav(
|
||
|
|
pcm=pcm, sample_rate=sample_rate, channels=channels
|
||
|
|
)
|
||
|
|
|
||
|
|
command = self._command_template
|
||
|
|
if "{input_wav}" in command:
|
||
|
|
command = command.replace("{input_wav}", shlex.quote(tmp_path))
|
||
|
|
else:
|
||
|
|
command = f"{command} {shlex.quote(tmp_path)}"
|
||
|
|
|
||
|
|
result = subprocess.run(
|
||
|
|
command,
|
||
|
|
shell=True,
|
||
|
|
capture_output=True,
|
||
|
|
text=True,
|
||
|
|
check=False,
|
||
|
|
)
|
||
|
|
if result.returncode != 0:
|
||
|
|
stderr = result.stderr.strip() or "unknown error"
|
||
|
|
raise RuntimeError(f"STT command failed: {stderr}")
|
||
|
|
|
||
|
|
transcript = result.stdout.strip()
|
||
|
|
return transcript or None
|
||
|
|
finally:
|
||
|
|
if tmp_path and os.path.exists(tmp_path):
|
||
|
|
with contextlib.suppress(OSError):
|
||
|
|
os.unlink(tmp_path)
|
||
|
|
|
||
|
|
|
||
|
|
class FasterWhisperSpeechToText:
|
||
|
|
def __init__(self) -> None:
|
||
|
|
self._model_name = os.getenv("HOST_STT_MODEL", "base.en").strip() or "base.en"
|
||
|
|
self._device = os.getenv("HOST_STT_DEVICE", "auto").strip() or "auto"
|
||
|
|
self._compute_type = (
|
||
|
|
os.getenv("HOST_STT_COMPUTE_TYPE", "int8").strip() or "int8"
|
||
|
|
)
|
||
|
|
self._language = os.getenv("HOST_STT_LANGUAGE", "en").strip()
|
||
|
|
self._beam_size = max(1, int(os.getenv("HOST_STT_BEAM_SIZE", "2")))
|
||
|
|
self._best_of = max(1, int(os.getenv("HOST_STT_BEST_OF", "2")))
|
||
|
|
self._vad_filter = os.getenv("HOST_STT_VAD_FILTER", "0").strip() not in {
|
||
|
|
"0",
|
||
|
|
"false",
|
||
|
|
"False",
|
||
|
|
"no",
|
||
|
|
"off",
|
||
|
|
}
|
||
|
|
self._temperature = float(os.getenv("HOST_STT_TEMPERATURE", "0.0"))
|
||
|
|
self._log_prob_threshold = float(
|
||
|
|
os.getenv("HOST_STT_LOG_PROB_THRESHOLD", "-1.0")
|
||
|
|
)
|
||
|
|
self._no_speech_threshold = float(
|
||
|
|
os.getenv("HOST_STT_NO_SPEECH_THRESHOLD", "0.6")
|
||
|
|
)
|
||
|
|
self._compression_ratio_threshold = float(
|
||
|
|
os.getenv("HOST_STT_COMPRESSION_RATIO_THRESHOLD", "2.4")
|
||
|
|
)
|
||
|
|
self._initial_prompt = (
|
||
|
|
os.getenv(
|
||
|
|
"HOST_STT_INITIAL_PROMPT",
|
||
|
|
"Transcribe brief spoken English precisely. Prefer common words over sound effects.",
|
||
|
|
).strip()
|
||
|
|
or None
|
||
|
|
)
|
||
|
|
|
||
|
|
self._model: Any = None
|
||
|
|
self._init_error: str | None = None
|
||
|
|
self._lock = asyncio.Lock()
|
||
|
|
|
||
|
|
@property
|
||
|
|
def enabled(self) -> bool:
|
||
|
|
return FASTER_WHISPER_AVAILABLE and WhisperModel is not None
|
||
|
|
|
||
|
|
@property
|
||
|
|
def init_error(self) -> str | None:
|
||
|
|
return self._init_error
|
||
|
|
|
||
|
|
async def transcribe_pcm(
|
||
|
|
self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
|
||
|
|
) -> str | None:
|
||
|
|
if not self.enabled or not pcm:
|
||
|
|
return None
|
||
|
|
async with self._lock:
|
||
|
|
return await asyncio.to_thread(
|
||
|
|
self._transcribe_blocking, pcm, sample_rate, channels
|
||
|
|
)
|
||
|
|
|
||
|
|
async def warmup(self) -> None:
|
||
|
|
if not self.enabled:
|
||
|
|
return
|
||
|
|
async with self._lock:
|
||
|
|
await asyncio.to_thread(self._initialize_blocking)
|
||
|
|
|
||
|
|
def _initialize_blocking(self) -> None:
|
||
|
|
if self._model is not None:
|
||
|
|
return
|
||
|
|
if not self.enabled or WhisperModel is None:
|
||
|
|
return
|
||
|
|
|
||
|
|
try:
|
||
|
|
self._model = WhisperModel(
|
||
|
|
self._model_name,
|
||
|
|
device=self._device,
|
||
|
|
compute_type=self._compute_type,
|
||
|
|
)
|
||
|
|
self._init_error = None
|
||
|
|
except Exception as exc:
|
||
|
|
self._init_error = str(exc)
|
||
|
|
self._model = None
|
||
|
|
|
||
|
|
def _transcribe_blocking(
|
||
|
|
self, pcm: bytes, sample_rate: int, channels: int
|
||
|
|
) -> str | None:
|
||
|
|
self._initialize_blocking()
|
||
|
|
if self._model is None:
|
||
|
|
if self._init_error:
|
||
|
|
raise RuntimeError(
|
||
|
|
f"faster-whisper initialization failed: {self._init_error}"
|
||
|
|
)
|
||
|
|
return None
|
||
|
|
|
||
|
|
if NUMPY_AVAILABLE and np is not None:
|
||
|
|
mono = pcm
|
||
|
|
if channels > 1:
|
||
|
|
mono = audioop.tomono(mono, 2, 0.5, 0.5)
|
||
|
|
if sample_rate != 16_000:
|
||
|
|
mono, _ = audioop.ratecv(
|
||
|
|
mono,
|
||
|
|
2,
|
||
|
|
1,
|
||
|
|
sample_rate,
|
||
|
|
16_000,
|
||
|
|
None,
|
||
|
|
)
|
||
|
|
audio = np.frombuffer(mono, dtype=np.int16).astype(np.float32) / 32768.0
|
||
|
|
if audio.size == 0:
|
||
|
|
return None
|
||
|
|
segments, _info = self._model.transcribe(
|
||
|
|
audio,
|
||
|
|
language=self._language or None,
|
||
|
|
beam_size=self._beam_size,
|
||
|
|
best_of=self._best_of,
|
||
|
|
vad_filter=self._vad_filter,
|
||
|
|
condition_on_previous_text=False,
|
||
|
|
without_timestamps=True,
|
||
|
|
initial_prompt=self._initial_prompt,
|
||
|
|
temperature=self._temperature,
|
||
|
|
log_prob_threshold=self._log_prob_threshold,
|
||
|
|
no_speech_threshold=self._no_speech_threshold,
|
||
|
|
compression_ratio_threshold=self._compression_ratio_threshold,
|
||
|
|
)
|
||
|
|
transcript_parts: list[str] = []
|
||
|
|
for segment in segments:
|
||
|
|
text = str(getattr(segment, "text", "")).strip()
|
||
|
|
if text:
|
||
|
|
transcript_parts.append(text)
|
||
|
|
transcript = " ".join(transcript_parts).strip()
|
||
|
|
return transcript or None
|
||
|
|
|
||
|
|
tmp_path: str | None = None
|
||
|
|
try:
|
||
|
|
tmp_path = _write_temp_wav(
|
||
|
|
pcm=pcm, sample_rate=sample_rate, channels=channels
|
||
|
|
)
|
||
|
|
segments, _info = self._model.transcribe(
|
||
|
|
tmp_path,
|
||
|
|
language=self._language or None,
|
||
|
|
beam_size=self._beam_size,
|
||
|
|
best_of=self._best_of,
|
||
|
|
vad_filter=self._vad_filter,
|
||
|
|
condition_on_previous_text=False,
|
||
|
|
without_timestamps=True,
|
||
|
|
initial_prompt=self._initial_prompt,
|
||
|
|
temperature=self._temperature,
|
||
|
|
log_prob_threshold=self._log_prob_threshold,
|
||
|
|
no_speech_threshold=self._no_speech_threshold,
|
||
|
|
compression_ratio_threshold=self._compression_ratio_threshold,
|
||
|
|
)
|
||
|
|
transcript_parts: list[str] = []
|
||
|
|
for segment in segments:
|
||
|
|
text = str(getattr(segment, "text", "")).strip()
|
||
|
|
if text:
|
||
|
|
transcript_parts.append(text)
|
||
|
|
transcript = " ".join(transcript_parts).strip()
|
||
|
|
return transcript or None
|
||
|
|
finally:
|
||
|
|
if tmp_path and os.path.exists(tmp_path):
|
||
|
|
with contextlib.suppress(OSError):
|
||
|
|
os.unlink(tmp_path)
|
||
|
|
|
||
|
|
|
||
|
|
class HostSpeechToText:
    """Facade routing STT requests to faster-whisper, a shell command, or both.

    ``HOST_STT_PROVIDER`` selects the backend ("faster-whisper", "command",
    or "auto"); unknown values fall back to "auto", which tries
    faster-whisper first and then the command backend.
    """

    def __init__(self) -> None:
        configured = (
            os.getenv("HOST_STT_PROVIDER", "faster-whisper").strip() or "faster-whisper"
        ).lower()
        if configured in {"faster-whisper", "command", "auto"}:
            self._provider = configured
        else:
            self._provider = "auto"
        self._faster_whisper = FasterWhisperSpeechToText()
        self._command = CommandSpeechToText()

    @property
    def enabled(self) -> bool:
        """True when at least one backend selected by the provider is usable."""
        if self._provider == "faster-whisper":
            return self._faster_whisper.enabled
        if self._provider == "command":
            return self._command.enabled
        return self._faster_whisper.enabled or self._command.enabled

    async def transcribe_pcm(
        self, pcm: bytes, sample_rate: int = 16_000, channels: int = 1
    ) -> str | None:
        """Try the preferred backend; in "auto" mode fall through to the command."""
        try_whisper = self._provider in {"faster-whisper", "auto"}
        try_command = self._provider in {"command", "auto"}

        if try_whisper:
            transcript = await self._faster_whisper.transcribe_pcm(
                pcm=pcm,
                sample_rate=sample_rate,
                channels=channels,
            )
            if transcript:
                return transcript

        if try_command:
            return await self._command.transcribe_pcm(
                pcm=pcm,
                sample_rate=sample_rate,
                channels=channels,
            )

        return None

    async def warmup(self) -> None:
        """Pre-load faster-whisper when it may be used for transcription."""
        if self._provider in {"faster-whisper", "auto"}:
            await self._faster_whisper.warmup()

    def unavailable_reason(self) -> str:
        """Explain, per provider mode, why no transcript could be produced."""
        if self._provider == "faster-whisper":
            if not self._faster_whisper.enabled:
                return "faster-whisper package is not available."
            if self._faster_whisper.init_error:
                return f"faster-whisper initialization failed: {self._faster_whisper.init_error}"
            return "faster-whisper did not return transcript."

        if self._provider == "command":
            return self._command.unavailable_reason()

        # "auto": report the most specific failure we know about.
        if self._faster_whisper.init_error:
            return f"faster-whisper initialization failed: {self._faster_whisper.init_error}"
        if self._command.enabled:
            return "HOST_STT_COMMAND failed to produce transcript."
        if not self._faster_whisper.enabled:
            return "faster-whisper package is not available."
        return "No STT provider is configured."
|
||
|
|
|
||
|
|
|
||
|
|
class SupertonicTextToSpeech:
    """TTS backend backed by the supertonic library.

    The engine and voice style are loaded lazily; configuration comes from
    SUPERTONIC_* environment variables. synthesize() returns s16 PCM or
    None when the engine is unavailable or produced nothing.
    """

    def __init__(self) -> None:
        self._model = (
            os.getenv("SUPERTONIC_MODEL", "supertonic-2").strip() or "supertonic-2"
        )
        self._voice_style_name = (
            os.getenv("SUPERTONIC_VOICE_STYLE", "M1").strip() or "M1"
        )
        self._lang = os.getenv("SUPERTONIC_LANG", "en").strip() or "en"
        self._total_steps = int(os.getenv("SUPERTONIC_TOTAL_STEPS", "5"))
        self._speed = float(os.getenv("SUPERTONIC_SPEED", "1.05"))
        # ONNX thread counts; None lets the runtime pick defaults.
        self._intra_op_num_threads = _optional_int_env("SUPERTONIC_INTRA_OP_THREADS")
        self._inter_op_num_threads = _optional_int_env("SUPERTONIC_INTER_OP_THREADS")
        self._auto_download = os.getenv(
            "SUPERTONIC_AUTO_DOWNLOAD", "1"
        ).strip() not in {
            "0",
            "false",
            "False",
            "no",
            "off",
        }

        # Lazily populated by _initialize_blocking().
        self._engine: Any = None
        self._voice_style: Any = None
        self._init_error: str | None = None
        # Serializes init and synthesis across concurrent callers.
        self._lock = asyncio.Lock()

    @property
    def enabled(self) -> bool:
        """True when both supertonic and numpy imported successfully."""
        return (
            SUPERTONIC_TTS_AVAILABLE and SupertonicTTS is not None and NUMPY_AVAILABLE
        )

    @property
    def init_error(self) -> str | None:
        """Last engine-initialization error message, if any."""
        return self._init_error

    async def synthesize(self, text: str) -> PCMChunk | None:
        """Synthesize *text* off-thread; None when disabled or text is blank."""
        if not self.enabled:
            return None

        clean_text = " ".join(text.split())
        if not clean_text:
            return None

        async with self._lock:
            return await asyncio.to_thread(self._synthesize_blocking, clean_text)

    def _synthesize_blocking(self, text: str) -> PCMChunk | None:
        """Blocking synthesis: sanitize text, run the engine, normalize PCM."""
        self._initialize_blocking()
        if self._engine is None or self._voice_style is None or np is None:
            return None

        text = _sanitize_tts_text(text)
        if not text:
            return None

        try:
            wav, _duration = self._engine.synthesize(
                text,
                voice_style=self._voice_style,
                lang=self._lang,
                total_steps=self._total_steps,
                speed=self._speed,
            )
        except ValueError as exc:
            # Retry once with a re-sanitized string only for the engine's
            # "unsupported character" complaint; anything else propagates.
            message = str(exc)
            if "unsupported character" not in message.lower():
                raise
            fallback_text = self._sanitize_text_for_supertonic(text)
            if not fallback_text or fallback_text == text:
                raise
            wav, _duration = self._engine.synthesize(
                fallback_text,
                voice_style=self._voice_style,
                lang=self._lang,
                total_steps=self._total_steps,
                speed=self._speed,
            )

        samples = np.asarray(wav)
        if samples.size == 0:
            return None

        # Infer channel layout from array rank/shape; anything ambiguous is
        # flattened and treated as mono.
        channels = 1
        if samples.ndim == 0:
            samples = samples.reshape(1)
        elif samples.ndim == 1:
            channels = 1
        elif samples.ndim == 2:
            # Normalize to frames x channels so PCM bytes are correctly interleaved.
            dim0, dim1 = int(samples.shape[0]), int(samples.shape[1])
            if dim0 <= 2 and dim1 > dim0:
                channels = dim0
                samples = samples.T
            elif dim1 <= 2 and dim0 > dim1:
                channels = dim1
            else:
                channels = 1
                samples = samples.reshape(-1)
        else:
            channels = 1
            samples = samples.reshape(-1)

        # Convert to signed 16-bit: floats are clipped/scaled, ints are cast.
        if np.issubdtype(samples.dtype, np.floating):
            samples = np.clip(samples, -1.0, 1.0)
            samples = (samples * 32767.0).astype(np.int16)
        else:
            if samples.dtype != np.int16:
                samples = samples.astype(np.int16)

        pcm = samples.tobytes()

        return PCMChunk(
            pcm=pcm,
            # 24 kHz fallback when the engine does not expose a sample rate.
            sample_rate=int(getattr(self._engine, "sample_rate", 24_000)),
            channels=max(1, channels),
        )

    def _sanitize_text_for_supertonic(self, text: str) -> str:
        """Hook for engine-specific sanitization; currently the shared cleaner."""
        return _sanitize_tts_text(text)

    def _initialize_blocking(self) -> None:
        """Load the engine and voice style once; record (not raise) failures."""
        if self._engine is not None and self._voice_style is not None:
            return
        if not self.enabled or SupertonicTTS is None:
            return

        try:
            engine = SupertonicTTS(
                model=self._model,
                auto_download=self._auto_download,
                intra_op_num_threads=self._intra_op_num_threads,
                inter_op_num_threads=self._inter_op_num_threads,
            )
            voice_style = engine.get_voice_style(self._voice_style_name)
        except Exception as exc:
            self._init_error = str(exc)
            return

        self._engine = engine
        self._voice_style = voice_style
        self._init_error = None
|
||
|
|
|
||
|
|
|
||
|
|
class HostTextToSpeech:
    """Facade routing TTS requests to supertonic, a shell command, or espeak.

    ``HOST_TTS_PROVIDER`` selects the backend ("supertonic", "command",
    "espeak", or "auto"); unknown values fall back to "auto", which tries
    the backends in that order.
    """

    def __init__(self) -> None:
        provider = (
            os.getenv("HOST_TTS_PROVIDER", "supertonic").strip() or "supertonic"
        ).lower()
        if provider not in {"supertonic", "command", "espeak", "auto"}:
            provider = "auto"
        self._provider = provider
        self._supertonic = SupertonicTextToSpeech()
        self._command_template = os.getenv("HOST_TTS_COMMAND", "").strip()
        # Absolute path of the espeak binary, or None when not on PATH.
        self._espeak = shutil.which("espeak")

    @property
    def enabled(self) -> bool:
        """True when at least one backend selected by the provider is usable."""
        if self._provider == "supertonic":
            return self._supertonic.enabled
        if self._provider == "command":
            return bool(self._command_template)
        if self._provider == "espeak":
            return bool(self._espeak)
        return self._supertonic.enabled or bool(self._command_template or self._espeak)

    async def synthesize(self, text: str) -> PCMChunk | None:
        """Synthesize *text* via the configured backend chain; None on failure."""
        clean_text = " ".join(text.split())
        if not clean_text:
            return None

        if self._provider in {"supertonic", "auto"}:
            audio = await self._supertonic.synthesize(clean_text)
            if audio:
                return audio
            if self._provider == "supertonic":
                return None

        if self._provider in {"command", "auto"} and self._command_template:
            return await asyncio.to_thread(self._synthesize_with_command, clean_text)
        if self._provider == "command":
            return None

        if self._provider in {"espeak", "auto"} and self._espeak:
            return await asyncio.to_thread(self._synthesize_with_espeak, clean_text)

        return None

    def unavailable_reason(self) -> str:
        """Explain, per provider mode, why no audio could be produced."""
        if self._provider == "supertonic":
            if not self._supertonic.enabled:
                return "supertonic package is not available."
            if self._supertonic.init_error:
                return (
                    f"supertonic initialization failed: {self._supertonic.init_error}"
                )
            return "supertonic did not return audio."
        if self._provider == "command":
            return "HOST_TTS_COMMAND is not configured."
        if self._provider == "espeak":
            return "espeak binary is not available."

        # "auto": report the most specific failure we know about.
        if self._supertonic.init_error:
            return f"supertonic initialization failed: {self._supertonic.init_error}"
        if self._command_template:
            return "HOST_TTS_COMMAND failed to produce audio."
        if self._espeak:
            return "espeak failed to produce audio."
        return "No TTS provider is configured."

    def _synthesize_with_command(self, text: str) -> PCMChunk | None:
        """Run the HOST_TTS_COMMAND template.

        A ``{text}`` placeholder is substituted (else text is appended).
        With ``{output_wav}`` the command writes a WAV file we read back;
        otherwise its raw stdout is decoded as WAV bytes.
        Raises RuntimeError when the command exits non-zero.
        """
        command = self._command_template
        if "{text}" in command:
            command = command.replace("{text}", shlex.quote(text))
        else:
            command = f"{command} {shlex.quote(text)}"

        if "{output_wav}" in command:
            tmp_path: str | None = None
            try:
                with tempfile.NamedTemporaryFile(
                    suffix=".wav", delete=False
                ) as tmp_file:
                    tmp_path = tmp_file.name
                command_with_output = command.replace(
                    "{output_wav}", shlex.quote(tmp_path)
                )
                # text=True here: we only read stderr for diagnostics.
                result = subprocess.run(
                    command_with_output,
                    shell=True,
                    capture_output=True,
                    text=True,
                    check=False,
                )
                if result.returncode != 0:
                    stderr = result.stderr.strip() or "unknown error"
                    raise RuntimeError(f"TTS command failed: {stderr}")
                return self._read_wav_file(tmp_path)
            finally:
                if tmp_path and os.path.exists(tmp_path):
                    with contextlib.suppress(OSError):
                        os.unlink(tmp_path)

        # No {output_wav}: stdout carries the WAV payload, so run in bytes mode.
        result = subprocess.run(
            command,
            shell=True,
            capture_output=True,
            check=False,
        )
        if result.returncode != 0:
            stderr = result.stderr.decode(errors="ignore").strip() or "unknown error"
            raise RuntimeError(f"TTS command failed: {stderr}")
        return self._decode_wav_bytes(result.stdout)

    def _synthesize_with_espeak(self, text: str) -> PCMChunk | None:
        """Synthesize via espeak's --stdout WAV output.

        Raises RuntimeError when espeak exits non-zero.
        """
        if not self._espeak:
            return None

        result = subprocess.run(
            [self._espeak, "--stdout", text],
            capture_output=True,
            check=False,
        )
        if result.returncode != 0:
            stderr = result.stderr.decode(errors="ignore").strip() or "unknown error"
            raise RuntimeError(f"espeak failed: {stderr}")
        return self._decode_wav_bytes(result.stdout)

    def _read_wav_file(self, path: str) -> PCMChunk | None:
        """Read a WAV file and decode it; None when the file cannot be read."""
        try:
            with open(path, "rb") as wav_file:
                return self._decode_wav_bytes(wav_file.read())
        except OSError:
            return None

    def _decode_wav_bytes(self, payload: bytes) -> PCMChunk | None:
        """Decode in-memory WAV bytes into a 16-bit PCMChunk."""
        if not payload:
            return None

        with wave.open(io.BytesIO(payload), "rb") as wav_file:
            channels = wav_file.getnchannels()
            sample_width = wav_file.getsampwidth()
            sample_rate = wav_file.getframerate()
            pcm = wav_file.readframes(wav_file.getnframes())

        # Convert any non-16-bit sample width to s16 for downstream code.
        if sample_width != 2:
            pcm = audioop.lin2lin(pcm, sample_width, 2)

        return PCMChunk(pcm=pcm, sample_rate=sample_rate, channels=max(1, channels))
|
||
|
|
|
||
|
|
|
||
|
|
# Async callback used to push a JSON-serializable payload back to the client.
SendJsonCallable = Callable[[dict[str, Any]], Awaitable[None]]
|
||
|
|
|
||
|
|
|
||
|
|
class WebRTCVoiceSession:
|
||
|
|
    def __init__(
        self, gateway: "SuperTonicGateway", send_json: SendJsonCallable
    ) -> None:
        """Set up STT/TTS backends and tunables for one WebRTC voice session.

        All numeric tunables are read from HOST_* environment variables with
        clamped defaults; no I/O happens here beyond reading the environment.
        """
        self._gateway = gateway
        self._send_json = send_json

        # WebRTC plumbing, created on offer and torn down on close.
        self._pc: RTCPeerConnection | None = None
        self._outbound_track: QueueAudioTrack | None = None
        self._incoming_audio_task: asyncio.Task[None] | None = None
        self._stt_worker_task: asyncio.Task[None] | None = None
        self._stt_warmup_task: asyncio.Task[None] | None = None

        self._stt = HostSpeechToText()
        self._tts = HostTextToSpeech()
        # Bounded queue of captured PCM segments awaiting transcription.
        self._stt_segment_queue_size = max(
            1, int(os.getenv("HOST_STT_SEGMENT_QUEUE_SIZE", "2"))
        )
        self._stt_segments: asyncio.Queue[bytes] = asyncio.Queue(
            maxsize=self._stt_segment_queue_size
        )

        # Outbound text is buffered and flushed to TTS on a delay so short
        # streamed chunks are spoken as coherent phrases.
        self._tts_buffer = ""
        self._tts_flush_handle: asyncio.TimerHandle | None = None
        self._tts_flush_lock = asyncio.Lock()
        self._tts_buffer_lock = asyncio.Lock()
        self._tts_flush_delay_s = max(
            0.08, float(os.getenv("HOST_TTS_FLUSH_DELAY_S", "0.45"))
        )
        # Shorter delay when the chunk ends a sentence.
        self._tts_sentence_flush_delay_s = max(
            0.06,
            float(os.getenv("HOST_TTS_SENTENCE_FLUSH_DELAY_S", "0.15")),
        )
        self._tts_min_chars = max(1, int(os.getenv("HOST_TTS_MIN_CHARS", "10")))
        self._tts_max_wait_ms = max(300, int(os.getenv("HOST_TTS_MAX_WAIT_MS", "1800")))
        self._tts_max_chunk_chars = max(
            60, int(os.getenv("HOST_TTS_MAX_CHUNK_CHARS", "140"))
        )
        self._tts_buffer_started_at = 0.0

        # One-shot notice flags so the client is warned at most once each.
        self._closed = False
        self._stt_unavailable_notice_sent = False
        self._tts_unavailable_notice_sent = False
        self._audio_seen_notice_sent = False
        self._audio_format_notice_sent = False
        self._stt_first_segment_notice_sent = False
        self._ptt_timing_correction_notice_sent = False

        # Push-to-talk segment length bounds (ms); legacy *_SEGMENT_MS names
        # are honored as fallbacks for the newer *_PTT_MS variables.
        self._stt_min_ptt_ms = max(
            120,
            int(
                os.getenv(
                    "HOST_STT_MIN_PTT_MS", os.getenv("HOST_STT_MIN_SEGMENT_MS", "220")
                )
            ),
        )
        self._stt_max_ptt_ms = max(
            self._stt_min_ptt_ms,
            int(
                os.getenv(
                    "HOST_STT_MAX_PTT_MS", os.getenv("HOST_STT_MAX_SEGMENT_MS", "12000")
                )
            ),
        )
        # Optionally ignore microphone input while (and shortly after) TTS
        # plays, to avoid transcribing our own speech.
        self._stt_suppress_during_tts = os.getenv(
            "HOST_STT_SUPPRESS_DURING_TTS", "1"
        ).strip() not in {
            "0",
            "false",
            "False",
            "no",
            "off",
        }
        self._stt_suppress_ms_after_tts = max(
            0,
            int(os.getenv("HOST_STT_SUPPRESS_MS_AFTER_TTS", "300")),
        )
        self._stt_suppress_until = 0.0
        # Rate limit for "STT backlog" notices to the client.
        self._stt_backlog_notice_interval_s = max(
            2.0,
            float(os.getenv("HOST_STT_BACKLOG_NOTICE_INTERVAL_S", "6.0")),
        )
        self._last_stt_backlog_notice_at = 0.0
        # ICE candidates received before the remote description was applied.
        self._pending_ice_candidates: list[dict[str, Any] | None] = []
        self._ptt_pressed = False
|
||
|
|
|
||
|
|
def set_push_to_talk_pressed(self, pressed: bool) -> None:
|
||
|
|
self._ptt_pressed = bool(pressed)
|
||
|
|
|
||
|
|
    async def queue_output_text(self, chunk: str) -> None:
        """Buffer a streamed text chunk for TTS and (re)schedule a flush.

        Chunks that end with sentence punctuation use the shorter
        sentence-flush delay so speech starts promptly at natural
        boundaries; mid-sentence chunks wait longer for more text.
        """
        normalized_chunk = chunk.strip()
        if not normalized_chunk:
            return
        # Pick the flush delay based on whether this chunk closes a sentence.
        flush_delay = (
            self._tts_sentence_flush_delay_s
            if SENTENCE_END_RE.search(normalized_chunk)
            else self._tts_flush_delay_s
        )
        loop = asyncio.get_running_loop()
        async with self._tts_buffer_lock:
            # No peer connection / outbound track means nobody can hear it.
            if not self._pc or not self._outbound_track:
                return
            if not self._tts_buffer.strip():
                # First chunk of a fresh buffer: start the age clock.
                self._tts_buffer_started_at = loop.time()
                self._tts_buffer = normalized_chunk
            else:
                # Keep line boundaries between streamed chunks so line-based filters stay accurate.
                self._tts_buffer = f"{self._tts_buffer}\n{normalized_chunk}"
            self._schedule_tts_flush_after(flush_delay)
|
||
|
|
|
||
|
|
    async def handle_offer(self, payload: dict[str, Any]) -> None:
        """Answer a WebRTC SDP offer from the client.

        Replaces any existing peer connection, wires up the outbound TTS
        track and inbound audio consumer, sends back an ``rtc-answer``
        message, and lazily starts the STT worker/warmup tasks.
        """
        if not AIORTC_AVAILABLE or not RTCPeerConnection or not RTCSessionDescription:
            await self._send_json(
                {
                    "type": "rtc-error",
                    "message": "WebRTC backend unavailable on host (aiortc is not installed).",
                }
            )
            return

        sdp = str(payload.get("sdp", "")).strip()
        rtc_type = str(payload.get("rtcType", "offer")).strip() or "offer"
        if not sdp:
            await self._send_json(
                {"type": "rtc-error", "message": "Missing SDP offer payload."}
            )
            return

        # A new offer supersedes any prior session entirely.
        await self._close_peer_connection()
        self._ptt_pressed = False

        peer_connection = RTCPeerConnection()
        self._pc = peer_connection
        self._outbound_track = QueueAudioTrack()
        peer_connection.addTrack(self._outbound_track)

        @peer_connection.on("connectionstatechange")
        def on_connectionstatechange() -> None:
            # NOTE(review): fire-and-forget task with no retained reference;
            # asyncio holds only weak refs to tasks, so this notification
            # could in principle be garbage-collected before running.
            asyncio.create_task(
                self._send_json(
                    {
                        "type": "rtc-state",
                        "state": peer_connection.connectionState,
                    }
                )
            )

        @peer_connection.on("track")
        def on_track(track: MediaStreamTrack) -> None:
            if track.kind != "audio":
                return
            # Only one inbound consumer at a time; cancel any previous task.
            if self._incoming_audio_task:
                self._incoming_audio_task.cancel()
            self._incoming_audio_task = asyncio.create_task(
                self._consume_audio_track(track),
                name="voice-inbound-track",
            )

        await peer_connection.setRemoteDescription(
            RTCSessionDescription(sdp=sdp, type=rtc_type)
        )
        # Apply ICE candidates that trickled in before the remote description.
        await self._drain_pending_ice_candidates(peer_connection)
        answer = await peer_connection.createAnswer()
        await peer_connection.setLocalDescription(answer)
        await self._wait_for_ice_gathering(peer_connection)

        local_description = peer_connection.localDescription
        sdp_answer = str(local_description.sdp or "")
        if sdp_answer:
            # Normalize line endings to CRLF, the required SDP form.
            sdp_answer = (
                sdp_answer.replace("\r\n", "\n")
                .replace("\r", "\n")
                .strip()
                .replace("\n", "\r\n")
                + "\r\n"
            )
        await self._send_json(
            {
                "type": "rtc-answer",
                "sdp": sdp_answer,
                "rtcType": local_description.type,
            }
        )

        # Start STT machinery once per session; warm up opportunistically.
        if self._stt.enabled and not self._stt_worker_task:
            self._stt_worker_task = asyncio.create_task(
                self._stt_worker(), name="voice-stt-worker"
            )
        if self._stt.enabled and (
            self._stt_warmup_task is None or self._stt_warmup_task.done()
        ):
            self._stt_warmup_task = asyncio.create_task(
                self._warmup_stt(), name="voice-stt-warmup"
            )
        elif not self._stt.enabled and not self._stt_unavailable_notice_sent:
            self._stt_unavailable_notice_sent = True
            await self._publish_system(
                f"Voice input backend unavailable. {self._stt.unavailable_reason()}"
            )
|
||
|
|
|
||
|
|
async def handle_ice_candidate(self, payload: dict[str, Any]) -> None:
|
||
|
|
if not AIORTC_AVAILABLE:
|
||
|
|
return
|
||
|
|
|
||
|
|
raw_candidate = payload.get("candidate")
|
||
|
|
candidate_payload: dict[str, Any] | None
|
||
|
|
if raw_candidate in (None, ""):
|
||
|
|
candidate_payload = None
|
||
|
|
elif isinstance(raw_candidate, dict):
|
||
|
|
candidate_payload = raw_candidate
|
||
|
|
else:
|
||
|
|
return
|
||
|
|
|
||
|
|
if not self._pc or self._pc.remoteDescription is None:
|
||
|
|
self._pending_ice_candidates.append(candidate_payload)
|
||
|
|
return
|
||
|
|
|
||
|
|
await self._apply_ice_candidate(self._pc, candidate_payload)
|
||
|
|
|
||
|
|
async def _drain_pending_ice_candidates(
|
||
|
|
self,
|
||
|
|
peer_connection: RTCPeerConnection,
|
||
|
|
) -> None:
|
||
|
|
if not self._pending_ice_candidates:
|
||
|
|
return
|
||
|
|
|
||
|
|
pending = list(self._pending_ice_candidates)
|
||
|
|
self._pending_ice_candidates.clear()
|
||
|
|
for pending_candidate in pending:
|
||
|
|
await self._apply_ice_candidate(peer_connection, pending_candidate)
|
||
|
|
|
||
|
|
async def _apply_ice_candidate(
|
||
|
|
self,
|
||
|
|
peer_connection: RTCPeerConnection,
|
||
|
|
raw_candidate: dict[str, Any] | None,
|
||
|
|
) -> None:
|
||
|
|
if raw_candidate is None:
|
||
|
|
with contextlib.suppress(Exception):
|
||
|
|
await peer_connection.addIceCandidate(None)
|
||
|
|
return
|
||
|
|
|
||
|
|
candidate_sdp = str(raw_candidate.get("candidate", "")).strip()
|
||
|
|
if not candidate_sdp or not candidate_from_sdp:
|
||
|
|
return
|
||
|
|
|
||
|
|
if candidate_sdp.startswith("candidate:"):
|
||
|
|
candidate_sdp = candidate_sdp[len("candidate:") :]
|
||
|
|
|
||
|
|
try:
|
||
|
|
candidate = candidate_from_sdp(candidate_sdp)
|
||
|
|
candidate.sdpMid = raw_candidate.get("sdpMid")
|
||
|
|
line_index = raw_candidate.get("sdpMLineIndex")
|
||
|
|
candidate.sdpMLineIndex = (
|
||
|
|
int(line_index) if line_index is not None else None
|
||
|
|
)
|
||
|
|
await peer_connection.addIceCandidate(candidate)
|
||
|
|
except Exception as exc:
|
||
|
|
await self._publish_system(f"Failed to add ICE candidate: {exc}")
|
||
|
|
|
||
|
|
async def close(self) -> None:
|
||
|
|
self._closed = True
|
||
|
|
self._ptt_pressed = False
|
||
|
|
if self._tts_flush_handle:
|
||
|
|
self._tts_flush_handle.cancel()
|
||
|
|
self._tts_flush_handle = None
|
||
|
|
self._tts_buffer = ""
|
||
|
|
self._tts_buffer_started_at = 0.0
|
||
|
|
|
||
|
|
if self._incoming_audio_task:
|
||
|
|
self._incoming_audio_task.cancel()
|
||
|
|
with contextlib.suppress(asyncio.CancelledError):
|
||
|
|
await self._incoming_audio_task
|
||
|
|
self._incoming_audio_task = None
|
||
|
|
|
||
|
|
if self._stt_worker_task:
|
||
|
|
self._stt_worker_task.cancel()
|
||
|
|
with contextlib.suppress(asyncio.CancelledError):
|
||
|
|
await self._stt_worker_task
|
||
|
|
self._stt_worker_task = None
|
||
|
|
|
||
|
|
if self._stt_warmup_task:
|
||
|
|
self._stt_warmup_task.cancel()
|
||
|
|
with contextlib.suppress(asyncio.CancelledError):
|
||
|
|
await self._stt_warmup_task
|
||
|
|
self._stt_warmup_task = None
|
||
|
|
|
||
|
|
await self._close_peer_connection()
|
||
|
|
|
||
|
|
def _schedule_tts_flush(self) -> None:
|
||
|
|
if self._closed:
|
||
|
|
return
|
||
|
|
asyncio.create_task(self._flush_tts(), name="voice-tts-flush")
|
||
|
|
|
||
|
|
def _schedule_tts_flush_after(self, delay_s: float) -> None:
|
||
|
|
if self._tts_flush_handle:
|
||
|
|
self._tts_flush_handle.cancel()
|
||
|
|
loop = asyncio.get_running_loop()
|
||
|
|
self._tts_flush_handle = loop.call_later(
|
||
|
|
max(0.05, delay_s), self._schedule_tts_flush
|
||
|
|
)
|
||
|
|
|
||
|
|
    async def _flush_tts(self) -> None:
        """Drain the TTS buffer: clean, possibly defer, chunk, synthesize, enqueue.

        Serialized by ``_tts_flush_lock`` so only one flush synthesizes at a
        time; the text buffer itself is guarded separately by
        ``_tts_buffer_lock`` so producers can keep appending.
        """
        async with self._tts_flush_lock:
            async with self._tts_buffer_lock:
                # Take ownership of the buffered text atomically.
                self._tts_flush_handle = None
                raw_text = self._tts_buffer
                buffer_started_at = self._tts_buffer_started_at
                self._tts_buffer = ""
                self._tts_buffer_started_at = 0.0
            clean_text = self._clean_tts_text(raw_text)
            if not clean_text:
                return

            loop = asyncio.get_running_loop()
            now = loop.time()
            buffer_age_ms = int(max(0.0, now - buffer_started_at) * 1000)
            # Too short, no sentence end, and not yet stale: wait for more text.
            should_wait_for_more = (
                len(clean_text) < self._tts_min_chars
                and not SENTENCE_END_RE.search(clean_text)
                and buffer_age_ms < self._tts_max_wait_ms
            )
            if should_wait_for_more:
                async with self._tts_buffer_lock:
                    # Re-prepend: new chunks may have arrived since we drained.
                    if self._tts_buffer.strip():
                        self._tts_buffer = f"{clean_text}\n{self._tts_buffer}".strip()
                    else:
                        self._tts_buffer = clean_text
                    if self._tts_buffer_started_at <= 0.0:
                        # Preserve the original age so the wait eventually expires.
                        self._tts_buffer_started_at = (
                            buffer_started_at if buffer_started_at > 0.0 else now
                        )
                self._schedule_tts_flush_after(self._tts_flush_delay_s)
                return

            if not self._outbound_track:
                return

            for part in self._chunk_tts_text(clean_text):
                try:
                    audio = await self._tts.synthesize(part)
                except Exception as exc:
                    await self._publish_system(f"Host TTS failed: {exc}")
                    return

                if not audio:
                    # Backend produced nothing: report once, then stop flushing.
                    if not self._tts_unavailable_notice_sent:
                        self._tts_unavailable_notice_sent = True
                        await self._publish_system(
                            f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}"
                        )
                    return

                # The track may have been torn down while synthesizing.
                if not self._outbound_track:
                    return
                # Keep STT muted while (and briefly after) this audio plays.
                self._extend_stt_suppression(audio)
                await self._outbound_track.enqueue_pcm(
                    pcm=audio.pcm,
                    sample_rate=audio.sample_rate,
                    channels=audio.channels,
                )
|
||
|
|
|
||
|
|
def _extend_stt_suppression(self, audio: PCMChunk) -> None:
|
||
|
|
if not self._stt_suppress_during_tts:
|
||
|
|
return
|
||
|
|
|
||
|
|
channels = max(1, int(audio.channels))
|
||
|
|
sample_rate = max(1, int(audio.sample_rate))
|
||
|
|
sample_count = len(audio.pcm) // (2 * channels)
|
||
|
|
if sample_count <= 0:
|
||
|
|
return
|
||
|
|
|
||
|
|
duration_s = sample_count / float(sample_rate)
|
||
|
|
cooldown_s = float(self._stt_suppress_ms_after_tts) / 1000.0
|
||
|
|
now = asyncio.get_running_loop().time()
|
||
|
|
base = max(now, self._stt_suppress_until)
|
||
|
|
self._stt_suppress_until = base + duration_s + cooldown_s
|
||
|
|
|
||
|
|
    async def _consume_audio_track(self, track: MediaStreamTrack) -> None:
        """Pull inbound audio frames and capture push-to-talk segments.

        When STT is disabled the track is still drained (frames discarded)
        so the transport does not back up. Otherwise every frame is
        converted to 16 kHz mono PCM; while push-to-talk is pressed, frames
        accumulate into a segment that is finalized on release.
        """
        if not self._stt.enabled:
            try:
                while True:
                    await track.recv()
            except asyncio.CancelledError:
                raise
            except Exception:
                # Track ended; nothing to report when STT is off.
                return

        resample_state = None  # carried between frames for audioop.ratecv
        recording = False
        recording_started_at = 0.0
        recording_truncated = False
        segment_ms = 0.0
        segment_buffer = bytearray()

        try:
            while True:
                frame = await track.recv()
                pcm16, frame_ms, resample_state = self._frame_to_pcm16k_mono(
                    frame, resample_state
                )
                if not pcm16:
                    continue

                # One-time diagnostics once real audio starts flowing.
                if not self._audio_seen_notice_sent:
                    self._audio_seen_notice_sent = True
                    await self._publish_system("Receiving microphone audio on host.")

                if not self._audio_format_notice_sent:
                    self._audio_format_notice_sent = True
                    await self._publish_system(
                        "Inbound audio frame stats: "
                        f"sample_rate={int(getattr(frame, 'sample_rate', 0) or 0)}, "
                        f"samples={int(getattr(frame, 'samples', 0) or 0)}, "
                        f"time_base={getattr(frame, 'time_base', None)}."
                    )

                # While our own TTS is (or was just) playing, drop input so we
                # don't transcribe the assistant's speech; reset any segment.
                if (
                    self._stt_suppress_during_tts
                    and asyncio.get_running_loop().time() < self._stt_suppress_until
                ):
                    recording = False
                    recording_started_at = 0.0
                    recording_truncated = False
                    segment_ms = 0.0
                    segment_buffer = bytearray()
                    continue

                if self._ptt_pressed:
                    if not recording:
                        # Press edge: start a fresh segment.
                        recording = True
                        recording_started_at = asyncio.get_running_loop().time()
                        recording_truncated = False
                        segment_ms = 0.0
                        segment_buffer = bytearray()

                    if not recording_truncated:
                        next_segment_ms = segment_ms + frame_ms
                        if next_segment_ms <= self._stt_max_ptt_ms:
                            segment_buffer.extend(pcm16)
                            segment_ms = next_segment_ms
                        else:
                            # Cap the segment; keep draining until release.
                            recording_truncated = True
                            await self._publish_system(
                                "PTT max length reached; extra audio will be ignored until release."
                            )
                    continue

                # Release edge: flush the finished segment, if any.
                if recording:
                    observed_duration_ms = max(
                        1.0,
                        (asyncio.get_running_loop().time() - recording_started_at)
                        * 1000.0,
                    )
                    await self._finalize_ptt_segment(
                        bytes(segment_buffer),
                        segment_ms,
                        observed_duration_ms=observed_duration_ms,
                    )
                    recording = False
                    recording_started_at = 0.0
                    recording_truncated = False
                    segment_ms = 0.0
                    segment_buffer = bytearray()
        except asyncio.CancelledError:
            raise
        except Exception as exc:
            details = str(exc).strip()
            if details:
                await self._publish_system(
                    f"Voice input stream ended ({exc.__class__.__name__}): {details}"
                )
            else:
                await self._publish_system(
                    f"Voice input stream ended ({exc.__class__.__name__})."
                )
        finally:
            # If the stream died mid-press, salvage a long-enough segment.
            if recording and segment_ms >= self._stt_min_ptt_ms:
                observed_duration_ms = max(
                    1.0,
                    (asyncio.get_running_loop().time() - recording_started_at) * 1000.0,
                )
                await self._finalize_ptt_segment(
                    bytes(segment_buffer),
                    segment_ms,
                    observed_duration_ms=observed_duration_ms,
                )
|
||
|
|
|
||
|
|
    async def _finalize_ptt_segment(
        self,
        pcm16: bytes,
        duration_ms: float,
        observed_duration_ms: float | None = None,
    ) -> None:
        """Correct a captured PTT segment's timing and enqueue it for STT.

        ``duration_ms`` is derived from sample counts assuming 16 kHz;
        ``observed_duration_ms`` is wall-clock. A large mismatch between
        them suggests the source's nominal sample rate was misreported, so
        the audio is resampled from an estimated true rate back to 16 kHz.
        """
        if not pcm16 or duration_ms <= 0.0:
            return

        normalized_pcm = pcm16
        normalized_duration_ms = duration_ms
        if observed_duration_ms is not None and observed_duration_ms > 0.0:
            duration_ratio = duration_ms / observed_duration_ms
            # Only correct when off by more than roughly a third either way.
            if duration_ratio < 0.70 or duration_ratio > 1.40:
                # Rate that would make the sample count match wall-clock time.
                estimated_source_rate = int(round(16_000 * duration_ratio))
                estimated_source_rate = max(8_000, min(96_000, estimated_source_rate))
                candidate_rates = [
                    8_000,
                    12_000,
                    16_000,
                    24_000,
                    32_000,
                    44_100,
                    48_000,
                ]
                # Snap to the nearest common audio rate.
                nearest_source_rate = min(
                    candidate_rates,
                    key=lambda candidate: abs(candidate - estimated_source_rate),
                )
                if nearest_source_rate != 16_000:
                    normalized_pcm, _state = audioop.ratecv(
                        pcm16,
                        2,
                        1,
                        nearest_source_rate,
                        16_000,
                        None,
                    )
                    # Recompute duration from the resampled byte count.
                    normalized_duration_ms = (len(normalized_pcm) / 2 / 16_000) * 1000.0
                    if not self._ptt_timing_correction_notice_sent:
                        self._ptt_timing_correction_notice_sent = True
                        await self._publish_system(
                            "Corrected PTT timing mismatch "
                            f"(estimated source={nearest_source_rate}Hz)."
                        )

        await self._enqueue_stt_segment(
            pcm16=normalized_pcm, duration_ms=normalized_duration_ms
        )
|
||
|
|
|
||
|
|
    async def _enqueue_stt_segment(self, pcm16: bytes, duration_ms: float) -> None:
        """Queue a finished PTT segment for the STT worker, dropping stale ones."""
        # Ignore blips shorter than the configured minimum press length.
        if duration_ms < self._stt_min_ptt_ms:
            return

        if self._stt_segments.full():
            # Bounded queue: discard the oldest segment to make room.
            with contextlib.suppress(asyncio.QueueEmpty):
                self._stt_segments.get_nowait()

            now = asyncio.get_running_loop().time()
            # Rate-limit the backlog warning to one per interval.
            if (
                now - self._last_stt_backlog_notice_at
            ) >= self._stt_backlog_notice_interval_s:
                self._last_stt_backlog_notice_at = now
                await self._publish_system(
                    "Voice input backlog detected; dropping stale segment."
                )

        with contextlib.suppress(asyncio.QueueFull):
            self._stt_segments.put_nowait(pcm16)
|
||
|
|
|
||
|
|
    async def _stt_worker(self) -> None:
        """Consume queued PTT segments forever, transcribing each on the host."""
        while True:
            pcm16 = await self._stt_segments.get()
            if not self._stt_first_segment_notice_sent:
                self._stt_first_segment_notice_sent = True
                await self._publish_system(
                    "Push-to-talk audio captured. Running host STT..."
                )
            try:
                # Segments are normalized to 16 kHz mono before being queued.
                transcript = await self._stt.transcribe_pcm(
                    pcm=pcm16,
                    sample_rate=16_000,
                    channels=1,
                )
            except asyncio.CancelledError:
                raise
            except Exception as exc:
                # Report and keep the worker alive for the next segment.
                await self._publish_system(f"Host STT failed: {exc}")
                continue

            if not transcript:
                continue

            transcript = transcript.strip()
            if not transcript:
                continue

            # Surface the transcript as an event and feed it into the chat.
            await self._gateway.bus.publish(
                WisperEvent(role="wisper", text=f"voice transcript: {transcript}")
            )
            await self._gateway.send_user_message(transcript)
|
||
|
|
|
||
|
|
async def _close_peer_connection(self) -> None:
|
||
|
|
if self._outbound_track:
|
||
|
|
self._outbound_track.stop()
|
||
|
|
self._outbound_track = None
|
||
|
|
|
||
|
|
if self._pc:
|
||
|
|
await self._pc.close()
|
||
|
|
self._pc = None
|
||
|
|
|
||
|
|
self._pending_ice_candidates.clear()
|
||
|
|
|
||
|
|
async def _wait_for_ice_gathering(self, peer_connection: RTCPeerConnection) -> None:
|
||
|
|
if peer_connection.iceGatheringState == "complete":
|
||
|
|
return
|
||
|
|
|
||
|
|
completed = asyncio.Event()
|
||
|
|
|
||
|
|
@peer_connection.on("icegatheringstatechange")
|
||
|
|
def on_icegatheringstatechange() -> None:
|
||
|
|
if peer_connection.iceGatheringState == "complete":
|
||
|
|
completed.set()
|
||
|
|
|
||
|
|
with contextlib.suppress(asyncio.TimeoutError):
|
||
|
|
await asyncio.wait_for(completed.wait(), timeout=3)
|
||
|
|
|
||
|
|
async def _warmup_stt(self) -> None:
|
||
|
|
try:
|
||
|
|
await self._stt.warmup()
|
||
|
|
except asyncio.CancelledError:
|
||
|
|
raise
|
||
|
|
except Exception:
|
||
|
|
return
|
||
|
|
|
||
|
|
async def _publish_system(self, text: str) -> None:
|
||
|
|
await self._gateway.bus.publish(WisperEvent(role="system", text=text))
|
||
|
|
|
||
|
|
def _clean_tts_text(self, raw_text: str) -> str:
|
||
|
|
lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
|
||
|
|
useful_lines = [
|
||
|
|
line
|
||
|
|
for line in lines
|
||
|
|
if not SPEECH_FILTER_RE.match(line)
|
||
|
|
and not THINKING_STATUS_RE.search(line)
|
||
|
|
and not USER_PREFIX_RE.match(line)
|
||
|
|
and not VOICE_TRANSCRIPT_RE.match(line)
|
||
|
|
]
|
||
|
|
return _sanitize_tts_text(" ".join(useful_lines))
|
||
|
|
|
||
|
|
def _chunk_tts_text(self, text: str) -> list[str]:
|
||
|
|
clean_text = " ".join(text.split())
|
||
|
|
if not clean_text:
|
||
|
|
return []
|
||
|
|
|
||
|
|
max_chars = max(60, int(self._tts_max_chunk_chars))
|
||
|
|
if len(clean_text) <= max_chars:
|
||
|
|
return [clean_text]
|
||
|
|
|
||
|
|
sentence_parts = [
|
||
|
|
part.strip()
|
||
|
|
for part in re.split(r"(?<=[.!?])\s+", clean_text)
|
||
|
|
if part.strip()
|
||
|
|
]
|
||
|
|
if not sentence_parts:
|
||
|
|
sentence_parts = [clean_text]
|
||
|
|
|
||
|
|
chunks: list[str] = []
|
||
|
|
for part in sentence_parts:
|
||
|
|
if len(part) <= max_chars:
|
||
|
|
chunks.append(part)
|
||
|
|
else:
|
||
|
|
chunks.extend(self._chunk_tts_words(part, max_chars))
|
||
|
|
return chunks
|
||
|
|
|
||
|
|
def _chunk_tts_words(self, text: str, max_chars: int) -> list[str]:
|
||
|
|
words = [word for word in text.split() if word]
|
||
|
|
if not words:
|
||
|
|
return []
|
||
|
|
|
||
|
|
chunks: list[str] = []
|
||
|
|
current_words: list[str] = []
|
||
|
|
current_len = 0
|
||
|
|
|
||
|
|
for word in words:
|
||
|
|
if len(word) > max_chars:
|
||
|
|
if current_words:
|
||
|
|
chunks.append(" ".join(current_words))
|
||
|
|
current_words = []
|
||
|
|
current_len = 0
|
||
|
|
start = 0
|
||
|
|
while start < len(word):
|
||
|
|
end = min(start + max_chars, len(word))
|
||
|
|
piece = word[start:end]
|
||
|
|
if len(piece) == max_chars:
|
||
|
|
chunks.append(piece)
|
||
|
|
else:
|
||
|
|
current_words = [piece]
|
||
|
|
current_len = len(piece)
|
||
|
|
start = end
|
||
|
|
continue
|
||
|
|
|
||
|
|
extra = len(word) if not current_words else (1 + len(word))
|
||
|
|
if current_words and (current_len + extra) > max_chars:
|
||
|
|
chunks.append(" ".join(current_words))
|
||
|
|
current_words = [word]
|
||
|
|
current_len = len(word)
|
||
|
|
else:
|
||
|
|
current_words.append(word)
|
||
|
|
current_len += extra
|
||
|
|
|
||
|
|
if current_words:
|
||
|
|
chunks.append(" ".join(current_words))
|
||
|
|
return chunks
|
||
|
|
|
||
|
|
    def _frame_to_pcm16k_mono(
        self, frame: AudioFrame, resample_state: tuple[Any, ...] | None
    ) -> tuple[bytes, float, tuple[Any, ...] | None]:
        """Convert one inbound ``av.AudioFrame`` to 16 kHz mono s16 PCM.

        Returns ``(pcm_bytes, duration_ms, resample_state)``. The returned
        state must be threaded back in on the next call so audioop.ratecv
        stays continuous across frames. Returns ``(b"", 0.0, state)`` when
        the frame cannot be converted.
        """
        try:
            pcm = frame.to_ndarray(format="s16")
        except TypeError:
            # Older av versions do not accept the format kwarg.
            pcm = frame.to_ndarray()

        # Coerce non-int16 arrays (e.g. float PCM in [-1, 1]) to s16.
        if (
            NUMPY_AVAILABLE
            and np is not None
            and getattr(pcm, "dtype", None) is not None
        ):
            if pcm.dtype != np.int16:
                if np.issubdtype(pcm.dtype, np.floating):
                    pcm = np.clip(pcm, -1.0, 1.0)
                    pcm = (pcm * 32767.0).astype(np.int16)
                else:
                    pcm = pcm.astype(np.int16)

        if pcm.ndim == 1:
            mono = pcm.tobytes()
        elif pcm.ndim == 2:
            expected_channels = 0
            if getattr(frame, "layout", None) is not None:
                with contextlib.suppress(Exception):
                    expected_channels = len(frame.layout.channels)

            rows = int(pcm.shape[0])
            cols = int(pcm.shape[1])

            # Normalize to [frames, channels] to avoid accidental channel mis-detection.
            if expected_channels > 0:
                if rows == expected_channels:
                    frames_channels = pcm.T
                elif cols == expected_channels:
                    frames_channels = pcm
                else:
                    frames_channels = pcm.reshape(-1, 1)
            else:
                # No layout info: guess orientation heuristically (<=8 channels).
                if rows == 1:
                    frames_channels = pcm.T
                elif cols == 1:
                    frames_channels = pcm
                elif rows <= 8 and cols > rows:
                    frames_channels = pcm.T
                elif cols <= 8 and rows > cols:
                    frames_channels = pcm
                else:
                    frames_channels = pcm.reshape(-1, 1)

            channel_count = (
                int(frames_channels.shape[1]) if frames_channels.ndim == 2 else 1
            )
            if channel_count <= 1:
                mono = frames_channels.reshape(-1).tobytes()
            elif NUMPY_AVAILABLE and np is not None:
                # Downmix by averaging channels in int32 to avoid overflow.
                mixed = frames_channels.astype(np.int32).mean(axis=1)
                mono = np.clip(mixed, -32768, 32767).astype(np.int16).tobytes()
            elif channel_count == 2:
                interleaved = frames_channels.reshape(-1).tobytes()
                mono = audioop.tomono(interleaved, 2, 0.5, 0.5)
            else:
                # No numpy and >2 channels: keep only the first channel.
                mono = frames_channels[:, 0].reshape(-1).tobytes()
        else:
            return b"", 0.0, resample_state

        source_rate = int(
            getattr(frame, "sample_rate", 0) or getattr(frame, "rate", 0) or 0
        )

        # A time_base of 1/N often encodes the sample rate N.
        time_base = getattr(frame, "time_base", None)
        tb_rate = 0
        if time_base is not None:
            with contextlib.suppress(Exception):
                numerator = int(getattr(time_base, "numerator", 0))
                denominator = int(getattr(time_base, "denominator", 0))
                if numerator == 1 and denominator > 0:
                    tb_rate = denominator

        # Infer the rate from the sample count assuming ~20 ms frames
        # (typical Opus/WebRTC packetization — heuristic, not guaranteed).
        samples_per_channel = int(getattr(frame, "samples", 0) or 0)
        if samples_per_channel > 0:
            candidate_rates = [8_000, 16_000, 24_000, 32_000, 44_100, 48_000]
            inferred_rate = min(
                candidate_rates,
                key=lambda rate: abs((samples_per_channel / float(rate)) - 0.020),
            )
            inferred_frame_ms = (samples_per_channel / float(inferred_rate)) * 1000.0
            # If metadata suggests implausibly long frames, trust the inferred rate instead.
            if (
                source_rate <= 0
                or (samples_per_channel / float(max(1, source_rate))) * 1000.0 > 40.0
            ):
                source_rate = inferred_rate
            elif abs(inferred_frame_ms - 20.0) <= 2.5 and source_rate not in {
                inferred_rate,
                tb_rate,
            }:
                source_rate = inferred_rate

        # time_base rate wins when metadata is missing or wildly different.
        if tb_rate > 0 and (source_rate <= 0 or abs(tb_rate - source_rate) > 2_000):
            source_rate = tb_rate
        if source_rate <= 0:
            source_rate = 48_000

        if source_rate != 16_000:
            mono, resample_state = audioop.ratecv(
                mono,
                2,
                1,
                source_rate,
                16_000,
                resample_state,
            )

        if not mono:
            return b"", 0.0, resample_state

        # Duration computed from the post-resample 16 kHz mono byte count.
        sample_count = len(mono) // 2
        duration_ms = (sample_count / 16_000) * 1000
        return mono, duration_ms, resample_state
|