Skip to content

avoid exposing asyncio.Future directly to api consumers #5765

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 7 additions & 13 deletions src/textual/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@
ScreenResultCallbackType,
ScreenResultType,
SystemModalScreen,
AwaitScreen,
)
from textual.signal import Signal
from textual.theme import BUILTIN_THEMES, Theme, ThemeProvider
Expand Down Expand Up @@ -2683,14 +2684,14 @@ def push_screen(
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: Literal[True] = True,
) -> asyncio.Future[ScreenResultType]: ...
) -> AwaitScreen[ScreenResultType]: ...

def push_screen(
self,
screen: Screen[ScreenResultType] | str,
callback: ScreenResultCallbackType[ScreenResultType] | None = None,
wait_for_dismiss: bool = False,
) -> AwaitMount | asyncio.Future[ScreenResultType]:
) -> AwaitMount | AwaitScreen[ScreenResultType]:
"""Push a new [screen](/guide/screens) on the screen stack, making it the current screen.

Args:
Expand All @@ -2703,22 +2704,14 @@ def push_screen(
NoActiveWorker: If using `wait_for_dismiss` outside of a worker.

Returns:
An optional awaitable that awaits the mounting of the screen and its children, or an asyncio Future
An optional awaitable that awaits the mounting of the screen and its children, or an awaitable
to await the result of the screen.
"""
if not isinstance(screen, (Screen, str)):
raise TypeError(
f"push_screen requires a Screen instance or str; not {screen!r}"
)

try:
loop = asyncio.get_running_loop()
except RuntimeError:
# Mainly for testing, when push_screen isn't called in an async context
future: asyncio.Future[ScreenResultType] = asyncio.Future()
else:
future = loop.create_future()

if self._screen_stack:
self.screen.post_message(events.ScreenSuspend())
self.screen.refresh()
Expand All @@ -2728,7 +2721,8 @@ def push_screen(
except LookupError:
message_pump = self.app

next_screen._push_result_callback(message_pump, callback, future)
await_screen: AwaitScreen[ScreenResultType] = AwaitScreen()
next_screen._push_result_callback(message_pump, callback, await_screen)
self._load_screen_css(next_screen)
self._screen_stack.append(next_screen)
next_screen.post_message(events.ScreenResume())
Expand All @@ -2740,7 +2734,7 @@ def push_screen(
raise NoActiveWorker(
"push_screen must be run from a worker when `wait_for_dismiss` is True"
) from None
return future
return await_screen
else:
return await_mount

Expand Down
50 changes: 41 additions & 9 deletions src/textual/screen.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

import enum
import asyncio
from functools import partial
from operator import attrgetter
Expand All @@ -23,6 +24,8 @@
Optional,
TypeVar,
Union,
Literal,
Generator,
)

import rich.repr
Expand Down Expand Up @@ -83,6 +86,33 @@
"""Type of a screen result callback function."""


class _Unset(enum.Enum):
UNSET = enum.auto()


class AwaitScreen(Generic[ScreenResultType]):
"""
An Awaitable object that resumes with the result of a Screen.
"""

def __init__(self) -> None:
self._event = asyncio.Event()
self._result: ScreenResultType | Literal[_Unset.UNSET] = _Unset.UNSET

async def wait(self) -> ScreenResultType:
await self._event.wait()
assert self._result is not _Unset.UNSET
return self._result

def __await__(self) -> Generator[Any, Any, ScreenResultType]:
return self.wait().__await__()

def set_result(self, result):
assert self._result is _Unset.UNSET
self._result = result
self._event.set()


@rich.repr.auto
class ResultCallback(Generic[ScreenResultType]):
"""Holds the details of a callback."""
Expand All @@ -91,21 +121,21 @@ def __init__(
self,
requester: MessagePump,
callback: ScreenResultCallbackType[ScreenResultType] | None,
future: asyncio.Future[ScreenResultType] | None = None,
await_screen: AwaitScreen[ScreenResultType] | None = None,
) -> None:
"""Initialise the result callback object.

Args:
requester: The object making a request for the callback.
callback: The callback function.
future: A Future to hold the result.
await_screen: An AwaitScreen to hold the result.
"""
self.requester = requester
"""The object in the DOM that requested the callback."""
self.callback: ScreenResultCallbackType | None = callback
"""The callback function."""
self.future = future
"""A future for the result"""
self.await_screen = await_screen
"""An AwaitScreen for the result"""

def __call__(self, result: ScreenResultType) -> None:
"""Call the callback, passing the given result.
Expand All @@ -116,8 +146,8 @@ def __call__(self, result: ScreenResultType) -> None:
Note:
If the requested or the callback are `None` this will be a no-op.
"""
if self.future is not None:
self.future.set_result(result)
if self.await_screen is not None:
self.await_screen.set_result(result)
if self.requester is not None and self.callback is not None:
self.requester.call_next(self.callback, result)
self.callback = None
Expand Down Expand Up @@ -1166,17 +1196,19 @@ def _push_result_callback(
self,
requester: MessagePump,
callback: ScreenResultCallbackType[ScreenResultType] | None,
future: asyncio.Future[ScreenResultType | None] | None = None,
await_screen: AwaitScreen[ScreenResultType] | None = None,
) -> None:
"""Add a result callback to the screen.

Args:
requester: The object requesting the callback.
callback: The callback.
future: A Future to hold the result.
await_screen: An AwaitScreen to hold the result.
"""
self._result_callbacks.append(
ResultCallback[Optional[ScreenResultType]](requester, callback, future)
ResultCallback[Optional[ScreenResultType]](
requester, callback, await_screen
)
)

async def _message_loop_exit(self) -> None:
Expand Down
Loading