nanobot-voice-interface/tool_job_service.py

226 lines
8.3 KiB
Python
Raw Permalink Normal View History

from __future__ import annotations
import asyncio
import json
import uuid
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any
# Coroutine function used to reach the nanobot API. It is invoked as
# ``send_request(method, params)``; ToolJobService issues ``"tool.call"``
# requests through it and treats the awaited value as the raw response.
SendNanobotRequest = Callable[
    [str, dict[str, Any]],
    Awaitable[Any],
]
class ToolJobService:
    """Run nanobot tool calls as background asyncio jobs and track their state.

    Each job is a plain dict stored in ``self._jobs`` and driven by an
    ``asyncio.Task`` kept in ``self._tasks`` while it runs. Subscribers get a
    serialized snapshot of the job on every state change through a bounded
    queue. Finished jobs are dropped once they are older than
    ``retention_seconds``.

    The shared dicts (and the job payloads inside them) are only mutated while
    ``self._lock`` is held.
    """

    def __init__(
        self,
        *,
        send_request: SendNanobotRequest,
        timeout_seconds: float,
        retention_seconds: float,
    ) -> None:
        """Create the service.

        Args:
            send_request: Coroutine function used to issue the ``"tool.call"``
                request to the nanobot API.
            timeout_seconds: Upper bound, in seconds, on a single tool call.
                A non-positive value disables the limit.
            retention_seconds: How long a finished job stays queryable.
        """
        self._send_request = send_request
        self._timeout_seconds = timeout_seconds
        self._retention_seconds = retention_seconds
        self._jobs: dict[str, dict[str, Any]] = {}
        self._tasks: dict[str, asyncio.Task[None]] = {}
        # job_id -> {subscription_id -> queue of serialized job snapshots}
        self._subscribers: dict[str, dict[str, asyncio.Queue[dict[str, Any]]]] = {}
        self._lock = asyncio.Lock()

    async def start_job(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Register a new job and schedule its tool call in the background.

        Returns:
            The serialized job snapshot; its status is ``"queued"`` because the
            background task has not run yet.
        """
        job_id = uuid.uuid4().hex
        payload: dict[str, Any] = {
            "job_id": job_id,
            "tool_name": tool_name,
            "status": "queued",
            "created_at": self._utc_now_iso(),
            "started_at": None,
            "finished_at": None,
            "result": None,
            "error": None,
            "error_code": None,
        }
        async with self._lock:
            await self._prune_locked()
            self._jobs[job_id] = payload
            self._tasks[job_id] = asyncio.create_task(
                # Copy the arguments so later caller-side mutation cannot leak
                # into the running call.
                self._run_job(job_id, tool_name, dict(arguments)),
                name=f"manual-tool-{job_id}",
            )
            snapshot = self._serialize_job(payload)
        return snapshot

    async def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return a serialized snapshot of the job, or ``None`` if unknown/expired."""
        async with self._lock:
            await self._prune_locked()
            payload = self._jobs.get(job_id)
            return self._serialize_job(payload) if payload is not None else None

    async def shutdown(self) -> None:
        """Cancel every unfinished job task and wait for all of them to settle."""
        async with self._lock:
            pending = [task for task in self._tasks.values() if not task.done()]
        # Cancel and await outside the lock: the tasks need the lock to record
        # their cancellation and clean up.
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

    async def subscribe_job(self, job_id: str) -> tuple[str, asyncio.Queue[dict[str, Any]]] | None:
        """Subscribe to state changes of a job.

        Only updates published after subscribing are delivered; the current
        state is not pushed. Call ``get_job`` for it.

        Returns:
            ``(subscription_id, queue)``, or ``None`` if the job is unknown or
            has expired.
        """
        async with self._lock:
            # Prune first so this agrees with get_job about which jobs exist.
            await self._prune_locked()
            if job_id not in self._jobs:
                return None
            subscription_id = uuid.uuid4().hex
            queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=16)
            self._subscribers.setdefault(job_id, {})[subscription_id] = queue
        return subscription_id, queue

    async def unsubscribe_job(self, job_id: str, subscription_id: str) -> None:
        """Remove a subscription; unknown ids are ignored."""
        async with self._lock:
            subscribers = self._subscribers.get(job_id)
            if not subscribers:
                return
            subscribers.pop(subscription_id, None)
            if not subscribers:
                self._subscribers.pop(job_id, None)

    @staticmethod
    def _utc_now_iso() -> str:
        """Current time as a timezone-aware UTC ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _normalize_tool_result_payload(result: dict[str, Any], tool_name: str) -> dict[str, Any]:
        """Coerce a raw ``tool.call`` response into ``{tool_name, content, parsed, is_json}``.

        Three shapes are handled:

        * The response already carries normalized metadata (a string
          ``tool_name``, a non-``None`` ``parsed`` or a boolean ``is_json``):
          those fields are used as reported. A non-string ``tool_name`` falls
          back to the requested *tool_name*.
        * No such metadata, but ``content`` is a string: it is decoded as JSON
          when possible.
        * Anything else: the whole response is the parsed value, and its JSON
          encoding becomes ``content``.
        """
        parsed = result.get("parsed")
        content = result.get("content")
        is_json = result.get("is_json")
        reported_name = result.get("tool_name")
        # ``isinstance(content, str)`` must not be part of this test; otherwise
        # the JSON-decoding branch below could never run.
        if isinstance(reported_name, str) or parsed is not None or isinstance(is_json, bool):
            return {
                "tool_name": reported_name if isinstance(reported_name, str) else tool_name,
                "content": str(content or ""),
                "parsed": parsed,
                "is_json": bool(is_json),
            }
        if isinstance(content, str):
            try:
                decoded = json.loads(content)
            except json.JSONDecodeError:
                decoded = None
            return {
                "tool_name": tool_name,
                "content": content,
                "parsed": decoded,
                "is_json": decoded is not None,
            }
        return {
            "tool_name": tool_name,
            "content": json.dumps(result, ensure_ascii=False),
            "parsed": result,
            "is_json": True,
        }

    @staticmethod
    def _serialize_job(payload: dict[str, Any]) -> dict[str, Any]:
        """Build the public snapshot of a job payload; a non-dict result becomes ``None``."""
        result = payload.get("result")
        if not isinstance(result, dict):
            result = None
        return {
            "job_id": str(payload.get("job_id", "")),
            "tool_name": str(payload.get("tool_name", "")),
            "status": str(payload.get("status", "queued") or "queued"),
            "created_at": str(payload.get("created_at", "")),
            "started_at": payload.get("started_at"),
            "finished_at": payload.get("finished_at"),
            "result": result,
            "error": payload.get("error"),
            "error_code": payload.get("error_code"),
        }

    async def _prune_locked(self) -> None:
        """Drop finished jobs older than the retention window. Caller must hold ``self._lock``.

        A job whose ``finished_at`` cannot be parsed is treated as expired. A
        job whose task is still running is kept.
        """
        cutoff = datetime.now(timezone.utc).timestamp() - self._retention_seconds
        expired_job_ids: list[str] = []
        for job_id, payload in self._jobs.items():
            finished_at = str(payload.get("finished_at", "") or "")
            if not finished_at:
                continue
            try:
                finished_ts = datetime.fromisoformat(finished_at).timestamp()
            except ValueError:
                finished_ts = 0.0
            if finished_ts <= cutoff:
                expired_job_ids.append(job_id)
        for job_id in expired_job_ids:
            task = self._tasks.get(job_id)
            if task is not None and not task.done():
                continue
            self._jobs.pop(job_id, None)
            self._tasks.pop(job_id, None)
            self._subscribers.pop(job_id, None)

    async def _publish_job_update(self, job_id: str) -> None:
        """Push the job's current snapshot to every subscriber queue.

        When a queue is full, the oldest snapshots are discarded so that the
        newest state always lands.
        """
        async with self._lock:
            payload = self._jobs.get(job_id)
            subscribers = list(self._subscribers.get(job_id, {}).values())
            serialized = self._serialize_job(payload) if payload is not None else None
        if serialized is None:
            return
        for queue in subscribers:
            while queue.full():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
            queue.put_nowait(serialized)

    async def _call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Issue the ``tool.call`` request, enforcing ``timeout_seconds``.

        Raises:
            TimeoutError: the call did not finish within the configured limit.
            RuntimeError: the response is not a dict.
        """
        timeout = (
            self._timeout_seconds
            if self._timeout_seconds and self._timeout_seconds > 0
            else None
        )
        try:
            result = await asyncio.wait_for(
                self._send_request(
                    "tool.call",
                    {"name": tool_name, "arguments": arguments},
                ),
                timeout=timeout,
            )
        except asyncio.TimeoutError as exc:
            # A bare TimeoutError stringifies to "", so attach a readable message.
            raise TimeoutError(
                f"tool job timed out after {self._timeout_seconds:g} seconds"
            ) from exc
        if not isinstance(result, dict):
            raise RuntimeError("Nanobot API returned an invalid tool response")
        return result

    async def _mark_failed(self, job_id: str, error: str, error_code: Any) -> None:
        """Record a terminal failure on the job (if it still exists) and notify subscribers."""
        async with self._lock:
            payload = self._jobs.get(job_id)
            if payload is not None:
                payload["status"] = "failed"
                payload["error"] = error
                payload["error_code"] = error_code
                payload["finished_at"] = self._utc_now_iso()
        await self._publish_job_update(job_id)

    async def _run_job(self, job_id: str, tool_name: str, arguments: dict[str, Any]) -> None:
        """Execute one job: mark it running, call the tool, record the outcome.

        Every state transition is published to subscribers. The task entry is
        removed from ``self._tasks`` however the job ends, including the early
        return for a missing job and cancellation before the call starts.
        """
        try:
            async with self._lock:
                payload = self._jobs.get(job_id)
                if payload is None:
                    return
                payload["status"] = "running"
                payload["started_at"] = self._utc_now_iso()
            await self._publish_job_update(job_id)

            result = await self._call_tool(tool_name, arguments)

            async with self._lock:
                payload = self._jobs.get(job_id)
                if payload is None:
                    return
                payload["status"] = "completed"
                payload["result"] = self._normalize_tool_result_payload(result, tool_name)
                payload["finished_at"] = self._utc_now_iso()
            await self._publish_job_update(job_id)
        except asyncio.CancelledError:
            await self._mark_failed(job_id, "tool job cancelled", None)
            raise
        except Exception as exc:
            await self._mark_failed(job_id, str(exc), getattr(exc, "code", None))
        finally:
            async with self._lock:
                self._tasks.pop(job_id, None)
                await self._prune_locked()