from __future__ import annotations

import asyncio
import json
import uuid
from collections.abc import Awaitable, Callable
from datetime import datetime, timezone
from typing import Any

# Async callable used to issue a request: (method name, params) -> response.
SendNanobotRequest = Callable[[str, dict[str, Any]], Awaitable[Any]]


class ToolJobService:
    """Run tool calls as background asyncio tasks and track their state.

    Every job is a plain dict keyed by a generated hex id. A job moves through
    ``queued`` -> ``running`` -> ``completed`` | ``failed``. Finished jobs are
    dropped once they are older than ``retention_seconds``. Callers can poll
    with :meth:`get_job` or receive serialized snapshots through a bounded
    queue obtained from :meth:`subscribe_job`.
    """

    def __init__(
        self,
        *,
        send_request: SendNanobotRequest,
        timeout_seconds: float,
        retention_seconds: float,
    ) -> None:
        """Create the service.

        Args:
            send_request: Awaitable callable invoked as
                ``send_request("tool.call", {"name": ..., "arguments": ...})``.
            timeout_seconds: Stored on the instance (see the note below).
            retention_seconds: How long a finished job is kept before pruning.
        """
        self._send_request = send_request
        # NOTE(review): this value is stored but never read by any method of
        # this class, so no timeout is enforced here. Confirm whether the
        # ``send_request`` callable is expected to apply it.
        self._timeout_seconds = timeout_seconds
        self._retention_seconds = retention_seconds
        self._jobs: dict[str, dict[str, Any]] = {}
        self._tasks: dict[str, asyncio.Task[None]] = {}
        # job_id -> subscription_id -> queue of serialized job snapshots.
        self._subscribers: dict[str, dict[str, asyncio.Queue[dict[str, Any]]]] = {}
        # Guards _jobs, _tasks and _subscribers. asyncio.Lock is not
        # re-entrant, so helpers suffixed ``_locked`` must be called with the
        # lock already held and must not acquire it themselves.
        self._lock = asyncio.Lock()

    async def start_job(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Register a new job, schedule it, and return its serialized state.

        Args:
            tool_name: Name of the tool to call.
            arguments: Tool arguments; shallow-copied before being handed to
                the background task.

        Returns:
            The serialized job, initially with status ``"queued"``.
        """
        job_id = uuid.uuid4().hex
        payload = {
            "job_id": job_id,
            "tool_name": tool_name,
            "status": "queued",
            "created_at": self._utc_now_iso(),
            "started_at": None,
            "finished_at": None,
            "result": None,
            "error": None,
            "error_code": None,
        }
        async with self._lock:
            await self._prune_locked()
            self._jobs[job_id] = payload
            # The task cannot start before the lock is released, so it always
            # finds its payload in ``_jobs``.
            self._tasks[job_id] = asyncio.create_task(
                self._run_job(job_id, tool_name, dict(arguments)),
                name=f"manual-tool-{job_id}",
            )
        return self._serialize_job(payload)

    async def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return the serialized job, or ``None`` if unknown or pruned."""
        async with self._lock:
            await self._prune_locked()
            payload = self._jobs.get(job_id)
            return self._serialize_job(payload) if payload is not None else None

    async def shutdown(self) -> None:
        """Cancel every unfinished job task and wait for them to stop.

        A task cancelled before its first step never executes the body of
        ``_run_job``, so its own cancellation handling cannot run. Such jobs
        are finalized here as ``failed`` / ``"tool job cancelled"`` so they do
        not stay ``queued`` forever and can be pruned later.
        """
        async with self._lock:
            pending = [
                (job_id, task)
                for job_id, task in self._tasks.items()
                if not task.done()
            ]
            for _, task in pending:
                task.cancel()
        if not pending:
            return
        await asyncio.gather(*(task for _, task in pending), return_exceptions=True)

        stranded: list[str] = []
        async with self._lock:
            for job_id, _ in pending:
                payload = self._jobs.get(job_id)
                if payload is None or payload.get("finished_at"):
                    # The task reached a terminal state on its own.
                    continue
                self._mark_finished(payload, status="failed", error="tool job cancelled")
                self._tasks.pop(job_id, None)
                stranded.append(job_id)
        for job_id in stranded:
            await self._publish_job_update(job_id)

    async def subscribe_job(self, job_id: str) -> tuple[str, asyncio.Queue[dict[str, Any]]] | None:
        """Subscribe to updates for a job.

        Returns:
            ``(subscription_id, queue)`` where ``queue`` receives serialized
            job snapshots (at most 16 buffered), or ``None`` if the job is
            unknown.
        """
        async with self._lock:
            if job_id not in self._jobs:
                return None
            subscription_id = uuid.uuid4().hex
            queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=16)
            self._subscribers.setdefault(job_id, {})[subscription_id] = queue
            return subscription_id, queue

    async def unsubscribe_job(self, job_id: str, subscription_id: str) -> None:
        """Remove a subscription; unknown ids are ignored."""
        async with self._lock:
            subscribers = self._subscribers.get(job_id)
            if not subscribers:
                return
            subscribers.pop(subscription_id, None)
            if not subscribers:
                # Drop the empty per-job mapping.
                self._subscribers.pop(job_id, None)

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string with offset."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _normalize_tool_result_payload(result: dict[str, Any], tool_name: str) -> dict[str, Any]:
        """Coerce a tool response into ``{tool_name, content, parsed, is_json}``.

        Three input shapes are handled, in this order:

        1. The response already carries parse metadata (``parsed`` is set or
           ``is_json`` is a bool), or it names its tool without a string
           ``content``: the fields are passed through with type coercion.
        2. The response has a string ``content`` but no parse metadata: the
           content is decoded as JSON on a best-effort basis.
        3. Anything else: the whole response dict is treated as the JSON
           result.

        Args:
            result: Raw response dict.
            tool_name: Name of the requested tool, used when the response
                does not name one.

        Returns:
            The normalized result dict.
        """
        parsed = result.get("parsed")
        content = result.get("content")
        is_json = result.get("is_json")
        raw_name = result.get("tool_name")
        # An explicit ``"tool_name": None`` must not be rendered as "None".
        name = tool_name if raw_name is None else str(raw_name)

        has_parse_metadata = parsed is not None or isinstance(is_json, bool)
        names_tool_only = isinstance(raw_name, str) and not isinstance(content, str)
        if has_parse_metadata or names_tool_only:
            return {
                "tool_name": name,
                "content": str(content or ""),
                "parsed": parsed,
                "is_json": bool(is_json),
            }

        if isinstance(content, str):
            try:
                decoded = json.loads(content)
            except json.JSONDecodeError:
                decoded = None
            return {
                "tool_name": name,
                "content": content,
                "parsed": decoded,
                "is_json": decoded is not None,
            }

        return {
            "tool_name": tool_name,
            "content": json.dumps(result, ensure_ascii=False),
            "parsed": result,
            "is_json": True,
        }

    @staticmethod
    def _serialize_job(payload: dict[str, Any]) -> dict[str, Any]:
        """Build the public view of a job payload.

        String fields are coerced with ``str``; a ``result`` that is not a
        dict is reported as ``None``.
        """
        result = payload.get("result")
        if not isinstance(result, dict):
            result = None
        return {
            "job_id": str(payload.get("job_id", "")),
            "tool_name": str(payload.get("tool_name", "")),
            "status": str(payload.get("status", "queued") or "queued"),
            "created_at": str(payload.get("created_at", "")),
            "started_at": payload.get("started_at"),
            "finished_at": payload.get("finished_at"),
            "result": result,
            "error": payload.get("error"),
            "error_code": payload.get("error_code"),
        }

    def _mark_finished(
        self,
        payload: dict[str, Any],
        *,
        status: str,
        result: dict[str, Any] | None = None,
        error: str | None = None,
        error_code: Any = None,
    ) -> None:
        """Write a terminal state into ``payload``; the lock must be held."""
        payload["status"] = status
        payload["result"] = result
        payload["error"] = error
        payload["error_code"] = error_code
        payload["finished_at"] = self._utc_now_iso()

    async def _finish_job(
        self,
        job_id: str,
        *,
        status: str,
        result: dict[str, Any] | None = None,
        error: str | None = None,
        error_code: Any = None,
    ) -> None:
        """Record a terminal state for a job and notify its subscribers.

        The first terminal state wins: a job that already has ``finished_at``
        is left untouched, so a cancellation that arrives while a completed
        job is being published cannot relabel it as failed.
        """
        async with self._lock:
            payload = self._jobs.get(job_id)
            if payload is None:
                return
            if not payload.get("finished_at"):
                self._mark_finished(
                    payload,
                    status=status,
                    result=result,
                    error=error,
                    error_code=error_code,
                )
        await self._publish_job_update(job_id)

    async def _prune_locked(self) -> None:
        """Drop finished jobs older than the retention window.

        Must be called with ``self._lock`` held. Jobs without ``finished_at``
        are kept. An unparseable ``finished_at`` is treated as timestamp 0 and
        therefore as expired. A job whose task is still running is never
        removed.
        """
        cutoff = datetime.now(timezone.utc).timestamp() - self._retention_seconds
        expired_job_ids: list[str] = []
        for job_id, payload in self._jobs.items():
            finished_at = str(payload.get("finished_at", "") or "")
            if not finished_at:
                continue
            try:
                finished_ts = datetime.fromisoformat(finished_at).timestamp()
            except ValueError:
                finished_ts = 0.0
            if finished_ts <= cutoff:
                expired_job_ids.append(job_id)

        # Deletion happens in a second pass so ``_jobs`` is not mutated while
        # it is being iterated.
        for job_id in expired_job_ids:
            task = self._tasks.get(job_id)
            if task is not None and not task.done():
                continue
            self._jobs.pop(job_id, None)
            self._tasks.pop(job_id, None)
            self._subscribers.pop(job_id, None)

    async def _publish_job_update(self, job_id: str) -> None:
        """Push the job's current snapshot to every subscriber queue.

        Does nothing when the job no longer exists. When a queue is full the
        oldest snapshots are discarded to make room for the new one.
        """
        async with self._lock:
            payload = self._jobs.get(job_id)
            subscribers = list(self._subscribers.get(job_id, {}).values())
            serialized = self._serialize_job(payload) if payload is not None else None
        if serialized is None:
            return
        for queue in subscribers:
            # No await between the drain and the put, so the put cannot fail
            # with QueueFull.
            while queue.full():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
            queue.put_nowait(serialized)

    async def _run_job(self, job_id: str, tool_name: str, arguments: dict[str, Any]) -> None:
        """Execute one job: mark it running, call the tool, store the outcome.

        Cancellation is recorded as a failure and re-raised. Any other
        exception is recorded as a failure and swallowed; its ``code``
        attribute, if present, is stored as ``error_code``.
        """
        async with self._lock:
            payload = self._jobs.get(job_id)
            if payload is None:
                return
            payload["status"] = "running"
            payload["started_at"] = self._utc_now_iso()
        await self._publish_job_update(job_id)

        try:
            result = await self._send_request(
                "tool.call",
                {"name": tool_name, "arguments": arguments},
            )
            if not isinstance(result, dict):
                raise RuntimeError("Nanobot API returned an invalid tool response")
            await self._finish_job(
                job_id,
                status="completed",
                result=self._normalize_tool_result_payload(result, tool_name),
            )
        except asyncio.CancelledError:
            await self._finish_job(job_id, status="failed", error="tool job cancelled")
            raise
        except Exception as exc:
            await self._finish_job(
                job_id,
                status="failed",
                # Exceptions such as TimeoutError() stringify to "", which is
                # useless to a client; fall back to the class name.
                error=str(exc) or type(exc).__name__,
                error_code=getattr(exc, "code", None),
            )
        finally:
            async with self._lock:
                self._tasks.pop(job_id, None)
                await self._prune_locked()