This commit is contained in:
kacper 2026-03-04 08:20:42 -05:00
parent 133b557512
commit ed629ff60e
7 changed files with 948 additions and 525 deletions

View file

@ -68,8 +68,12 @@ SPEECH_FILTER_RE = re.compile(
r"^(spawned nanobot tui|stopped nanobot tui|nanobot tui exited|websocket)",
re.IGNORECASE,
)
THINKING_STATUS_RE = re.compile(r"\bnanobot is thinking\b", re.IGNORECASE)
THINKING_STATUS_RE = re.compile(
r"\b(?:agent|nanobot|napbot)\b(?:\s+is)?\s+thinking\b",
re.IGNORECASE,
)
USER_PREFIX_RE = re.compile(r"^(?:you|user)\s*:\s*", re.IGNORECASE)
AGENT_PREFIX_RE = re.compile(r"^(?:nanobot|napbot)\b\s*[:>\-]?\s*", re.IGNORECASE)
VOICE_TRANSCRIPT_RE = re.compile(
r"^(?:wisper\s*:\s*)?voice\s+transcript\s*:\s*",
re.IGNORECASE,
@ -77,7 +81,6 @@ VOICE_TRANSCRIPT_RE = re.compile(
ANSI_ESCAPE_RE = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
CONTROL_CHAR_RE = re.compile(r"[\x00-\x1f\x7f]")
BRAILLE_SPINNER_RE = re.compile(r"[\u2800-\u28ff]")
SENTENCE_END_RE = re.compile(r"[.!?]\s*$")
TTS_ALLOWED_ASCII = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZ"
"abcdefghijklmnopqrstuvwxyz"
@ -141,6 +144,15 @@ if AIORTC_AVAILABLE:
self._closed = False
self._frame_duration_s = frame_ms / 1000.0
self._last_recv_at = 0.0
self._playing = False
self._idle_frames = 0
# Number of consecutive silent frames before signalling idle.
# At 20ms per frame, 15 frames = 300ms grace period to avoid
# flickering between TTS synthesis chunks.
self._idle_grace_frames = max(
1, int(os.getenv("HOST_RTC_IDLE_GRACE_MS", "300")) // max(1, frame_ms)
)
self._on_playing_changed: Callable[[bool], None] | None = None
async def enqueue_pcm(
self, pcm: bytes, sample_rate: int, channels: int = 1
@ -211,8 +223,24 @@ if AIORTC_AVAILABLE:
try:
payload = self._queue.get_nowait()
has_audio = True
except asyncio.QueueEmpty:
payload = b"\x00" * self._bytes_per_frame
has_audio = False
# Notify when playback state changes.
if has_audio:
self._idle_frames = 0
if not self._playing:
self._playing = True
if self._on_playing_changed:
self._on_playing_changed(True)
elif self._playing:
self._idle_frames += 1
if self._idle_frames >= self._idle_grace_frames:
self._playing = False
if self._on_playing_changed:
self._on_playing_changed(False)
self._last_recv_at = loop.time()
@ -233,6 +261,8 @@ if AIORTC_AVAILABLE:
else:
class QueueAudioTrack: # pragma: no cover - used only when aiortc is unavailable
_on_playing_changed: Callable[[bool], None] | None = None
async def enqueue_pcm(
self, pcm: bytes, sample_rate: int, channels: int = 1
) -> None:
@ -343,6 +373,15 @@ class FasterWhisperSpeechToText:
).strip()
or None
)
self._repetition_penalty = float(
os.getenv("HOST_STT_REPETITION_PENALTY", "1.0")
)
raw_hallucination_threshold = os.getenv(
"HOST_STT_HALLUCINATION_SILENCE_THRESHOLD", ""
).strip()
self._hallucination_silence_threshold: float | None = (
float(raw_hallucination_threshold) if raw_hallucination_threshold else None
)
self._model: Any = None
self._init_error: str | None = None
@ -429,6 +468,8 @@ class FasterWhisperSpeechToText:
log_prob_threshold=self._log_prob_threshold,
no_speech_threshold=self._no_speech_threshold,
compression_ratio_threshold=self._compression_ratio_threshold,
repetition_penalty=self._repetition_penalty,
hallucination_silence_threshold=self._hallucination_silence_threshold,
)
transcript_parts: list[str] = []
for segment in segments:
@ -456,6 +497,8 @@ class FasterWhisperSpeechToText:
log_prob_threshold=self._log_prob_threshold,
no_speech_threshold=self._no_speech_threshold,
compression_ratio_threshold=self._compression_ratio_threshold,
repetition_penalty=self._repetition_penalty,
hallucination_silence_threshold=self._hallucination_silence_threshold,
)
transcript_parts: list[str] = []
for segment in segments:
@ -541,11 +584,11 @@ class SupertonicTextToSpeech:
os.getenv("SUPERTONIC_MODEL", "supertonic-2").strip() or "supertonic-2"
)
self._voice_style_name = (
os.getenv("SUPERTONIC_VOICE_STYLE", "M1").strip() or "M1"
os.getenv("SUPERTONIC_VOICE_STYLE", "F1").strip() or "F1"
)
self._lang = os.getenv("SUPERTONIC_LANG", "en").strip() or "en"
self._total_steps = int(os.getenv("SUPERTONIC_TOTAL_STEPS", "5"))
self._speed = float(os.getenv("SUPERTONIC_SPEED", "1.05"))
self._total_steps = int(os.getenv("SUPERTONIC_TOTAL_STEPS", "8"))
self._speed = float(os.getenv("SUPERTONIC_SPEED", "1.5"))
self._intra_op_num_threads = _optional_int_env("SUPERTONIC_INTRA_OP_THREADS")
self._inter_op_num_threads = _optional_int_env("SUPERTONIC_INTER_OP_THREADS")
self._auto_download = os.getenv(
@ -605,7 +648,7 @@ class SupertonicTextToSpeech:
message = str(exc)
if "unsupported character" not in message.lower():
raise
fallback_text = self._sanitize_text_for_supertonic(text)
fallback_text = _sanitize_tts_text(text)
if not fallback_text or fallback_text == text:
raise
wav, _duration = self._engine.synthesize(
@ -655,9 +698,6 @@ class SupertonicTextToSpeech:
channels=max(1, channels),
)
def _sanitize_text_for_supertonic(self, text: str) -> str:
return _sanitize_tts_text(text)
def _initialize_blocking(self) -> None:
if self._engine is not None and self._voice_style is not None:
return
@ -853,23 +893,15 @@ class WebRTCVoiceSession:
maxsize=self._stt_segment_queue_size
)
self._tts_buffer = ""
self._tts_chunks: list[str] = []
self._tts_flush_handle: asyncio.TimerHandle | None = None
self._tts_flush_lock = asyncio.Lock()
self._tts_buffer_lock = asyncio.Lock()
self._tts_flush_delay_s = max(
0.08, float(os.getenv("HOST_TTS_FLUSH_DELAY_S", "0.45"))
# How long to wait after the last incoming chunk before flushing the
# entire accumulated response to TTS in one go.
self._tts_response_end_delay_s = max(
0.1, float(os.getenv("HOST_TTS_RESPONSE_END_DELAY_S", "1.5"))
)
self._tts_sentence_flush_delay_s = max(
0.06,
float(os.getenv("HOST_TTS_SENTENCE_FLUSH_DELAY_S", "0.15")),
)
self._tts_min_chars = max(1, int(os.getenv("HOST_TTS_MIN_CHARS", "10")))
self._tts_max_wait_ms = max(300, int(os.getenv("HOST_TTS_MAX_WAIT_MS", "1800")))
self._tts_max_chunk_chars = max(
60, int(os.getenv("HOST_TTS_MAX_CHUNK_CHARS", "140"))
)
self._tts_buffer_started_at = 0.0
self._closed = False
self._stt_unavailable_notice_sent = False
@ -887,14 +919,7 @@ class WebRTCVoiceSession:
)
),
)
self._stt_max_ptt_ms = max(
self._stt_min_ptt_ms,
int(
os.getenv(
"HOST_STT_MAX_PTT_MS", os.getenv("HOST_STT_MAX_SEGMENT_MS", "12000")
)
),
)
self._stt_suppress_during_tts = os.getenv(
"HOST_STT_SUPPRESS_DURING_TTS", "1"
).strip() not in {
@ -924,22 +949,16 @@ class WebRTCVoiceSession:
normalized_chunk = chunk.strip()
if not normalized_chunk:
return
flush_delay = (
self._tts_sentence_flush_delay_s
if SENTENCE_END_RE.search(normalized_chunk)
else self._tts_flush_delay_s
)
loop = asyncio.get_running_loop()
async with self._tts_buffer_lock:
if not self._pc or not self._outbound_track:
return
if not self._tts_buffer.strip():
self._tts_buffer_started_at = loop.time()
self._tts_buffer = normalized_chunk
else:
# Keep line boundaries between streamed chunks so line-based filters stay accurate.
self._tts_buffer = f"{self._tts_buffer}\n{normalized_chunk}"
self._schedule_tts_flush_after(flush_delay)
# Keep line boundaries between streamed chunks so line-based filters
# stay accurate while avoiding repeated full-string copies.
self._tts_chunks.append(normalized_chunk)
# Reset the flush timer on every incoming chunk so the entire
# response is accumulated before synthesis begins. The timer
# fires once no new chunks arrive for the configured delay.
self._schedule_tts_flush_after(self._tts_response_end_delay_s)
async def handle_offer(self, payload: dict[str, Any]) -> None:
if not AIORTC_AVAILABLE or not RTCPeerConnection or not RTCSessionDescription:
@ -965,6 +984,7 @@ class WebRTCVoiceSession:
peer_connection = RTCPeerConnection()
self._pc = peer_connection
self._outbound_track = QueueAudioTrack()
self._outbound_track._on_playing_changed = self._on_track_playing_changed
peer_connection.addTrack(self._outbound_track)
@peer_connection.on("connectionstatechange")
@ -1096,8 +1116,7 @@ class WebRTCVoiceSession:
if self._tts_flush_handle:
self._tts_flush_handle.cancel()
self._tts_flush_handle = None
self._tts_buffer = ""
self._tts_buffer_started_at = 0.0
self._tts_chunks.clear()
if self._incoming_audio_task:
self._incoming_audio_task.cancel()
@ -1136,61 +1155,45 @@ class WebRTCVoiceSession:
async with self._tts_flush_lock:
async with self._tts_buffer_lock:
self._tts_flush_handle = None
raw_text = self._tts_buffer
buffer_started_at = self._tts_buffer_started_at
self._tts_buffer = ""
self._tts_buffer_started_at = 0.0
raw_text = "\n".join(self._tts_chunks)
self._tts_chunks.clear()
clean_text = self._clean_tts_text(raw_text)
if not clean_text:
return
loop = asyncio.get_running_loop()
now = loop.time()
buffer_age_ms = int(max(0.0, now - buffer_started_at) * 1000)
should_wait_for_more = (
len(clean_text) < self._tts_min_chars
and not SENTENCE_END_RE.search(clean_text)
and buffer_age_ms < self._tts_max_wait_ms
)
if should_wait_for_more:
async with self._tts_buffer_lock:
if self._tts_buffer.strip():
self._tts_buffer = f"{clean_text}\n{self._tts_buffer}".strip()
else:
self._tts_buffer = clean_text
if self._tts_buffer_started_at <= 0.0:
self._tts_buffer_started_at = (
buffer_started_at if buffer_started_at > 0.0 else now
)
self._schedule_tts_flush_after(self._tts_flush_delay_s)
return
if not self._outbound_track:
return
for part in self._chunk_tts_text(clean_text):
try:
audio = await self._tts.synthesize(part)
except Exception as exc:
await self._publish_system(f"Host TTS failed: {exc}")
return
try:
audio = await self._tts.synthesize(clean_text)
except asyncio.CancelledError:
raise
except Exception as exc:
import traceback # noqa: local import in exception handler
if not audio:
if not self._tts_unavailable_notice_sent:
self._tts_unavailable_notice_sent = True
await self._publish_system(
f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}"
)
return
traceback.print_exc()
# Restore the lost text so a future flush can retry it.
async with self._tts_buffer_lock:
self._tts_chunks.insert(0, clean_text)
await self._publish_system(f"TTS synthesis error: {exc}")
return
if not self._outbound_track:
return
self._extend_stt_suppression(audio)
await self._outbound_track.enqueue_pcm(
pcm=audio.pcm,
sample_rate=audio.sample_rate,
channels=audio.channels,
)
if not audio:
if not self._tts_unavailable_notice_sent:
self._tts_unavailable_notice_sent = True
await self._publish_system(
f"Host TTS backend is unavailable. {self._tts.unavailable_reason()}"
)
return
if not self._outbound_track:
return
self._extend_stt_suppression(audio)
await self._outbound_track.enqueue_pcm(
pcm=audio.pcm,
sample_rate=audio.sample_rate,
channels=audio.channels,
)
def _extend_stt_suppression(self, audio: PCMChunk) -> None:
if not self._stt_suppress_during_tts:
@ -1221,7 +1224,6 @@ class WebRTCVoiceSession:
resample_state = None
recording = False
recording_started_at = 0.0
recording_truncated = False
segment_ms = 0.0
segment_buffer = bytearray()
@ -1253,7 +1255,6 @@ class WebRTCVoiceSession:
):
recording = False
recording_started_at = 0.0
recording_truncated = False
segment_ms = 0.0
segment_buffer = bytearray()
continue
@ -1262,20 +1263,11 @@ class WebRTCVoiceSession:
if not recording:
recording = True
recording_started_at = asyncio.get_running_loop().time()
recording_truncated = False
segment_ms = 0.0
segment_buffer = bytearray()
if not recording_truncated:
next_segment_ms = segment_ms + frame_ms
if next_segment_ms <= self._stt_max_ptt_ms:
segment_buffer.extend(pcm16)
segment_ms = next_segment_ms
else:
recording_truncated = True
await self._publish_system(
"PTT max length reached; extra audio will be ignored until release."
)
segment_buffer.extend(pcm16)
segment_ms += frame_ms
continue
if recording:
@ -1291,7 +1283,6 @@ class WebRTCVoiceSession:
)
recording = False
recording_started_at = 0.0
recording_truncated = False
segment_ms = 0.0
segment_buffer = bytearray()
except asyncio.CancelledError:
@ -1456,10 +1447,21 @@ class WebRTCVoiceSession:
async def _publish_system(self, text: str) -> None:
await self._gateway.bus.publish(WisperEvent(role="system", text=text))
async def _publish_agent_state(self, state: str) -> None:
await self._gateway.bus.publish(WisperEvent(role="agent-state", text=state))
def _on_track_playing_changed(self, playing: bool) -> None:
"""Called from QueueAudioTrack.recv() when audio playback starts or stops."""
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
loop.create_task(self._publish_agent_state("speaking" if playing else "idle"))
def _clean_tts_text(self, raw_text: str) -> str:
lines = [line.strip() for line in raw_text.splitlines() if line.strip()]
useful_lines = [
line
AGENT_PREFIX_RE.sub("", line)
for line in lines
if not SPEECH_FILTER_RE.match(line)
and not THINKING_STATUS_RE.search(line)
@ -1468,71 +1470,6 @@ class WebRTCVoiceSession:
]
return _sanitize_tts_text(" ".join(useful_lines))
def _chunk_tts_text(self, text: str) -> list[str]:
clean_text = " ".join(text.split())
if not clean_text:
return []
max_chars = max(60, int(self._tts_max_chunk_chars))
if len(clean_text) <= max_chars:
return [clean_text]
sentence_parts = [
part.strip()
for part in re.split(r"(?<=[.!?])\s+", clean_text)
if part.strip()
]
if not sentence_parts:
sentence_parts = [clean_text]
chunks: list[str] = []
for part in sentence_parts:
if len(part) <= max_chars:
chunks.append(part)
else:
chunks.extend(self._chunk_tts_words(part, max_chars))
return chunks
def _chunk_tts_words(self, text: str, max_chars: int) -> list[str]:
words = [word for word in text.split() if word]
if not words:
return []
chunks: list[str] = []
current_words: list[str] = []
current_len = 0
for word in words:
if len(word) > max_chars:
if current_words:
chunks.append(" ".join(current_words))
current_words = []
current_len = 0
start = 0
while start < len(word):
end = min(start + max_chars, len(word))
piece = word[start:end]
if len(piece) == max_chars:
chunks.append(piece)
else:
current_words = [piece]
current_len = len(piece)
start = end
continue
extra = len(word) if not current_words else (1 + len(word))
if current_words and (current_len + extra) > max_chars:
chunks.append(" ".join(current_words))
current_words = [word]
current_len = len(word)
else:
current_words.append(word)
current_len += extra
if current_words:
chunks.append(" ".join(current_words))
return chunks
def _frame_to_pcm16k_mono(
self, frame: AudioFrame, resample_state: tuple[Any, ...] | None
) -> tuple[bytes, float, tuple[Any, ...] | None]: