226 lines
8.3 KiB
Python
226 lines
8.3 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import uuid
|
||
|
|
from collections.abc import Awaitable, Callable
|
||
|
|
from datetime import datetime, timezone
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
# Async callback used to forward a JSON-RPC style request to the Nanobot
# backend: invoked as (method, params) and awaited for the raw response.
SendNanobotRequest = Callable[[str, dict[str, Any]], Awaitable[Any]]
|
||
|
|
|
||
|
|
|
||
|
|
class ToolJobService:
    """Run Nanobot tool calls as in-memory background jobs.

    Each job is an :func:`asyncio.create_task`-backed execution of a single
    ``tool.call`` request. Callers can poll job state via :meth:`get_job` or
    receive push updates via :meth:`subscribe_job`. Finished jobs are pruned
    after ``retention_seconds``. All mutable state (``_jobs``, ``_tasks``,
    ``_subscribers``) is guarded by one asyncio lock; the lock is never held
    across an await into the Nanobot transport or task gathering, to avoid
    deadlocks with job tasks that re-acquire it.
    """

    def __init__(
        self,
        *,
        send_request: SendNanobotRequest,
        timeout_seconds: float,
        retention_seconds: float,
    ) -> None:
        """Configure the service.

        Args:
            send_request: Awaitable callback ``(method, params)`` that performs
                the actual Nanobot API request.
            timeout_seconds: Maximum wall-clock time a single tool call may run
                before the job is marked failed (``<= 0`` disables the timeout).
            retention_seconds: How long finished jobs remain queryable before
                being pruned.
        """
        self._send_request = send_request
        self._timeout_seconds = timeout_seconds
        self._retention_seconds = retention_seconds
        self._jobs: dict[str, dict[str, Any]] = {}
        self._tasks: dict[str, asyncio.Task[None]] = {}
        # job_id -> subscription_id -> queue of serialized job snapshots.
        self._subscribers: dict[str, dict[str, asyncio.Queue[dict[str, Any]]]] = {}
        self._lock = asyncio.Lock()

    async def start_job(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
        """Register a queued job, schedule its execution, and return a snapshot."""
        job_id = uuid.uuid4().hex
        payload = {
            "job_id": job_id,
            "tool_name": tool_name,
            "status": "queued",
            "created_at": self._utc_now_iso(),
            "started_at": None,
            "finished_at": None,
            "result": None,
            "error": None,
            "error_code": None,
        }

        async with self._lock:
            await self._prune_locked()
            self._jobs[job_id] = payload
            self._tasks[job_id] = asyncio.create_task(
                # Copy the arguments so later caller-side mutation cannot
                # affect the in-flight request.
                self._run_job(job_id, tool_name, dict(arguments)),
                name=f"manual-tool-{job_id}",
            )
        return self._serialize_job(payload)

    async def get_job(self, job_id: str) -> dict[str, Any] | None:
        """Return a snapshot of the job, or ``None`` if unknown or pruned."""
        async with self._lock:
            await self._prune_locked()
            payload = self._jobs.get(job_id)
            return self._serialize_job(payload) if payload is not None else None

    async def shutdown(self) -> None:
        """Cancel all in-flight job tasks and wait for them to settle."""
        async with self._lock:
            tasks = [task for task in self._tasks.values() if not task.done()]
        for task in tasks:
            task.cancel()
        if tasks:
            # Gather OUTSIDE the lock: a cancelled _run_job re-acquires the
            # lock to record its failure state, so holding it here would
            # deadlock.
            await asyncio.gather(*tasks, return_exceptions=True)

    async def subscribe_job(self, job_id: str) -> tuple[str, asyncio.Queue[dict[str, Any]]] | None:
        """Attach a bounded update queue to a job.

        Returns:
            ``(subscription_id, queue)`` for a known job, else ``None``.
        """
        async with self._lock:
            if job_id not in self._jobs:
                return None
            subscription_id = uuid.uuid4().hex
            queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue(maxsize=16)
            self._subscribers.setdefault(job_id, {})[subscription_id] = queue
            return subscription_id, queue

    async def unsubscribe_job(self, job_id: str, subscription_id: str) -> None:
        """Detach a subscription; drop the job's subscriber map when empty."""
        async with self._lock:
            subscribers = self._subscribers.get(job_id)
            if not subscribers:
                return
            subscribers.pop(subscription_id, None)
            if not subscribers:
                self._subscribers.pop(job_id, None)

    @staticmethod
    def _utc_now_iso() -> str:
        """Return the current UTC time as an ISO-8601 string."""
        return datetime.now(timezone.utc).isoformat()

    @staticmethod
    def _normalize_tool_result_payload(result: dict[str, Any], tool_name: str) -> dict[str, Any]:
        """Coerce a raw tool response into the canonical result shape.

        Three cases, in order:
        1. The response already carries normalization markers (``tool_name``,
           ``parsed``, or ``is_json``) -> pass the fields through.
        2. The response has plain string ``content`` -> attempt to JSON-decode
           it so callers get structured data when possible.
        3. Anything else -> treat the whole dict as the JSON result.
        """
        parsed = result.get("parsed")
        content = result.get("content")
        is_json = result.get("is_json")
        # NOTE: the passthrough branch must NOT trigger on string content
        # alone, otherwise the JSON-decode fallback below is unreachable and
        # raw JSON tool output is never parsed.
        if (
            isinstance(result.get("tool_name"), str)
            or parsed is not None
            or isinstance(is_json, bool)
        ):
            return {
                "tool_name": str(result.get("tool_name", tool_name)),
                "content": str(content or ""),
                "parsed": parsed,
                "is_json": bool(is_json),
            }

        if isinstance(content, str):
            try:
                decoded = json.loads(content)
            except json.JSONDecodeError:
                decoded = None
            return {
                "tool_name": tool_name,
                "content": content,
                "parsed": decoded,
                "is_json": decoded is not None,
            }

        return {
            "tool_name": tool_name,
            "content": json.dumps(result, ensure_ascii=False),
            "parsed": result,
            "is_json": True,
        }

    @staticmethod
    def _serialize_job(payload: dict[str, Any]) -> dict[str, Any]:
        """Produce the public snapshot dict for a job payload."""
        result = payload.get("result")
        if not isinstance(result, dict):
            result = None
        return {
            "job_id": str(payload.get("job_id", "")),
            "tool_name": str(payload.get("tool_name", "")),
            "status": str(payload.get("status", "queued") or "queued"),
            "created_at": str(payload.get("created_at", "")),
            "started_at": payload.get("started_at"),
            "finished_at": payload.get("finished_at"),
            "result": result,
            "error": payload.get("error"),
            "error_code": payload.get("error_code"),
        }

    async def _prune_locked(self) -> None:
        """Drop finished jobs older than the retention window.

        Caller must hold ``self._lock``. Jobs whose task is still running are
        never pruned, even if their ``finished_at`` looks stale.
        """
        cutoff = datetime.now(timezone.utc).timestamp() - self._retention_seconds
        expired_job_ids: list[str] = []

        for job_id, payload in self._jobs.items():
            finished_at = str(payload.get("finished_at", "") or "")
            if not finished_at:
                continue  # still queued/running
            try:
                finished_ts = datetime.fromisoformat(finished_at).timestamp()
            except ValueError:
                # Unparseable timestamp: treat as ancient so it gets pruned.
                finished_ts = 0.0
            if finished_ts <= cutoff:
                expired_job_ids.append(job_id)

        for job_id in expired_job_ids:
            task = self._tasks.get(job_id)
            if task is not None and not task.done():
                continue
            self._jobs.pop(job_id, None)
            self._tasks.pop(job_id, None)
            self._subscribers.pop(job_id, None)

    async def _publish_job_update(self, job_id: str) -> None:
        """Push the job's current snapshot to every subscriber queue."""
        async with self._lock:
            payload = self._jobs.get(job_id)
            subscribers = list(self._subscribers.get(job_id, {}).values())
            serialized = self._serialize_job(payload) if payload is not None else None
        if serialized is None:
            return
        for queue in subscribers:
            # Slow consumers never block the publisher: drop the oldest
            # snapshot(s) to make room for the newest one.
            while queue.full():
                try:
                    queue.get_nowait()
                except asyncio.QueueEmpty:
                    break
            queue.put_nowait(serialized)

    async def _run_job(self, job_id: str, tool_name: str, arguments: dict[str, Any]) -> None:
        """Execute one tool call and record its outcome on the job payload."""
        async with self._lock:
            payload = self._jobs.get(job_id)
            if payload is None:
                return
            payload["status"] = "running"
            payload["started_at"] = self._utc_now_iso()
        await self._publish_job_update(job_id)

        # A non-positive timeout disables the deadline.
        timeout = self._timeout_seconds if self._timeout_seconds > 0 else None
        try:
            # Enforce the configured timeout; previously _timeout_seconds was
            # stored but never applied, so a hung tool call ran forever.
            result = await asyncio.wait_for(
                self._send_request(
                    "tool.call",
                    {"name": tool_name, "arguments": arguments},
                ),
                timeout=timeout,
            )
            if not isinstance(result, dict):
                raise RuntimeError("Nanobot API returned an invalid tool response")
            async with self._lock:
                payload = self._jobs.get(job_id)
                if payload is None:
                    return
                payload["status"] = "completed"
                payload["result"] = self._normalize_tool_result_payload(result, tool_name)
                payload["finished_at"] = self._utc_now_iso()
            await self._publish_job_update(job_id)
        except asyncio.CancelledError:
            async with self._lock:
                payload = self._jobs.get(job_id)
                if payload is not None:
                    payload["status"] = "failed"
                    payload["error"] = "tool job cancelled"
                    payload["finished_at"] = self._utc_now_iso()
            await self._publish_job_update(job_id)
            raise  # propagate so the task is observed as cancelled
        except asyncio.TimeoutError:
            # str(TimeoutError()) is empty, so build an explicit message.
            async with self._lock:
                payload = self._jobs.get(job_id)
                if payload is not None:
                    payload["status"] = "failed"
                    payload["error"] = (
                        f"tool call timed out after {self._timeout_seconds} seconds"
                    )
                    payload["error_code"] = "timeout"
                    payload["finished_at"] = self._utc_now_iso()
            await self._publish_job_update(job_id)
        except Exception as exc:
            async with self._lock:
                payload = self._jobs.get(job_id)
                if payload is not None:
                    payload["status"] = "failed"
                    payload["error"] = str(exc)
                    # Surface a structured error code when the transport
                    # exception carries one (e.g. JSON-RPC errors).
                    payload["error_code"] = getattr(exc, "code", None)
                    payload["finished_at"] = self._utc_now_iso()
            await self._publish_job_update(job_id)
        finally:
            async with self._lock:
                self._tasks.pop(job_id, None)
                await self._prune_locked()