138 lines
5.1 KiB
Python
138 lines
5.1 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
|
|
from fastapi import APIRouter, Depends, Request
|
|
from fastapi.responses import JSONResponse, StreamingResponse
|
|
|
|
from app_dependencies import get_runtime
|
|
from message_pipeline import encode_sse_data
|
|
from nanobot_api_client import NanobotApiError, send_nanobot_api_request
|
|
from route_helpers import read_json_request
|
|
from web_runtime import WebAppRuntime
|
|
|
|
# Router that the endpoints below register on (GET /tools, POST /tools/call,
# GET /tools/jobs/{job_id}, GET /tools/jobs/{job_id}/stream).
# NOTE(review): presumably included into the app elsewhere — confirm.
router = APIRouter()
|
|
|
|
|
|
@router.get("/tools")
async def list_tools() -> JSONResponse:
    """Return the tool catalogue reported by the Nanobot API.

    Responses:
        200 with ``{"tools": [...]}`` when the API returns a dict whose
            ``tools`` entry (default ``[]``) is a list.
        502 when the API reports an error code other than -32000, or when its
            result is not a dict holding a ``tools`` list.
        503 when the API reports error code -32000 or the request raises
            ``RuntimeError``.
    """
    try:
        reply = await send_nanobot_api_request("tool.list", {}, timeout_seconds=20.0)
    except NanobotApiError as api_error:
        # Error code -32000 becomes 503; every other API error becomes 502.
        unavailable = api_error.code == -32000
        return JSONResponse(
            {"error": str(api_error)}, status_code=503 if unavailable else 502
        )
    except RuntimeError as transport_error:
        return JSONResponse({"error": str(transport_error)}, status_code=503)

    # A non-dict reply and a non-list "tools" entry are rejected identically.
    tool_entries = reply.get("tools", []) if isinstance(reply, dict) else None
    if isinstance(tool_entries, list):
        return JSONResponse({"tools": tool_entries})
    return JSONResponse(
        {"error": "Nanobot API returned an invalid tool list"}, status_code=502
    )
|
|
|
|
|
|
@router.post("/tools/call")
async def call_tool(
    request: Request, runtime: WebAppRuntime = Depends(get_runtime)
) -> JSONResponse:
    """Invoke a Nanobot tool, either inline or as a background job.

    The JSON body names the tool via ``tool_name`` (alias: ``name``) and passes
    its arguments via ``arguments`` (alias: ``params``). An explicit ``null``
    is treated like a missing key: the alias is consulted, and a null name is
    rejected instead of being stringified into the tool name ``"None"``.

    When the body contains ``"async": true`` the call is handed to
    ``runtime.tool_job_service.start_job`` and its payload is returned with
    status 202. Otherwise the tool is called through the Nanobot API with a
    60 second timeout and the API's result dict is returned with status 200.

    Error responses:
        400 for an unreadable or non-object body, a missing/blank tool name,
            non-object arguments, or API error code -32602.
        502 for any other API error code or a non-dict API result.
        503 for API error code -32000 or a ``RuntimeError`` from the request.
    """
    try:
        payload = await read_json_request(request)
    except ValueError as exc:
        return JSONResponse({"error": str(exc)}, status_code=400)
    # NOTE(review): read_json_request may already guarantee a dict; if not,
    # this turns a top-level JSON array/scalar into a 400 instead of an
    # AttributeError (HTTP 500) on ``payload.get`` below.
    if not isinstance(payload, dict):
        return JSONResponse(
            {"error": "request body must be a JSON object"}, status_code=400
        )

    # Treat an explicit null like a missing key so the alias is consulted and
    # str(None) never produces the bogus tool name "None".
    raw_name = payload.get("tool_name")
    if raw_name is None:
        raw_name = payload.get("name")
    tool_name = "" if raw_name is None else str(raw_name).strip()
    if not tool_name:
        return JSONResponse({"error": "tool_name is required"}, status_code=400)

    arguments = payload.get("arguments")
    if arguments is None:
        arguments = payload.get("params")
    if arguments is None:
        arguments = {}
    if not isinstance(arguments, dict):
        return JSONResponse({"error": "arguments must be a JSON object"}, status_code=400)

    # Only a literal JSON ``true`` selects the background-job path.
    if payload.get("async") is True:
        job_payload = await runtime.tool_job_service.start_job(tool_name, arguments)
        return JSONResponse(job_payload, status_code=202)

    try:
        result = await send_nanobot_api_request(
            "tool.call",
            {"name": tool_name, "arguments": arguments},
            timeout_seconds=60.0,
        )
    except NanobotApiError as exc:
        # -32602 -> 400 (client error), -32000 -> 503, anything else -> 502.
        status_code = 400 if exc.code == -32602 else 503 if exc.code == -32000 else 502
        return JSONResponse({"error": str(exc)}, status_code=status_code)
    except RuntimeError as exc:
        return JSONResponse({"error": str(exc)}, status_code=503)

    if not isinstance(result, dict):
        return JSONResponse(
            {"error": "Nanobot API returned an invalid tool response"}, status_code=502
        )
    return JSONResponse(result)
|
|
|
|
|
|
@router.get("/tools/jobs/{job_id}")
async def get_tool_job(job_id: str, runtime: WebAppRuntime = Depends(get_runtime)) -> JSONResponse:
    """Return the current snapshot of a background tool job.

    Responds 400 when ``job_id`` is blank after trimming whitespace, 404 when
    ``runtime.tool_job_service.get_job`` returns ``None`` for it, and 200 with
    the returned snapshot otherwise.
    """
    normalized_id = job_id.strip()
    if normalized_id:
        snapshot = await runtime.tool_job_service.get_job(normalized_id)
        if snapshot is not None:
            return JSONResponse(snapshot)
        return JSONResponse({"error": "tool job not found"}, status_code=404)
    return JSONResponse({"error": "job id is required"}, status_code=400)
|
|
|
|
|
|
# Job statuses after which the event stream stops waiting for further updates.
_TERMINAL_JOB_STATUSES = frozenset({"completed", "failed"})


@router.get("/tools/jobs/{job_id}/stream")
async def stream_tool_job(
    job_id: str,
    request: Request,
    runtime: WebAppRuntime = Depends(get_runtime),
):
    """Stream updates for a background tool job as Server-Sent Events.

    Responds 400 for a blank ``job_id`` and 404 when the job service cannot
    subscribe to the job or returns no snapshot for it. Otherwise the stream
    emits an ``: stream-open`` comment, the current job snapshot, then each
    update taken from the subscription queue; every event is encoded as
    ``{"type": "tool.job", "job": ...}``. A ``: keepalive`` comment is sent
    whenever 20 seconds pass without an update. The stream ends after an event
    whose ``status`` is ``completed`` or ``failed``, or once the client
    disconnects, and the queue subscription is released when it ends.
    """
    safe_job_id = job_id.strip()
    if not safe_job_id:
        return JSONResponse({"error": "job id is required"}, status_code=400)

    job_service = runtime.tool_job_service
    # Subscribing before reading the snapshot presumably ensures updates
    # published in between land in the queue rather than being lost.
    subscription = await job_service.subscribe_job(safe_job_id)
    if subscription is None:
        return JSONResponse({"error": "tool job not found"}, status_code=404)
    subscription_id, queue = subscription

    try:
        current = await job_service.get_job(safe_job_id)
    except BaseException:
        # Release the subscription if the snapshot lookup fails or is
        # cancelled; otherwise nothing would ever unsubscribe it.
        await job_service.unsubscribe_job(safe_job_id, subscription_id)
        raise
    if current is None:
        await job_service.unsubscribe_job(safe_job_id, subscription_id)
        return JSONResponse({"error": "tool job not found"}, status_code=404)

    # NOTE(review): cleanup below relies on the generator's ``finally``. An
    # async generator that is never iterated does not run its ``finally``, so
    # if the response body is never started the subscription is not
    # released — confirm whether that can happen with the ASGI server in use.
    async def stream_events():
        try:
            yield ": stream-open\n\n"
            yield encode_sse_data({"type": "tool.job", "job": current})
            if current.get("status") in _TERMINAL_JOB_STATUSES:
                return
            while True:
                if await request.is_disconnected():
                    break
                try:
                    payload = await asyncio.wait_for(queue.get(), timeout=20.0)
                except asyncio.TimeoutError:
                    # SSE comment line keeps idle connections/proxies alive.
                    yield ": keepalive\n\n"
                    continue
                yield encode_sse_data({"type": "tool.job", "job": payload})
                if payload.get("status") in _TERMINAL_JOB_STATUSES:
                    break
        finally:
            await job_service.unsubscribe_job(safe_job_id, subscription_id)

    return StreamingResponse(
        stream_events(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            # Ask nginx-style reverse proxies not to buffer the event stream.
            "X-Accel-Buffering": "no",
        },
    )
|