from __future__ import annotations
from contextlib import AsyncExitStack
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Callable, Generic
from anyio import create_task_group as _create_task_group
from anyio.abc import TaskGroup as _TaskGroup
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar, override
from async_wrapper.task_group.value import SoonValue
if TYPE_CHECKING:
from collections.abc import Awaitable, Coroutine
from types import TracebackType
from anyio.abc import CancelScope, CapacityLimiter, Lock, Semaphore
# Public API of this module.
__all__ = ["TaskGroupWrapper", "create_task_group_wrapper"]

# Result type produced by a wrapped callable (variance left to the type checker).
_T = TypeVar("_T", infer_variance=True)
# Parameter specification of a wrapped callable.
_P = ParamSpec("_P")
class TaskGroupWrapper(_TaskGroup):
    """
    wrap :class:`anyio.abc.TaskGroup`

    Every task-group operation is delegated to the wrapped task group.
    Used as an async context manager, the wrapper only enters (and later
    exits) the wrapped group when that group is not already active; a group
    that was entered elsewhere is left for its owner to exit, so leaving the
    wrapper does not wait for the scheduled tasks in that case.

    Example:
        .. code-block:: python

            import anyio

            from async_wrapper import TaskGroupWrapper


            async def test(x: int) -> int:
                await anyio.sleep(0.1)
                return x


            async def main() -> None:
                async with anyio.create_task_group() as task_group:
                    async with TaskGroupWrapper(task_group) as tg:
                        func = tg.wrap(test)
                        soon_1 = func(1)
                        soon_2 = func(2)

                # the outer task group has exited, so both tasks have finished
                assert soon_1.is_ready
                assert soon_2.is_ready
                assert soon_1.value == 1
                assert soon_2.value == 2


            if __name__ == "__main__":
                anyio.run(main)
    """

    __slots__ = ("_task_group", "_active_self")

    def __init__(self, task_group: _TaskGroup) -> None:
        self._task_group = task_group
        # True only while *this* wrapper is responsible for exiting the group.
        self._active_self = False

    @property
    @override
    def cancel_scope(self) -> CancelScope:  # pyright: ignore[reportIncompatibleVariableOverride]
        """The cancel scope of the wrapped task group."""
        return self._task_group.cancel_scope

    @override
    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: Any, name: Any = None
    ) -> None:
        """Delegate to :meth:`anyio.abc.TaskGroup.start_soon` of the wrapped group."""
        return self._task_group.start_soon(func, *args, name=name)

    @override
    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: Any, name: Any = None
    ) -> Any:
        """
        Delegate to :meth:`anyio.abc.TaskGroup.start` of the wrapped group.

        Returns:
            Whatever the wrapped group's ``start`` returns.
        """
        return await self._task_group.start(func, *args, name=name)

    @override
    async def __aenter__(self) -> Self:
        """Enter the wrapped group unless it is already active."""
        if _is_active(self._task_group):
            # someone else owns the group; do not enter (or later exit) it
            return self
        await self._task_group.__aenter__()
        self._active_self = True
        return self

    @override
    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """Exit the wrapped group only if :meth:`__aenter__` entered it."""
        if self._active_self:
            try:
                return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
            finally:
                self._active_self = False
        return None

    def wrap(
        self,
        func: Callable[_P, Awaitable[_T]],
        semaphore: Semaphore | None = None,
        limiter: CapacityLimiter | None = None,
        lock: Lock | None = None,
    ) -> SoonWrapper[_P, _T]:
        """
        Wrap ``func`` so that calling it schedules it on this task group.

        Each call of the returned wrapper starts ``func`` via
        :meth:`start_soon` and immediately returns a ``SoonValue`` that is
        filled with the result once ``func`` completes.

        Args:
            func: The target function to be wrapped.
            semaphore: An :obj:`anyio.abc.Semaphore`. Defaults to None.
            limiter: An :obj:`anyio.abc.CapacityLimiter`. Defaults to None.
            lock: An :obj:`anyio.abc.Lock`. Defaults to None.

        Returns:
            The wrapped function.
        """
        return SoonWrapper(func, self, semaphore=semaphore, limiter=limiter, lock=lock)
class SoonWrapper(Generic[_P, _T]):
    """
    wrapped func using in :class:`TaskGroupWrapper`

    Calling an instance schedules ``func`` on ``task_group`` with
    ``start_soon`` and returns a :class:`SoonValue` right away; the value is
    set once ``func`` returns. The optional semaphore, limiter and lock are
    entered (in that order) around every call of ``func``.
    """

    __slots__ = ("func", "task_group", "semaphore", "limiter", "lock", "_wrapped")

    def __init__(
        self,
        func: Callable[_P, Awaitable[_T]],
        task_group: _TaskGroup,
        semaphore: Semaphore | None = None,
        limiter: CapacityLimiter | None = None,
        lock: Lock | None = None,
    ) -> None:
        self.func = func
        self.task_group = task_group
        self.semaphore = semaphore
        self.limiter = limiter
        self.lock = lock
        # lazily built by the ``wrapped`` property and reused afterwards
        self._wrapped = None

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> SoonValue[_T]:
        """Schedule ``func(*args, **kwargs)`` and return its pending value."""
        value: SoonValue[_T] = SoonValue()
        # start_soon only forwards positional arguments, so bind kwargs here
        wrapped = partial(self.wrapped, value, *args, **kwargs)
        self.task_group.start_soon(wrapped)
        return value

    @property
    def wrapped(
        self,
    ) -> Callable[Concatenate[SoonValue[_T], _P], Coroutine[Any, Any, _T]]:
        """wrapped func using semaphore, limiter and lock; stores its result"""
        if self._wrapped is not None:
            return self._wrapped

        @wraps(self.func)
        async def wrapped(
            value: SoonValue[_T], *args: _P.args, **kwargs: _P.kwargs
        ) -> _T:
            async with AsyncExitStack() as stack:
                # attributes are read at call time, so later changes apply
                if self.semaphore is not None:
                    await stack.enter_async_context(self.semaphore)
                if self.limiter is not None:
                    await stack.enter_async_context(self.limiter)
                if self.lock is not None:
                    await stack.enter_async_context(self.lock)
                result = await self.func(*args, **kwargs)
                value._value = result  # noqa: SLF001
                return result
            # only reachable if an entered context manager suppressed an error
            raise RuntimeError("never")  # pragma: no cover

        self._wrapped = wrapped
        return wrapped

    def copy(
        self,
        semaphore: Semaphore | None = None,
        limiter: CapacityLimiter | None = None,
        lock: Lock | None = None,
    ) -> Self:
        """
        Create a copy of this object.

        Passing ``None`` for an argument keeps the current component; an
        existing component cannot be removed through this method.

        Args:
            semaphore: An :obj:`anyio.abc.Semaphore`.
                If provided, it will overwrite the existing semaphore. Defaults to None.
            limiter: An :obj:`anyio.abc.CapacityLimiter`.
                If provided, it will overwrite the existing limiter. Defaults to None.
            lock: An :obj:`anyio.abc.Lock`.
                If provided, it will overwrite the existing lock. Defaults to None.

        Returns:
            A copy of this object (of the same class) with optional
            overwritten components.
        """
        if semaphore is None:
            semaphore = self.semaphore
        if limiter is None:
            limiter = self.limiter
        if lock is None:
            lock = self.lock
        # type(self) honours the ``Self`` return annotation for subclasses
        return type(self)(
            self.func, self.task_group, semaphore=semaphore, limiter=limiter, lock=lock
        )
def create_task_group_wrapper() -> TaskGroupWrapper:
    """
    Build a wrapper around a brand-new anyio task group.

    Returns:
        A :obj:`TaskGroupWrapper` around a freshly created task group.
    """
    fresh_group = _create_task_group()
    return TaskGroupWrapper(fresh_group)
def _is_active(task_group: _TaskGroup) -> bool:
    """
    Report whether ``task_group`` has already been entered.

    :class:`TaskGroupWrapper` has no ``_active`` attribute of its own, so
    (possibly nested) wrappers are unwrapped to the task group they delegate
    to before the flag is read.

    Args:
        task_group: The task group, or a wrapper around one, to inspect.

    Returns:
        The private ``_active`` flag of the underlying task group.
    """
    while isinstance(task_group, TaskGroupWrapper):
        task_group = task_group._task_group  # noqa: SLF001
    # trio, asyncio: relies on the backends' private ``_active`` attribute
    return task_group._active  # type: ignore # noqa: SLF001