# Source code for async_wrapper.convert._sync.main

from __future__ import annotations

from concurrent.futures import ThreadPoolExecutor, wait
from contextvars import ContextVar
from functools import cached_property, partial, update_wrapper, wraps
from importlib.util import find_spec
from typing import TYPE_CHECKING, Any, Generic, overload

import anyio
from sniffio import AsyncLibraryNotFoundError, current_async_library
from typing_extensions import ParamSpec, TypeVar

from async_wrapper.convert._sync.sqlalchemy import check_is_unset, run_sa_greenlet

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

__all__ = ["async_to_sync"]

_T = TypeVar("_T", infer_variance=True)
_P = ParamSpec("_P")

# Backend name used when no async library is detected in the current thread.
# `_init` overwrites it inside the worker threads that `async_to_sync`
# spawns, so nested runs reuse the caller's backend.
current_async_lib_var: ContextVar[str] = ContextVar(
    "current_async_lib", default="asyncio"
)
# Whether the asyncio backend should be started with uvloop; propagated to
# worker threads together with `current_async_lib_var` by `_init`.
use_uvloop_var: ContextVar[bool] = ContextVar("use_uvloop", default=False)
# True only if both optional packages are importable; gates the SQLAlchemy
# greenlet path that `async_to_sync` tries for awaitable objects.
has_sqlalchemy = all(
    find_spec(_module) is not None for _module in ("sqlalchemy", "greenlet")
)


class Sync(Generic[_P, _T]):
    """Synchronous callable that runs a wrapped async function to completion.

    Returned by :func:`async_to_sync` for async callables. The wrapped
    function's metadata (``__name__``, ``__qualname__``, ``__doc__``,
    ``__module__``, ...) and ``__wrapped__`` are copied onto the instance, so
    it can be introspected the same way as the ``functools.wraps``-based
    wrapper produced for awaitable objects.
    """

    def __init__(self, func: Callable[_P, Awaitable[_T]]) -> None:
        """Store *func*; the synchronous wrapper is built lazily on first call.

        Args:
            func: The async callable to expose synchronously.
        """
        self._func = func
        # `updated=()` skips merging func.__dict__ into this instance, which
        # could otherwise shadow attributes such as `_func` or the cached
        # `_wrapped` value.
        update_wrapper(self, func, updated=())

    @cached_property
    def _wrapped(self) -> Callable[_P, _T]:
        # Reuse the exact wrapping used for awaitable objects instead of a
        # duplicated copy of it; built once per instance.
        return _async_func_to_sync(self._func)

    def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T:
        """Run the wrapped async function with the given arguments.

        Returns:
            The wrapped function's result.
        """
        return self._wrapped(*args, **kwargs)


# Overload: an async callable becomes a sync callable taking the same
# parameters and returning the awaited result type.
@overload
def async_to_sync(
    func_or_awaitable: Callable[_P, Awaitable[_T]],
) -> Callable[_P, _T]: ...
# Overload: a bare awaitable object (e.g. a coroutine) becomes a
# zero-argument sync callable returning the awaited result.
@overload
def async_to_sync(func_or_awaitable: Awaitable[_T]) -> Callable[[], _T]: ...
# Fallback overload for arguments statically typed as either form.
@overload
def async_to_sync(
    func_or_awaitable: Callable[..., Awaitable[_T]] | Awaitable[_T],
) -> Callable[..., _T]: ...
def async_to_sync(
    func_or_awaitable: Callable[_P, Awaitable[_T]] | Awaitable[_T],
) -> Callable[_P, _T] | Callable[[], _T]:
    """
    Convert an awaitable function or awaitable object to a synchronous function.

    If used within an asynchronous context, attempts to use the same backend.
    Defaults to asyncio.

    Args:
        func_or_awaitable: An awaitable function or awaitable object.

    Returns:
        A synchronous function. If *func_or_awaitable* is an ``Async``
        wrapper, the original function it wraps is returned instead.

    Note:
        For an awaitable object, every call of the returned function awaits
        that same object, so a coroutine can only be run through it once.

    Example:
        .. code-block:: python

            import asyncio
            import time

            import anyio
            import sniffio

            from async_wrapper import async_to_sync


            @async_to_sync
            async def test(x: int) -> int:
                backend = sniffio.current_async_library()
                if backend == "asyncio":
                    loop = asyncio.get_running_loop()
                    print(backend, loop)
                else:
                    print(backend)
                await anyio.sleep(1)
                return x


            def main() -> None:
                start = time.perf_counter()
                result = test(1)
                end = time.perf_counter()
                assert result == 1
                assert end - start < 1.1


            async def async_main() -> None:
                start = time.perf_counter()
                result = test(1)
                end = time.perf_counter()
                assert result == 1
                assert end - start < 1.1


            if __name__ == "__main__":
                main()
                anyio.run(
                    async_main,
                    backend="asyncio",
                    backend_options={"use_uvloop": False},
                )
                anyio.run(
                    async_main,
                    backend="asyncio",
                    backend_options={"use_uvloop": True},
                )
                anyio.run(async_main, backend="trio")
    """
    if callable(func_or_awaitable):
        # Imported lazily, presumably to avoid a circular import between the
        # sync and async converters.
        from async_wrapper.convert._async import Async

        # Converting an `Async` wrapper back simply unwraps it.
        if isinstance(func_or_awaitable, Async):
            return func_or_awaitable._func  # noqa: SLF001
        return Sync(func_or_awaitable)

    if has_sqlalchemy:
        # NOTE(review): whatever `run_sa_greenlet` produces is returned as-is
        # when it is not the "unset" sentinel; the pyright ignore suggests it
        # may not be a Callable as the signature promises - confirm against
        # `run_sa_greenlet`.
        result = run_sa_greenlet(func_or_awaitable)
        if not check_is_unset(result):
            return result  # pyright: ignore[reportReturnType]

    awaitable_func = _awaitable_to_function(func_or_awaitable)
    return _async_func_to_sync(awaitable_func)
def _async_func_to_sync(func: Callable[_P, Awaitable[_T]]) -> Callable[_P, _T]:
    """Build a synchronous wrapper around the async function *func*.

    When no async library is running in the calling thread, *func* is run
    directly via :func:`_run`. Otherwise it is run in a fresh single-worker
    thread (presumably because a second event loop cannot be started in a
    thread that is already running one). The worker's context variables are
    primed with the caller's backend and uvloop setting, and the calling
    thread blocks until the result is available.

    Args:
        func: The async callable to wrap.

    Returns:
        A function carrying *func*'s metadata that returns *func*'s result.
    """
    sync_func = _as_sync(func)

    @wraps(func)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        if not _running_in_async_context():
            return _run(func, *args, **kwargs)

        # Capture the caller's backend so the worker thread, which has no
        # running async library of its own, starts the same one.
        backend = _get_current_backend()
        use_uvloop = _check_uvloop()

        with ThreadPoolExecutor(
            1, initializer=_init, initargs=(backend, use_uvloop)
        ) as pool:
            future = pool.submit(sync_func, *args, **kwargs)
            wait([future])
            # Blocks the calling thread (and so its event loop) until the
            # worker finishes; re-raises any exception from `func`.
            return future.result()

    return inner


def _as_sync(func: Callable[_P, Awaitable[_T]]) -> Callable[_P, _T]:
    """Return a plain function that runs *func* to completion via :func:`_run`."""

    @wraps(func)
    def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
        return _run(func, *args, **kwargs)

    return inner


def _run(
    func: Callable[_P, Awaitable[_T]], *args: _P.args, **kwargs: _P.kwargs
) -> _T:
    """Run ``func(*args, **kwargs)`` with :func:`anyio.run` and return its result.

    The backend comes from :func:`_get_current_backend`; for the asyncio
    backend, whether to use uvloop comes from :func:`_check_uvloop`.
    """
    backend = _get_current_backend()
    new_func = partial(func, *args, **kwargs)
    backend_options: dict[str, Any] = {}
    if backend == "asyncio":
        backend_options["use_uvloop"] = _check_uvloop()
    return anyio.run(new_func, backend=backend, backend_options=backend_options)


def _check_uvloop() -> bool:
    """Return whether the asyncio backend should run on uvloop.

    True if the flag was propagated through ``use_uvloop_var``, if the
    currently running asyncio loop is a uvloop loop, or if uvloop's event
    loop policy is installed. Always False when uvloop cannot be imported.
    """
    if use_uvloop_var.get():
        return True
    try:
        import uvloop
    except ImportError:  # pragma: no cover
        return False

    import asyncio

    # anyio 4.x starts uvloop through a loop factory instead of installing
    # uvloop's policy, so the policy check alone would miss a uvloop loop
    # that is actually running; inspect the running loop (if any) first.
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None
    if isinstance(running_loop, uvloop.Loop):
        return True

    # NOTE(review): event loop policies are deprecated from Python 3.14;
    # this call may emit a DeprecationWarning there.
    policy = asyncio.get_event_loop_policy()
    return isinstance(policy, uvloop.EventLoopPolicy)


def _running_in_async_context() -> bool:
    """Return True if sniffio can detect a running async library."""
    try:
        current_async_library()
    except AsyncLibraryNotFoundError:
        return False
    return True


def _get_current_backend() -> str:
    """Return the running async library's name, else ``current_async_lib_var``."""
    try:
        return current_async_library()
    except AsyncLibraryNotFoundError:
        return current_async_lib_var.get()


def _init(backend: str, use_uvloop: bool) -> None:  # noqa: FBT001
    """Thread-pool initializer: seed this worker thread's context variables.

    Args:
        backend: Async library name to use inside the worker.
        use_uvloop: Whether the asyncio backend should use uvloop.
    """
    current_async_lib_var.set(backend)
    use_uvloop_var.set(use_uvloop)


def _awaitable_to_function(value: Awaitable[_T]) -> Callable[[], Awaitable[_T]]:
    """Wrap *value* in a zero-argument async function that awaits it.

    Every call awaits the same object, so for a coroutine the returned
    function can be run successfully only once.
    """

    async def awaitable() -> _T:
        return await value

    return awaitable