# nanobot-voice-interface/routes/tools.py

from __future__ import annotations
import asyncio
from fastapi import APIRouter, Depends, Request
from fastapi.responses import JSONResponse, StreamingResponse
from app_dependencies import get_runtime
from message_pipeline import encode_sse_data
from nanobot_api_client import NanobotApiError, send_nanobot_api_request
from route_helpers import read_json_request
from web_runtime import WebAppRuntime
# Router on which the tool endpoints declared below are registered.
router = APIRouter()
@router.get("/tools")
async def list_tools() -> JSONResponse:
    """Return the tool list reported by the Nanobot API.

    Sends a ``tool.list`` request with a 20 second timeout and responds with
    ``{"tools": [...]}`` on success (an absent ``tools`` key yields an empty
    list).

    Error mapping:
        * ``NanobotApiError`` with code ``-32000`` -> 503, any other code -> 502.
        * ``RuntimeError`` -> 503.
        * A result that is not a dict, or whose ``tools`` entry is not a
          list -> 502.
    """
    try:
        api_result = await send_nanobot_api_request("tool.list", {}, timeout_seconds=20.0)
    except NanobotApiError as api_error:
        http_status = 502
        if api_error.code == -32000:
            http_status = 503
        return JSONResponse({"error": str(api_error)}, status_code=http_status)
    except RuntimeError as runtime_error:
        return JSONResponse({"error": str(runtime_error)}, status_code=503)

    # A non-dict result and a non-list "tools" entry are reported identically,
    # so both cases funnel through the single isinstance check below.
    tool_entries = api_result.get("tools", []) if isinstance(api_result, dict) else None
    if not isinstance(tool_entries, list):
        return JSONResponse(
            {"error": "Nanobot API returned an invalid tool list"}, status_code=502
        )
    return JSONResponse({"tools": tool_entries})
@router.post("/tools/call")
async def call_tool(
    request: Request, runtime: WebAppRuntime = Depends(get_runtime)
) -> JSONResponse:
    """Invoke a Nanobot tool, either directly or as a background job.

    The JSON body supplies the tool name under ``tool_name`` (falling back to
    ``name``) and its arguments under ``arguments`` (falling back to
    ``params``). When ``async`` is exactly ``true`` the call is handed to
    ``runtime.tool_job_service.start_job`` and the returned job payload is
    sent back with status 202. Otherwise a ``tool.call`` request is sent to
    the Nanobot API with a 60 second timeout and its result is returned.

    Error mapping:
        * ``read_json_request`` raising ``ValueError`` -> 400.
        * Missing/blank tool name, or arguments that are not an object -> 400.
        * ``NanobotApiError`` code ``-32602`` -> 400, ``-32000`` -> 503,
          any other code -> 502.
        * ``RuntimeError`` -> 503.
        * A result that is not a dict -> 502.
    """
    try:
        payload = await read_json_request(request)
    except ValueError as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)

    # An explicit JSON null must count as "not provided": str(None) would
    # otherwise produce the literal tool name "None", slip past the emptiness
    # check below, and mask a valid "name" alias.
    raw_name = payload.get("tool_name")
    if raw_name is None:
        raw_name = payload.get("name")
    tool_name = "" if raw_name is None else str(raw_name).strip()
    if not tool_name:
        return JSONResponse({"error": "tool_name is required"}, status_code=400)

    arguments = payload.get("arguments", payload.get("params", {}))
    if arguments is None:
        arguments = {}
    if not isinstance(arguments, dict):
        return JSONResponse({"error": "arguments must be a JSON object"}, status_code=400)

    # Only a literal JSON true opts in to background execution.
    if payload.get("async") is True:
        job_payload = await runtime.tool_job_service.start_job(tool_name, arguments)
        return JSONResponse(job_payload, status_code=202)

    try:
        result = await send_nanobot_api_request(
            "tool.call",
            {"name": tool_name, "arguments": arguments},
            timeout_seconds=60.0,
        )
    except NanobotApiError as exc:
        if exc.code == -32602:
            status_code = 400
        elif exc.code == -32000:
            status_code = 503
        else:
            status_code = 502
        return JSONResponse({"error": str(exc)}, status_code=status_code)
    except RuntimeError as exc:
        return JSONResponse({"error": str(exc)}, status_code=503)

    if not isinstance(result, dict):
        return JSONResponse(
            {"error": "Nanobot API returned an invalid tool response"}, status_code=502
        )
    return JSONResponse(result)
@router.get("/tools/jobs/{job_id}")
async def get_tool_job(job_id: str, runtime: WebAppRuntime = Depends(get_runtime)) -> JSONResponse:
    """Return the current payload of a tool job.

    The path parameter is stripped of surrounding whitespace before lookup.
    Responds with 400 when the stripped id is empty, 404 when
    ``runtime.tool_job_service.get_job`` returns ``None``, and otherwise with
    the job payload.
    """
    lookup_id = job_id.strip()
    if lookup_id == "":
        return JSONResponse({"error": "job id is required"}, status_code=400)

    job_snapshot = await runtime.tool_job_service.get_job(lookup_id)
    if job_snapshot is not None:
        return JSONResponse(job_snapshot)
    return JSONResponse({"error": "tool job not found"}, status_code=404)
@router.get("/tools/jobs/{job_id}/stream")
async def stream_tool_job(
    job_id: str,
    request: Request,
    runtime: WebAppRuntime = Depends(get_runtime),
):
    """Stream updates for a tool job as ``text/event-stream``.

    Subscribes to the job through ``runtime.tool_job_service``, then emits:

    1. a ``: stream-open`` comment line,
    2. the job payload as it was when the stream was opened,
    3. each payload taken from the subscription queue, with a ``: keepalive``
       comment whenever 20 seconds pass without one.

    The stream ends once a payload with status ``completed`` or ``failed`` has
    been sent, or when the client is found to be disconnected. The
    subscription is released in a ``finally`` block when the generator exits.

    Returns a JSON error instead of a stream: 400 when the stripped job id is
    empty, 404 when the job cannot be subscribed to or fetched.
    """
    stream_job_id = job_id.strip()
    if not stream_job_id:
        return JSONResponse({"error": "job id is required"}, status_code=400)

    handle = await runtime.tool_job_service.subscribe_job(stream_job_id)
    if handle is None:
        return JSONResponse({"error": "tool job not found"}, status_code=404)
    subscription_id, updates = handle

    snapshot = await runtime.tool_job_service.get_job(stream_job_id)
    if snapshot is None:
        # Subscribing succeeded but the job can no longer be fetched; drop the
        # subscription before reporting the job as missing.
        await runtime.tool_job_service.unsubscribe_job(stream_job_id, subscription_id)
        return JSONResponse({"error": "tool job not found"}, status_code=404)

    terminal_statuses = {"completed", "failed"}

    async def sse_events():
        try:
            yield ": stream-open\n\n"
            yield encode_sse_data({"type": "tool.job", "job": snapshot})
            finished = snapshot.get("status") in terminal_statuses
            # Short-circuit keeps the disconnect probe from running once a
            # terminal payload has already been delivered.
            while not finished and not await request.is_disconnected():
                try:
                    update = await asyncio.wait_for(updates.get(), timeout=20.0)
                except asyncio.TimeoutError:
                    yield ": keepalive\n\n"
                    continue
                yield encode_sse_data({"type": "tool.job", "job": update})
                finished = update.get("status") in terminal_statuses
        finally:
            await runtime.tool_job_service.unsubscribe_job(stream_job_id, subscription_id)

    sse_headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(
        sse_events(),
        media_type="text/event-stream",
        headers=sse_headers,
    )