diff --git a/src/textual/app.py b/src/textual/app.py index 478593124e..cd946e31ea 100644 --- a/src/textual/app.py +++ b/src/textual/app.py @@ -121,6 +121,7 @@ ScreenResultCallbackType, ScreenResultType, SystemModalScreen, + AwaitScreen, ) from textual.signal import Signal from textual.theme import BUILTIN_THEMES, Theme, ThemeProvider @@ -2683,14 +2684,14 @@ def push_screen( screen: Screen[ScreenResultType] | str, callback: ScreenResultCallbackType[ScreenResultType] | None = None, wait_for_dismiss: Literal[True] = True, - ) -> asyncio.Future[ScreenResultType]: ... + ) -> AwaitScreen[ScreenResultType]: ... def push_screen( self, screen: Screen[ScreenResultType] | str, callback: ScreenResultCallbackType[ScreenResultType] | None = None, wait_for_dismiss: bool = False, - ) -> AwaitMount | asyncio.Future[ScreenResultType]: + ) -> AwaitMount | AwaitScreen[ScreenResultType]: """Push a new [screen](/guide/screens) on the screen stack, making it the current screen. Args: @@ -2703,7 +2704,7 @@ def push_screen( NoActiveWorker: If using `wait_for_dismiss` outside of a worker. Returns: - An optional awaitable that awaits the mounting of the screen and its children, or an asyncio Future + An optional awaitable that awaits the mounting of the screen and its children, or an awaitable to await the result of the screen. """ if not isinstance(screen, (Screen, str)): @@ -2711,14 +2712,6 @@ def push_screen( f"push_screen requires a Screen instance or str; not {screen!r}" ) - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # Mainly for testing, when push_screen isn't called in an async context - future: asyncio.Future[ScreenResultType] = asyncio.Future() - else: - future = loop.create_future() - if self._screen_stack: self.screen.post_message(events.ScreenSuspend()) self.screen.refresh() @@ -2728,7 +2721,8 @@ def push_screen( except LookupError: message_pump = self.app - next_screen._push_result_callback(message_pump, callback, future) + await_screen: AwaitScreen[ScreenResultType] = AwaitScreen() + next_screen._push_result_callback(message_pump, callback, await_screen) self._load_screen_css(next_screen) self._screen_stack.append(next_screen) next_screen.post_message(events.ScreenResume()) @@ -2740,7 +2734,7 @@ def push_screen( raise NoActiveWorker( "push_screen must be run from a worker when `wait_for_dismiss` is True" ) from None - return future + return await_screen else: return await_mount diff --git a/src/textual/screen.py b/src/textual/screen.py index 23517afaa1..5536ba1c9a 100644 --- a/src/textual/screen.py +++ b/src/textual/screen.py @@ -8,6 +8,7 @@ from __future__ import annotations +import enum import asyncio from functools import partial from operator import attrgetter @@ -23,6 +24,8 @@ Optional, TypeVar, Union, + Literal, + Generator, ) import rich.repr @@ -83,6 +86,33 @@ """Type of a screen result callback function.""" +class _Unset(enum.Enum): + UNSET = enum.auto() + + +class AwaitScreen(Generic[ScreenResultType]): + """ + An Awaitable object that resumes with the result of a Screen. + """ + + def __init__(self) -> None: + self._event = asyncio.Event() + self._result: ScreenResultType | Literal[_Unset.UNSET] = _Unset.UNSET + + async def wait(self) -> ScreenResultType: + await self._event.wait() + assert self._result is not _Unset.UNSET + return self._result + + def __await__(self) -> Generator[Any, Any, ScreenResultType]: + return self.wait().__await__() + + def set_result(self, result): + assert self._result is _Unset.UNSET + self._result = result + self._event.set() + + @rich.repr.auto class ResultCallback(Generic[ScreenResultType]): """Holds the details of a callback.""" @@ -91,21 +121,21 @@ def __init__( self, requester: MessagePump, callback: ScreenResultCallbackType[ScreenResultType] | None, - future: asyncio.Future[ScreenResultType] | None = None, + await_screen: AwaitScreen[ScreenResultType] | None = None, ) -> None: """Initialise the result callback object. Args: requester: The object making a request for the callback. callback: The callback function. - future: A Future to hold the result. + await_screen: An AwaitScreen to hold the result. """ self.requester = requester """The object in the DOM that requested the callback.""" self.callback: ScreenResultCallbackType | None = callback """The callback function.""" - self.future = future - """A future for the result""" + self.await_screen = await_screen + """An AwaitScreen for the result""" def __call__(self, result: ScreenResultType) -> None: """Call the callback, passing the given result. @@ -116,8 +146,8 @@ def __call__(self, result: ScreenResultType) -> None: Note: If the requested or the callback are `None` this will be a no-op. """ - if self.future is not None: - self.future.set_result(result) + if self.await_screen is not None: + self.await_screen.set_result(result) if self.requester is not None and self.callback is not None: self.requester.call_next(self.callback, result) self.callback = None @@ -1166,17 +1196,19 @@ def _push_result_callback( self, requester: MessagePump, callback: ScreenResultCallbackType[ScreenResultType] | None, - future: asyncio.Future[ScreenResultType | None] | None = None, + await_screen: AwaitScreen[ScreenResultType] | None = None, ) -> None: """Add a result callback to the screen. Args: requester: The object requesting the callback. callback: The callback. - future: A Future to hold the result. + await_screen: An AwaitScreen to hold the result. """ self._result_callbacks.append( - ResultCallback[Optional[ScreenResultType]](requester, callback, future) + ResultCallback[Optional[ScreenResultType]]( + requester, callback, await_screen + ) ) async def _message_loop_exit(self) -> None: