api channel and tools

This commit is contained in:
kacper 2026-03-05 15:10:14 -05:00
parent 9222c59f03
commit 3816a9627e
4 changed files with 684 additions and 582 deletions

View file

@ -41,9 +41,7 @@ try:
from faster_whisper import WhisperModel
FASTER_WHISPER_AVAILABLE = True
except (
Exception
): # pragma: no cover - runtime fallback when faster-whisper is unavailable
except Exception: # pragma: no cover - runtime fallback when faster-whisper is unavailable
WhisperModel = None # type: ignore[assignment]
FASTER_WHISPER_AVAILABLE = False
@ -82,10 +80,7 @@ ANSI_ESCAPE_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]")
BRAILLE_SPINNER_RE = re.compile(r"[\u2800-\u28ff]")
TTS_ALLOWED_ASCII = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
"0123456789"
" .,!?;:'\"()[]{}@#%&*+-_/<>|"
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789 .,!?;:'\"()[]{}@#%&*+-_/<>|"
)
@ -95,9 +90,7 @@ def _sanitize_tts_text(text: str) -> str:
cleaned = cleaned.replace("\u00a0", " ")
cleaned = cleaned.replace("", " ")
cleaned = CONTROL_CHAR_RE.sub(" ", cleaned)
cleaned = "".join(
ch if (ch in TTS_ALLOWED_ASCII or ch.isspace()) else " " for ch in cleaned
)
cleaned = "".join(ch if (ch in TTS_ALLOWED_ASCII or ch.isspace()) else " " for ch in cleaned)
cleaned = re.sub(r"\s+", " ", cleaned).strip()
return cleaned
@ -131,15 +124,9 @@ if AIORTC_AVAILABLE:
self._timestamp = 0
self._resample_state = None
self._resample_source_rate: int | None = None
self._lead_in_ms = max(
0, int(os.getenv("HOST_RTC_OUTBOUND_LEAD_IN_MS", "120"))
)
self._lead_in_frames = (
self._lead_in_ms + self._frame_ms - 1
) // self._frame_ms
self._lead_in_idle_s = max(
0.1, float(os.getenv("HOST_RTC_OUTBOUND_IDLE_S", "0.6"))
)
self._lead_in_ms = max(0, int(os.getenv("HOST_RTC_OUTBOUND_LEAD_IN_MS", "120")))
self._lead_in_frames = (self._lead_in_ms + self._frame_ms - 1) // self._frame_ms
self._lead_in_idle_s = max(0.1, float(os.getenv("HOST_RTC_OUTBOUND_IDLE_S", "0.6")))
self._last_enqueue_at = 0.0
self._closed = False
self._frame_duration_s = frame_ms / 1000.0
@ -154,9 +141,7 @@ if AIORTC_AVAILABLE:
)
self._on_playing_changed: Callable[[bool], None] | None = None
async def enqueue_pcm(
self, pcm: bytes, sample_rate: int, channels: int = 1
) -> None:
async def enqueue_pcm(self, pcm: bytes, sample_rate: int, channels: int = 1) -> None:
if self._closed or not pcm:
return
@ -244,9 +229,7 @@ if AIORTC_AVAILABLE:
self._last_recv_at = loop.time()
frame = AudioFrame(
format="s16", layout="mono", samples=self._samples_per_frame
)
frame = AudioFrame(format="s16", layout="mono", samples=self._samples_per_frame)
frame.planes[0].update(payload)
frame.sample_rate = self._sample_rate
frame.time_base = Fraction(1, self._sample_rate)
@ -263,9 +246,7 @@ else:
class QueueAudioTrack: # pragma: no cover - used only when aiortc is unavailable
_on_playing_changed: Callable[[bool], None] | None = None
async def enqueue_pcm(
self, pcm: bytes, sample_rate: int, channels: int = 1
) -> None:
async def enqueue_pcm(self, pcm: bytes, sample_rate: int, channels: int = 1) -> None:
return
def stop(self) -> None:
@ -296,23 +277,17 @@ class CommandSpeechToText:
) -> str | None:
if not self.enabled or not pcm:
return None
return await asyncio.to_thread(
self._transcribe_blocking, pcm, sample_rate, channels
)
return await asyncio.to_thread(self._transcribe_blocking, pcm, sample_rate, channels)
def unavailable_reason(self) -> str:
if not self._command_template:
return "HOST_STT_COMMAND is not configured."
return "HOST_STT_COMMAND failed to produce transcript."
def _transcribe_blocking(
self, pcm: bytes, sample_rate: int, channels: int
) -> str | None:
def _transcribe_blocking(self, pcm: bytes, sample_rate: int, channels: int) -> str | None:
tmp_path: str | None = None
try:
tmp_path = _write_temp_wav(
pcm=pcm, sample_rate=sample_rate, channels=channels
)
tmp_path = _write_temp_wav(pcm=pcm, sample_rate=sample_rate, channels=channels)
command = self._command_template
if "{input_wav}" in command:
@ -343,9 +318,7 @@ class FasterWhisperSpeechToText:
def __init__(self) -> None:
self._model_name = os.getenv("HOST_STT_MODEL", "tiny.en").strip() or "tiny.en"
self._device = os.getenv("HOST_STT_DEVICE", "auto").strip() or "auto"
self._compute_type = (
os.getenv("HOST_STT_COMPUTE_TYPE", "int8").strip() or "int8"
)
self._compute_type = os.getenv("HOST_STT_COMPUTE_TYPE", "int8").strip() or "int8"
self._language = os.getenv("HOST_STT_LANGUAGE", "en").strip()
self._beam_size = max(1, int(os.getenv("HOST_STT_BEAM_SIZE", "1")))
self._best_of = max(1, int(os.getenv("HOST_STT_BEST_OF", "1")))
@ -357,12 +330,8 @@ class FasterWhisperSpeechToText:
"off",
}
self._temperature = float(os.getenv("HOST_STT_TEMPERATURE", "0.0"))
self._log_prob_threshold = float(
os.getenv("HOST_STT_LOG_PROB_THRESHOLD", "-1.0")
)
self._no_speech_threshold = float(
os.getenv("HOST_STT_NO_SPEECH_THRESHOLD", "0.6")
)
self._log_prob_threshold = float(os.getenv("HOST_STT_LOG_PROB_THRESHOLD", "-1.0"))
self._no_speech_threshold = float(os.getenv("HOST_STT_NO_SPEECH_THRESHOLD", "0.6"))
self._compression_ratio_threshold = float(
os.getenv("HOST_STT_COMPRESSION_RATIO_THRESHOLD", "2.4")
)
@ -373,9 +342,7 @@ class FasterWhisperSpeechToText:
).strip()
or None
)
self._repetition_penalty = float(
os.getenv("HOST_STT_REPETITION_PENALTY", "1.0")
)
self._repetition_penalty = float(os.getenv("HOST_STT_REPETITION_PENALTY", "1.0"))
raw_hallucination_threshold = os.getenv(
"HOST_STT_HALLUCINATION_SILENCE_THRESHOLD", ""
).strip()
@ -401,9 +368,7 @@ class FasterWhisperSpeechToText:
if not self.enabled or not pcm:
return None
async with self._lock:
return await asyncio.to_thread(
self._transcribe_blocking, pcm, sample_rate, channels
)
return await asyncio.to_thread(self._transcribe_blocking, pcm, sample_rate, channels)
async def warmup(self) -> None:
if not self.enabled:
@ -428,15 +393,11 @@ class FasterWhisperSpeechToText:
self._init_error = str(exc)
self._model = None
def _transcribe_blocking(
self, pcm: bytes, sample_rate: int, channels: int
) -> str | None:
def _transcribe_blocking(self, pcm: bytes, sample_rate: int, channels: int) -> str | None:
self._initialize_blocking()
if self._model is None:
if self._init_error:
raise RuntimeError(
f"faster-whisper initialization failed: {self._init_error}"
)
raise RuntimeError(f"faster-whisper initialization failed: {self._init_error}")
return None
if NUMPY_AVAILABLE and np is not None:
@ -481,9 +442,7 @@ class FasterWhisperSpeechToText:
tmp_path: str | None = None
try:
tmp_path = _write_temp_wav(
pcm=pcm, sample_rate=sample_rate, channels=channels
)
tmp_path = _write_temp_wav(pcm=pcm, sample_rate=sample_rate, channels=channels)
segments, _info = self._model.transcribe(
tmp_path,
language=self._language or None,
@ -580,20 +539,14 @@ class HostSpeechToText:
class SupertonicTextToSpeech:
def __init__(self) -> None:
self._model = (
os.getenv("SUPERTONIC_MODEL", "supertonic-2").strip() or "supertonic-2"
)
self._voice_style_name = (
os.getenv("SUPERTONIC_VOICE_STYLE", "F1").strip() or "F1"
)
self._model = os.getenv("SUPERTONIC_MODEL", "supertonic-2").strip() or "supertonic-2"
self._voice_style_name = os.getenv("SUPERTONIC_VOICE_STYLE", "F1").strip() or "F1"
self._lang = os.getenv("SUPERTONIC_LANG", "en").strip() or "en"
self._total_steps = int(os.getenv("SUPERTONIC_TOTAL_STEPS", "4"))
self._speed = float(os.getenv("SUPERTONIC_SPEED", "1.5"))
self._intra_op_num_threads = _optional_int_env("SUPERTONIC_INTRA_OP_THREADS")
self._inter_op_num_threads = _optional_int_env("SUPERTONIC_INTER_OP_THREADS")
self._auto_download = os.getenv(
"SUPERTONIC_AUTO_DOWNLOAD", "1"
).strip() not in {
self._auto_download = os.getenv("SUPERTONIC_AUTO_DOWNLOAD", "1").strip() not in {
"0",
"false",
"False",
@ -608,9 +561,7 @@ class SupertonicTextToSpeech:
@property
def enabled(self) -> bool:
return (
SUPERTONIC_TTS_AVAILABLE and SupertonicTTS is not None and NUMPY_AVAILABLE
)
return SUPERTONIC_TTS_AVAILABLE and SupertonicTTS is not None and NUMPY_AVAILABLE
@property
def init_error(self) -> str | None:
@ -723,9 +674,7 @@ class SupertonicTextToSpeech:
class HostTextToSpeech:
def __init__(self) -> None:
provider = (
os.getenv("HOST_TTS_PROVIDER", "supertonic").strip() or "supertonic"
).lower()
provider = (os.getenv("HOST_TTS_PROVIDER", "supertonic").strip() or "supertonic").lower()
if provider not in {"supertonic", "command", "espeak", "auto"}:
provider = "auto"
self._provider = provider
@ -770,9 +719,7 @@ class HostTextToSpeech:
if not self._supertonic.enabled:
return "supertonic package is not available."
if self._supertonic.init_error:
return (
f"supertonic initialization failed: {self._supertonic.init_error}"
)
return f"supertonic initialization failed: {self._supertonic.init_error}"
return "supertonic did not return audio."
if self._provider == "command":
return "HOST_TTS_COMMAND is not configured."
@ -797,13 +744,9 @@ class HostTextToSpeech:
if "{output_wav}" in command:
tmp_path: str | None = None
try:
with tempfile.NamedTemporaryFile(
suffix=".wav", delete=False
) as tmp_file:
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp_file:
tmp_path = tmp_file.name
command_with_output = command.replace(
"{output_wav}", shlex.quote(tmp_path)
)
command_with_output = command.replace("{output_wav}", shlex.quote(tmp_path))
result = subprocess.run(
command_with_output,
shell=True,
@ -872,9 +815,7 @@ SendJsonCallable = Callable[[dict[str, Any]], Awaitable[None]]
class WebRTCVoiceSession:
def __init__(
self, gateway: "SuperTonicGateway", send_json: SendJsonCallable
) -> None:
def __init__(self, gateway: "SuperTonicGateway", send_json: SendJsonCallable) -> None:
self._gateway = gateway
self._send_json = send_json
@ -886,9 +827,7 @@ class WebRTCVoiceSession:
self._stt = HostSpeechToText()
self._tts = HostTextToSpeech()
self._stt_segment_queue_size = max(
1, int(os.getenv("HOST_STT_SEGMENT_QUEUE_SIZE", "2"))
)
self._stt_segment_queue_size = max(1, int(os.getenv("HOST_STT_SEGMENT_QUEUE_SIZE", "2")))
self._stt_segments: asyncio.Queue[bytes] = asyncio.Queue(
maxsize=self._stt_segment_queue_size
)
@ -913,11 +852,7 @@ class WebRTCVoiceSession:
self._stt_min_ptt_ms = max(
120,
int(
os.getenv(
"HOST_STT_MIN_PTT_MS", os.getenv("HOST_STT_MIN_SEGMENT_MS", "220")
)
),
int(os.getenv("HOST_STT_MIN_PTT_MS", os.getenv("HOST_STT_MIN_SEGMENT_MS", "220"))),
)
self._stt_suppress_during_tts = os.getenv(
@ -973,9 +908,7 @@ class WebRTCVoiceSession:
sdp = str(payload.get("sdp", "")).strip()
rtc_type = str(payload.get("rtcType", "offer")).strip() or "offer"
if not sdp:
await self._send_json(
{"type": "rtc-error", "message": "Missing SDP offer payload."}
)
await self._send_json({"type": "rtc-error", "message": "Missing SDP offer payload."})
return
await self._close_peer_connection()
@ -1009,9 +942,7 @@ class WebRTCVoiceSession:
name="voice-inbound-track",
)
await peer_connection.setRemoteDescription(
RTCSessionDescription(sdp=sdp, type=rtc_type)
)
await peer_connection.setRemoteDescription(RTCSessionDescription(sdp=sdp, type=rtc_type))
await self._drain_pending_ice_candidates(peer_connection)
answer = await peer_connection.createAnswer()
await peer_connection.setLocalDescription(answer)
@ -1021,10 +952,7 @@ class WebRTCVoiceSession:
sdp_answer = str(local_description.sdp or "")
if sdp_answer:
sdp_answer = (
sdp_answer.replace("\r\n", "\n")
.replace("\r", "\n")
.strip()
.replace("\n", "\r\n")
sdp_answer.replace("\r\n", "\n").replace("\r", "\n").strip().replace("\n", "\r\n")
+ "\r\n"
)
await self._send_json(
@ -1036,15 +964,9 @@ class WebRTCVoiceSession:
)
if self._stt.enabled and not self._stt_worker_task:
self._stt_worker_task = asyncio.create_task(
self._stt_worker(), name="voice-stt-worker"
)
if self._stt.enabled and (
self._stt_warmup_task is None or self._stt_warmup_task.done()
):
self._stt_warmup_task = asyncio.create_task(
self._warmup_stt(), name="voice-stt-warmup"
)
self._stt_worker_task = asyncio.create_task(self._stt_worker(), name="voice-stt-worker")
if self._stt.enabled and (self._stt_warmup_task is None or self._stt_warmup_task.done()):
self._stt_warmup_task = asyncio.create_task(self._warmup_stt(), name="voice-stt-warmup")
elif not self._stt.enabled and not self._stt_unavailable_notice_sent:
self._stt_unavailable_notice_sent = True
await self._publish_system(
@ -1103,9 +1025,7 @@ class WebRTCVoiceSession:
candidate = candidate_from_sdp(candidate_sdp)
candidate.sdpMid = raw_candidate.get("sdpMid")
line_index = raw_candidate.get("sdpMLineIndex")
candidate.sdpMLineIndex = (
int(line_index) if line_index is not None else None
)
candidate.sdpMLineIndex = int(line_index) if line_index is not None else None
await peer_connection.addIceCandidate(candidate)
except Exception as exc:
await self._publish_system(f"Failed to add ICE candidate: {exc}")
@ -1147,9 +1067,7 @@ class WebRTCVoiceSession:
if self._tts_flush_handle:
self._tts_flush_handle.cancel()
loop = asyncio.get_running_loop()
self._tts_flush_handle = loop.call_later(
max(0.05, delay_s), self._schedule_tts_flush
)
self._tts_flush_handle = loop.call_later(max(0.05, delay_s), self._schedule_tts_flush)
async def _flush_tts(self) -> None:
async with self._tts_flush_lock:
@ -1230,9 +1148,7 @@ class WebRTCVoiceSession:
try:
while True:
frame = await track.recv()
pcm16, frame_ms, resample_state = self._frame_to_pcm16k_mono(
frame, resample_state
)
pcm16, frame_ms, resample_state = self._frame_to_pcm16k_mono(frame, resample_state)
if not pcm16:
continue
@ -1249,10 +1165,9 @@ class WebRTCVoiceSession:
f"time_base={getattr(frame, 'time_base', None)}."
)
if (
self._stt_suppress_during_tts
and asyncio.get_running_loop().time() < self._stt_suppress_until
):
loop = asyncio.get_running_loop()
if self._stt_suppress_during_tts and loop.time() < self._stt_suppress_until:
recording = False
recording_started_at = 0.0
segment_ms = 0.0
@ -1262,7 +1177,7 @@ class WebRTCVoiceSession:
if self._ptt_pressed:
if not recording:
recording = True
recording_started_at = asyncio.get_running_loop().time()
recording_started_at = loop.time()
segment_ms = 0.0
segment_buffer = bytearray()
@ -1273,8 +1188,7 @@ class WebRTCVoiceSession:
if recording:
observed_duration_ms = max(
1.0,
(asyncio.get_running_loop().time() - recording_started_at)
* 1000.0,
(loop.time() - recording_started_at) * 1000.0,
)
await self._finalize_ptt_segment(
bytes(segment_buffer),
@ -1285,6 +1199,7 @@ class WebRTCVoiceSession:
recording_started_at = 0.0
segment_ms = 0.0
segment_buffer = bytearray()
except asyncio.CancelledError:
raise
except Exception as exc:
@ -1294,9 +1209,7 @@ class WebRTCVoiceSession:
f"Voice input stream ended ({exc.__class__.__name__}): {details}"
)
else:
await self._publish_system(
f"Voice input stream ended ({exc.__class__.__name__})."
)
await self._publish_system(f"Voice input stream ended ({exc.__class__.__name__}).")
finally:
if recording and segment_ms >= self._stt_min_ptt_ms:
observed_duration_ms = max(
@ -1355,9 +1268,7 @@ class WebRTCVoiceSession:
f"(estimated source={nearest_source_rate}Hz)."
)
await self._enqueue_stt_segment(
pcm16=normalized_pcm, duration_ms=normalized_duration_ms
)
await self._enqueue_stt_segment(pcm16=normalized_pcm, duration_ms=normalized_duration_ms)
async def _enqueue_stt_segment(self, pcm16: bytes, duration_ms: float) -> None:
if duration_ms < self._stt_min_ptt_ms:
@ -1368,13 +1279,9 @@ class WebRTCVoiceSession:
self._stt_segments.get_nowait()
now = asyncio.get_running_loop().time()
if (
now - self._last_stt_backlog_notice_at
) >= self._stt_backlog_notice_interval_s:
if (now - self._last_stt_backlog_notice_at) >= self._stt_backlog_notice_interval_s:
self._last_stt_backlog_notice_at = now
await self._publish_system(
"Voice input backlog detected; dropping stale segment."
)
await self._publish_system("Voice input backlog detected; dropping stale segment.")
with contextlib.suppress(asyncio.QueueFull):
self._stt_segments.put_nowait(pcm16)
@ -1384,9 +1291,7 @@ class WebRTCVoiceSession:
pcm16 = await self._stt_segments.get()
if not self._stt_first_segment_notice_sent:
self._stt_first_segment_notice_sent = True
await self._publish_system(
"Push-to-talk audio captured. Running host STT..."
)
await self._publish_system("Push-to-talk audio captured. Running host STT...")
try:
transcript = await self._stt.transcribe_pcm(
pcm=pcm16,
@ -1478,11 +1383,7 @@ class WebRTCVoiceSession:
except TypeError:
pcm = frame.to_ndarray()
if (
NUMPY_AVAILABLE
and np is not None
and getattr(pcm, "dtype", None) is not None
):
if NUMPY_AVAILABLE and np is not None and getattr(pcm, "dtype", None) is not None:
if pcm.dtype != np.int16:
if np.issubdtype(pcm.dtype, np.floating):
pcm = np.clip(pcm, -1.0, 1.0)
@ -1521,9 +1422,7 @@ class WebRTCVoiceSession:
else:
frames_channels = pcm.reshape(-1, 1)
channel_count = (
int(frames_channels.shape[1]) if frames_channels.ndim == 2 else 1
)
channel_count = int(frames_channels.shape[1]) if frames_channels.ndim == 2 else 1
if channel_count <= 1:
mono = frames_channels.reshape(-1).tobytes()
elif NUMPY_AVAILABLE and np is not None:
@ -1537,9 +1436,7 @@ class WebRTCVoiceSession:
else:
return b"", 0.0, resample_state
source_rate = int(
getattr(frame, "sample_rate", 0) or getattr(frame, "rate", 0) or 0
)
source_rate = int(getattr(frame, "sample_rate", 0) or getattr(frame, "rate", 0) or 0)
time_base = getattr(frame, "time_base", None)
tb_rate = 0