Created
May 4, 2025 23:09
-
-
Save jhgaylor/92a6eb31c8024b41f663f1ee112c73e1 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import logging | |
from typing import AsyncIterable, List, Union | |
from google_a2a.common.server.task_manager import TaskManager | |
from google_a2a.common.server.utils import new_not_implemented_error | |
from google_a2a.common.types import ( | |
Task, | |
TaskSendParams, | |
TaskStatus, | |
TaskState, | |
Artifact, | |
PushNotificationConfig, | |
TaskStatusUpdateEvent, | |
GetTaskRequest, | |
GetTaskResponse, | |
TaskQueryParams, | |
TaskNotFoundError, | |
CancelTaskRequest, | |
CancelTaskResponse, | |
TaskIdParams, | |
TaskNotCancelableError, | |
SendTaskRequest, | |
SendTaskResponse, | |
SendTaskStreamingRequest, | |
SendTaskStreamingResponse, | |
JSONRPCResponse, | |
JSONRPCError, | |
SetTaskPushNotificationRequest, | |
SetTaskPushNotificationResponse, | |
GetTaskPushNotificationRequest, | |
GetTaskPushNotificationResponse, | |
TaskPushNotificationConfig, | |
TaskResubscriptionRequest, | |
InternalError, | |
) | |
logger = logging.getLogger(__name__) | |
class Memory:
    """In-memory storage for tasks, push-notification configs and SSE queues.

    ``lock`` is held around every access to ``tasks`` and
    ``push_notification_infos``; ``subscriber_lock`` is held around every
    access to ``task_sse_subscribers``.
    """

    def __init__(self):
        # Tasks keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Push-notification configuration keyed by task id.
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        # One queue per active SSE consumer, grouped by task id.
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def set_push_notification_info(
        self, task_id: str, notification_config: PushNotificationConfig
    ):
        """Store ``notification_config`` for an already-known task.

        Raises:
            ValueError: if ``task_id`` is not present in ``tasks``.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push-notification config stored for ``task_id``.

        Raises:
            ValueError: if ``task_id`` is not present in ``tasks``.
            KeyError: if the task exists but no config was stored for it.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return whether a push-notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def upsert_task(self, params: TaskSendParams) -> Task:
        """Create the task for ``params.id`` or append its message to the history.

        A new task starts in ``TaskState.SUBMITTED`` with ``params.message``
        as its only history entry. For an existing task only the history is
        extended; its status is left untouched.
        """
        logger.info(f"Upserting task {params.id}")
        async with self.lock:
            task = self.tasks.get(params.id)
            if task is None:
                task = Task(
                    id=params.id,
                    sessionId=params.sessionId,
                    messages=[params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[params.message],
                )
                self.tasks[params.id] = task
            else:
                task.history.append(params.message)
            return task

    async def update_store(
        self, task_id: str, status: TaskStatus, artifacts: list[Artifact] | None
    ) -> Task:
        """Replace a task's status and append any new artifacts.

        ``status.message``, when set, is also appended to the task history.
        ``artifacts`` may be ``None`` when there is nothing to add.

        Raises:
            ValueError: if ``task_id`` is not present in ``tasks``.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError as err:
                logger.error(f"Task {task_id} not found for updating the task")
                raise ValueError(f"Task {task_id} not found") from err

            task.status = status
            if status.message is not None:
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def append_task_history(self, task: Task, historyLength: int | None) -> Task:
        """Return a copy of ``task`` with its history trimmed.

        Despite the name nothing is appended: the copy keeps the last
        ``historyLength`` history entries, or an empty history when
        ``historyLength`` is ``None`` or not positive. The stored task's
        ``history`` attribute is not reassigned.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0:
            # A task whose history is unset is treated as having an empty one
            # instead of failing on the slice.
            new_task.history = (new_task.history or [])[-historyLength:]
        else:
            new_task.history = []
        return new_task

    async def setup_sse_consumer(
        self, task_id: str, is_resubscribe: bool = False
    ) -> asyncio.Queue:
        """Register and return a new event queue for ``task_id``.

        Raises:
            ValueError: if ``is_resubscribe`` is true and ``task_id`` has no
                subscriber list yet.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []

            sse_event_queue = asyncio.Queue(maxsize=0)  # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id: str, task_update_event):
        """Put ``task_update_event`` on every queue registered for ``task_id``.

        Does nothing when ``task_id`` has no subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return
            # Iterate over a snapshot of the subscriber list.
            for subscriber in list(self.task_sse_subscribers[task_id]):
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses for events read from ``sse_event_queue``.

        A ``JSONRPCError`` event is yielded as an error response and ends the
        stream; a ``TaskStatusUpdateEvent`` with ``final`` set ends it after
        being yielded. The queue is unregistered when the generator finishes
        or is closed.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)
class MemoryTaskManager(TaskManager, Memory):
    """``TaskManager`` whose state lives in the ``Memory`` mixin.

    ``on_send_task`` and ``on_send_task_subscribe`` are not defined here, so
    a subclass has to provide them.
    """

    def __init__(self):
        Memory.__init__(self)

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the requested task with its history trimmed.

        The history is cut to ``request.params.historyLength`` entries.
        Responds with ``TaskNotFoundError`` when the id is unknown.
        """
        logger.info(f"Getting task {request.params.id}")
        task_query_params: TaskQueryParams = request.params

        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())

            task_result = self.append_task_history(
                task, task_query_params.historyLength
            )

        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject cancellation.

        Responds with ``TaskNotFoundError`` for an unknown id and with
        ``TaskNotCancelableError`` for every known task.
        """
        logger.info(f"Cancelling task {request.params.id}")
        task_id_params: TaskIdParams = request.params

        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the push-notification config carried by ``request``.

        Any failure while storing it is logged and reported as an
        ``InternalError`` response.
        """
        logger.info(f"Setting task push notification {request.params.id}")
        task_notification_params = request.params

        try:
            await self.set_push_notification_info(
                task_notification_params.id,
                task_notification_params.pushNotificationConfig,
            )
        except Exception as e:
            logger.error(f"Error while setting push notification info: {e}")
            # Use the declared response type (as the getter below does)
            # rather than the bare JSONRPCResponse base class.
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )

        return SetTaskPushNotificationResponse(
            id=request.id, result=task_notification_params
        )

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config stored for the requested task.

        Any failure while reading it is logged and reported as an
        ``InternalError`` response.
        """
        logger.info(f"Getting task push notification {request.params.id}")
        task_params: TaskIdParams = request.params

        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error(f"Error while getting push notification info: {e}")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )

        return GetTaskPushNotificationResponse(
            id=request.id,
            result=TaskPushNotificationConfig(
                id=task_params.id,
                pushNotificationConfig=notification_info,
            ),
        )

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Resubscription is not supported; return a not-implemented error."""
        return new_not_implemented_error(request.id)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from abc import ABC, abstractmethod | |
from typing import Union, AsyncIterable, List | |
from google_a2a.common.types import Task | |
from google_a2a.common.types import ( | |
JSONRPCResponse, | |
TaskIdParams, | |
TaskQueryParams, | |
GetTaskRequest, | |
TaskNotFoundError, | |
SendTaskRequest, | |
CancelTaskRequest, | |
TaskNotCancelableError, | |
SetTaskPushNotificationRequest, | |
GetTaskPushNotificationRequest, | |
GetTaskResponse, | |
CancelTaskResponse, | |
SendTaskResponse, | |
SetTaskPushNotificationResponse, | |
GetTaskPushNotificationResponse, | |
PushNotificationNotSupportedError, | |
TaskSendParams, | |
TaskStatus, | |
TaskState, | |
TaskResubscriptionRequest, | |
SendTaskStreamingRequest, | |
SendTaskStreamingResponse, | |
Artifact, | |
PushNotificationConfig, | |
TaskStatusUpdateEvent, | |
JSONRPCError, | |
TaskPushNotificationConfig, | |
InternalError, | |
) | |
from google_a2a.common.server.utils import new_not_implemented_error | |
import asyncio | |
import logging | |
logger = logging.getLogger(__name__) | |
class TaskManager(ABC):
    """Abstract interface for handling task-related requests.

    Each handler receives one request object and returns the matching
    response object; the two streaming handlers may instead return an async
    iterable of streaming responses.
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Handle a get-task request."""

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle a cancel-task request."""

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a non-streaming send-task request."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming send-task request."""

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Handle a request to set a task's push-notification config."""

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Handle a request to read a task's push-notification config."""

    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a request to resubscribe to a task's event stream."""
class InMemoryTaskManager(TaskManager):
    """``TaskManager`` base that keeps all state in process memory.

    ``lock`` is held around every access to ``tasks`` and
    ``push_notification_infos``; ``subscriber_lock`` is held around every
    access to ``task_sse_subscribers``. ``on_send_task`` and
    ``on_send_task_subscribe`` stay abstract and must be supplied by a
    subclass.
    """

    def __init__(self):
        # Tasks keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Push-notification configuration keyed by task id.
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        # One queue per active SSE consumer, grouped by task id.
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the requested task with its history trimmed.

        The history is cut to ``request.params.historyLength`` entries.
        Responds with ``TaskNotFoundError`` when the id is unknown.
        """
        logger.info(f"Getting task {request.params.id}")
        task_query_params: TaskQueryParams = request.params

        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())

            task_result = self.append_task_history(
                task, task_query_params.historyLength
            )

        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject cancellation.

        Responds with ``TaskNotFoundError`` for an unknown id and with
        ``TaskNotCancelableError`` for every known task.
        """
        logger.info(f"Cancelling task {request.params.id}")
        task_id_params: TaskIdParams = request.params

        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())

        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a non-streaming send-task request (subclass responsibility)."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming send-task request (subclass responsibility)."""

    async def set_push_notification_info(
        self, task_id: str, notification_config: PushNotificationConfig
    ):
        """Store ``notification_config`` for an already-known task.

        Raises:
            ValueError: if ``task_id`` is not present in ``tasks``.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push-notification config stored for ``task_id``.

        Raises:
            ValueError: if ``task_id`` is not present in ``tasks``.
            KeyError: if the task exists but no config was stored for it.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return whether a push-notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the push-notification config carried by ``request``.

        Any failure while storing it is logged and reported as an
        ``InternalError`` response.
        """
        logger.info(f"Setting task push notification {request.params.id}")
        task_notification_params: TaskPushNotificationConfig = request.params

        try:
            await self.set_push_notification_info(
                task_notification_params.id,
                task_notification_params.pushNotificationConfig,
            )
        except Exception as e:
            logger.error(f"Error while setting push notification info: {e}")
            # Use the declared response type (as the getter below does)
            # rather than the bare JSONRPCResponse base class.
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )

        return SetTaskPushNotificationResponse(
            id=request.id, result=task_notification_params
        )

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config stored for the requested task.

        Any failure while reading it is logged and reported as an
        ``InternalError`` response.
        """
        logger.info(f"Getting task push notification {request.params.id}")
        task_params: TaskIdParams = request.params

        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error(f"Error while getting push notification info: {e}")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )

        return GetTaskPushNotificationResponse(
            id=request.id,
            result=TaskPushNotificationConfig(
                id=task_params.id,
                pushNotificationConfig=notification_info,
            ),
        )

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create the task for ``task_send_params.id`` or extend its history.

        A new task starts in ``TaskState.SUBMITTED`` with the incoming message
        as its only history entry. For an existing task only the history is
        extended; its status is left untouched.
        """
        logger.info(f"Upserting task {task_send_params.id}")
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                task = Task(
                    id=task_send_params.id,
                    sessionId=task_send_params.sessionId,
                    messages=[task_send_params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_send_params.message],
                )
                self.tasks[task_send_params.id] = task
            else:
                task.history.append(task_send_params.message)

            return task

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Resubscription is not supported; return a not-implemented error."""
        return new_not_implemented_error(request.id)

    async def update_store(
        self, task_id: str, status: TaskStatus, artifacts: list[Artifact] | None
    ) -> Task:
        """Replace a task's status and append any new artifacts.

        ``status.message``, when set, is also appended to the task history.
        ``artifacts`` may be ``None`` when there is nothing to add.

        Raises:
            ValueError: if ``task_id`` is not present in ``tasks``.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError as err:
                logger.error(f"Task {task_id} not found for updating the task")
                raise ValueError(f"Task {task_id} not found") from err

            task.status = status
            if status.message is not None:
                task.history.append(status.message)

            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)

            return task

    def append_task_history(self, task: Task, historyLength: int | None) -> Task:
        """Return a copy of ``task`` with its history trimmed.

        Despite the name nothing is appended: the copy keeps the last
        ``historyLength`` history entries, or an empty history when
        ``historyLength`` is ``None`` or not positive. The stored task's
        ``history`` attribute is not reassigned.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0:
            # A task whose history is unset is treated as having an empty one
            # instead of failing on the slice.
            new_task.history = (new_task.history or [])[-historyLength:]
        else:
            new_task.history = []
        return new_task

    async def setup_sse_consumer(
        self, task_id: str, is_resubscribe: bool = False
    ) -> asyncio.Queue:
        """Register and return a new event queue for ``task_id``.

        Raises:
            ValueError: if ``is_resubscribe`` is true and ``task_id`` has no
                subscriber list yet.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []

            sse_event_queue = asyncio.Queue(maxsize=0)  # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id, task_update_event):
        """Put ``task_update_event`` on every queue registered for ``task_id``.

        Does nothing when ``task_id`` has no subscriber list.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return

            for subscriber in self.task_sse_subscribers[task_id]:
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses for events read from ``sse_event_queue``.

        A ``JSONRPCError`` event is yielded as an error response and ends the
        stream; a ``TaskStatusUpdateEvent`` with ``final`` set ends it after
        being yielded. The queue is unregistered when the generator finishes
        or is closed.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break

                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment