Skip to content

Instantly share code, notes, and snippets.

@jhgaylor
Created May 4, 2025 23:09
Show Gist options
  • Save jhgaylor/92a6eb31c8024b41f663f1ee112c73e1 to your computer and use it in GitHub Desktop.
Save jhgaylor/92a6eb31c8024b41f663f1ee112c73e1 to your computer and use it in GitHub Desktop.
import asyncio
import logging
from typing import AsyncIterable, List, Union
from google_a2a.common.server.task_manager import TaskManager
from google_a2a.common.server.utils import new_not_implemented_error
from google_a2a.common.types import (
Task,
TaskSendParams,
TaskStatus,
TaskState,
Artifact,
PushNotificationConfig,
TaskStatusUpdateEvent,
GetTaskRequest,
GetTaskResponse,
TaskQueryParams,
TaskNotFoundError,
CancelTaskRequest,
CancelTaskResponse,
TaskIdParams,
TaskNotCancelableError,
SendTaskRequest,
SendTaskResponse,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
JSONRPCResponse,
JSONRPCError,
SetTaskPushNotificationRequest,
SetTaskPushNotificationResponse,
GetTaskPushNotificationRequest,
GetTaskPushNotificationResponse,
TaskPushNotificationConfig,
TaskResubscriptionRequest,
InternalError,
)
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class Memory:
    """In-memory bookkeeping for tasks, push-notification configs and SSE queues.

    All state lives in plain dictionaries on the instance. ``self.lock`` is
    acquired by the methods that read or write ``tasks`` and
    ``push_notification_infos``; ``self.subscriber_lock`` is acquired by the
    methods that read or write ``task_sse_subscribers``.
    """

    def __init__(self):
        # Task records keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Push-notification configuration keyed by task id.
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        # One list of subscriber queues per task id.
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig):
        """Store ``notification_config`` for an existing task.

        Raises:
            ValueError: if no task with ``task_id`` is stored.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push-notification config stored for ``task_id``.

        Raises:
            ValueError: if no task with ``task_id`` is stored.
            KeyError: if the task exists but no config was stored for it.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return True when a push-notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def upsert_task(self, params: TaskSendParams) -> Task:
        """Create the task named by ``params.id`` or append to its history.

        A new task starts in ``TaskState.SUBMITTED`` with ``params.message`` as
        its only history entry. For an existing task the message is appended
        to the history (a missing history is initialised to an empty list
        first).

        Returns:
            The stored task object (not a copy).
        """
        logger.info(f"Upserting task {params.id}")
        async with self.lock:
            task = self.tasks.get(params.id)
            if task is None:
                task = Task(
                    id=params.id,
                    sessionId=params.sessionId,
                    messages=[params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[params.message],
                )
                self.tasks[params.id] = task
            else:
                # Guard against a task stored without a history list.
                if task.history is None:
                    task.history = []
                task.history.append(params.message)
            return task

    async def update_store(self, task_id: str, status: TaskStatus, artifacts: list[Artifact]) -> Task:
        """Set a task's status and append the status message and artifacts.

        Args:
            task_id: id of the task to update.
            status: new status; a non-None ``status.message`` is appended to
                the task history.
            artifacts: artifacts to append; ``None`` leaves them untouched.

        Returns:
            The stored task object (not a copy).

        Raises:
            ValueError: if no task with ``task_id`` is stored.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError:
                logger.error(f"Task {task_id} not found for updating the task")
                raise ValueError(f"Task {task_id} not found")
            task.status = status
            if status.message is not None:
                # Same None-tolerance as for artifacts below.
                if task.history is None:
                    task.history = []
                task.history.append(status.message)
            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)
            return task

    def append_task_history(self, task: Task, historyLength: int | None):
        """Return a copy of ``task`` whose history is trimmed for a response.

        The copy comes from ``task.model_copy()``. Its history holds the last
        ``historyLength`` entries when ``historyLength`` is a positive integer
        and the task has a history; otherwise it is an empty list. The
        ``history`` attribute of ``task`` itself is not reassigned.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0 and new_task.history:
            new_task.history = new_task.history[-historyLength:]
        else:
            # Covers no/zero/negative length and a missing (None) history.
            new_task.history = []
        return new_task

    async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
        """Register and return a new event queue for ``task_id``.

        Raises:
            ValueError: if ``is_resubscribe`` is true and no subscriber list
                exists yet for ``task_id``.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []
            # maxsize=0 makes the queue unbounded.
            sse_event_queue = asyncio.Queue(maxsize=0)
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id: str, task_update_event):
        """Put ``task_update_event`` on every queue registered for ``task_id``.

        Does nothing when no subscriber list exists for the task.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return
            for subscriber in list(self.task_sse_subscribers[task_id]):
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses for events taken from ``sse_event_queue``.

        A ``JSONRPCError`` event is yielded as an error response and ends the
        stream. Any other event is yielded as a result; the stream ends after
        a ``TaskStatusUpdateEvent`` whose ``final`` flag is set. When the
        generator finishes, the queue is removed from the task's subscriber
        list.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break
                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)
class MemoryTaskManager(TaskManager, Memory):
    """Task manager whose request handlers are backed by the ``Memory`` store.

    ``on_send_task`` and ``on_send_task_subscribe`` are not defined here;
    they are expected to come from a subclass.
    """

    def __init__(self):
        # Only Memory has state to initialise.
        Memory.__init__(self)

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the requested task with its history trimmed to ``historyLength``.

        Responds with ``TaskNotFoundError`` when the task id is unknown.
        """
        logger.info(f"Getting task {request.params.id}")
        task_query_params: TaskQueryParams = request.params
        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())
            task_result = self.append_task_history(task, task_query_params.historyLength)
        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject the cancellation request.

        Responds with ``TaskNotFoundError`` for an unknown task id and with
        ``TaskNotCancelableError`` for every known task.
        """
        logger.info(f"Cancelling task {request.params.id}")
        task_id_params: TaskIdParams = request.params
        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the push-notification config carried by the request.

        Any exception raised while storing is logged and reported as an
        ``InternalError`` response; on success the request params are echoed
        back as the result.
        """
        logger.info(f"Setting task push notification {request.params.id}")
        task_notification_params = request.params
        try:
            await self.set_push_notification_info(
                task_notification_params.id,
                task_notification_params.pushNotificationConfig,
            )
        except Exception as e:
            logger.error(f"Error while setting push notification info: {e}")
            # Use the declared response type, mirroring
            # on_get_task_push_notification, instead of a bare JSONRPCResponse.
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )
        return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config stored for the requested task.

        Any exception raised while looking it up is logged and reported as an
        ``InternalError`` response.
        """
        logger.info(f"Getting task push notification {request.params.id}")
        task_params: TaskIdParams = request.params
        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error(f"Error while getting push notification info: {e}")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )
        return GetTaskPushNotificationResponse(
            id=request.id,
            result=TaskPushNotificationConfig(
                id=task_params.id,
                pushNotificationConfig=notification_info,
            ),
        )

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Return the result of ``new_not_implemented_error`` for this request id."""
        return new_not_implemented_error(request.id)
from abc import ABC, abstractmethod
from typing import Union, AsyncIterable, List
from google_a2a.common.types import Task
from google_a2a.common.types import (
JSONRPCResponse,
TaskIdParams,
TaskQueryParams,
GetTaskRequest,
TaskNotFoundError,
SendTaskRequest,
CancelTaskRequest,
TaskNotCancelableError,
SetTaskPushNotificationRequest,
GetTaskPushNotificationRequest,
GetTaskResponse,
CancelTaskResponse,
SendTaskResponse,
SetTaskPushNotificationResponse,
GetTaskPushNotificationResponse,
PushNotificationNotSupportedError,
TaskSendParams,
TaskStatus,
TaskState,
TaskResubscriptionRequest,
SendTaskStreamingRequest,
SendTaskStreamingResponse,
Artifact,
PushNotificationConfig,
TaskStatusUpdateEvent,
JSONRPCError,
TaskPushNotificationConfig,
InternalError,
)
from google_a2a.common.server.utils import new_not_implemented_error
import asyncio
import logging
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
class TaskManager(ABC):
    """Abstract interface of the request handlers a task manager must provide.

    Every method is abstract; a concrete subclass has to implement all of
    them before it can be instantiated.
    """

    @abstractmethod
    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Handle a get-task request."""

    @abstractmethod
    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Handle a cancel-task request."""

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a send-task request."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming send-task request."""

    @abstractmethod
    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Handle a request to set a task's push-notification config."""

    @abstractmethod
    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Handle a request to read a task's push-notification config."""

    # The streamed item type is SendTaskStreamingResponse, as in
    # on_send_task_subscribe; the previous SendTaskResponse annotation did not
    # match the implementations of this method.
    @abstractmethod
    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a request to resubscribe to a task's event stream."""
class InMemoryTaskManager(TaskManager):
    """Task manager that keeps tasks, push configs and SSE queues in memory.

    ``on_send_task`` and ``on_send_task_subscribe`` stay abstract, so this
    class must be subclassed. ``self.lock`` is acquired by the methods that
    read or write ``tasks`` and ``push_notification_infos``;
    ``self.subscriber_lock`` is acquired by the methods that read or write
    ``task_sse_subscribers``.
    """

    def __init__(self):
        # Task records keyed by task id.
        self.tasks: dict[str, Task] = {}
        # Push-notification configuration keyed by task id.
        self.push_notification_infos: dict[str, PushNotificationConfig] = {}
        self.lock = asyncio.Lock()
        # One list of subscriber queues per task id.
        self.task_sse_subscribers: dict[str, List[asyncio.Queue]] = {}
        self.subscriber_lock = asyncio.Lock()

    async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse:
        """Return the requested task with its history trimmed to ``historyLength``.

        Responds with ``TaskNotFoundError`` when the task id is unknown.
        """
        logger.info(f"Getting task {request.params.id}")
        task_query_params: TaskQueryParams = request.params
        async with self.lock:
            task = self.tasks.get(task_query_params.id)
            if task is None:
                return GetTaskResponse(id=request.id, error=TaskNotFoundError())
            task_result = self.append_task_history(
                task, task_query_params.historyLength
            )
        return GetTaskResponse(id=request.id, result=task_result)

    async def on_cancel_task(self, request: CancelTaskRequest) -> CancelTaskResponse:
        """Reject the cancellation request.

        Responds with ``TaskNotFoundError`` for an unknown task id and with
        ``TaskNotCancelableError`` for every known task.
        """
        logger.info(f"Cancelling task {request.params.id}")
        task_id_params: TaskIdParams = request.params
        async with self.lock:
            task = self.tasks.get(task_id_params.id)
            if task is None:
                return CancelTaskResponse(id=request.id, error=TaskNotFoundError())
        return CancelTaskResponse(id=request.id, error=TaskNotCancelableError())

    @abstractmethod
    async def on_send_task(self, request: SendTaskRequest) -> SendTaskResponse:
        """Handle a send-task request; to be implemented by subclasses."""

    @abstractmethod
    async def on_send_task_subscribe(
        self, request: SendTaskStreamingRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Handle a streaming send-task request; to be implemented by subclasses."""

    async def set_push_notification_info(self, task_id: str, notification_config: PushNotificationConfig):
        """Store ``notification_config`` for an existing task.

        Raises:
            ValueError: if no task with ``task_id`` is stored.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            self.push_notification_infos[task_id] = notification_config

    async def get_push_notification_info(self, task_id: str) -> PushNotificationConfig:
        """Return the push-notification config stored for ``task_id``.

        Raises:
            ValueError: if no task with ``task_id`` is stored.
            KeyError: if the task exists but no config was stored for it.
        """
        async with self.lock:
            if self.tasks.get(task_id) is None:
                raise ValueError(f"Task not found for {task_id}")
            return self.push_notification_infos[task_id]

    async def has_push_notification_info(self, task_id: str) -> bool:
        """Return True when a push-notification config is stored for ``task_id``."""
        async with self.lock:
            return task_id in self.push_notification_infos

    async def on_set_task_push_notification(
        self, request: SetTaskPushNotificationRequest
    ) -> SetTaskPushNotificationResponse:
        """Store the push-notification config carried by the request.

        Any exception raised while storing is logged and reported as an
        ``InternalError`` response; on success the request params are echoed
        back as the result.
        """
        logger.info(f"Setting task push notification {request.params.id}")
        task_notification_params: TaskPushNotificationConfig = request.params
        try:
            await self.set_push_notification_info(
                task_notification_params.id,
                task_notification_params.pushNotificationConfig,
            )
        except Exception as e:
            logger.error(f"Error while setting push notification info: {e}")
            # Use the declared response type, mirroring
            # on_get_task_push_notification, instead of a bare JSONRPCResponse.
            return SetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while setting push notification info"
                ),
            )
        return SetTaskPushNotificationResponse(id=request.id, result=task_notification_params)

    async def on_get_task_push_notification(
        self, request: GetTaskPushNotificationRequest
    ) -> GetTaskPushNotificationResponse:
        """Return the push-notification config stored for the requested task.

        Any exception raised while looking it up is logged and reported as an
        ``InternalError`` response.
        """
        logger.info(f"Getting task push notification {request.params.id}")
        task_params: TaskIdParams = request.params
        try:
            notification_info = await self.get_push_notification_info(task_params.id)
        except Exception as e:
            logger.error(f"Error while getting push notification info: {e}")
            return GetTaskPushNotificationResponse(
                id=request.id,
                error=InternalError(
                    message="An error occurred while getting push notification info"
                ),
            )
        return GetTaskPushNotificationResponse(
            id=request.id,
            result=TaskPushNotificationConfig(
                id=task_params.id,
                pushNotificationConfig=notification_info,
            ),
        )

    async def upsert_task(self, task_send_params: TaskSendParams) -> Task:
        """Create the task named by ``task_send_params.id`` or append to its history.

        A new task starts in ``TaskState.SUBMITTED`` with the incoming message
        as its only history entry. For an existing task the message is
        appended to the history (a missing history is initialised to an empty
        list first).

        Returns:
            The stored task object (not a copy).
        """
        logger.info(f"Upserting task {task_send_params.id}")
        async with self.lock:
            task = self.tasks.get(task_send_params.id)
            if task is None:
                task = Task(
                    id=task_send_params.id,
                    sessionId=task_send_params.sessionId,
                    messages=[task_send_params.message],
                    status=TaskStatus(state=TaskState.SUBMITTED),
                    history=[task_send_params.message],
                )
                self.tasks[task_send_params.id] = task
            else:
                # Guard against a task stored without a history list.
                if task.history is None:
                    task.history = []
                task.history.append(task_send_params.message)
            return task

    async def on_resubscribe_to_task(
        self, request: TaskResubscriptionRequest
    ) -> Union[AsyncIterable[SendTaskStreamingResponse], JSONRPCResponse]:
        """Return the result of ``new_not_implemented_error`` for this request id."""
        return new_not_implemented_error(request.id)

    async def update_store(
        self, task_id: str, status: TaskStatus, artifacts: list[Artifact]
    ) -> Task:
        """Set a task's status and append the status message and artifacts.

        Args:
            task_id: id of the task to update.
            status: new status; a non-None ``status.message`` is appended to
                the task history.
            artifacts: artifacts to append; ``None`` leaves them untouched.

        Returns:
            The stored task object (not a copy).

        Raises:
            ValueError: if no task with ``task_id`` is stored.
        """
        async with self.lock:
            try:
                task = self.tasks[task_id]
            except KeyError:
                logger.error(f"Task {task_id} not found for updating the task")
                raise ValueError(f"Task {task_id} not found")
            task.status = status
            if status.message is not None:
                # Same None-tolerance as for artifacts below.
                if task.history is None:
                    task.history = []
                task.history.append(status.message)
            if artifacts is not None:
                if task.artifacts is None:
                    task.artifacts = []
                task.artifacts.extend(artifacts)
            return task

    def append_task_history(self, task: Task, historyLength: int | None):
        """Return a copy of ``task`` whose history is trimmed for a response.

        The copy comes from ``task.model_copy()``. Its history holds the last
        ``historyLength`` entries when ``historyLength`` is a positive integer
        and the task has a history; otherwise it is an empty list. The
        ``history`` attribute of ``task`` itself is not reassigned.
        """
        new_task = task.model_copy()
        if historyLength is not None and historyLength > 0 and new_task.history:
            new_task.history = new_task.history[-historyLength:]
        else:
            # Covers no/zero/negative length and a missing (None) history.
            new_task.history = []
        return new_task

    async def setup_sse_consumer(self, task_id: str, is_resubscribe: bool = False):
        """Register and return a new event queue for ``task_id``.

        Raises:
            ValueError: if ``is_resubscribe`` is true and no subscriber list
                exists yet for ``task_id``.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                if is_resubscribe:
                    raise ValueError("Task not found for resubscription")
                self.task_sse_subscribers[task_id] = []
            sse_event_queue = asyncio.Queue(maxsize=0)  # <=0 is unlimited
            self.task_sse_subscribers[task_id].append(sse_event_queue)
            return sse_event_queue

    async def enqueue_events_for_sse(self, task_id, task_update_event):
        """Put ``task_update_event`` on every queue registered for ``task_id``.

        Does nothing when no subscriber list exists for the task.
        """
        async with self.subscriber_lock:
            if task_id not in self.task_sse_subscribers:
                return
            current_subscribers = self.task_sse_subscribers[task_id]
            for subscriber in current_subscribers:
                await subscriber.put(task_update_event)

    async def dequeue_events_for_sse(
        self, request_id, task_id, sse_event_queue: asyncio.Queue
    ) -> AsyncIterable[SendTaskStreamingResponse] | JSONRPCResponse:
        """Yield streaming responses for events taken from ``sse_event_queue``.

        A ``JSONRPCError`` event is yielded as an error response and ends the
        stream. Any other event is yielded as a result; the stream ends after
        a ``TaskStatusUpdateEvent`` whose ``final`` flag is set. When the
        generator finishes, the queue is removed from the task's subscriber
        list.
        """
        try:
            while True:
                event = await sse_event_queue.get()
                if isinstance(event, JSONRPCError):
                    yield SendTaskStreamingResponse(id=request_id, error=event)
                    break
                yield SendTaskStreamingResponse(id=request_id, result=event)
                if isinstance(event, TaskStatusUpdateEvent) and event.final:
                    break
        finally:
            async with self.subscriber_lock:
                if task_id in self.task_sse_subscribers:
                    self.task_sse_subscribers[task_id].remove(sse_event_queue)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment