Skip to content

Commit c5e4021

Browse files
authored
refactor(framework): Improve ContextVar handling
1 parent b2d6c7e commit c5e4021

File tree

5 files changed

+43
-36
lines changed

5 files changed

+43
-36
lines changed

framework/py/flwr/superlink/servicer/control/control_account_auth_interceptor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,19 @@
4646
)
4747

4848

49-
shared_account_info: contextvars.ContextVar[AccountInfo] = contextvars.ContextVar(
50-
"account_info", default=AccountInfo(flwr_aid=None, account_name=None)
49+
shared_account_info: contextvars.ContextVar[AccountInfo | None] = (
50+
contextvars.ContextVar("account_info", default=None)
5151
)
5252

5353

54+
def get_current_account_info() -> AccountInfo:
55+
"""Get the current account info from context, or return a default if not set."""
56+
account_info = shared_account_info.get()
57+
if account_info is None:
58+
return AccountInfo(flwr_aid=None, account_name=None)
59+
return account_info
60+
61+
5462
class ControlAccountAuthInterceptor(grpc.ServerInterceptor): # type: ignore
5563
"""Control API interceptor for account authentication."""
5664

framework/py/flwr/superlink/servicer/control/control_account_auth_interceptor_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242

4343
from .control_account_auth_interceptor import (
4444
ControlAccountAuthInterceptor,
45+
get_current_account_info,
4546
shared_account_info,
4647
)
4748

@@ -99,7 +100,7 @@ def test_unary_unary_authentication_successful(self, request: GrpcMessage) -> No
99100
# Assert response is as expected
100101
self.assertEqual(response, "dummy_response")
101102
# Assert `shared_account_info` is not set
102-
account_info_from_context = shared_account_info.get()
103+
account_info_from_context = get_current_account_info()
103104
self.assertIsNone(account_info_from_context.flwr_aid)
104105
self.assertIsNone(account_info_from_context.account_name)
105106

@@ -206,7 +207,7 @@ def test_unary_validate_tokens_successful(self, request: GrpcMessage) -> None:
206207
self.assertEqual(response, "dummy_response")
207208

208209
# Assert `shared_account_info` is set
209-
account_info_from_context = shared_account_info.get()
210+
account_info_from_context = get_current_account_info()
210211
self.assertEqual(
211212
account_info_from_context.flwr_aid, self.expected_account_info.flwr_aid
212213
)

framework/py/flwr/superlink/servicer/control/control_event_log_interceptor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from flwr.common.event_log_plugin.event_log_plugin import EventLogWriterPlugin
2525
from flwr.common.typing import LogEntry
2626

27-
from .control_account_auth_interceptor import shared_account_info
27+
from .control_account_auth_interceptor import get_current_account_info
2828

2929

3030
class ControlEventLogInterceptor(grpc.ServerInterceptor): # type: ignore
@@ -66,7 +66,7 @@ def _generic_method_handler(
6666
log_entry = self.log_plugin.compose_log_before_event(
6767
request=request,
6868
context=context,
69-
account_info=shared_account_info.get(),
69+
account_info=get_current_account_info(),
7070
method_name=method_name,
7171
)
7272
self.log_plugin.write_log(log_entry)
@@ -85,7 +85,7 @@ def _generic_method_handler(
8585
log_entry = self.log_plugin.compose_log_after_event(
8686
request=request,
8787
context=context,
88-
account_info=shared_account_info.get(),
88+
account_info=get_current_account_info(),
8989
method_name=method_name,
9090
response=unary_response or error,
9191
)
@@ -115,7 +115,7 @@ def response_wrapper() -> Iterator[GrpcMessage]:
115115
log_entry = self.log_plugin.compose_log_after_event(
116116
request=request,
117117
context=context,
118-
account_info=shared_account_info.get(),
118+
account_info=get_current_account_info(),
119119
method_name=method_name,
120120
response=stream_response or error,
121121
)

framework/py/flwr/superlink/servicer/control/control_servicer.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
from flwr.superlink.artifact_provider import ArtifactProvider
8181
from flwr.superlink.auth_plugin import ControlAuthnPlugin
8282

83-
from .control_account_auth_interceptor import shared_account_info
83+
from .control_account_auth_interceptor import get_current_account_info
8484

8585

8686
class ControlServicer(control_pb2_grpc.ControlServicer):
@@ -118,7 +118,7 @@ def StartRun( # pylint: disable=too-many-locals
118118
)
119119
return StartRunResponse()
120120

121-
flwr_aid = shared_account_info.get().flwr_aid
121+
flwr_aid = get_current_account_info().flwr_aid
122122
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
123123
override_config = user_config_from_proto(request.override_config)
124124
federation_options = config_record_from_proto(request.federation_options)
@@ -215,7 +215,7 @@ def StreamLogs( # pylint: disable=C0103
215215
context.abort(grpc.StatusCode.NOT_FOUND, RUN_ID_NOT_FOUND_MESSAGE)
216216

217217
# Check if `flwr_aid` matches the run's `flwr_aid`
218-
flwr_aid = shared_account_info.get().flwr_aid
218+
flwr_aid = get_current_account_info().flwr_aid
219219
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=cast(Run, run), context=context)
220220

221221
after_timestamp = request.after_timestamp + 1e-6
@@ -254,7 +254,7 @@ def ListRuns(
254254
if not request.HasField("run_id"):
255255
# If no `run_id` is specified and account auth is enabled,
256256
# return run IDs for the authenticated account
257-
flwr_aid = shared_account_info.get().flwr_aid
257+
flwr_aid = get_current_account_info().flwr_aid
258258
_check_flwr_aid_exists(flwr_aid, context)
259259
run_ids = state.get_run_ids(flwr_aid=flwr_aid)
260260
# Build a set of run IDs for `flwr ls --run-id <run_id>`
@@ -269,7 +269,7 @@ def ListRuns(
269269
raise grpc.RpcError() # This line is unreachable
270270

271271
# Check if `flwr_aid` matches the run's `flwr_aid`
272-
flwr_aid = shared_account_info.get().flwr_aid
272+
flwr_aid = get_current_account_info().flwr_aid
273273
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
274274

275275
run_ids = {run_id}
@@ -295,7 +295,7 @@ def StopRun(
295295
raise grpc.RpcError() # This line is unreachable
296296

297297
# Check if `flwr_aid` matches the run's `flwr_aid`
298-
flwr_aid = shared_account_info.get().flwr_aid
298+
flwr_aid = get_current_account_info().flwr_aid
299299
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
300300

301301
run_status = state.get_run_status({run_id})[run_id]
@@ -405,7 +405,7 @@ def PullArtifacts(
405405
)
406406

407407
# Check if `flwr_aid` matches the run's `flwr_aid`
408-
flwr_aid = shared_account_info.get().flwr_aid
408+
flwr_aid = get_current_account_info().flwr_aid
409409
_check_flwr_aid_in_run(flwr_aid=flwr_aid, run=run, context=context)
410410

411411
# Call artifact provider
@@ -435,10 +435,10 @@ def RegisterNode(
435435
state = self.linkstate_factory.state()
436436
node_id = 0
437437

438-
flwr_aid = shared_account_info.get().flwr_aid
438+
flwr_aid = get_current_account_info().flwr_aid
439439
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
440440
# Account name exists if `flwr_aid` exists
441-
account_name = cast(str, shared_account_info.get().account_name)
441+
account_name = cast(str, get_current_account_info().account_name)
442442
try:
443443
node_id = state.create_node(
444444
owner_aid=flwr_aid,
@@ -466,7 +466,7 @@ def UnregisterNode(
466466
# Init link state
467467
state = self.linkstate_factory.state()
468468

469-
flwr_aid = shared_account_info.get().flwr_aid
469+
flwr_aid = get_current_account_info().flwr_aid
470470
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
471471
try:
472472
state.delete_node(owner_aid=flwr_aid, node_id=request.node_id)
@@ -494,7 +494,7 @@ def ListNodes(
494494
# Init link state
495495
state = self.linkstate_factory.state()
496496

497-
flwr_aid = shared_account_info.get().flwr_aid
497+
flwr_aid = get_current_account_info().flwr_aid
498498
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
499499
# Retrieve all nodes for the account
500500
nodes_info = state.get_node_info(owner_aids=[flwr_aid])
@@ -510,7 +510,7 @@ def ListFederations(
510510
# Init link state
511511
state = self.linkstate_factory.state()
512512

513-
flwr_aid = shared_account_info.get().flwr_aid
513+
flwr_aid = get_current_account_info().flwr_aid
514514
flwr_aid = _check_flwr_aid_exists(flwr_aid, context)
515515

516516
# Get federations the account is a member of

framework/py/flwr/superlink/servicer/control/control_servicer_test.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,8 @@ def test_list_nodes_cli(self, flwr_aid_retrieving: str, expected: bool) -> None:
283283

284284
# Execute
285285
with patch(
286-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
287-
new=SimpleNamespace(
288-
get=lambda: SimpleNamespace(flwr_aid=flwr_aid_retrieving)
289-
),
286+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
287+
return_value=SimpleNamespace(flwr_aid=flwr_aid_retrieving),
290288
):
291289
res: ListNodesResponse = self.servicer.ListNodes(ListNodesRequest(), Mock())
292290

@@ -349,8 +347,8 @@ def test_streamlogs_auth_unsucessful(
349347

350348
# Execute & Assert
351349
with patch(
352-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
353-
new=SimpleNamespace(get=lambda: SimpleNamespace(flwr_aid=context_flwr_aid)),
350+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
351+
return_value=SimpleNamespace(flwr_aid=context_flwr_aid),
354352
):
355353
gen = self.servicer.StreamLogs(request, ctx)
356354
with self.assertRaises(RuntimeError) as cm:
@@ -378,8 +376,8 @@ def test_streamlogs_auth_successful(self) -> None:
378376
},
379377
),
380378
patch(
381-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
382-
new=SimpleNamespace(get=lambda: SimpleNamespace(flwr_aid="user-123")),
379+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
380+
return_value=SimpleNamespace(flwr_aid="user-123"),
383381
),
384382
):
385383
msgs = list(self.servicer.StreamLogs(request, ctx))
@@ -403,8 +401,8 @@ def test_stoprun_auth_unsuccessful(
403401

404402
# Execute & Assert
405403
with patch(
406-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
407-
new=SimpleNamespace(get=lambda: SimpleNamespace(flwr_aid=context_flwr_aid)),
404+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
405+
return_value=SimpleNamespace(flwr_aid=context_flwr_aid),
408406
):
409407
with self.assertRaises(RuntimeError) as cm:
410408
self.servicer.StopRun(request, ctx)
@@ -419,8 +417,8 @@ def test_stoprun_auth_successful(self) -> None:
419417

420418
# Execute & Assert
421419
with patch(
422-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
423-
new=SimpleNamespace(get=lambda: SimpleNamespace(flwr_aid="user-123")),
420+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
421+
return_value=SimpleNamespace(flwr_aid="user-123"),
424422
):
425423
response = self.servicer.StopRun(request, ctx)
426424
self.assertTrue(response.success)
@@ -441,8 +439,8 @@ def test_listruns_auth_unsuccessful(
441439

442440
# Execute & Assert
443441
with patch(
444-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
445-
new=SimpleNamespace(get=lambda: SimpleNamespace(flwr_aid=context_flwr_aid)),
442+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
443+
return_value=SimpleNamespace(flwr_aid=context_flwr_aid),
446444
):
447445
with self.assertRaises(RuntimeError) as cm:
448446
self.servicer.ListRuns(request, ctx)
@@ -457,8 +455,8 @@ def test_listruns_auth_run_success(self) -> None:
457455

458456
# Execute & Assert
459457
with patch(
460-
"flwr.superlink.servicer.control.control_servicer.shared_account_info",
461-
new=SimpleNamespace(get=lambda: SimpleNamespace(flwr_aid="user-123")),
458+
"flwr.superlink.servicer.control.control_servicer.get_current_account_info",
459+
return_value=SimpleNamespace(flwr_aid="user-123"),
462460
):
463461
response = self.servicer.ListRuns(request, ctx)
464462
self.assertEqual(set(response.run_dict.keys()), {run_id})

0 commit comments

Comments
 (0)