132 lines
4.5 KiB
Python
132 lines
4.5 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import unittest
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from tool_job_service import ToolJobService
|
||
|
|
|
||
|
|
|
||
|
|
async def wait_for_job_status(
    service: ToolJobService,
    job_id: str,
    *,
    status: str,
    attempts: int = 50,
    interval: float = 0.01,
) -> dict[str, Any]:
    """Poll ``service.get_job`` until job *job_id* reports *status*.

    Args:
        service: Service whose ``get_job`` coroutine is polled.
        job_id: Identifier of the job to watch.
        status: Status value to wait for, compared with ``payload["status"]``.
        attempts: Maximum number of ``get_job`` calls before giving up.
        interval: Seconds to sleep between consecutive polls.

    Returns:
        The first job payload whose ``"status"`` equals *status*.

    Raises:
        AssertionError: If *status* is not observed within *attempts* polls.
            The message includes the last status seen (``None`` if the job
            was never found), so a failing test shows where the job stalled.
    """
    last_status: Any = None
    for attempt in range(attempts):
        payload = await service.get_job(job_id)
        if payload is not None:
            last_status = payload["status"]
            if last_status == status:
                return payload
        # Skip the sleep after the final poll; we are about to fail anyway.
        if attempt < attempts - 1:
            await asyncio.sleep(interval)
    raise AssertionError(
        f"job {job_id} never reached status {status} "
        f"(last observed status: {last_status})"
    )
|
||
|
|
|
||
|
|
|
||
|
|
class ToolJobServiceTests(unittest.IsolatedAsyncioTestCase):
    """Behavioral tests for ``ToolJobService`` driven by fake ``send_request`` callables."""

    # Upper bound for waiting on subscription updates and request start. It is
    # generous so slow CI machines do not time out spuriously; passing runs
    # return as soon as the awaited event happens.
    _UPDATE_TIMEOUT_SECONDS = 1.0

    def _make_service(self, send_request: Any) -> ToolJobService:
        """Build a service around *send_request* with the settings every test shares."""
        return ToolJobService(
            send_request=send_request,
            timeout_seconds=30.0,
            retention_seconds=60.0,
        )

    async def test_start_job_completes_and_returns_unwrapped_result(self) -> None:
        calls: list[tuple[str, dict[str, Any]]] = []

        async def send_request(method: str, params: dict[str, Any]) -> dict[str, Any]:
            # Record the call instead of asserting here. This coroutine runs
            # inside the service's job, where a raised exception turns into a
            # failed job (see the error-code test below). The test would then
            # report only a polling timeout rather than the mismatched value.
            calls.append((method, params))
            return {
                "tool_name": "demo.tool",
                "content": '{"ok": true}',
                "parsed": {"ok": True, "arguments": params["arguments"]},
                "is_json": True,
            }

        service = self._make_service(send_request)
        self.addAsyncCleanup(service.shutdown)

        job = await service.start_job("demo.tool", {"count": 3})
        self.assertEqual(job["status"], "queued")
        self.assertEqual(job["tool_name"], "demo.tool")

        completed = await wait_for_job_status(service, job["job_id"], status="completed")

        self.assertTrue(calls, "send_request was never invoked")
        for method, params in calls:
            self.assertEqual(method, "tool.call")
            self.assertEqual(params["name"], "demo.tool")

        self.assertEqual(
            completed["result"],
            {
                "tool_name": "demo.tool",
                "content": '{"ok": true}',
                "parsed": {"ok": True, "arguments": {"count": 3}},
                "is_json": True,
            },
        )
        self.assertIsNone(completed["error"])

    async def test_shutdown_cancels_running_jobs(self) -> None:
        # Never set: it keeps the fake request in flight, so shutdown has to
        # cancel the job rather than let it finish.
        never_released = asyncio.Event()

        async def send_request(method: str, params: dict[str, Any]) -> dict[str, Any]:
            await never_released.wait()
            return {"parsed": {"done": True}}

        service = self._make_service(send_request)

        job = await service.start_job("slow.tool", {})
        # Wait until the job is actually running instead of sleeping for a
        # fixed interval, which could race with a slow scheduler.
        await wait_for_job_status(service, job["job_id"], status="running")
        await service.shutdown()

        failed = await wait_for_job_status(service, job["job_id"], status="failed")
        self.assertEqual(failed["error"], "tool job cancelled")
        self.assertIsNone(failed["result"])

    async def test_failed_job_exposes_error_code_when_present(self) -> None:
        class ToolFailure(RuntimeError):
            """Exception carrying a numeric ``code`` attribute for the service to surface."""

            def __init__(self, message: str, code: int) -> None:
                super().__init__(message)
                self.code = code

        async def send_request(method: str, params: dict[str, Any]) -> dict[str, Any]:
            raise ToolFailure("boom", 418)

        service = self._make_service(send_request)
        self.addAsyncCleanup(service.shutdown)

        job = await service.start_job("broken.tool", {})
        failed = await wait_for_job_status(service, job["job_id"], status="failed")
        self.assertEqual(failed["error"], "boom")
        self.assertEqual(failed["error_code"], 418)

    async def test_subscribers_receive_running_and_completed_updates(self) -> None:
        request_started = asyncio.Event()

        async def send_request(method: str, params: dict[str, Any]) -> dict[str, Any]:
            request_started.set()
            return {
                "tool_name": "demo.tool",
                "content": '{"ok": true}',
                "parsed": {"ok": True},
                "is_json": True,
            }

        service = self._make_service(send_request)
        self.addAsyncCleanup(service.shutdown)

        job = await service.start_job("demo.tool", {})
        # NOTE(review): the subscription is registered after start_job. This
        # relies on the job not publishing "running" before subscribe_job
        # returns; confirm against ToolJobService's scheduling.
        subscription = await service.subscribe_job(job["job_id"])
        if subscription is None:
            self.fail("subscribe_job returned None for a freshly started job")
        subscription_id, queue = subscription
        # Cleanups run LIFO: unsubscribe first, then the shutdown registered
        # above. They run even when an assertion below fails.
        self.addAsyncCleanup(service.unsubscribe_job, job["job_id"], subscription_id)

        # Bounded wait: if the job never reaches send_request, fail instead of
        # hanging the whole suite.
        await asyncio.wait_for(request_started.wait(), timeout=self._UPDATE_TIMEOUT_SECONDS)

        first_update = await asyncio.wait_for(queue.get(), timeout=self._UPDATE_TIMEOUT_SECONDS)
        self.assertEqual(first_update["status"], "running")

        second_update = await asyncio.wait_for(queue.get(), timeout=self._UPDATE_TIMEOUT_SECONDS)
        self.assertEqual(second_update["status"], "completed")
        self.assertEqual(second_update["result"]["tool_name"], "demo.tool")
|