Skip to content

refactor!: Convert class fields in types.py to snake_case #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ sql = ["sqlalchemy[asyncio,postgresql-asyncpg,aiomysql,aiosqlite]>=2.0.0"]
encryption = ["cryptography>=43.0.0"]

[project.urls]
homepage = "https://a2a-protocol.org/"
homepage = "https://a2aproject.github.io/A2A/"
repository = "https://github.com/a2aproject/a2a-python"
changelog = "https://github.com/a2aproject/a2a-python/blob/main/CHANGELOG.md"
documentation = "https://a2a-protocol.org/latest/sdk/python/"
documentation = "https://a2aproject.github.io/A2A/sdk/python/"

[tool.hatch.build.targets.wheel]
packages = ["src/a2a"]
Expand Down
4 changes: 3 additions & 1 deletion scripts/generate_types.sh
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ uv run datamodel-codegen \
--class-name A2A \
--use-standard-collections \
--use-subclass-enum \
--base-class a2a._base.A2ABaseModel
--base-class a2a._base.A2ABaseModel \
--snake-case-field \
--no-alias

echo "Formatting generated file with ruff..."
uv run ruff format "$GENERATED_FILE"
Expand Down
3 changes: 3 additions & 0 deletions src/a2a/_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel


class A2ABaseModel(BaseModel):
Expand All @@ -12,4 +13,6 @@ class A2ABaseModel(BaseModel):
# SEE: https://docs.pydantic.dev/latest/api/config/#pydantic.config.ConfigDict.populate_by_name
validate_by_name=True,
validate_by_alias=True,
serialize_by_alias=True,
alias_generator=to_camel,
)
6 changes: 3 additions & 3 deletions src/a2a/client/auth/interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def intercept(
if (
agent_card is None
or agent_card.security is None
or agent_card.securitySchemes is None
or agent_card.security_schemes is None
):
return request_payload, http_kwargs

Expand All @@ -45,8 +45,8 @@ async def intercept(
credential = await self._credential_service.get_credentials(
scheme_name, context
)
if credential and scheme_name in agent_card.securitySchemes:
scheme_def_union = agent_card.securitySchemes.get(
if credential and scheme_name in agent_card.security_schemes:
scheme_def_union = agent_card.security_schemes.get(
scheme_name
)
if not scheme_def_union:
Expand Down
4 changes: 2 additions & 2 deletions src/a2a/client/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def create_text_message_object(
content: The text content of the message. Defaults to an empty string.
Returns:
A `Message` object with a new UUID messageId.
A `Message` object with a new UUID message_id.
"""
return Message(
role=role, parts=[Part(TextPart(text=content))], messageId=str(uuid4())
role=role, parts=[Part(TextPart(text=content))], message_id=str(uuid4())
)
22 changes: 11 additions & 11 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,14 @@ def __init__( # noqa: PLR0913
# match the request. Otherwise, create them
if self._params:
if task_id:
self._params.message.taskId = task_id
self._params.message.task_id = task_id
if task and task.id != task_id:
raise ServerError(InvalidParamsError(message='bad task id'))
else:
self._check_or_generate_task_id()
if context_id:
self._params.message.contextId = context_id
if task and task.contextId != context_id:
self._params.message.context_id = context_id
if task and task.context_id != context_id:
raise ServerError(
InvalidParamsError(message='bad context id')
)
Expand Down Expand Up @@ -148,17 +148,17 @@ def _check_or_generate_task_id(self) -> None:
if not self._params:
return

if not self._task_id and not self._params.message.taskId:
self._params.message.taskId = str(uuid.uuid4())
if self._params.message.taskId:
self._task_id = self._params.message.taskId
if not self._task_id and not self._params.message.task_id:
self._params.message.task_id = str(uuid.uuid4())
if self._params.message.task_id:
self._task_id = self._params.message.task_id

def _check_or_generate_context_id(self) -> None:
"""Ensures a context ID is present, generating one if necessary."""
if not self._params:
return

if not self._context_id and not self._params.message.contextId:
self._params.message.contextId = str(uuid.uuid4())
if self._params.message.contextId:
self._context_id = self._params.message.contextId
if not self._context_id and not self._params.message.context_id:
self._params.message.context_id = str(uuid.uuid4())
if self._params.message.context_id:
self._context_id = self._params.message.context_id
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(

Args:
should_populate_referred_tasks: If True, the builder will fetch tasks
referenced in `params.message.referenceTaskIds` and populate the
referenced in `params.message.reference_task_ids` and populate the
`related_tasks` field in the RequestContext. Defaults to False.
task_store: The TaskStore instance to use for fetching referred tasks.
Required if `should_populate_referred_tasks` is True.
Expand All @@ -38,7 +38,7 @@ async def build(

This method assembles the RequestContext object. If the builder was
initialized with `should_populate_referred_tasks=True`, it fetches all tasks
referenced in `params.message.referenceTaskIds` from the `task_store`.
referenced in `params.message.reference_task_ids` from the `task_store`.

Args:
params: The parameters of the incoming message send request.
Expand All @@ -57,12 +57,12 @@ async def build(
self._task_store
and self._should_populate_referred_tasks
and params
and params.message.referenceTaskIds
and params.message.reference_task_ids
):
tasks = await asyncio.gather(
*[
self._task_store.get(task_id)
for task_id in params.message.referenceTaskIds
for task_id in params.message.reference_task_ids
]
)
related_tasks = [x for x in tasks if x is not None]
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/apps/jsonrpc/fastapi_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def add_routes_to_app(
)(self._handle_requests)
app.get(agent_card_url)(self._handle_get_agent_card)

if self.agent_card.supportsAuthenticatedExtendedCard:
if self.agent_card.supports_authenticated_extended_card:
app.get(extended_agent_card_url)(
self._handle_get_authenticated_extended_agent_card
)
Expand Down
8 changes: 4 additions & 4 deletions src/a2a/server/apps/jsonrpc/jsonrpc_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ def __init__(
agent_card=agent_card, request_handler=http_handler
)
if (
self.agent_card.supportsAuthenticatedExtendedCard
self.agent_card.supports_authenticated_extended_card
and self.extended_agent_card is None
):
logger.error(
'AgentCard.supportsAuthenticatedExtendedCard is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
'AgentCard.supports_authenticated_extended_card is True, but no extended_agent_card was provided. The /agent/authenticatedExtendedCard endpoint will return 404.'
)
self._context_builder = context_builder or DefaultCallContextBuilder()

Expand Down Expand Up @@ -421,7 +421,7 @@ async def _handle_get_authenticated_extended_agent_card(
self, request: Request
) -> JSONResponse:
"""Handles GET requests for the authenticated extended agent card."""
if not self.agent_card.supportsAuthenticatedExtendedCard:
if not self.agent_card.supports_authenticated_extended_card:
return JSONResponse(
{'error': 'Extended agent card not supported or not enabled.'},
status_code=404,
Expand All @@ -435,7 +435,7 @@ async def _handle_get_authenticated_extended_agent_card(
by_alias=True,
)
)
# If supportsAuthenticatedExtendedCard is true, but no specific
# If supports_authenticated_extended_card is true, but no specific
# extended_agent_card was provided during server initialization,
# return a 404
return JSONResponse(
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/apps/jsonrpc/starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def routes(
),
]

if self.agent_card.supportsAuthenticatedExtendedCard:
if self.agent_card.supports_authenticated_extended_card:
app_routes.append(
Route(
extended_agent_card_url,
Expand Down
10 changes: 5 additions & 5 deletions src/a2a/server/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ class TaskMixin:
"""Mixin providing standard task columns with proper type handling."""

id: Mapped[str] = mapped_column(String(36), primary_key=True, index=True)
contextId: Mapped[str] = mapped_column(String(36), nullable=False) # noqa: N815
context_id: Mapped[str] = mapped_column(String(36), nullable=False)
kind: Mapped[str] = mapped_column(
String(16), nullable=False, default='task'
)
Expand All @@ -148,12 +148,12 @@ def task_metadata(cls) -> Mapped[dict[str, Any] | None]:
def __repr__(self) -> str:
"""Return a string representation of the task."""
repr_template = (
'<{CLS}(id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
'<{CLS}(id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
)
return repr_template.format(
CLS=self.__class__.__name__,
ID=self.id,
CTX_ID=self.contextId,
CTX_ID=self.context_id,
STATUS=self.status,
)

Expand Down Expand Up @@ -188,11 +188,11 @@ class TaskModel(TaskMixin, base):
@override
def __repr__(self) -> str:
"""Return a string representation of the task."""
repr_template = '<TaskModel[{TABLE}](id="{ID}", contextId="{CTX_ID}", status="{STATUS}")>'
repr_template = '<TaskModel[{TABLE}](id="{ID}", context_id="{CTX_ID}", status="{STATUS}")>'
return repr_template.format(
TABLE=table_name,
ID=self.id,
CTX_ID=self.contextId,
CTX_ID=self.context_id,
STATUS=self.status,
)

Expand Down
29 changes: 15 additions & 14 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ async def on_cancel_task(

task_manager = TaskManager(
task_id=task.id,
context_id=task.contextId,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
)
Expand All @@ -140,7 +140,7 @@ async def on_cancel_task(
RequestContext(
None,
task_id=task.id,
context_id=task.contextId,
context_id=task.context_id,
task=task,
),
queue,
Expand Down Expand Up @@ -184,8 +184,8 @@ async def _setup_message_execution(
"""
# Create task manager and validate existing task
task_manager = TaskManager(
task_id=params.message.taskId,
context_id=params.message.contextId,
task_id=params.message.task_id,
context_id=params.message.context_id,
task_store=self.task_store,
initial_message=params.message,
)
Expand All @@ -205,7 +205,7 @@ async def _setup_message_execution(
request_context = await self._request_context_builder.build(
params=params,
task_id=task.id if task else None,
context_id=params.message.contextId,
context_id=params.message.context_id,
task=task,
context=context,
)
Expand All @@ -218,10 +218,10 @@ async def _setup_message_execution(
if (
self._push_config_store
and params.configuration
and params.configuration.pushNotificationConfig
and params.configuration.push_notification_config
):
await self._push_config_store.set_info(
task_id, params.configuration.pushNotificationConfig
task_id, params.configuration.push_notification_config
)

queue = await self._queue_manager.create_or_tap(task_id)
Expand Down Expand Up @@ -366,13 +366,13 @@ async def on_set_task_push_notification_config(
if not self._push_config_store:
raise ServerError(error=UnsupportedOperationError())

task: Task | None = await self.task_store.get(params.taskId)
task: Task | None = await self.task_store.get(params.task_id)
if not task:
raise ServerError(error=TaskNotFoundError())

await self._push_config_store.set_info(
params.taskId,
params.pushNotificationConfig,
params.task_id,
params.push_notification_config,
)

return params
Expand Down Expand Up @@ -404,7 +404,8 @@ async def on_get_task_push_notification_config(
)

return TaskPushNotificationConfig(
taskId=params.id, pushNotificationConfig=push_notification_config[0]
task_id=params.id,
pushNotificationConfig=push_notification_config[0],
)

async def on_resubscribe_to_task(
Expand All @@ -430,7 +431,7 @@ async def on_resubscribe_to_task(

task_manager = TaskManager(
task_id=task.id,
context_id=task.contextId,
context_id=task.context_id,
task_store=self.task_store,
initial_message=None,
)
Expand Down Expand Up @@ -470,7 +471,7 @@ async def on_list_task_push_notification_config(
for config in push_notification_config_list:
task_push_notification_config.append(
TaskPushNotificationConfig(
taskId=params.id, pushNotificationConfig=config
task_id=params.id, pushNotificationConfig=config
)
)

Expand All @@ -493,5 +494,5 @@ async def on_delete_task_push_notification_config(
raise ServerError(error=TaskNotFoundError())

await self._push_config_store.delete_info(
params.id, params.pushNotificationConfigId
params.id, params.push_notification_config_id
)
2 changes: 1 addition & 1 deletion src/a2a/server/request_handlers/grpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ async def GetTaskPushNotificationConfig(
return a2a_pb2.TaskPushNotificationConfig()

@validate(
lambda self: self.agent_card.capabilities.pushNotifications,
lambda self: self.agent_card.capabilities.push_notifications,
'Push notifications are not supported by the agent',
)
async def CreateTaskPushNotificationConfig(
Expand Down
2 changes: 1 addition & 1 deletion src/a2a/server/request_handlers/jsonrpc_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ async def get_push_notification_config(
)

@validate(
lambda self: self.agent_card.capabilities.pushNotifications,
lambda self: self.agent_card.capabilities.push_notifications,
'Push notifications are not supported by the agent',
)
async def set_push_notification_config(
Expand Down
4 changes: 2 additions & 2 deletions src/a2a/server/tasks/database_task_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _to_orm(self, task: Task) -> TaskModel:
"""Maps a Pydantic Task to a SQLAlchemy TaskModel instance."""
return self.task_model(
id=task.id,
contextId=task.contextId,
context_id=task.context_id,
kind=task.kind,
status=task.status,
artifacts=task.artifacts,
Expand All @@ -108,7 +108,7 @@ def _from_orm(self, task_model: TaskModel) -> Task:
# Map database columns to Pydantic model fields
task_data_from_db = {
'id': task_model.id,
'contextId': task_model.contextId,
'context_id': task_model.context_id,
'kind': task_model.kind,
'status': task_model.status,
'artifacts': task_model.artifacts,
Expand Down
Loading
Loading