diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 14b4fe11..24c493e0 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -9,8 +9,10 @@ AServer AServers AService AStarlette +AUser EUR GBP +GVsb INR JPY JSONRPCt @@ -29,6 +31,7 @@ coro datamodel dunders euo +excinfo genai getkwargs gle @@ -39,9 +42,11 @@ lifecycles linting lstrips mockurl +notif oauthoidc oidc opensource +otherurl protoc pyi pyversions diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt deleted file mode 100644 index ade6eb7b..00000000 --- a/.github/actions/spelling/expect.txt +++ /dev/null @@ -1,5 +0,0 @@ -AUser -excinfo -GVsb -notif -otherurl diff --git a/.github/workflows/update-a2a-types.yml b/.github/workflows/update-a2a-types.yml index a01ea487..971bbebc 100644 --- a/.github/workflows/update-a2a-types.yml +++ b/.github/workflows/update-a2a-types.yml @@ -4,6 +4,16 @@ on: repository_dispatch: types: [a2a_json_update] workflow_dispatch: + pull_request: + branches: + - main + paths: + - "scripts/generate_types.sh" + - "src/a2a/pydantic_base.py" + types: + - opened + - synchronize + - reopened jobs: generate_and_pr: @@ -15,6 +25,9 @@ jobs: steps: - name: Checkout code uses: actions/checkout@v4 + with: + ref: ${{ github.event.pull_request.head.ref }} + repository: ${{ github.event.pull_request.head.repo.full_name }} - name: Set up Python uses: actions/setup-python@v5 @@ -53,7 +66,19 @@ jobs: uv run scripts/grpc_gen_post_processor.py echo "Buf generate finished." - - name: Create Pull Request with Updates + - name: Commit changes to current PR + if: github.event_name == 'pull_request' # Only run this step for pull_request events + run: | + git config user.name "a2a-bot" + git config user.email "a2a-bot@google.com" + git add ${{ steps.vars.outputs.GENERATED_FILE }} src/a2a/grpc/ + git diff --cached --exit-code || git commit -m "feat: Update A2A types from specification 🤖" + git push + env: + GITHUB_TOKEN: ${{ secrets.A2A_BOT_PAT }} + + - name: Create Pull Request with Updates (for repository_dispatch/workflow_dispatch) + if: github.event_name != 'pull_request' uses: peter-evans/create-pull-request@v6 with: token: ${{ secrets.A2A_BOT_PAT }} diff --git a/scripts/generate_types.sh b/scripts/generate_types.sh old mode 100644 new mode 100755 index 942197fe..fb04f2bd --- a/scripts/generate_types.sh +++ b/scripts/generate_types.sh @@ -23,6 +23,7 @@ uv run datamodel-codegen \ --output "$GENERATED_FILE" \ --target-python-version 3.10 \ --output-model-type pydantic_v2.BaseModel \ + --base-class a2a.pydantic_base.A2ABaseModel \ --disable-timestamp \ --use-schema-description \ --use-union-operator \ @@ -32,6 +33,9 @@ uv run datamodel-codegen \ --use-one-literal-as-default \ --class-name A2A \ --use-standard-collections \ - --use-subclass-enum + --use-subclass-enum \ + --snake-case-field \ + --no-alias + echo "Codegen finished successfully." diff --git a/src/a2a/client/helpers.py b/src/a2a/client/helpers.py index 4eedadb8..60aabe7b 100644 --- a/src/a2a/client/helpers.py +++ b/src/a2a/client/helpers.py @@ -18,5 +18,7 @@ def create_text_message_object( A `Message` object with a new UUID messageId. """ return Message( - role=role, parts=[Part(TextPart(text=content))], messageId=str(uuid4()) + role=role, + parts=[Part(root=TextPart(text=content))], + message_id=str(uuid4()), ) diff --git a/src/a2a/pydantic_base.py b/src/a2a/pydantic_base.py new file mode 100644 index 00000000..dbad5076 --- /dev/null +++ b/src/a2a/pydantic_base.py @@ -0,0 +1,23 @@ +"""A2A Pydantic Base Model with shared configuration.""" + +from typing import Any + +from pydantic import BaseModel, ConfigDict +from pydantic.alias_generators import to_camel, to_snake + + +class A2ABaseModel(BaseModel): + """Base model for all A2A types with shared configuration.""" + + model_config = ConfigDict( + alias_generator=to_camel, + populate_by_name=True, + ) + + def __getattr__(self, name: str) -> Any: # noqa: D105 + snake = to_snake(name) + if hasattr(self, snake): + return getattr(self, snake) + raise AttributeError( + f'{type(self).__name__} object has no attribute {name!r}' + ) diff --git a/src/a2a/server/tasks/task_manager.py b/src/a2a/server/tasks/task_manager.py index ca42b69b..db5f8a37 100644 --- a/src/a2a/server/tasks/task_manager.py +++ b/src/a2a/server/tasks/task_manager.py @@ -96,7 +96,7 @@ async def save_task_event( when the TaskManager's ID is already set. """ task_id_from_event = ( - event.id if isinstance(event, Task) else event.taskId + event.id if isinstance(event, Task) else event.task_id ) # If task id is known, make sure it is matched if self.task_id and self.task_id != task_id_from_event: @@ -107,8 +107,8 @@ async def save_task_event( ) if not self.task_id: self.task_id = task_id_from_event - if not self.context_id and self.context_id != event.contextId: - self.context_id = event.contextId + if not self.context_id and self.context_id != event.context_id: + self.context_id = event.context_id logger.debug( 'Processing save of task event of type %s for task_id: %s', @@ -160,12 +160,12 @@ async def ensure_task( if not task: logger.info( 'Task not found or task_id not set. Creating new task for event (task_id: %s, context_id: %s).', - event.taskId, - event.contextId, + event.task_id, + event.context_id, ) # streaming agent did not previously stream task object. # Create a task object with the available information and persist the event - task = self._init_task_obj(event.taskId, event.contextId) + task = self._init_task_obj(event.task_id, event.context_id) await self._save_task(task) return task @@ -207,7 +207,7 @@ def _init_task_obj(self, task_id: str, context_id: str) -> Task: history = [self._initial_message] if self._initial_message else [] return Task( id=task_id, - contextId=context_id, + context_id=context_id, status=TaskStatus(state=TaskState.submitted), history=history, ) @@ -224,7 +224,7 @@ async def _save_task(self, task: Task) -> None: if not self.task_id: logger.info('New task created with id: %s', task.id) self.task_id = task.id - self.context_id = task.contextId + self.context_id = task.context_id def update_with_message(self, message: Message, task: Task) -> Task: """Updates a task object in memory by adding a new message to its history. diff --git a/src/a2a/types.py b/src/a2a/types.py index b3ad722e..b94ecafe 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -6,7 +6,9 @@ from enum import Enum from typing import Any, Literal -from pydantic import BaseModel, Field, RootModel +from pydantic import RootModel + +from a2a.pydantic_base import A2ABaseModel class A2A(RootModel[Any]): @@ -23,7 +25,7 @@ class In(str, Enum): query = 'query' -class APIKeySecurityScheme(BaseModel): +class APIKeySecurityScheme(A2ABaseModel): """ API Key security scheme. """ @@ -32,7 +34,7 @@ class APIKeySecurityScheme(BaseModel): """ Description of this security scheme. """ - in_: In = Field(..., alias='in') + in_: In """ The location of the API key. Valid values are "query", "header", or "cookie". """ @@ -43,7 +45,7 @@ class APIKeySecurityScheme(BaseModel): type: Literal['apiKey'] = 'apiKey' -class AgentExtension(BaseModel): +class AgentExtension(A2ABaseModel): """ A declaration of an extension supported by an Agent. """ @@ -66,7 +68,7 @@ class AgentExtension(BaseModel): """ -class AgentInterface(BaseModel): +class AgentInterface(A2ABaseModel): """ AgentInterface provides a declaration of a combination of the target url and the supported transport to interact with the agent. @@ -81,7 +83,7 @@ class AgentInterface(BaseModel): url: str -class AgentProvider(BaseModel): +class AgentProvider(A2ABaseModel): """ Represents the service provider of an agent. """ @@ -96,7 +98,7 @@ class AgentProvider(BaseModel): """ -class AgentSkill(BaseModel): +class AgentSkill(A2ABaseModel): """ Represents a unit of capability that an agent can perform. """ @@ -115,7 +117,7 @@ class AgentSkill(BaseModel): """ Unique identifier for the agent's skill. """ - inputModes: list[str] | None = None + input_modes: list[str] | None = None """ The set of interaction modes that the skill supports (if different than the default). @@ -125,7 +127,7 @@ class AgentSkill(BaseModel): """ Human readable name of the skill. """ - outputModes: list[str] | None = None + output_modes: list[str] | None = None """ Supported media types for output. """ @@ -135,17 +137,17 @@ class AgentSkill(BaseModel): """ -class AuthorizationCodeOAuthFlow(BaseModel): +class AuthorizationCodeOAuthFlow(A2ABaseModel): """ Configuration details for a supported OAuth Flow """ - authorizationUrl: str + authorization_url: str """ The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS """ - refreshUrl: str | None = None + refresh_url: str | None = None """ The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. @@ -155,19 +157,19 @@ class AuthorizationCodeOAuthFlow(BaseModel): The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty. """ - tokenUrl: str + token_url: str """ The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. """ -class ClientCredentialsOAuthFlow(BaseModel): +class ClientCredentialsOAuthFlow(A2ABaseModel): """ Configuration details for a supported OAuth Flow """ - refreshUrl: str | None = None + refresh_url: str | None = None """ The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. @@ -177,14 +179,14 @@ class ClientCredentialsOAuthFlow(BaseModel): The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty. """ - tokenUrl: str + token_url: str """ The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. """ -class ContentTypeNotSupportedError(BaseModel): +class ContentTypeNotSupportedError(A2ABaseModel): """ A2A specific error indicating incompatible content types between request and agent capabilities. """ @@ -204,7 +206,7 @@ class ContentTypeNotSupportedError(BaseModel): """ -class DataPart(BaseModel): +class DataPart(A2ABaseModel): """ Represents a structured data segment within a message part. """ @@ -223,7 +225,7 @@ class DataPart(BaseModel): """ -class DeleteTaskPushNotificationConfigParams(BaseModel): +class DeleteTaskPushNotificationConfigParams(A2ABaseModel): """ Parameters for removing pushNotificationConfiguration associated with a Task """ @@ -233,10 +235,10 @@ class DeleteTaskPushNotificationConfigParams(BaseModel): Task id. """ metadata: dict[str, Any] | None = None - pushNotificationConfigId: str + push_notification_config_id: str -class DeleteTaskPushNotificationConfigRequest(BaseModel): +class DeleteTaskPushNotificationConfigRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/pushNotificationConfig/delete' method. """ @@ -262,7 +264,7 @@ class DeleteTaskPushNotificationConfigRequest(BaseModel): """ -class DeleteTaskPushNotificationConfigSuccessResponse(BaseModel): +class DeleteTaskPushNotificationConfigSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'tasks/pushNotificationConfig/delete' method. """ @@ -282,12 +284,12 @@ class DeleteTaskPushNotificationConfigSuccessResponse(BaseModel): """ -class FileBase(BaseModel): +class FileBase(A2ABaseModel): """ Represents the base entity for FileParts """ - mimeType: str | None = None + mime_type: str | None = None """ Optional mimeType for the file """ @@ -297,7 +299,7 @@ class FileBase(BaseModel): """ -class FileWithBytes(BaseModel): +class FileWithBytes(A2ABaseModel): """ Define the variant where 'bytes' is present and 'uri' is absent """ @@ -306,7 +308,7 @@ class FileWithBytes(BaseModel): """ base64 encoded content of the file """ - mimeType: str | None = None + mime_type: str | None = None """ Optional mimeType for the file """ @@ -316,12 +318,12 @@ class FileWithBytes(BaseModel): """ -class FileWithUri(BaseModel): +class FileWithUri(A2ABaseModel): """ Define the variant where 'uri' is present and 'bytes' is absent """ - mimeType: str | None = None + mime_type: str | None = None """ Optional mimeType for the file """ @@ -335,7 +337,7 @@ class FileWithUri(BaseModel): """ -class GetTaskPushNotificationConfigParams(BaseModel): +class GetTaskPushNotificationConfigParams(A2ABaseModel): """ Parameters for fetching a pushNotificationConfiguration associated with a Task """ @@ -345,15 +347,15 @@ class GetTaskPushNotificationConfigParams(BaseModel): Task id. """ metadata: dict[str, Any] | None = None - pushNotificationConfigId: str | None = None + push_notification_config_id: str | None = None -class HTTPAuthSecurityScheme(BaseModel): +class HTTPAuthSecurityScheme(A2ABaseModel): """ HTTP Authentication security scheme. """ - bearerFormat: str | None = None + bearer_format: str | None = None """ A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation @@ -372,17 +374,17 @@ class HTTPAuthSecurityScheme(BaseModel): type: Literal['http'] = 'http' -class ImplicitOAuthFlow(BaseModel): +class ImplicitOAuthFlow(A2ABaseModel): """ Configuration details for a supported OAuth Flow """ - authorizationUrl: str + authorization_url: str """ The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS """ - refreshUrl: str | None = None + refresh_url: str | None = None """ The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. @@ -394,7 +396,7 @@ class ImplicitOAuthFlow(BaseModel): """ -class InternalError(BaseModel): +class InternalError(A2ABaseModel): """ JSON-RPC error indicating an internal JSON-RPC error on the server. """ @@ -414,7 +416,7 @@ class InternalError(BaseModel): """ -class InvalidAgentResponseError(BaseModel): +class InvalidAgentResponseError(A2ABaseModel): """ A2A specific error indicating agent returned invalid response for the current method """ @@ -434,7 +436,7 @@ class InvalidAgentResponseError(BaseModel): """ -class InvalidParamsError(BaseModel): +class InvalidParamsError(A2ABaseModel): """ JSON-RPC error indicating invalid method parameter(s). """ @@ -454,7 +456,7 @@ class InvalidParamsError(BaseModel): """ -class InvalidRequestError(BaseModel): +class InvalidRequestError(A2ABaseModel): """ JSON-RPC error indicating the JSON sent is not a valid Request object. """ @@ -474,7 +476,7 @@ class InvalidRequestError(BaseModel): """ -class JSONParseError(BaseModel): +class JSONParseError(A2ABaseModel): """ JSON-RPC error indicating invalid JSON was received by the server. """ @@ -494,7 +496,7 @@ class JSONParseError(BaseModel): """ -class JSONRPCError(BaseModel): +class JSONRPCError(A2ABaseModel): """ Represents a JSON-RPC 2.0 Error object. This is typically included in a JSONRPCErrorResponse when an error occurs. @@ -515,7 +517,7 @@ class JSONRPCError(BaseModel): """ -class JSONRPCMessage(BaseModel): +class JSONRPCMessage(A2ABaseModel): """ Base interface for any JSON-RPC 2.0 request or response. """ @@ -531,7 +533,7 @@ class JSONRPCMessage(BaseModel): """ -class JSONRPCRequest(BaseModel): +class JSONRPCRequest(A2ABaseModel): """ Represents a JSON-RPC 2.0 Request object. """ @@ -555,7 +557,7 @@ class JSONRPCRequest(BaseModel): """ -class JSONRPCSuccessResponse(BaseModel): +class JSONRPCSuccessResponse(A2ABaseModel): """ Represents a JSON-RPC 2.0 Success Response object. """ @@ -575,7 +577,7 @@ class JSONRPCSuccessResponse(BaseModel): """ -class ListTaskPushNotificationConfigParams(BaseModel): +class ListTaskPushNotificationConfigParams(A2ABaseModel): """ Parameters for getting list of pushNotificationConfigurations associated with a Task """ @@ -587,7 +589,7 @@ class ListTaskPushNotificationConfigParams(BaseModel): metadata: dict[str, Any] | None = None -class ListTaskPushNotificationConfigRequest(BaseModel): +class ListTaskPushNotificationConfigRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/pushNotificationConfig/list' method. """ @@ -622,7 +624,7 @@ class Role(str, Enum): user = 'user' -class MethodNotFoundError(BaseModel): +class MethodNotFoundError(A2ABaseModel): """ JSON-RPC error indicating the method does not exist or is not available. """ @@ -642,7 +644,7 @@ class MethodNotFoundError(BaseModel): """ -class OpenIdConnectSecurityScheme(BaseModel): +class OpenIdConnectSecurityScheme(A2ABaseModel): """ OpenID Connect security scheme configuration. """ @@ -651,14 +653,14 @@ class OpenIdConnectSecurityScheme(BaseModel): """ Description of this security scheme. """ - openIdConnectUrl: str + open_id_connect_url: str """ Well-known URL to discover the [[OpenID-Connect-Discovery]] provider metadata. """ type: Literal['openIdConnect'] = 'openIdConnect' -class PartBase(BaseModel): +class PartBase(A2ABaseModel): """ Base properties common to all message parts. """ @@ -669,12 +671,12 @@ class PartBase(BaseModel): """ -class PasswordOAuthFlow(BaseModel): +class PasswordOAuthFlow(A2ABaseModel): """ Configuration details for a supported OAuth Flow """ - refreshUrl: str | None = None + refresh_url: str | None = None """ The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. @@ -684,14 +686,14 @@ class PasswordOAuthFlow(BaseModel): The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty. """ - tokenUrl: str + token_url: str """ The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS. """ -class PushNotificationAuthenticationInfo(BaseModel): +class PushNotificationAuthenticationInfo(A2ABaseModel): """ Defines authentication details for push notifications. """ @@ -706,7 +708,7 @@ class PushNotificationAuthenticationInfo(BaseModel): """ -class PushNotificationConfig(BaseModel): +class PushNotificationConfig(A2ABaseModel): """ Configuration for setting up push notifications for task updates. """ @@ -726,7 +728,7 @@ class PushNotificationConfig(BaseModel): """ -class PushNotificationNotSupportedError(BaseModel): +class PushNotificationNotSupportedError(A2ABaseModel): """ A2A specific error indicating the agent does not support push notifications. """ @@ -746,7 +748,7 @@ class PushNotificationNotSupportedError(BaseModel): """ -class SecuritySchemeBase(BaseModel): +class SecuritySchemeBase(A2ABaseModel): """ Base properties shared by all security schemes. """ @@ -757,7 +759,7 @@ class SecuritySchemeBase(BaseModel): """ -class TaskIdParams(BaseModel): +class TaskIdParams(A2ABaseModel): """ Parameters containing only a task ID, used for simple task operations. """ @@ -769,7 +771,7 @@ class TaskIdParams(BaseModel): metadata: dict[str, Any] | None = None -class TaskNotCancelableError(BaseModel): +class TaskNotCancelableError(A2ABaseModel): """ A2A specific error indicating the task is in a state where it cannot be canceled. """ @@ -789,7 +791,7 @@ class TaskNotCancelableError(BaseModel): """ -class TaskNotFoundError(BaseModel): +class TaskNotFoundError(A2ABaseModel): """ A2A specific error indicating the requested task ID was not found. """ @@ -809,27 +811,27 @@ class TaskNotFoundError(BaseModel): """ -class TaskPushNotificationConfig(BaseModel): +class TaskPushNotificationConfig(A2ABaseModel): """ Parameters for setting or getting push notification configuration for a task """ - pushNotificationConfig: PushNotificationConfig + push_notification_config: PushNotificationConfig """ Push notification configuration. """ - taskId: str + task_id: str """ Task id. """ -class TaskQueryParams(BaseModel): +class TaskQueryParams(A2ABaseModel): """ Parameters for querying a task, including optional history length. """ - historyLength: int | None = None + history_length: int | None = None """ Number of recent messages to be retrieved. """ @@ -840,7 +842,7 @@ class TaskQueryParams(BaseModel): metadata: dict[str, Any] | None = None -class TaskResubscriptionRequest(BaseModel): +class TaskResubscriptionRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/resubscribe' method. """ @@ -880,7 +882,7 @@ class TaskState(str, Enum): unknown = 'unknown' -class TextPart(BaseModel): +class TextPart(A2ABaseModel): """ Represents a text segment within parts. """ @@ -899,7 +901,7 @@ class TextPart(BaseModel): """ -class UnsupportedOperationError(BaseModel): +class UnsupportedOperationError(A2ABaseModel): """ A2A specific error indicating the requested operation is not supported by the agent. """ @@ -949,7 +951,7 @@ class A2AError( ) -class AgentCapabilities(BaseModel): +class AgentCapabilities(A2ABaseModel): """ Defines optional capabilities supported by an agent. """ @@ -958,11 +960,11 @@ class AgentCapabilities(BaseModel): """ extensions supported by this agent. """ - pushNotifications: bool | None = None + push_notifications: bool | None = None """ true if the agent can notify updates to client. """ - stateTransitionHistory: bool | None = None + state_transition_history: bool | None = None """ true if the agent exposes status change history for tasks. """ @@ -972,7 +974,7 @@ class AgentCapabilities(BaseModel): """ -class CancelTaskRequest(BaseModel): +class CancelTaskRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/cancel' method. """ @@ -996,7 +998,7 @@ class CancelTaskRequest(BaseModel): """ -class FilePart(BaseModel): +class FilePart(A2ABaseModel): """ Represents a File segment within parts. """ @@ -1015,7 +1017,7 @@ class FilePart(BaseModel): """ -class GetTaskPushNotificationConfigRequest(BaseModel): +class GetTaskPushNotificationConfigRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/pushNotificationConfig/get' method. """ @@ -1042,7 +1044,7 @@ class GetTaskPushNotificationConfigRequest(BaseModel): """ -class GetTaskPushNotificationConfigSuccessResponse(BaseModel): +class GetTaskPushNotificationConfigSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'tasks/pushNotificationConfig/get' method. """ @@ -1062,7 +1064,7 @@ class GetTaskPushNotificationConfigSuccessResponse(BaseModel): """ -class GetTaskRequest(BaseModel): +class GetTaskRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/get' method. """ @@ -1086,7 +1088,7 @@ class GetTaskRequest(BaseModel): """ -class JSONRPCErrorResponse(BaseModel): +class JSONRPCErrorResponse(A2ABaseModel): """ Represents a JSON-RPC 2.0 Error Response object. """ @@ -1116,7 +1118,7 @@ class JSONRPCErrorResponse(BaseModel): """ -class ListTaskPushNotificationConfigSuccessResponse(BaseModel): +class ListTaskPushNotificationConfigSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'tasks/pushNotificationConfig/list' method. """ @@ -1136,12 +1138,12 @@ class ListTaskPushNotificationConfigSuccessResponse(BaseModel): """ -class MessageSendConfiguration(BaseModel): +class MessageSendConfiguration(A2ABaseModel): """ Configuration for the send message request. """ - acceptedOutputModes: list[str] + accepted_output_modes: list[str] """ Accepted output modalities by the client. """ @@ -1149,26 +1151,26 @@ class MessageSendConfiguration(BaseModel): """ If the server should treat the client as a blocking request. """ - historyLength: int | None = None + history_length: int | None = None """ Number of recent messages to be retrieved. """ - pushNotificationConfig: PushNotificationConfig | None = None + push_notification_config: PushNotificationConfig | None = None """ Where the server should send notifications when disconnected. """ -class OAuthFlows(BaseModel): +class OAuthFlows(A2ABaseModel): """ Allows configuration of the supported OAuth Flows """ - authorizationCode: AuthorizationCodeOAuthFlow | None = None + authorization_code: AuthorizationCodeOAuthFlow | None = None """ Configuration for the OAuth Authorization Code flow. Previously called accessCode in OpenAPI 2.0. """ - clientCredentials: ClientCredentialsOAuthFlow | None = None + client_credentials: ClientCredentialsOAuthFlow | None = None """ Configuration for the OAuth Client Credentials flow. Previously called application in OpenAPI 2.0 """ @@ -1189,7 +1191,7 @@ class Part(RootModel[TextPart | FilePart | DataPart]): """ -class SetTaskPushNotificationConfigRequest(BaseModel): +class SetTaskPushNotificationConfigRequest(A2ABaseModel): """ JSON-RPC request model for the 'tasks/pushNotificationConfig/set' method. """ @@ -1215,7 +1217,7 @@ class SetTaskPushNotificationConfigRequest(BaseModel): """ -class SetTaskPushNotificationConfigSuccessResponse(BaseModel): +class SetTaskPushNotificationConfigSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'tasks/pushNotificationConfig/set' method. """ @@ -1235,12 +1237,12 @@ class SetTaskPushNotificationConfigSuccessResponse(BaseModel): """ -class Artifact(BaseModel): +class Artifact(A2ABaseModel): """ Represents an artifact generated for a task. """ - artifactId: str + artifact_id: str """ Unique identifier for the artifact. """ @@ -1293,12 +1295,12 @@ class ListTaskPushNotificationConfigResponse( """ -class Message(BaseModel): +class Message(A2ABaseModel): """ Represents a single message exchanged between user and agent. """ - contextId: str | None = None + context_id: str | None = None """ The context the message is associated with """ @@ -1310,7 +1312,7 @@ class Message(BaseModel): """ Event type """ - messageId: str + message_id: str """ Identifier created by the message creator """ @@ -1322,7 +1324,7 @@ class Message(BaseModel): """ Message content """ - referenceTaskIds: list[str] | None = None + reference_task_ids: list[str] | None = None """ List of tasks referenced as context by this message. """ @@ -1330,13 +1332,13 @@ class Message(BaseModel): """ Message sender's role """ - taskId: str | None = None + task_id: str | None = None """ Identifier of task the message is related to """ -class MessageSendParams(BaseModel): +class MessageSendParams(A2ABaseModel): """ Sent by the client to the agent as a request. May create, continue or restart a task. """ @@ -1355,7 +1357,7 @@ class MessageSendParams(BaseModel): """ -class OAuth2SecurityScheme(BaseModel): +class OAuth2SecurityScheme(A2ABaseModel): """ OAuth2.0 security scheme configuration. """ @@ -1391,7 +1393,7 @@ class SecurityScheme( """ -class SendMessageRequest(BaseModel): +class SendMessageRequest(A2ABaseModel): """ JSON-RPC request model for the 'message/send' method. """ @@ -1415,7 +1417,7 @@ class SendMessageRequest(BaseModel): """ -class SendStreamingMessageRequest(BaseModel): +class SendStreamingMessageRequest(A2ABaseModel): """ JSON-RPC request model for the 'message/stream' method. """ @@ -1448,7 +1450,7 @@ class SetTaskPushNotificationConfigResponse( """ -class TaskArtifactUpdateEvent(BaseModel): +class TaskArtifactUpdateEvent(A2ABaseModel): """ Sent by server during sendStream or subscribe requests """ @@ -1461,7 +1463,7 @@ class TaskArtifactUpdateEvent(BaseModel): """ Generated artifact """ - contextId: str + context_id: str """ The context the task is associated with """ @@ -1469,7 +1471,7 @@ class TaskArtifactUpdateEvent(BaseModel): """ Event type """ - lastChunk: bool | None = None + last_chunk: bool | None = None """ Indicates if this is the last chunk of the artifact """ @@ -1477,13 +1479,13 @@ class TaskArtifactUpdateEvent(BaseModel): """ Extension metadata. """ - taskId: str + task_id: str """ Task id """ -class TaskStatus(BaseModel): +class TaskStatus(A2ABaseModel): """ TaskState and accompanying message. """ @@ -1499,12 +1501,12 @@ class TaskStatus(BaseModel): """ -class TaskStatusUpdateEvent(BaseModel): +class TaskStatusUpdateEvent(A2ABaseModel): """ Sent by server during sendStream or subscribe requests """ - contextId: str + context_id: str """ The context the task is associated with """ @@ -1524,7 +1526,7 @@ class TaskStatusUpdateEvent(BaseModel): """ Current status of the task """ - taskId: str + task_id: str """ Task id """ @@ -1559,7 +1561,7 @@ class A2ARequest( """ -class AgentCard(BaseModel): +class AgentCard(A2ABaseModel): """ An AgentCard conveys key information: - Overall details (version, name, description, uses) @@ -1568,7 +1570,7 @@ class AgentCard(BaseModel): - Authentication requirements """ - additionalInterfaces: list[AgentInterface] | None = None + additional_interfaces: list[AgentInterface] | None = None """ Announcement of additional supported transports. Client can use any of the supported transports. @@ -1577,12 +1579,12 @@ class AgentCard(BaseModel): """ Optional capabilities supported by the agent. """ - defaultInputModes: list[str] + default_input_modes: list[str] """ The set of interaction modes that the agent supports across all skills. This can be overridden per-skill. Supported media types for input. """ - defaultOutputModes: list[str] + default_output_modes: list[str] """ Supported media types for output. """ @@ -1591,11 +1593,11 @@ class AgentCard(BaseModel): A human-readable description of the agent. Used to assist users and other agents in understanding what the agent can do. """ - documentationUrl: str | None = None + documentation_url: str | None = None """ A URL to documentation for the agent. """ - iconUrl: str | None = None + icon_url: str | None = None """ A URL to an icon for the agent. """ @@ -1603,7 +1605,7 @@ class AgentCard(BaseModel): """ Human readable name of the agent. """ - preferredTransport: str | None = None + preferred_transport: str | None = None """ The transport of the preferred endpoint. If empty, defaults to JSONRPC. """ @@ -1615,7 +1617,7 @@ class AgentCard(BaseModel): """ Security requirements for contacting the agent. """ - securitySchemes: dict[str, SecurityScheme] | None = None + security_schemes: dict[str, SecurityScheme] | None = None """ Security scheme details used for authenticating with this agent. """ @@ -1623,7 +1625,7 @@ class AgentCard(BaseModel): """ Skills are a unit of capability that an agent can perform. """ - supportsAuthenticatedExtendedCard: bool | None = None + supports_authenticated_extended_card: bool | None = None """ true if the agent supports providing an extended agent card when the user is authenticated. Defaults to false if not specified. @@ -1639,12 +1641,12 @@ class AgentCard(BaseModel): """ -class Task(BaseModel): +class Task(A2ABaseModel): artifacts: list[Artifact] | None = None """ Collection of artifacts created by the agent. """ - contextId: str + context_id: str """ Server-generated id for contextual alignment across interactions """ @@ -1667,7 +1669,7 @@ class Task(BaseModel): """ -class CancelTaskSuccessResponse(BaseModel): +class CancelTaskSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'tasks/cancel' method. """ @@ -1687,7 +1689,7 @@ class CancelTaskSuccessResponse(BaseModel): """ -class GetTaskSuccessResponse(BaseModel): +class GetTaskSuccessResponse(A2ABaseModel): """ JSON-RPC success response for the 'tasks/get' method. """ @@ -1707,7 +1709,7 @@ class GetTaskSuccessResponse(BaseModel): """ -class SendMessageSuccessResponse(BaseModel): +class SendMessageSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'message/send' method. """ @@ -1727,7 +1729,7 @@ class SendMessageSuccessResponse(BaseModel): """ -class SendStreamingMessageSuccessResponse(BaseModel): +class SendStreamingMessageSuccessResponse(A2ABaseModel): """ JSON-RPC success response model for the 'message/stream' method. """ diff --git a/src/a2a/utils/artifact.py b/src/a2a/utils/artifact.py index ee91a891..d4a316f6 100644 --- a/src/a2a/utils/artifact.py +++ b/src/a2a/utils/artifact.py @@ -21,7 +21,7 @@ def new_artifact( A new `Artifact` object with a generated artifactId. """ return Artifact( - artifactId=str(uuid.uuid4()), + artifact_id=str(uuid.uuid4()), parts=parts, name=name, description=description, diff --git a/src/a2a/utils/helpers.py b/src/a2a/utils/helpers.py index a1cc43ec..091268ba 100644 --- a/src/a2a/utils/helpers.py +++ b/src/a2a/utils/helpers.py @@ -36,12 +36,12 @@ def create_task_obj(message_send_params: MessageSendParams) -> Task: Returns: A new `Task` object initialized with 'submitted' status and the input message in history. """ - if not message_send_params.message.contextId: - message_send_params.message.contextId = str(uuid4()) + if not message_send_params.message.context_id: + message_send_params.message.context_id = str(uuid4()) return Task( id=str(uuid4()), - contextId=message_send_params.message.contextId, + context_id=message_send_params.message.context_id, status=TaskStatus(state=TaskState.submitted), history=[message_send_params.message], ) @@ -62,7 +62,7 @@ def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: task.artifacts = [] new_artifact_data: Artifact = event.artifact - artifact_id: str = new_artifact_data.artifactId + artifact_id: str = new_artifact_data.artifact_id append_parts: bool = event.append or False existing_artifact: Artifact | None = None @@ -70,7 +70,7 @@ def append_artifact_to_task(task: Task, event: TaskArtifactUpdateEvent) -> None: # Find existing artifact by its id for i, art in enumerate(task.artifacts): - if art.artifactId == artifact_id: + if art.artifact_id == artifact_id: existing_artifact = art existing_artifact_list_index = i break @@ -115,7 +115,7 @@ def build_text_artifact(text: str, artifact_id: str) -> Artifact: """ text_part = TextPart(text=text) part = Part(root=text_part) - return Artifact(parts=[part], artifactId=artifact_id) + return Artifact(parts=[part], artifact_id=artifact_id) def validate( diff --git a/src/a2a/utils/message.py b/src/a2a/utils/message.py index 0e08597a..831d6675 100644 --- a/src/a2a/utils/message.py +++ b/src/a2a/utils/message.py @@ -28,9 +28,9 @@ def new_agent_text_message( return Message( role=Role.agent, parts=[Part(root=TextPart(text=text))], - messageId=str(uuid.uuid4()), - taskId=task_id, - contextId=context_id, + message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id, ) @@ -52,9 +52,9 @@ def new_agent_parts_message( return Message( role=Role.agent, parts=parts, - messageId=str(uuid.uuid4()), - taskId=task_id, - contextId=context_id, + message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id, ) diff --git a/src/a2a/utils/task.py b/src/a2a/utils/task.py index 9cf4df43..386c9769 100644 --- a/src/a2a/utils/task.py +++ b/src/a2a/utils/task.py @@ -18,9 +18,9 @@ def new_task(request: Message) -> Task: """ return Task( status=TaskStatus(state=TaskState.submitted), - id=(request.taskId if request.taskId else str(uuid.uuid4())), - contextId=( - request.contextId if request.contextId else str(uuid.uuid4()) + id=(request.task_id if request.task_id else str(uuid.uuid4())), + context_id=( + request.context_id if request.context_id else str(uuid.uuid4()) ), history=[request], ) @@ -51,7 +51,7 @@ def completed_task( return Task( status=TaskStatus(state=TaskState.completed), id=task_id, - contextId=context_id, + context_id=context_id, artifacts=artifacts, history=history, ) diff --git a/tests/client/test_client.py b/tests/client/test_client.py index 5b6e9491..4df0812c 100644 --- a/tests/client/test_client.py +++ b/tests/client/test_client.py @@ -54,8 +54,8 @@ description='Just a hello world agent', url='http://localhost:9999/', version='1.0.0', - defaultInputModes=['text'], - defaultOutputModes=['text'], + default_input_modes=['text'], + default_output_modes=['text'], capabilities=AgentCapabilities(), skills=[ AgentSkill( @@ -86,7 +86,7 @@ ) AGENT_CARD_SUPPORTS_EXTENDED = AGENT_CARD.model_copy( - update={'supportsAuthenticatedExtendedCard': True} + update={'supports_authenticated_extended_card': True} ) AGENT_CARD_NO_URL_SUPPORTS_EXTENDED = AGENT_CARD_SUPPORTS_EXTENDED.model_copy( update={'url': ''} @@ -174,7 +174,9 @@ async def test_get_agent_card_success_public_only( ): mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 - mock_response.json.return_value = AGENT_CARD.model_dump(mode='json') + mock_response.json.return_value = AGENT_CARD.model_dump( + mode='json', by_alias=True + ) mock_httpx_client.get.return_value = mock_response resolver = A2ACardResolver( @@ -200,7 +202,7 @@ async def test_get_agent_card_success_with_specified_path_for_extended_card( extended_card_response = AsyncMock(spec=httpx.Response) extended_card_response.status_code = 200 extended_card_response.json.return_value = ( - AGENT_CARD_EXTENDED.model_dump(mode='json') + AGENT_CARD_EXTENDED.model_dump(mode='json', by_alias=True) ) # Mock the single call for the extended card @@ -452,7 +454,7 @@ async def test_send_message_success_use_request( success_response = create_text_message_object( role=Role.agent, content='Hi there!' - ).model_dump(exclude_none=True) + ).model_dump(by_alias=True, exclude_none=True) rpc_response: dict[str, Any] = { 'id': 123, @@ -482,13 +484,15 @@ async def test_send_message_success_use_request( assert isinstance(a2a_request_arg.root.params, MessageSendParams) assert a2a_request_arg.root.params.model_dump( - exclude_none=True - ) == params.model_dump(exclude_none=True) + by_alias=True, exclude_none=True + ) == params.model_dump(by_alias=True, exclude_none=True) assert isinstance(response, SendMessageResponse) assert isinstance(response.root, SendMessageSuccessResponse) assert ( - response.root.result.model_dump(exclude_none=True) + response.root.result.model_dump( + by_alias=True, exclude_none=True + ) == success_response ) @@ -511,7 +515,9 @@ async def test_send_message_error_response( rpc_response: dict[str, Any] = { 'id': 123, 'jsonrpc': '2.0', - 'error': error_response.model_dump(exclude_none=True), + 'error': error_response.model_dump( + by_alias=True, exclude_none=True + ), } with patch.object( @@ -523,8 +529,10 @@ async def test_send_message_error_response( assert isinstance(response, SendMessageResponse) assert isinstance(response.root, JSONRPCErrorResponse) assert response.root.error.model_dump( - exclude_none=True - ) == InvalidParamsError().model_dump(exclude_none=True) + by_alias=True, exclude_none=True + ) == InvalidParamsError().model_dump( + by_alias=True, exclude_none=True + ) @pytest.mark.asyncio @patch('a2a.client.client.aconnect_sse') @@ -548,14 +556,14 @@ async def test_send_message_streaming_success_request( 'jsonrpc': '2.0', 'result': create_text_message_object( content='First part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), + ).model_dump(mode='json', by_alias=True, exclude_none=True), } mock_stream_response_2_dict: dict[str, Any] = { 'id': 'stream_id_123', 'jsonrpc': '2.0', 'result': create_text_message_object( content='second part ', role=Role.agent - ).model_dump(mode='json', exclude_none=True), + ).model_dump(mode='json', by_alias=True, exclude_none=True), } sse_event_1 = ServerSentEvent( @@ -586,7 +594,7 @@ async def test_send_message_streaming_success_request( assert results[0].root.id == 'stream_id_123' assert ( results[0].root.result.model_dump( # type: ignore - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) == mock_stream_response_1_dict['result'] ) @@ -595,7 +603,7 @@ async def test_send_message_streaming_success_request( assert results[1].root.id == 'stream_id_123' assert ( results[1].root.result.model_dump( # type: ignore - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) == mock_stream_response_2_dict['result'] ) @@ -609,7 +617,7 @@ async def test_send_message_streaming_success_request( sent_json_payload = call_kwargs['json'] assert sent_json_payload['method'] == 'message/stream' assert sent_json_payload['params'] == params.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) assert ( call_kwargs['timeout'] is None @@ -837,7 +845,7 @@ async def test_set_task_callback_success( ) # Correctly create the TaskPushNotificationConfig (outer model) params_model = TaskPushNotificationConfig( - taskId=task_id_val, pushNotificationConfig=push_config_payload + task_id=task_id_val, push_notification_config=push_config_payload ) # request.id will be generated by the client method if not provided @@ -849,7 +857,9 @@ async def test_set_task_callback_success( rpc_response_payload: dict[str, Any] = { 'id': ANY, # Will be checked against generated ID 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json', exclude_none=True), + 'result': params_model.model_dump( + mode='json', by_alias=True, exclude_none=True + ), } with ( @@ -880,7 +890,7 @@ async def test_set_task_callback_success( == 'tasks/pushNotificationConfig/set' ) assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) assert isinstance(response, SetTaskPushNotificationConfigResponse) @@ -889,8 +899,10 @@ async def test_set_task_callback_success( ) assert response.root.id == generated_id assert response.root.result.model_dump( - mode='json', exclude_none=True - ) == params_model.model_dump(mode='json', exclude_none=True) + mode='json', by_alias=True, exclude_none=True + ) == params_model.model_dump( + mode='json', by_alias=True, exclude_none=True + ) @pytest.mark.asyncio async def test_set_task_callback_error_response( @@ -902,7 +914,7 @@ async def test_set_task_callback_error_response( req_id = 'set_cb_err_req' push_config_payload = PushNotificationConfig(url='https://errors.com') params_model = TaskPushNotificationConfig( - taskId='task_err_cb', pushNotificationConfig=push_config_payload + task_id='task_err_cb', push_notification_config=push_config_payload ) request = SetTaskPushNotificationConfigRequest( id=req_id, params=params_model @@ -912,7 +924,9 @@ async def test_set_task_callback_error_response( rpc_response_payload: dict[str, Any] = { 'id': req_id, 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'error': error_details.model_dump( + mode='json', by_alias=True, exclude_none=True + ), } with patch.object( @@ -924,8 +938,8 @@ async def test_set_task_callback_error_response( assert isinstance(response, SetTaskPushNotificationConfigResponse) assert isinstance(response.root, JSONRPCErrorResponse) assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) + mode='json', by_alias=True, exclude_none=True + ) == error_details.model_dump(by_alias=True, exclude_none=True) assert response.root.id == req_id @pytest.mark.asyncio @@ -937,7 +951,8 @@ async def test_set_task_callback_http_kwargs_passed( ) push_config_payload = PushNotificationConfig(url='https://kwargs.com') params_model = TaskPushNotificationConfig( - taskId='task_cb_kwargs', pushNotificationConfig=push_config_payload + task_id='task_cb_kwargs', + push_notification_config=push_config_payload, ) request = SetTaskPushNotificationConfigRequest( id='cb_kwargs_req', params=params_model @@ -948,7 +963,7 @@ async def test_set_task_callback_http_kwargs_passed( rpc_response_payload: dict[str, Any] = { 'id': 'cb_kwargs_req', 'jsonrpc': '2.0', - 'result': params_model.model_dump(mode='json'), + 'result': params_model.model_dump(mode='json', by_alias=True), } with patch.object( @@ -986,13 +1001,13 @@ async def test_get_task_callback_success( url='https://callback.example.com/taskupdate' ) expected_callback_config = TaskPushNotificationConfig( - taskId=task_id_val, pushNotificationConfig=push_config_payload + task_id=task_id_val, push_notification_config=push_config_payload ) rpc_response_payload: dict[str, Any] = { 'id': ANY, 'jsonrpc': '2.0', 'result': expected_callback_config.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ), } @@ -1021,7 +1036,7 @@ async def test_get_task_callback_success( == 'tasks/pushNotificationConfig/get' ) assert sent_json_payload['params'] == params_model.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) assert isinstance(response, GetTaskPushNotificationConfigResponse) @@ -1030,9 +1045,9 @@ async def test_get_task_callback_success( ) assert response.root.id == generated_id assert response.root.result.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) == expected_callback_config.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) @pytest.mark.asyncio @@ -1054,7 +1069,9 @@ async def test_get_task_callback_error_response( rpc_response_payload: dict[str, Any] = { 'id': req_id, 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'error': error_details.model_dump( + mode='json', by_alias=True, exclude_none=True + ), } with patch.object( @@ -1066,8 +1083,8 @@ async def test_get_task_callback_error_response( assert isinstance(response, GetTaskPushNotificationConfigResponse) assert isinstance(response.root, JSONRPCErrorResponse) assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(mode='json', exclude_none=True) + mode='json', by_alias=True, exclude_none=True + ) == error_details.model_dump(by_alias=True, exclude_none=True) assert response.root.id == req_id @pytest.mark.asyncio @@ -1088,13 +1105,15 @@ async def test_get_task_callback_http_kwargs_passed( url='https://getkwargs.com' ) expected_callback_config = TaskPushNotificationConfig( - taskId='task_get_cb_kwargs', - pushNotificationConfig=push_config_payload_for_expected, + task_id='task_get_cb_kwargs', + push_notification_config=push_config_payload_for_expected, ) rpc_response_payload: dict[str, Any] = { 'id': 'get_cb_kwargs_req', 'jsonrpc': '2.0', - 'result': expected_callback_config.model_dump(mode='json'), + 'result': expected_callback_config.model_dump( + mode='json', by_alias=True + ), } with patch.object( @@ -1148,13 +1167,15 @@ async def test_get_task_success_use_request( assert json_rpc_request_sent['method'] == 'tasks/get' assert json_rpc_request_sent['id'] == request_obj_id assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) assert isinstance(response, GetTaskResponse) assert hasattr(response.root, 'result') assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore + response.root.result.model_dump( + mode='json', by_alias=True, exclude_none=True + ) # type: ignore == MINIMAL_TASK ) assert response.root.id == request_obj_id @@ -1173,7 +1194,9 @@ async def test_get_task_error_response( rpc_response_payload: dict[str, Any] = { 'id': 'err_req_id', 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'error': error_details.model_dump( + mode='json', by_alias=True, exclude_none=True + ), } with patch.object( @@ -1185,8 +1208,8 @@ async def test_get_task_error_response( assert isinstance(response, GetTaskResponse) assert isinstance(response.root, JSONRPCErrorResponse) assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) + mode='json', by_alias=True, exclude_none=True + ) == error_details.model_dump(by_alias=True, exclude_none=True) assert response.root.id == 'err_req_id' @pytest.mark.asyncio @@ -1226,13 +1249,15 @@ async def test_cancel_task_success_use_request( assert json_rpc_request_sent['method'] == 'tasks/cancel' assert json_rpc_request_sent['id'] == request_obj_id assert json_rpc_request_sent['params'] == params_model.model_dump( - mode='json', exclude_none=True + mode='json', by_alias=True, exclude_none=True ) assert isinstance(response, CancelTaskResponse) assert isinstance(response.root, CancelTaskSuccessResponse) assert ( - response.root.result.model_dump(mode='json', exclude_none=True) # type: ignore + response.root.result.model_dump( + mode='json', by_alias=True, exclude_none=True + ) # type: ignore == MINIMAL_CANCELLED_TASK ) assert response.root.id == request_obj_id @@ -1251,7 +1276,9 @@ async def test_cancel_task_error_response( rpc_response_payload: dict[str, Any] = { 'id': 'err_cancel_req', 'jsonrpc': '2.0', - 'error': error_details.model_dump(mode='json', exclude_none=True), + 'error': error_details.model_dump( + mode='json', by_alias=True, exclude_none=True + ), } with patch.object( @@ -1263,6 +1290,6 @@ async def test_cancel_task_error_response( assert isinstance(response, CancelTaskResponse) assert isinstance(response.root, JSONRPCErrorResponse) assert response.root.error.model_dump( - mode='json', exclude_none=True - ) == error_details.model_dump(exclude_none=True) + mode='json', by_alias=True, exclude_none=True + ) == error_details.model_dump(by_alias=True, exclude_none=True) assert response.root.id == 'err_cancel_req' diff --git a/tests/server/agent_execution/test_context.py b/tests/server/agent_execution/test_context.py index 92d09707..1e0e776c 100644 --- a/tests/server/agent_execution/test_context.py +++ b/tests/server/agent_execution/test_context.py @@ -18,7 +18,7 @@ class TestRequestContext: @pytest.fixture def mock_message(self): """Fixture for a mock Message.""" - return Mock(spec=Message, taskId=None, contextId=None) + return Mock(spec=Message, task_id=None, context_id=None) @pytest.fixture def mock_params(self, mock_message): @@ -28,7 +28,7 @@ def mock_params(self, mock_message): @pytest.fixture def mock_task(self): """Fixture for a mock Task.""" - return Mock(spec=Task, id='task-123', contextId='context-456') + return Mock(spec=Task, id='task-123', context_id='context-456') def test_init_without_params(self): """Test initialization without parameters.""" @@ -53,11 +53,12 @@ def test_init_with_params_no_ids(self, mock_params): assert context.message == mock_params.message assert context.task_id == '00000000-0000-0000-0000-000000000001' assert ( - mock_params.message.taskId == '00000000-0000-0000-0000-000000000001' + mock_params.message.task_id + == '00000000-0000-0000-0000-000000000001' ) assert context.context_id == '00000000-0000-0000-0000-000000000002' assert ( - mock_params.message.contextId + mock_params.message.context_id == '00000000-0000-0000-0000-000000000002' ) @@ -67,7 +68,7 @@ def test_init_with_task_id(self, mock_params): context = RequestContext(request=mock_params, task_id=task_id) assert context.task_id == task_id - assert mock_params.message.taskId == task_id + assert mock_params.message.task_id == task_id def test_init_with_context_id(self, mock_params): """Test initialization with context ID provided.""" @@ -75,7 +76,7 @@ def test_init_with_context_id(self, mock_params): context = RequestContext(request=mock_params, context_id=context_id) assert context.context_id == context_id - assert mock_params.message.contextId == context_id + assert mock_params.message.context_id == context_id def test_init_with_both_ids(self, mock_params): """Test initialization with both task and context IDs provided.""" @@ -86,9 +87,9 @@ def test_init_with_both_ids(self, mock_params): ) assert context.task_id == task_id - assert mock_params.message.taskId == task_id + assert mock_params.message.task_id == task_id assert context.context_id == context_id - assert mock_params.message.contextId == context_id + assert mock_params.message.context_id == context_id def test_init_with_task(self, mock_params, mock_task): """Test initialization with a task object.""" @@ -138,13 +139,13 @@ def test_check_or_generate_task_id_no_params(self): def test_check_or_generate_task_id_with_existing_task_id(self, mock_params): """Test _check_or_generate_task_id with existing task ID.""" existing_id = 'existing-task-id' - mock_params.message.taskId = existing_id + mock_params.message.task_id = existing_id context = RequestContext(request=mock_params) # The method is called during initialization assert context.task_id == existing_id - assert mock_params.message.taskId == existing_id + assert mock_params.message.task_id == existing_id def test_check_or_generate_context_id_no_params(self): """Test _check_or_generate_context_id with no params does nothing.""" @@ -157,13 +158,13 @@ def test_check_or_generate_context_id_with_existing_context_id( ): """Test _check_or_generate_context_id with existing context ID.""" existing_id = 'existing-context-id' - mock_params.message.contextId = existing_id + mock_params.message.context_id = existing_id context = RequestContext(request=mock_params) # The method is called during initialization assert context.context_id == existing_id - assert mock_params.message.contextId == existing_id + assert mock_params.message.context_id == existing_id def test_with_related_tasks_provided(self, mock_task): """Test initialization with related tasks provided.""" @@ -185,8 +186,8 @@ def test_message_property_with_params(self, mock_params): def test_init_with_existing_ids_in_message(self, mock_message, mock_params): """Test initialization with existing IDs in the message.""" - mock_message.taskId = 'existing-task-id' - mock_message.contextId = 'existing-context-id' + mock_message.task_id = 'existing-task-id' + mock_message.context_id = 'existing-context-id' context = RequestContext(request=mock_params) @@ -198,7 +199,7 @@ def test_init_with_task_id_and_existing_task_id_match( self, mock_params, mock_task ): """Test initialization succeeds when task_id matches task.id.""" - mock_params.message.taskId = mock_task.id + mock_params.message.task_id = mock_task.id context = RequestContext( request=mock_params, task_id=mock_task.id, task=mock_task @@ -210,16 +211,16 @@ def test_init_with_task_id_and_existing_task_id_match( def test_init_with_context_id_and_existing_context_id_match( self, mock_params, mock_task ): - """Test initialization succeeds when context_id matches task.contextId.""" - mock_params.message.taskId = mock_task.id # Set matching task ID - mock_params.message.contextId = mock_task.contextId + """Test initialization succeeds when context_id matches task.context_id.""" + mock_params.message.task_id = mock_task.id # Set matching task ID + mock_params.message.context_id = mock_task.context_id context = RequestContext( request=mock_params, task_id=mock_task.id, - context_id=mock_task.contextId, + context_id=mock_task.context_id, task=mock_task, ) - assert context.context_id == mock_task.contextId + assert context.context_id == mock_task.context_id assert context.current_task == mock_task diff --git a/tests/test_types.py b/tests/test_types.py index 2c0843e7..777b3460 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -199,17 +199,17 @@ def test_security_scheme_invalid(): def test_agent_capabilities(): caps = AgentCapabilities( - streaming=None, stateTransitionHistory=None, pushNotifications=None + streaming=None, state_transition_history=None, push_notifications=None ) # All optional - assert caps.pushNotifications is None - assert caps.stateTransitionHistory is None + assert caps.push_notifications is None + assert caps.state_transition_history is None assert caps.streaming is None caps_full = AgentCapabilities( - pushNotifications=True, stateTransitionHistory=False, streaming=True + push_notifications=True, state_transition_history=False, streaming=True ) - assert caps_full.pushNotifications is True - assert caps_full.stateTransitionHistory is False + assert caps_full.push_notifications is True + assert caps_full.state_transition_history is False assert caps_full.streaming is True @@ -232,7 +232,7 @@ def test_agent_skill_valid(): skill_full = AgentSkill(**FULL_AGENT_SKILL) assert skill_full.examples == ['Find me a pasta recipe'] - assert skill_full.inputModes == ['text/plain'] + assert skill_full.input_modes == ['text/plain'] def test_agent_skill_invalid(): @@ -263,6 +263,60 @@ def test_agent_card_invalid(): AgentCard(**bad_card_data) # Missing name +def test_agent_card_with_camel_case(): + card = AgentCard( + capabilities={}, + defaultInputModes=['text/plain'], + defaultOutputModes=['application/json'], + description='TestAgent', + name='TestAgent', + skills=[ + AgentSkill( + id='skill-123', + name='Recipe Finder', + description='Finds recipes', + tags=['cooking'], + ) + ], + url='http://example.com/agent', + version='1.0', + ) + assert card.name == 'TestAgent' + assert card.version == '1.0' + assert len(card.skills) == 1 + assert card.skills[0].id == 'skill-123' + assert card.provider is None # Optional + assert card.defaultInputModes[0] == 'text/plain' + assert card.defaultOutputModes[0] == 'application/json' + + +def test_agent_card_with_snake_case(): + card = AgentCard( + capabilities={}, + default_input_modes=['text/plain'], + default_output_modes=['application/json'], + description='TestAgent', + name='TestAgent', + skills=[ + AgentSkill( + id='skill-123', + name='Recipe Finder', + description='Finds recipes', + tags=['cooking'], + ) + ], + url='http://example.com/agent', + version='1.0', + ) + assert card.name == 'TestAgent' + assert card.version == '1.0' + assert len(card.skills) == 1 + assert card.skills[0].id == 'skill-123' + assert card.provider is None # Optional + assert card.default_input_modes[0] == 'text/plain' + assert card.default_output_modes[0] == 'application/json' + + # --- Test Parts --- @@ -284,12 +338,12 @@ def test_text_part(): def test_file_part_variants(): # URI variant file_uri = FileWithUri( - uri='file:///path/to/file.txt', mimeType='text/plain' + uri='file:///path/to/file.txt', mime_type='text/plain' ) part_uri = FilePart(kind='file', file=file_uri) assert isinstance(part_uri.file, FileWithUri) assert part_uri.file.uri == 'file:///path/to/file.txt' - assert part_uri.file.mimeType == 'text/plain' + assert part_uri.file.mime_type == 'text/plain' assert not hasattr(part_uri.file, 'bytes') # Bytes variant @@ -340,9 +394,16 @@ def test_part_root_model(): assert part_data.root.data == {'key': 'value'} # Test serialization - assert part_text.model_dump(exclude_none=True) == TEXT_PART_DATA - assert part_file.model_dump(exclude_none=True) == FILE_URI_PART_DATA - assert part_data.model_dump(exclude_none=True) == DATA_PART_DATA + assert ( + part_text.model_dump(by_alias=True, exclude_none=True) == TEXT_PART_DATA + ) + assert ( + part_file.model_dump(by_alias=True, exclude_none=True) + == FILE_URI_PART_DATA + ) + assert ( + part_data.model_dump(by_alias=True, exclude_none=True) == DATA_PART_DATA + ) # --- Test Message and Task --- @@ -390,7 +451,7 @@ def test_task_status(): def test_task(): task = Task(**MINIMAL_TASK) assert task.id == 'task-abc' - assert task.contextId == 'session-xyz' + assert task.context_id == 'session-xyz' assert task.status.state == TaskState.submitted assert task.history is None assert task.artifacts is None @@ -484,7 +545,7 @@ def test_jsonrpc_response_root_model() -> None: assert isinstance(resp_error.root, JSONRPCErrorResponse) assert resp_error.root.error.code == -32600 # Note: .model_dump() might serialize the nested error model - assert resp_error.model_dump(exclude_none=True) == error_data + assert resp_error.model_dump(by_alias=True, exclude_none=True) == error_data # Invalid case (neither success nor error structure) with pytest.raises(ValidationError): @@ -499,7 +560,7 @@ def test_send_message_request() -> None: req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'message/send', - 'params': params.model_dump(), + 'params': params.model_dump(by_alias=True), 'id': 5, } req = SendMessageRequest.model_validate(req_data) @@ -518,7 +579,7 @@ def test_send_subscribe_request() -> None: req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'message/stream', - 'params': params.model_dump(), + 'params': params.model_dump(by_alias=True), 'id': 5, } req = SendStreamingMessageRequest.model_validate(req_data) @@ -533,18 +594,18 @@ def test_send_subscribe_request() -> None: def test_get_task_request() -> None: - params = TaskQueryParams(id='task-1', historyLength=2) + params = TaskQueryParams(id='task-1', history_length=2) req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/get', - 'params': params.model_dump(), + 'params': params.model_dump(by_alias=True), 'id': 5, } req = GetTaskRequest.model_validate(req_data) assert req.method == 'tasks/get' assert isinstance(req.params, TaskQueryParams) assert req.params.id == 'task-1' - assert req.params.historyLength == 2 + assert req.params.history_length == 2 with pytest.raises(ValidationError): # Wrong method literal GetTaskRequest.model_validate({**req_data, 'method': 'wrong/method'}) @@ -555,7 +616,7 @@ def test_cancel_task_request() -> None: req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/cancel', - 'params': params.model_dump(), + 'params': params.model_dump(by_alias=True), 'id': 5, } req = CancelTaskRequest.model_validate(req_data) @@ -668,7 +729,7 @@ def test_send_message_streaming_status_update_response() -> None: assert isinstance(response.root, SendStreamingMessageSuccessResponse) assert isinstance(response.root.result, TaskStatusUpdateEvent) assert response.root.result.status.state == TaskState.submitted - assert response.root.result.taskId == '1' + assert response.root.result.task_id == '1' assert not response.root.result.final with pytest.raises( @@ -705,12 +766,12 @@ def test_send_message_streaming_artifact_update_response() -> None: text_part = TextPart(**TEXT_PART_DATA) data_part = DataPart(**DATA_PART_DATA) artifact = Artifact( - artifactId='artifact-123', + artifact_id='artifact-123', name='result_data', parts=[Part(root=text_part), Part(root=data_part)], ) task_artifact_update_event_data: dict[str, Any] = { - 'artifact': artifact, + 'artifact': artifact.model_dump(by_alias=True), 'taskId': 'task_id', 'contextId': '2', 'append': False, @@ -726,11 +787,11 @@ def test_send_message_streaming_artifact_update_response() -> None: assert response.root.id == 1 assert isinstance(response.root, SendStreamingMessageSuccessResponse) assert isinstance(response.root.result, TaskArtifactUpdateEvent) - assert response.root.result.artifact.artifactId == 'artifact-123' + assert response.root.result.artifact.artifact_id == 'artifact-123' assert response.root.result.artifact.name == 'result_data' - assert response.root.result.taskId == 'task_id' + assert response.root.result.task_id == 'task_id' assert not response.root.result.append - assert response.root.result.lastChunk + assert response.root.result.last_chunk assert len(response.root.result.artifact.parts) == 2 assert isinstance(response.root.result.artifact.parts[0].root, TextPart) assert isinstance(response.root.result.artifact.parts[1].root, DataPart) @@ -738,46 +799,48 @@ def test_send_message_streaming_artifact_update_response() -> None: def test_set_task_push_notification_response() -> None: task_push_config = TaskPushNotificationConfig( - taskId='t2', - pushNotificationConfig=PushNotificationConfig( + task_id='t2', + push_notification_config=PushNotificationConfig( url='https://example.com', token='token' ), ) resp_data: dict[str, Any] = { 'jsonrpc': '2.0', - 'result': task_push_config.model_dump(), + 'result': task_push_config.model_dump(by_alias=True), 'id': 1, } resp = SetTaskPushNotificationConfigResponse.model_validate(resp_data) assert resp.root.id == 1 assert isinstance(resp.root, SetTaskPushNotificationConfigSuccessResponse) assert isinstance(resp.root.result, TaskPushNotificationConfig) - assert resp.root.result.taskId == 't2' - assert resp.root.result.pushNotificationConfig.url == 'https://example.com' - assert resp.root.result.pushNotificationConfig.token == 'token' - assert resp.root.result.pushNotificationConfig.authentication is None + assert resp.root.result.task_id == 't2' + assert ( + resp.root.result.push_notification_config.url == 'https://example.com' + ) + assert resp.root.result.push_notification_config.token == 'token' + assert resp.root.result.push_notification_config.authentication is None auth_info_dict: dict[str, Any] = { 'schemes': ['Bearer', 'Basic'], 'credentials': 'user:pass', } - task_push_config.pushNotificationConfig.authentication = ( + task_push_config.push_notification_config.authentication = ( PushNotificationAuthenticationInfo(**auth_info_dict) ) resp_data = { 'jsonrpc': '2.0', - 'result': task_push_config.model_dump(), + 'result': task_push_config.model_dump(by_alias=True), 'id': 1, } resp = SetTaskPushNotificationConfigResponse.model_validate(resp_data) assert isinstance(resp.root, SetTaskPushNotificationConfigSuccessResponse) - assert resp.root.result.pushNotificationConfig.authentication is not None - assert resp.root.result.pushNotificationConfig.authentication.schemes == [ + assert resp.root.result.push_notification_config.authentication is not None + assert resp.root.result.push_notification_config.authentication.schemes == [ 'Bearer', 'Basic', ] assert ( - resp.root.result.pushNotificationConfig.authentication.credentials + resp.root.result.push_notification_config.authentication.credentials == 'user:pass' ) @@ -797,46 +860,48 @@ def test_set_task_push_notification_response() -> None: def test_get_task_push_notification_response() -> None: task_push_config = TaskPushNotificationConfig( - taskId='t2', - pushNotificationConfig=PushNotificationConfig( + task_id='t2', + push_notification_config=PushNotificationConfig( url='https://example.com', token='token' ), ) resp_data: dict[str, Any] = { 'jsonrpc': '2.0', - 'result': task_push_config.model_dump(), + 'result': task_push_config.model_dump(by_alias=True), 'id': 1, } resp = GetTaskPushNotificationConfigResponse.model_validate(resp_data) assert resp.root.id == 1 assert isinstance(resp.root, GetTaskPushNotificationConfigSuccessResponse) assert isinstance(resp.root.result, TaskPushNotificationConfig) - assert resp.root.result.taskId == 't2' - assert resp.root.result.pushNotificationConfig.url == 'https://example.com' - assert resp.root.result.pushNotificationConfig.token == 'token' - assert resp.root.result.pushNotificationConfig.authentication is None + assert resp.root.result.task_id == 't2' + assert ( + resp.root.result.push_notification_config.url == 'https://example.com' + ) + assert resp.root.result.push_notification_config.token == 'token' + assert resp.root.result.push_notification_config.authentication is None auth_info_dict: dict[str, Any] = { 'schemes': ['Bearer', 'Basic'], 'credentials': 'user:pass', } - task_push_config.pushNotificationConfig.authentication = ( + task_push_config.push_notification_config.authentication = ( PushNotificationAuthenticationInfo(**auth_info_dict) ) resp_data = { 'jsonrpc': '2.0', - 'result': task_push_config.model_dump(), + 'result': task_push_config.model_dump(by_alias=True), 'id': 1, } resp = GetTaskPushNotificationConfigResponse.model_validate(resp_data) assert isinstance(resp.root, GetTaskPushNotificationConfigSuccessResponse) - assert resp.root.result.pushNotificationConfig.authentication is not None - assert resp.root.result.pushNotificationConfig.authentication.schemes == [ + assert resp.root.result.push_notification_config.authentication is not None + assert resp.root.result.push_notification_config.authentication.schemes == [ 'Bearer', 'Basic', ] assert ( - resp.root.result.pushNotificationConfig.authentication.credentials + resp.root.result.push_notification_config.authentication.credentials == 'user:pass' ) @@ -863,7 +928,7 @@ def test_a2a_request_root_model() -> None: send_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'message/send', - 'params': send_params.model_dump(), + 'params': send_params.model_dump(by_alias=True), 'id': 1, } a2a_req_send = A2ARequest.model_validate(send_req_data) @@ -874,7 +939,7 @@ def test_a2a_request_root_model() -> None: send_subs_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'message/stream', - 'params': send_params.model_dump(), + 'params': send_params.model_dump(by_alias=True), 'id': 1, } a2a_req_send_subs = A2ARequest.model_validate(send_subs_req_data) @@ -886,7 +951,7 @@ def test_a2a_request_root_model() -> None: get_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/get', - 'params': get_params.model_dump(), + 'params': get_params.model_dump(by_alias=True), 'id': 2, } a2a_req_get = A2ARequest.model_validate(get_req_data) @@ -898,7 +963,7 @@ def test_a2a_request_root_model() -> None: cancel_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/cancel', - 'params': id_params.model_dump(), + 'params': id_params.model_dump(by_alias=True), 'id': 2, } a2a_req_cancel = A2ARequest.model_validate(cancel_req_data) @@ -907,8 +972,8 @@ def test_a2a_request_root_model() -> None: # SetTaskPushNotificationConfigRequest task_push_config = TaskPushNotificationConfig( - taskId='t2', - pushNotificationConfig=PushNotificationConfig( + task_id='t2', + push_notification_config=PushNotificationConfig( url='https://example.com', token='token' ), ) @@ -916,7 +981,7 @@ def test_a2a_request_root_model() -> None: 'id': 1, 'jsonrpc': '2.0', 'method': 'tasks/pushNotificationConfig/set', - 'params': task_push_config.model_dump(), + 'params': task_push_config.model_dump(by_alias=True), 'taskId': 2, } a2a_req_set_push_req = A2ARequest.model_validate(set_push_notif_req_data) @@ -936,7 +1001,7 @@ def test_a2a_request_root_model() -> None: 'id': 1, 'jsonrpc': '2.0', 'method': 'tasks/pushNotificationConfig/get', - 'params': id_params.model_dump(), + 'params': id_params.model_dump(by_alias=True), 'taskId': 2, } a2a_req_get_push_req = A2ARequest.model_validate(get_push_notif_req_data) @@ -952,7 +1017,7 @@ def test_a2a_request_root_model() -> None: task_resubscribe_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/resubscribe', - 'params': id_params.model_dump(), + 'params': id_params.model_dump(by_alias=True), 'id': 2, } a2a_req_task_resubscribe_req = A2ARequest.model_validate( @@ -981,7 +1046,7 @@ def test_a2a_request_root_model_id_validation() -> None: send_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'message/send', - 'params': send_params.model_dump(), + 'params': send_params.model_dump(by_alias=True), } with pytest.raises(ValidationError): A2ARequest.model_validate(send_req_data) # missing id @@ -990,7 +1055,7 @@ def test_a2a_request_root_model_id_validation() -> None: send_subs_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'message/stream', - 'params': send_params.model_dump(), + 'params': send_params.model_dump(by_alias=True), } with pytest.raises(ValidationError): A2ARequest.model_validate(send_subs_req_data) # missing id @@ -1000,7 +1065,7 @@ def test_a2a_request_root_model_id_validation() -> None: get_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/get', - 'params': get_params.model_dump(), + 'params': get_params.model_dump(by_alias=True), } with pytest.raises(ValidationError): A2ARequest.model_validate(get_req_data) # missing id @@ -1010,22 +1075,22 @@ def test_a2a_request_root_model_id_validation() -> None: cancel_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/cancel', - 'params': id_params.model_dump(), + 'params': id_params.model_dump(by_alias=True), } with pytest.raises(ValidationError): A2ARequest.model_validate(cancel_req_data) # missing id # SetTaskPushNotificationConfigRequest task_push_config = TaskPushNotificationConfig( - taskId='t2', - pushNotificationConfig=PushNotificationConfig( + task_id='t2', + push_notification_config=PushNotificationConfig( url='https://example.com', token='token' ), ) set_push_notif_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/pushNotificationConfig/set', - 'params': task_push_config.model_dump(), + 'params': task_push_config.model_dump(by_alias=True), 'taskId': 2, } with pytest.raises(ValidationError): @@ -1036,7 +1101,7 @@ def test_a2a_request_root_model_id_validation() -> None: get_push_notif_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/pushNotificationConfig/get', - 'params': id_params.model_dump(), + 'params': id_params.model_dump(by_alias=True), 'taskId': 2, } with pytest.raises(ValidationError): @@ -1046,7 +1111,7 @@ def test_a2a_request_root_model_id_validation() -> None: task_resubscribe_req_data: dict[str, Any] = { 'jsonrpc': '2.0', 'method': 'tasks/resubscribe', - 'params': id_params.model_dump(), + 'params': id_params.model_dump(by_alias=True), } with pytest.raises(ValidationError): A2ARequest.model_validate(task_resubscribe_req_data) @@ -1300,14 +1365,16 @@ def test_task_push_notification_config() -> None: assert push_notification_config.authentication == auth_info task_push_notification_config = TaskPushNotificationConfig( - taskId='task-123', pushNotificationConfig=push_notification_config + task_id='task-123', push_notification_config=push_notification_config ) - assert task_push_notification_config.taskId == 'task-123' + assert task_push_notification_config.task_id == 'task-123' assert ( - task_push_notification_config.pushNotificationConfig + task_push_notification_config.push_notification_config == push_notification_config ) - assert task_push_notification_config.model_dump(exclude_none=True) == { + assert task_push_notification_config.model_dump( + by_alias=True, exclude_none=True + ) == { 'taskId': 'task-123', 'pushNotificationConfig': { 'url': 'https://example.com', @@ -1356,22 +1423,22 @@ def test_file_base_valid(): """Tests successful validation of FileBase.""" # No optional fields base1 = FileBase() - assert base1.mimeType is None + assert base1.mime_type is None assert base1.name is None # With mimeType only - base2 = FileBase(mimeType='image/png') - assert base2.mimeType == 'image/png' + base2 = FileBase(mime_type='image/png') + assert base2.mime_type == 'image/png' assert base2.name is None # With name only base3 = FileBase(name='document.pdf') - assert base3.mimeType is None + assert base3.mime_type is None assert base3.name == 'document.pdf' # With both fields - base4 = FileBase(mimeType='application/json', name='data.json') - assert base4.mimeType == 'application/json' + base4 = FileBase(mime_type='application/json', name='data.json') + assert base4.mime_type == 'application/json' assert base4.name == 'data.json' @@ -1381,8 +1448,8 @@ def test_file_base_invalid(): # Incorrect type for mimeType with pytest.raises(ValidationError) as excinfo_type_mime: - FileBase(mimeType=123) # type: ignore - assert 'mimeType' in str(excinfo_type_mime.value) + FileBase(mime_type=123) # type: ignore + assert 'mime_type' in str(excinfo_type_mime.value) # Incorrect type for name with pytest.raises(ValidationError) as excinfo_type_name: @@ -1417,46 +1484,56 @@ def test_a2a_error_validation_and_serialization() -> None: # 1. Test JSONParseError json_parse_instance = JSONParseError() - json_parse_data = json_parse_instance.model_dump(exclude_none=True) + json_parse_data = json_parse_instance.model_dump( + by_alias=True, exclude_none=True + ) a2a_err_parse = A2AError.model_validate(json_parse_data) assert isinstance(a2a_err_parse.root, JSONParseError) # 2. Test InvalidRequestError invalid_req_instance = InvalidRequestError() - invalid_req_data = invalid_req_instance.model_dump(exclude_none=True) + invalid_req_data = invalid_req_instance.model_dump( + by_alias=True, exclude_none=True + ) a2a_err_invalid_req = A2AError.model_validate(invalid_req_data) assert isinstance(a2a_err_invalid_req.root, InvalidRequestError) # 3. Test MethodNotFoundError method_not_found_instance = MethodNotFoundError() method_not_found_data = method_not_found_instance.model_dump( - exclude_none=True + by_alias=True, exclude_none=True ) a2a_err_method = A2AError.model_validate(method_not_found_data) assert isinstance(a2a_err_method.root, MethodNotFoundError) # 4. Test InvalidParamsError invalid_params_instance = InvalidParamsError() - invalid_params_data = invalid_params_instance.model_dump(exclude_none=True) + invalid_params_data = invalid_params_instance.model_dump( + by_alias=True, exclude_none=True + ) a2a_err_params = A2AError.model_validate(invalid_params_data) assert isinstance(a2a_err_params.root, InvalidParamsError) # 5. Test InternalError internal_err_instance = InternalError() - internal_err_data = internal_err_instance.model_dump(exclude_none=True) + internal_err_data = internal_err_instance.model_dump( + by_alias=True, exclude_none=True + ) a2a_err_internal = A2AError.model_validate(internal_err_data) assert isinstance(a2a_err_internal.root, InternalError) # 6. Test TaskNotFoundError task_not_found_instance = TaskNotFoundError(data={'taskId': 't1'}) - task_not_found_data = task_not_found_instance.model_dump(exclude_none=True) + task_not_found_data = task_not_found_instance.model_dump( + by_alias=True, exclude_none=True + ) a2a_err_task_nf = A2AError.model_validate(task_not_found_data) assert isinstance(a2a_err_task_nf.root, TaskNotFoundError) # 7. Test TaskNotCancelableError task_not_cancelable_instance = TaskNotCancelableError() task_not_cancelable_data = task_not_cancelable_instance.model_dump( - exclude_none=True + by_alias=True, exclude_none=True ) a2a_err_task_nc = A2AError.model_validate(task_not_cancelable_data) assert isinstance(a2a_err_task_nc.root, TaskNotCancelableError) @@ -1464,21 +1541,23 @@ def test_a2a_error_validation_and_serialization() -> None: # 8. Test PushNotificationNotSupportedError push_not_supported_instance = PushNotificationNotSupportedError() push_not_supported_data = push_not_supported_instance.model_dump( - exclude_none=True + by_alias=True, exclude_none=True ) a2a_err_push_ns = A2AError.model_validate(push_not_supported_data) assert isinstance(a2a_err_push_ns.root, PushNotificationNotSupportedError) # 9. Test UnsupportedOperationError unsupported_op_instance = UnsupportedOperationError() - unsupported_op_data = unsupported_op_instance.model_dump(exclude_none=True) + unsupported_op_data = unsupported_op_instance.model_dump( + by_alias=True, exclude_none=True + ) a2a_err_unsupported = A2AError.model_validate(unsupported_op_data) assert isinstance(a2a_err_unsupported.root, UnsupportedOperationError) # 10. Test ContentTypeNotSupportedError content_type_err_instance = ContentTypeNotSupportedError() content_type_err_data = content_type_err_instance.model_dump( - exclude_none=True + by_alias=True, exclude_none=True ) a2a_err_content = A2AError.model_validate(content_type_err_data) assert isinstance(a2a_err_content.root, ContentTypeNotSupportedError) diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 16469a02..5fb524a8 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -53,7 +53,7 @@ def test_create_task_obj(): task = create_task_obj(send_params) assert task.id is not None - assert task.contextId == message.contextId + assert task.context_id == message.context_id assert task.status.state == TaskState.submitted assert len(task.history) == 1 assert task.history[0] == message @@ -65,13 +65,13 @@ def test_create_task_obj_generates_context_id(): message_no_context_id = Message( role=Role.user, parts=[Part(root=TextPart(text='test'))], - messageId='msg-no-ctx', - taskId='task-from-msg', # Provide a taskId to differentiate from generated task.id + message_id='msg-no-ctx', + task_id='task-from-msg', # Provide a taskId to differentiate from generated task.id ) send_params = MessageSendParams(message=message_no_context_id) # Ensure message.contextId is None initially - assert send_params.message.contextId is None + assert send_params.message.context_id is None known_task_uuid = uuid.UUID('11111111-1111-1111-1111-111111111111') known_context_uuid = uuid.UUID('22222222-2222-2222-2222-222222222222') @@ -88,17 +88,17 @@ def test_create_task_obj_generates_context_id(): assert mock_uuid4.call_count == 2 # Assert that message.contextId was set to the first generated UUID - assert send_params.message.contextId == str(known_context_uuid) + assert send_params.message.context_id == str(known_context_uuid) # Assert that task.contextId is the same generated UUID - assert task.contextId == str(known_context_uuid) + assert task.context_id == str(known_context_uuid) # Assert that task.id is the second generated UUID assert task.id == str(known_task_uuid) # Ensure the original message in history also has the updated contextId assert len(task.history) == 1 - assert task.history[0].contextId == str(known_context_uuid) + assert task.history[0].context_id == str(known_context_uuid) # Test append_artifact_to_task @@ -106,7 +106,7 @@ def test_append_artifact_to_task(): # Prepare base task task = Task(**MINIMAL_TASK) assert task.id == 'task-abc' - assert task.contextId == 'session-xyz' + assert task.context_id == 'session-xyz' assert task.status.state == TaskState.submitted assert task.history is None assert task.artifacts is None @@ -114,42 +114,45 @@ def test_append_artifact_to_task(): # Prepare appending artifact and event artifact_1 = Artifact( - artifactId='artifact-123', parts=[Part(root=TextPart(text='Hello'))] + artifact_id='artifact-123', parts=[Part(root=TextPart(text='Hello'))] ) append_event_1 = TaskArtifactUpdateEvent( - artifact=artifact_1, append=False, taskId='123', contextId='123' + artifact=artifact_1, append=False, task_id='123', context_id='123' ) # Test adding a new artifact (not appending) append_artifact_to_task(task, append_event_1) assert len(task.artifacts) == 1 - assert task.artifacts[0].artifactId == 'artifact-123' + assert task.artifacts[0].artifact_id == 'artifact-123' assert task.artifacts[0].name is None assert len(task.artifacts[0].parts) == 1 assert task.artifacts[0].parts[0].root.text == 'Hello' # Test replacing the artifact artifact_2 = Artifact( - artifactId='artifact-123', + artifact_id='artifact-123', name='updated name', parts=[Part(root=TextPart(text='Updated'))], ) append_event_2 = TaskArtifactUpdateEvent( - artifact=artifact_2, append=False, taskId='123', contextId='123' + artifact=artifact_2, append=False, task_id='123', context_id='123' ) append_artifact_to_task(task, append_event_2) assert len(task.artifacts) == 1 # Should still have one artifact - assert task.artifacts[0].artifactId == 'artifact-123' + assert task.artifacts[0].artifact_id == 'artifact-123' assert task.artifacts[0].name == 'updated name' assert len(task.artifacts[0].parts) == 1 assert task.artifacts[0].parts[0].root.text == 'Updated' # Test appending parts to an existing artifact artifact_with_parts = Artifact( - artifactId='artifact-123', parts=[Part(root=TextPart(text='Part 2'))] + artifact_id='artifact-123', parts=[Part(root=TextPart(text='Part 2'))] ) append_event_3 = TaskArtifactUpdateEvent( - artifact=artifact_with_parts, append=True, taskId='123', contextId='123' + artifact=artifact_with_parts, + append=True, + task_id='123', + context_id='123', ) append_artifact_to_task(task, append_event_3) assert len(task.artifacts[0].parts) == 2 @@ -158,31 +161,31 @@ def test_append_artifact_to_task(): # Test adding another new artifact another_artifact_with_parts = Artifact( - artifactId='new_artifact', + artifact_id='new_artifact', parts=[Part(root=TextPart(text='new artifact Part 1'))], ) append_event_4 = TaskArtifactUpdateEvent( artifact=another_artifact_with_parts, append=False, - taskId='123', - contextId='123', + task_id='123', + context_id='123', ) append_artifact_to_task(task, append_event_4) assert len(task.artifacts) == 2 - assert task.artifacts[0].artifactId == 'artifact-123' - assert task.artifacts[1].artifactId == 'new_artifact' + assert task.artifacts[0].artifact_id == 'artifact-123' + assert task.artifacts[1].artifact_id == 'new_artifact' assert len(task.artifacts[0].parts) == 2 assert len(task.artifacts[1].parts) == 1 # Test appending part to a task that does not have a matching artifact non_existing_artifact_with_parts = Artifact( - artifactId='artifact-456', parts=[Part(root=TextPart(text='Part 1'))] + artifact_id='artifact-456', parts=[Part(root=TextPart(text='Part 1'))] ) append_event_5 = TaskArtifactUpdateEvent( artifact=non_existing_artifact_with_parts, append=True, - taskId='123', - contextId='123', + task_id='123', + context_id='123', ) append_artifact_to_task(task, append_event_5) assert len(task.artifacts) == 2 @@ -196,7 +199,7 @@ def test_build_text_artifact(): text = 'This is a sample text' artifact = build_text_artifact(text, artifact_id) - assert artifact.artifactId == artifact_id + assert artifact.artifact_id == artifact_id assert len(artifact.parts) == 1 assert artifact.parts[0].root.text == text diff --git a/tests/utils/test_message.py b/tests/utils/test_message.py index 6851a3ca..66acdffa 100644 --- a/tests/utils/test_message.py +++ b/tests/utils/test_message.py @@ -27,9 +27,9 @@ def test_new_agent_text_message_basic(self): assert message.role == Role.agent assert len(message.parts) == 1 assert message.parts[0].root.text == text - assert message.messageId == '12345678-1234-5678-1234-567812345678' - assert message.taskId is None - assert message.contextId is None + assert message.message_id == '12345678-1234-5678-1234-567812345678' + assert message.task_id is None + assert message.context_id is None def test_new_agent_text_message_with_context_id(self): # Setup @@ -46,9 +46,9 @@ def test_new_agent_text_message_with_context_id(self): # Verify assert message.role == Role.agent assert message.parts[0].root.text == text - assert message.messageId == '12345678-1234-5678-1234-567812345678' - assert message.contextId == context_id - assert message.taskId is None + assert message.message_id == '12345678-1234-5678-1234-567812345678' + assert message.context_id == context_id + assert message.task_id is None def test_new_agent_text_message_with_task_id(self): # Setup @@ -65,9 +65,9 @@ def test_new_agent_text_message_with_task_id(self): # Verify assert message.role == Role.agent assert message.parts[0].root.text == text - assert message.messageId == '12345678-1234-5678-1234-567812345678' - assert message.taskId == task_id - assert message.contextId is None + assert message.message_id == '12345678-1234-5678-1234-567812345678' + assert message.task_id == task_id + assert message.context_id is None def test_new_agent_text_message_with_both_ids(self): # Setup @@ -87,9 +87,9 @@ def test_new_agent_text_message_with_both_ids(self): # Verify assert message.role == Role.agent assert message.parts[0].root.text == text - assert message.messageId == '12345678-1234-5678-1234-567812345678' - assert message.contextId == context_id - assert message.taskId == task_id + assert message.message_id == '12345678-1234-5678-1234-567812345678' + assert message.context_id == context_id + assert message.task_id == task_id def test_new_agent_text_message_empty_text(self): # Setup @@ -105,7 +105,7 @@ def test_new_agent_text_message_empty_text(self): # Verify assert message.role == Role.agent assert message.parts[0].root.text == '' - assert message.messageId == '12345678-1234-5678-1234-567812345678' + assert message.message_id == '12345678-1234-5678-1234-567812345678' class TestGetTextParts: @@ -150,7 +150,7 @@ def test_get_message_text_single_part(self): message = Message( role=Role.agent, parts=[Part(root=TextPart(text='Hello world'))], - messageId='test-message-id', + message_id='test-message-id', ) # Exercise @@ -168,7 +168,7 @@ def test_get_message_text_multiple_parts(self): Part(root=TextPart(text='Second line')), Part(root=TextPart(text='Third line')), ], - messageId='test-message-id', + message_id='test-message-id', ) # Exercise @@ -186,7 +186,7 @@ def test_get_message_text_custom_delimiter(self): Part(root=TextPart(text='Second part')), Part(root=TextPart(text='Third part')), ], - messageId='test-message-id', + message_id='test-message-id', ) # Exercise @@ -200,7 +200,7 @@ def test_get_message_text_empty_parts(self): message = Message( role=Role.agent, parts=[], - messageId='test-message-id', + message_id='test-message-id', ) # Exercise diff --git a/tests/utils/test_task.py b/tests/utils/test_task.py index 796a7ad8..1b818cf4 100644 --- a/tests/utils/test_task.py +++ b/tests/utils/test_task.py @@ -12,7 +12,7 @@ def test_new_task_status(self): message = Message( role=Role.user, parts=[Part(root=TextPart(text='test message'))], - messageId=str(uuid.uuid4()), + message_id=str(uuid.uuid4()), ) task = new_task(message) self.assertEqual(task.status.state.value, 'submitted') @@ -24,11 +24,11 @@ def test_new_task_generates_ids(self, mock_uuid4): message = Message( role=Role.user, parts=[Part(root=TextPart(text='test message'))], - messageId=str(uuid.uuid4()), + message_id=str(uuid.uuid4()), ) task = new_task(message) self.assertEqual(task.id, str(mock_uuid)) - self.assertEqual(task.contextId, str(mock_uuid)) + self.assertEqual(task.context_id, str(mock_uuid)) def test_new_task_uses_provided_ids(self): task_id = str(uuid.uuid4()) @@ -36,19 +36,19 @@ def test_new_task_uses_provided_ids(self): message = Message( role=Role.user, parts=[Part(root=TextPart(text='test message'))], - messageId=str(uuid.uuid4()), - taskId=task_id, - contextId=context_id, + message_id=str(uuid.uuid4()), + task_id=task_id, + context_id=context_id, ) task = new_task(message) self.assertEqual(task.id, task_id) - self.assertEqual(task.contextId, context_id) + self.assertEqual(task.context_id, context_id) def test_new_task_initial_message_in_history(self): message = Message( role=Role.user, parts=[Part(root=TextPart(text='test message'))], - messageId=str(uuid.uuid4()), + message_id=str(uuid.uuid4()), ) task = new_task(message) self.assertEqual(len(task.history), 1) @@ -77,7 +77,7 @@ def test_completed_task_assigns_ids_and_artifacts(self): history=[], ) self.assertEqual(task.id, task_id) - self.assertEqual(task.contextId, context_id) + self.assertEqual(task.context_id, context_id) self.assertEqual(task.artifacts, artifacts) def test_completed_task_empty_history_if_not_provided(self): @@ -97,12 +97,12 @@ def test_completed_task_uses_provided_history(self): Message( role=Role.user, parts=[Part(root=TextPart(text='Hello'))], - messageId=str(uuid.uuid4()), + message_id=str(uuid.uuid4()), ), Message( role=Role.agent, parts=[Part(root=TextPart(text='Hi there'))], - messageId=str(uuid.uuid4()), + message_id=str(uuid.uuid4()), ), ] task = completed_task(