Source code for async_wrapper.wait

from __future__ import annotations

from functools import partial
from threading import local
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar

from anyio import EndOfStream, Event, create_memory_object_stream, create_task_group
from typing_extensions import ParamSpec, Self, override

from async_wrapper.exception import PendingError

if TYPE_CHECKING:
    from types import TracebackType

    from anyio import EventStatistics
    from anyio.abc import CancelScope, TaskGroup
    from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream


__all__ = ["Waiter", "Completed", "wait_for"]

ValueT_co = TypeVar("ValueT_co", covariant=True)
ParamT = ParamSpec("ParamT")
Pending = local()


[docs]class Waiter(Event): """ wait wrapper Example: >>> import anyio >>> >>> from async_wrapper import Waiter >>> >>> >>> async def test() -> None: >>> print("test: start") >>> await anyio.sleep(1) >>> print("test: end") >>> >>> >>> async def test2(event: anyio.Event) -> None: >>> print("test2: start") >>> await event.wait() >>> print("test2: end") >>> >>> >>> async def main() -> None: >>> async with anyio.create_task_group() as task_group: >>> event = Waiter(test)(task_group) >>> task_group.start_soon(test2, event) >>> >>> >>> if __name__ == "__main__": >>> anyio.run(main) $ poetry run python main.py test: start test2: start test: end test2: end """ __slots__ = ("_event", "_func", "_args", "_kwargs") _event: Event def __init__( self, func: Callable[ParamT, Awaitable[Any]], *args: ParamT.args, **kwargs: ParamT.kwargs, ) -> None: self._func = func self._args = args self._kwargs = kwargs def __call__(self, task_group: TaskGroup, *, name: Any = None) -> Self: """start soon in task group""" task_group.start_soon( wait_for, self, partial(self._func, *self._args, **self._kwargs), name=name ) return self def __copy__(self) -> Self: return self.copy()
[docs] def copy(self, *args: Any, **kwargs: Any) -> Self: """ create new event Returns: new :obj:`Waiter` """ if not args: args = tuple(self._args) if not kwargs: kwargs = self._kwargs.copy() return Waiter(self._func, *args, **kwargs) # type: ignore
@override def __new__( cls, func: Callable[ParamT, Awaitable[Any]], *args: ParamT.args, **kwargs: ParamT.kwargs, ) -> Self: new = object.__new__(cls) new._event = super().__new__(cls) # noqa: SLF001 return new
[docs] @override def set(self) -> None: return self._event.set()
[docs] @override def is_set(self) -> bool: return self._event.is_set()
[docs] @override async def wait(self) -> None: return await self._event.wait()
[docs] @override def statistics(self) -> EventStatistics: return self._event.statistics()
[docs]class Completed: """ like :func:`asyncio.as_completed` Example: >>> from __future__ import annotations >>> >>> import anyio >>> >>> from async_wrapper import Completed >>> >>> >>> async def test( >>> x: int, >>> sleep: float, >>> result: list[int] | None = None, >>> ) -> int: >>> print(f"[{x}] test: start") >>> await anyio.sleep(sleep) >>> print(f"[{x}] test: end") >>> if result is not None: >>> result.append(x) >>> return x >>> >>> >>> async def main() -> None: >>> result: list[int] = [] >>> async with anyio.create_task_group() as task_group: >>> task_group.start_soon(test, 1, 1, result) >>> async with Completed(task_group) as completed: >>> completed.start_soon(None, test, 2, 0.2) >>> completed.start_soon(None, test, 3, 0.1) >>> completed.start_soon(None, test, 4, 0.3) >>> >>> result.extend([value async for value in completed]) >>> >>> assert result == [3, 2, 4, 1] >>> >>> result = [] >>> async with anyio.create_task_group() as task_group: >>> task_group.start_soon(test, 1, 1, result) >>> async with Completed() as completed: >>> completed.start_soon(task_group, test, 2, 0.2) >>> completed.start_soon(task_group, test, 3, 0.1) >>> completed.start_soon(task_group, test, 4, 0.3) >>> >>> result.extend([value async for value in completed]) >>> >>> assert result == [3, 2, 4, 1] >>> >>> >>> if __name__ == "__main__": >>> anyio.run(main) """ __slots__ = ("_events", "__setter", "__getter", "__task_group") def __init__(self, task_group: TaskGroup | None = None) -> None: self._events: dict[Waiter, MemoryObjectReceiveStream[Any]] = {} self.__setter: MemoryObjectSendStream[Waiter] | None = None self.__getter: MemoryObjectReceiveStream[Waiter] | None = None self.__task_group: TaskGroup | None = task_group @property def _is_active(self) -> bool: return self.__setter is not None @property def _setter(self) -> MemoryObjectSendStream[Waiter]: if self.__setter is None: raise PendingError("enter first") return self.__setter @property def _getter(self) -> MemoryObjectReceiveStream[Waiter]: if self.__getter is None: raise PendingError("enter first") return self.__getter def _task_group(self, task_group: TaskGroup | None) -> TaskGroup: if self.__task_group is None: if task_group is None: raise ValueError("there is no task group") return task_group if task_group is None: return self.__task_group if task_group is not self.__task_group: raise ValueError("diff task groups") return task_group
[docs] def start_soon( self, task_group: TaskGroup | None, func: Callable[..., Awaitable[Any]], *args: Any, name: Any = None, ) -> None: """ Start a coroutine in a task group, similar to :meth:`anyio.abc.TaskGroup.start_soon`. If a task group is already provided, the task_group parameter should be the same object. Args: task_group: An :class:`anyio.abc.TaskGroup`. Defaults to None. func: The target coroutine function. name: The name used in :meth:`anyio.abc.TaskGroup.start_soon`. Defaults to None. """ # noqa: D205 if not self._is_active: raise PendingError("enter first") task_group = self._task_group(task_group) waiter, getter = _create_waiter(func, *args) waiter(task_group, name=name) self._events[waiter] = getter
async def _shutdown(self) -> None: async with create_task_group() as task_group: task_group.start_soon(self._setter.aclose) task_group.start_soon(self._getter.aclose) for getter in self._events.values(): if not getter._closed: # noqa: SLF001 task_group.start_soon(getter.aclose) async def _anext(self) -> Any: if not self._events: raise EndOfStream async with create_task_group() as task_group: for event in self._events: task_group.start_soon( _wait_waiter, event, task_group.cancel_scope, self ) event = await self._getter.receive() getter = self._events.pop(event) async with getter: return await getter.receive() async def __aenter__(self) -> Self: self.__setter, self.__getter = create_memory_object_stream(1) return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: try: await self._shutdown() finally: self.__setter = None self.__getter = None def __aiter__(self) -> Self: if self.__setter is None: raise PendingError("enter first") return self async def __anext__(self) -> Any: try: return await self._anext() except EndOfStream as exc: raise StopAsyncIteration from exc
[docs]async def wait_for( event: Event | Iterable[Event], func: Callable[ParamT, Awaitable[ValueT_co]], *args: ParamT.args, **kwargs: ParamT.kwargs, ) -> ValueT_co: """ Wait for an event before executing an awaitable function. like :func:`asyncio.wait_for` Args: event: An :obj:`anyio.Event` or an iterable of events. func: An awaitable function to be executed. Returns: The result of the executed function. Example: >>> import anyio >>> >>> from async_wrapper import wait_for >>> >>> >>> async def test() -> None: >>> print("test: start") >>> await anyio.sleep(1) >>> print("test: end") >>> >>> >>> async def test2(event: anyio.Event) -> None: >>> print("test2: start") >>> await event.wait() >>> print("test2: end") >>> >>> >>> async def main() -> None: >>> event = anyio.Event() >>> async with anyio.create_task_group() as task_group: >>> task_group.start_soon(wait_for, event, test) >>> task_group.start_soon(test2, event) >>> >>> >>> if __name__ == "__main__": >>> anyio.run(main) $ poetry run python main.py test: start test2: start test: end test2: end """ event = set(event) if not isinstance(event, Event) else (event,) try: return await func(*args, **kwargs) finally: for sub in event: sub.set()
async def _wait_waiter( waiter: Waiter, scope: CancelScope, completed: Completed ) -> None: await waiter.wait() completed._setter.send_nowait(waiter) # noqa: SLF001 scope.cancel() async def _intercept_value( setter: MemoryObjectSendStream[Any], getter: MemoryObjectReceiveStream[Any], func: Callable[..., Awaitable[Any]], *args: Any, ) -> None: async with setter: result = await func(*args) if not getter._closed: # noqa: SLF001 await setter.send(result) def _create_waiter( func: Callable[..., Awaitable[Any]], *args: Any ) -> tuple[Waiter, MemoryObjectReceiveStream[Any]]: setter, getter = create_memory_object_stream(1) waiter = Waiter(_intercept_value, setter, getter, func, *args) return waiter, getter