Skip to content

Commit efaabd3

Browse files
authored
feat: Add middleware to the client SDK (#171)
This PR introduces a middleware framework to the a2a-python client SDK, which allows for various authentication mechanisms. The new AuthInterceptor automatically injects authentication credentials into outgoing requests based on the agent's defined security schemes. This change simplifies the process of handling different authentication methods like Bearer tokens and API keys for developers using the SDK.
1 parent 39307f1 commit efaabd3

File tree

10 files changed

+736
-36
lines changed

10 files changed

+736
-36
lines changed

.github/actions/spelling/allow.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ opensource
4545
protoc
4646
pyi
4747
pyversions
48+
respx
4849
resub
4950
socio
5051
sse

.vscode/launch.json

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
"PYTHONPATH": "${workspaceFolder}"
1313
},
1414
"cwd": "${workspaceFolder}/examples/helloworld",
15-
"args": ["--host", "localhost", "--port", "9999"]
15+
"args": [
16+
"--host",
17+
"localhost",
18+
"--port",
19+
"9999"
20+
]
1621
},
1722
{
1823
"name": "Debug Currency Agent",
@@ -25,7 +30,24 @@
2530
"PYTHONPATH": "${workspaceFolder}"
2631
},
2732
"cwd": "${workspaceFolder}/examples/langgraph",
28-
"args": ["--host", "localhost", "--port", "10000"]
33+
"args": [
34+
"--host",
35+
"localhost",
36+
"--port",
37+
"10000"
38+
]
39+
},
40+
{
41+
"name": "Pytest All",
42+
"type": "debugpy",
43+
"request": "launch",
44+
"module": "pytest",
45+
"args": [
46+
"-v",
47+
"-s"
48+
],
49+
"console": "integratedTerminal",
50+
"justMyCode": true
2951
}
3052
]
31-
}
53+
}

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ dev = [
7676
"pytest-asyncio>=0.26.0",
7777
"pytest-cov>=6.1.1",
7878
"pytest-mock>=3.14.0",
79+
"respx>=0.20.2",
7980
"ruff>=0.11.6",
8081
"uv-dynamic-versioning>=0.8.2",
8182
"types-protobuf",

src/a2a/client/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
"""Client-side components for interacting with an A2A agent."""
22

3+
from a2a.client.auth import (
4+
AuthInterceptor,
5+
CredentialService,
6+
InMemoryContextCredentialStore,
7+
)
38
from a2a.client.client import A2ACardResolver, A2AClient
49
from a2a.client.errors import (
510
A2AClientError,
@@ -8,6 +13,7 @@
813
)
914
from a2a.client.grpc_client import A2AGrpcClient
1015
from a2a.client.helpers import create_text_message_object
16+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
1117

1218

1319
__all__ = [
@@ -17,5 +23,10 @@
1723
'A2AClientHTTPError',
1824
'A2AClientJSONError',
1925
'A2AGrpcClient',
26+
'AuthInterceptor',
27+
'ClientCallContext',
28+
'ClientCallInterceptor',
29+
'CredentialService',
30+
'InMemoryContextCredentialStore',
2031
'create_text_message_object',
2132
]

src/a2a/client/auth/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
"""Client-side authentication components for the A2A Python SDK."""
2+
3+
from a2a.client.auth.credentials import (
4+
CredentialService,
5+
InMemoryContextCredentialStore,
6+
)
7+
from a2a.client.auth.interceptor import AuthInterceptor
8+
9+
10+
__all__ = [
11+
'AuthInterceptor',
12+
'CredentialService',
13+
'InMemoryContextCredentialStore',
14+
]

src/a2a/client/auth/credentials.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from abc import ABC, abstractmethod
2+
3+
from a2a.client.middleware import ClientCallContext
4+
5+
6+
class CredentialService(ABC):
7+
"""An abstract service for retrieving credentials."""
8+
9+
@abstractmethod
10+
async def get_credentials(
11+
self,
12+
security_scheme_name: str,
13+
context: ClientCallContext | None,
14+
) -> str | None:
15+
"""
16+
Retrieves a credential (e.g., token) for a security scheme.
17+
"""
18+
19+
20+
class InMemoryContextCredentialStore(CredentialService):
21+
"""A simple in-memory store for session-keyed credentials.
22+
23+
This class uses the 'sessionId' from the ClientCallContext state to
24+
store and retrieve credentials...
25+
"""
26+
27+
def __init__(self) -> None:
28+
self._store: dict[str, dict[str, str]] = {}
29+
30+
async def get_credentials(
31+
self,
32+
security_scheme_name: str,
33+
context: ClientCallContext | None,
34+
) -> str | None:
35+
"""Retrieves credentials from the in-memory store.
36+
37+
Args:
38+
security_scheme_name: The name of the security scheme.
39+
context: The client call context.
40+
41+
Returns:
42+
The credential string, or None if not found.
43+
"""
44+
if not context or 'sessionId' not in context.state:
45+
return None
46+
session_id = context.state['sessionId']
47+
return self._store.get(session_id, {}).get(security_scheme_name)
48+
49+
async def set_credentials(
50+
self, session_id: str, security_scheme_name: str, credential: str
51+
) -> None:
52+
"""Method to populate the store."""
53+
if session_id not in self._store:
54+
self._store[session_id] = {}
55+
self._store[session_id][security_scheme_name] = credential

src/a2a/client/auth/interceptor.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import logging # noqa: I001
2+
from typing import Any
3+
4+
from a2a.client.auth.credentials import CredentialService
5+
from a2a.client.middleware import ClientCallContext, ClientCallInterceptor
6+
from a2a.types import (
7+
AgentCard,
8+
APIKeySecurityScheme,
9+
HTTPAuthSecurityScheme,
10+
In,
11+
OAuth2SecurityScheme,
12+
OpenIdConnectSecurityScheme,
13+
)
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class AuthInterceptor(ClientCallInterceptor):
19+
"""An interceptor that automatically adds authentication details to requests.
20+
21+
Based on the agent's security schemes.
22+
"""
23+
24+
def __init__(self, credential_service: CredentialService):
25+
self._credential_service = credential_service
26+
27+
async def intercept(
28+
self,
29+
method_name: str,
30+
request_payload: dict[str, Any],
31+
http_kwargs: dict[str, Any],
32+
agent_card: AgentCard | None,
33+
context: ClientCallContext | None,
34+
) -> tuple[dict[str, Any], dict[str, Any]]:
35+
"""Applies authentication headers to the request if credentials are available."""
36+
if (
37+
agent_card is None
38+
or agent_card.security is None
39+
or agent_card.securitySchemes is None
40+
):
41+
return request_payload, http_kwargs
42+
43+
for requirement in agent_card.security:
44+
for scheme_name in requirement:
45+
credential = await self._credential_service.get_credentials(
46+
scheme_name, context
47+
)
48+
if credential and scheme_name in agent_card.securitySchemes:
49+
scheme_def_union = agent_card.securitySchemes.get(
50+
scheme_name
51+
)
52+
if not scheme_def_union:
53+
continue
54+
scheme_def = scheme_def_union.root
55+
56+
headers = http_kwargs.get('headers', {})
57+
58+
match scheme_def:
59+
# Case 1a: HTTP Bearer scheme with an if guard
60+
case HTTPAuthSecurityScheme() if (
61+
scheme_def.scheme.lower() == 'bearer'
62+
):
63+
headers['Authorization'] = f'Bearer {credential}'
64+
logger.debug(
65+
f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})."
66+
)
67+
http_kwargs['headers'] = headers
68+
return request_payload, http_kwargs
69+
70+
# Case 1b: OAuth2 and OIDC schemes, which are implicitly Bearer
71+
case (
72+
OAuth2SecurityScheme()
73+
| OpenIdConnectSecurityScheme()
74+
):
75+
headers['Authorization'] = f'Bearer {credential}'
76+
logger.debug(
77+
f"Added Bearer token for scheme '{scheme_name}' (type: {scheme_def.type})."
78+
)
79+
http_kwargs['headers'] = headers
80+
return request_payload, http_kwargs
81+
82+
# Case 2: API Key in Header
83+
case APIKeySecurityScheme(in_=In.header):
84+
headers[scheme_def.name] = credential
85+
logger.debug(
86+
f"Added API Key Header for scheme '{scheme_name}'."
87+
)
88+
http_kwargs['headers'] = headers
89+
return request_payload, http_kwargs
90+
91+
# Note: Other cases like API keys in query/cookie are not handled and will be skipped.
92+
93+
return request_payload, http_kwargs

0 commit comments

Comments
 (0)