254 lines
8.2 KiB
Python
254 lines
8.2 KiB
Python
# Copyright 2022 Amethyst Reese
|
|
# Licensed under the MIT license
|
|
|
|
"""
|
|
Friendlier version of asyncio standard library.
|
|
|
|
Provisional library. Must be imported as `aioitertools.asyncio`.
|
|
"""
|
|
|
|
import asyncio
|
|
import time
|
|
from collections.abc import AsyncGenerator, AsyncIterable, Awaitable, Iterable
|
|
from typing import Any, cast, Optional
|
|
|
|
from .builtins import iter as aiter, maybe_await
|
|
from .types import AnyIterable, AsyncIterator, MaybeAwaitable, T
|
|
|
|
|
|
async def as_completed(
    aws: Iterable[Awaitable[T]],
    *,
    timeout: Optional[float] = None,
) -> AsyncIterator[T]:
    """
    Run awaitables in `aws` concurrently, and yield results as they complete.

    Unlike `asyncio.as_completed`, this yields actual results, and does not require
    awaiting each item in the iterable.

    Cancels all remaining awaitables if a timeout is given and the timeout threshold
    is reached.

    Example::

        async for value in as_completed(futures):
            ...  # use value immediately

    """
    # Wrap every awaitable in a task up front so they all start running now.
    pending: set[Awaitable[T]] = {asyncio.ensure_future(a) for a in aws}
    remaining: Optional[float] = None

    # Use the monotonic clock for the deadline: time.time() can jump forwards
    # or backwards (NTP adjustments), which would corrupt the timeout window.
    deadline: Optional[float] = None
    if timeout and timeout > 0:
        deadline = time.monotonic() + timeout

    while pending:
        if deadline is not None:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # Deadline reached: cancel everything still outstanding.
                # Items in `pending` all came from ensure_future()/asyncio.wait(),
                # so they are Futures, but keep the guard for type safety.
                for fut in pending:
                    if isinstance(fut, asyncio.Future):
                        fut.cancel()
                raise asyncio.TimeoutError()

        # asyncio.wait() returns tuple[set[Future], set[Future]], which mypy
        # won't assign to our set[Awaitable[T]] variables, so cast the result.
        done, pending = cast(
            tuple[set[Awaitable[T]], set[Awaitable[T]]],
            await asyncio.wait(
                pending,
                timeout=remaining,
                return_when=asyncio.FIRST_COMPLETED,
            ),
        )

        # Everything in `done` has finished; awaiting just fetches the result
        # (or re-raises the exception the task ended with).
        for item in done:
            yield await item
|
|
|
|
|
|
async def as_generated(
    iterables: Iterable[AsyncIterable[T]],
    *,
    return_exceptions: bool = False,
) -> AsyncIterable[T]:
    """
    Yield results from one or more async iterables, in the order they are produced.

    Like :func:`as_completed`, but for async iterators or generators instead of futures.
    Creates a separate task to drain each iterable, and a single queue for results.

    If ``return_exceptions`` is ``False``, then any exception will be raised, and
    pending iterables and tasks will be cancelled, and async generators will be closed.
    If ``return_exceptions`` is ``True``, any exceptions will be yielded as results,
    and execution will continue until all iterables have been fully consumed.

    Example::

        async def generator(x):
            for i in range(x):
                yield i

        gen1 = generator(10)
        gen2 = generator(12)

        async for value in as_generated([gen1, gen2]):
            ...  # intermixed values yielded from gen1 and gen2
    """

    # Exceptions raised inside tailer tasks are funneled here for the consumer
    # loop to either re-raise or yield, depending on return_exceptions.
    exc_queue: asyncio.Queue[Exception] = asyncio.Queue()
    # Single shared queue that every tailer pushes produced values onto.
    queue: asyncio.Queue[T] = asyncio.Queue()

    async def tailer(iter: AsyncIterable[T]) -> None:
        # Drain one iterable into the shared queue until it is exhausted.
        try:
            async for item in iter:
                await queue.put(item)
        except asyncio.CancelledError:
            # Async generators need an explicit aclose() so their cleanup
            # (e.g. finally blocks) runs before the cancellation propagates.
            if isinstance(iter, AsyncGenerator):  # pragma:nocover
                await iter.aclose()
            raise
        except Exception as e:
            # Hand the failure to the consumer loop instead of letting it
            # die silently inside this background task.
            await exc_queue.put(e)

    # One draining task per iterable; they all run concurrently.
    tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables]
    pending = set(tasks)

    try:
        while pending:
            # First surface any exception a tailer hit.
            try:
                exc = exc_queue.get_nowait()
                if return_exceptions:
                    yield exc  # type: ignore
                else:
                    raise exc
            except asyncio.QueueEmpty:
                pass

            # Then yield the next available value. If both queues are empty,
            # prune finished tailers and sleep briefly to avoid a busy spin
            # while the remaining tasks produce more values.
            try:
                value = queue.get_nowait()
                yield value
            except asyncio.QueueEmpty:
                for task in list(pending):
                    if task.done():
                        pending.remove(task)
                await asyncio.sleep(0.001)

    except (asyncio.CancelledError, GeneratorExit):
        # Consumer stopped early (cancelled, or closed this generator);
        # fall through to cleanup without raising.
        pass

    finally:
        # Cancel any tailers still running, then await every task so that
        # cancellation has fully completed before this generator exits.
        for task in tasks:
            if not task.done():
                task.cancel()

        for task in tasks:
            try:
                await task
            except asyncio.CancelledError:
                pass
|
|
|
|
|
|
async def gather(
    *args: Awaitable[T],
    return_exceptions: bool = False,
    limit: int = -1,
) -> list[Any]:
    """
    Like asyncio.gather but with a limit on concurrency.

    Note that all results are buffered.

    If gather is cancelled all tasks that were internally created and still pending
    will be cancelled as well.

    Example::

        futures = [some_coro(i) for i in range(10)]

        results = await gather(*futures, limit=2)
    """

    # Maps each distinct input awaitable to every position it occupies in args,
    # so duplicate inputs are awaited once and their result copied at the end.
    seen: dict[Awaitable[T], list[int]] = {}
    # Maps each spawned task back to the position of its input in args.
    task_index: dict[asyncio.Future[T], int] = {}
    results: list[Any] = [None] * len(args)

    in_flight: set[asyncio.Future[T]] = set()
    cursor = 0

    while True:
        # Spawn tasks lazily: a Task starts executing the moment it is created,
        # regardless of the concurrency limit, so defer creation until a slot
        # is actually free.
        while cursor < len(args) and (limit == -1 or len(in_flight) < limit):
            arg = args[cursor]
            if arg in seen:
                seen[arg].append(cursor)
            else:
                # Call ensure_future() explicitly so we hold the exact Task
                # object; otherwise asyncio.wait would wrap the awaitable in
                # an implicit task and we couldn't map the returned future
                # back to its input position.
                task: asyncio.Future[T] = asyncio.ensure_future(arg)
                in_flight.add(task)
                task_index[task] = cursor
                seen[arg] = [cursor]
            cursor += 1

        # in_flight may be empty if the trailing args were all duplicates;
        # asyncio.wait() raises on an empty collection.
        if in_flight:
            try:
                finished, in_flight = await asyncio.wait(
                    in_flight, return_when=asyncio.FIRST_COMPLETED
                )
                for fut in finished:
                    if return_exceptions and fut.exception():
                        results[task_index[fut]] = fut.exception()
                    else:
                        results[task_index[fut]] = fut.result()
            except asyncio.CancelledError:
                # We created these tasks, so we're responsible for cancelling
                # them; wait until every cancellation completes before raising.
                for fut in in_flight:
                    fut.cancel()
                await asyncio.gather(*in_flight, return_exceptions=True)
                raise

        if not in_flight and cursor == len(args):
            break

    # Copy results into the positions held by duplicate inputs.
    for positions in seen.values():
        for i in range(1, len(positions)):
            results[positions[i]] = results[positions[0]]

    return results
|
|
|
|
|
|
async def gather_iter(
    itr: AnyIterable[MaybeAwaitable[T]],
    return_exceptions: bool = False,
    limit: int = -1,
) -> list[T]:
    """
    Wrapper around gather to handle gathering an iterable instead of ``*args``.

    Note that the iterable values don't have to be awaitable.
    """
    # Normalize every element to an awaitable, then defer to gather() for the
    # actual concurrency-limited execution.
    awaitables = [maybe_await(value) async for value in aiter(itr)]
    return await gather(
        *awaitables,
        return_exceptions=return_exceptions,
        limit=limit,
    )
|