Source code for async_wrapper.task_group.task_group

from __future__ import annotations

from contextlib import AsyncExitStack
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Generic

from anyio import create_task_group as _create_task_group
from anyio.abc import TaskGroup as _TaskGroup
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar, override

from async_wrapper.task_group.value import SoonValue

if TYPE_CHECKING:
    from collections.abc import Awaitable, Coroutine
    from types import TracebackType

    from anyio.abc import CancelScope, CapacityLimiter, Lock, Semaphore

ValueT = TypeVar("ValueT", infer_variance=True)
ParamT = ParamSpec("ParamT")

__all__ = ["TaskGroupWrapper", "create_task_group_wrapper"]


[docs] class TaskGroupWrapper(_TaskGroup): """ wrap :class:`anyio.abc.TaskGroup` Example: >>> import anyio >>> >>> from async_wrapper import TaskGroupWrapper >>> >>> >>> async def test(x: int) -> int: >>> await anyio.sleep(0.1) >>> return x >>> >>> >>> async def main() -> None: >>> async with anyio.create_task_group() as task_group: >>> async with TaskGroupWrapper(task_group) as tg: >>> func = tg.wrap(test) >>> soon_1 = func(1) >>> soon_2 = func(2) >>> >>> assert soon_1.is_ready >>> assert soon_2.is_ready >>> assert soon_1.value == 1 >>> assert soon_2.value == 2 >>> >>> >>> if __name__ == "__main__": >>> anyio.run(main) """ __slots__ = ("_task_group", "_active_self") def __init__(self, task_group: _TaskGroup) -> None: self._task_group = task_group self._active_self = False @property @override def cancel_scope(self) -> CancelScope: # pyright: ignore[reportIncompatibleVariableOverride] return self._task_group.cancel_scope
[docs] @override def start_soon( self, func: Callable[..., Awaitable[Any]], *args: Any, name: Any = None ) -> None: return self._task_group.start_soon(func, *args, name=name)
@override async def start( self, func: Callable[..., Awaitable[Any]], *args: Any, name: Any = None ) -> Any: raise NotImplementedError @override async def __aenter__(self) -> Self: if _is_active(self._task_group): return self await self._task_group.__aenter__() self._active_self = True return self @override async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: if self._active_self: try: return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) finally: self._active_self = False return None
[docs] def wrap( self, func: Callable[ParamT, Awaitable[ValueT]], semaphore: Semaphore | None = None, limiter: CapacityLimiter | None = None, lock: Lock | None = None, ) -> SoonWrapper[ParamT, ValueT]: """ Wrap a function to be used within a wrapper. The wrapped function will return a value shortly. Args: func: The target function to be wrapped. semaphore: An :obj:`anyio.abc.Semaphore`. Defaults to None. limiter: An :obj:`anyio.abc.CapacityLimiter`. Defaults to None. lock: An :obj:`anyio.abc.Lock`. Defaults to None. Returns: The wrapped function. """ return SoonWrapper(func, self, semaphore=semaphore, limiter=limiter, lock=lock)
class SoonWrapper(Generic[ParamT, ValueT]): """wrapped func using in :class:`TaskGroupWrapper`""" __slots__ = ("func", "task_group", "semaphore", "limiter", "lock", "_wrapped") def __init__( self, func: Callable[ParamT, Awaitable[ValueT]], task_group: _TaskGroup, semaphore: Semaphore | None = None, limiter: CapacityLimiter | None = None, lock: Lock | None = None, ) -> None: self.func = func self.task_group = task_group self.semaphore = semaphore self.limiter = limiter self.lock = lock self._wrapped = None def __call__( self, *args: ParamT.args, **kwargs: ParamT.kwargs ) -> SoonValue[ValueT]: value: SoonValue[ValueT] = SoonValue() wrapped = partial(self.wrapped, value, *args, **kwargs) self.task_group.start_soon(wrapped) return value @property def wrapped( self, ) -> Callable[Concatenate[SoonValue[ValueT], ParamT], Coroutine[Any, Any, ValueT]]: """wrapped func using semaphore""" if self._wrapped is not None: return self._wrapped @wraps(self.func) async def wrapped( value: SoonValue[ValueT], *args: ParamT.args, **kwargs: ParamT.kwargs ) -> ValueT: async with AsyncExitStack() as stack: if self.semaphore is not None: await stack.enter_async_context(self.semaphore) if self.limiter is not None: await stack.enter_async_context(self.limiter) if self.lock is not None: await stack.enter_async_context(self.lock) result = await self.func(*args, **kwargs) value._value = result # noqa: SLF001 return result raise RuntimeError("never") # pragma: no cover self._wrapped = wrapped return wrapped def copy( self, semaphore: Semaphore | None = None, limiter: CapacityLimiter | None = None, lock: Lock | None = None, ) -> Self: """ Create a copy of this object. Args: semaphore: An :obj:`anyio.abc.Semaphore`. If provided, it will overwrite the existing semaphore. Defaults to None. limiter: An :obj:`anyio.abc.CapacityLimiter`. If provided, it will overwrite the existing limiter. Defaults to None. lock: An :obj:`anyio.abc.Lock`. If provided, it will overwrite the existing lock. Defaults to None. Returns: A copy of this object with optional overwritten components. """ if semaphore is None: semaphore = self.semaphore if limiter is None: limiter = self.limiter if lock is None: lock = self.lock return SoonWrapper( # type: ignore self.func, self.task_group, semaphore=semaphore, limiter=limiter, lock=lock )
[docs] def create_task_group_wrapper() -> TaskGroupWrapper: """ create new task group wrapper Returns: new :obj:`TaskGroupWrapper` """ return TaskGroupWrapper(_create_task_group())
def _is_active(task_group: _TaskGroup) -> bool: # trio, asyncio return task_group._active # type: ignore # noqa: SLF001