260 lines
7.6 KiB
Python
260 lines
7.6 KiB
Python
# Copyright 2022 Amethyst Reese
|
|
# Licensed under the MIT license
|
|
|
|
import asyncio
|
|
from unittest import TestCase
|
|
|
|
import aioitertools as ait
|
|
import aioitertools.asyncio as aio
|
|
from .helpers import async_test
|
|
|
|
# Module-level fixtures: three single-letter strings and a range of the same length.
slist = list("ABC")
srange = range(len(slist))
|
|
|
|
|
|
class AsyncioTest(TestCase):
    """Tests for ``aioitertools.asyncio`` (imported here as ``aio``)."""

    def test_import(self):
        """The submodule is reachable as an attribute of the package."""
        self.assertEqual(ait.asyncio, aio)

    @async_test
    async def test_as_completed(self):
        """Results arrive in completion order, not submission order."""

        async def sleepy(number, duration):
            await asyncio.sleep(duration)
            return number

        pairs = [(1, 0.3), (2, 0.1), (3, 0.5)]
        expected = [2, 1, 3]

        futures = [sleepy(*pair) for pair in pairs]
        results = await ait.list(aio.as_completed(futures))
        self.assertEqual(results, expected)

        futures = [sleepy(*pair) for pair in pairs]
        results = []
        async for value in aio.as_completed(futures):
            results.append(value)
        self.assertEqual(results, expected)

    @async_test
    async def test_as_completed_timeout(self):
        """A timeout raises TimeoutError after yielding what finished in time."""
        calls = [(1.0,), (0.1,)]

        futures = [asyncio.sleep(*args) for args in calls]
        with self.assertRaises(asyncio.TimeoutError):
            await ait.list(aio.as_completed(futures, timeout=0.5))

        futures = [asyncio.sleep(*args) for args in calls]
        count = 0
        with self.assertRaises(asyncio.TimeoutError):
            async for _ in aio.as_completed(futures, timeout=0.5):
                count += 1
        # Only the 0.1s sleep can finish inside the 0.5s timeout.
        self.assertEqual(count, 1)

    @async_test
    async def test_as_generated(self):
        """Every value from every generator is yielded."""

        async def gen():
            for i in range(10):
                yield i
                await asyncio.sleep(0)

        gens = [gen(), gen(), gen()]
        expected = list(range(10)) * 3
        results = []
        async for value in aio.as_generated(gens):
            results.append(value)
        self.assertEqual(30, len(results))
        self.assertListEqual(sorted(expected), sorted(results))

    @async_test
    async def test_as_generated_exception(self):
        """An exception in one generator propagates and ends iteration."""

        async def gen1():
            for i in range(3):
                yield i
                await asyncio.sleep(0)
            raise Exception("fake")

        async def gen2():
            for i in range(10):
                yield i
                await asyncio.sleep(0)

        gens = [gen1(), gen2()]
        results = []
        with self.assertRaisesRegex(Exception, "fake"):
            async for value in aio.as_generated(gens):
                results.append(value)
        # Neither generator ever yields 10, so checking for it proves nothing.
        # Instead check that only generator-produced values were seen and that
        # the exception cut iteration short: gen1 raises long before gen2
        # produces its last value, so not all 3 + 10 values can have arrived.
        self.assertTrue(set(results).issubset(range(10)))
        self.assertLess(len(results), 3 + 10)

    @async_test
    async def test_as_generated_return_exception(self):
        """With return_exceptions=True, errors are yielded as values."""

        async def gen1():
            for i in range(3):
                yield i
                await asyncio.sleep(0)
            raise Exception("fake")

        async def gen2():
            for i in range(10):
                yield i
                await asyncio.sleep(0)

        gens = [gen1(), gen2()]
        expected = list(range(3)) + list(range(10))
        errors = []
        results = []
        async for value in aio.as_generated(gens, return_exceptions=True):
            if isinstance(value, Exception):
                errors.append(value)
            else:
                results.append(value)
        self.assertListEqual(sorted(expected), sorted(results))
        self.assertEqual(1, len(errors))
        self.assertIsInstance(errors[0], Exception)
        self.assertEqual("fake", str(errors[0]))

    @async_test
    async def test_as_generated_task_cancelled(self):
        """A generator raising CancelledError stops without aborting the others."""

        # ``limit`` (not ``max``) to avoid shadowing the builtin.
        async def gen(limit: int = 10):
            for i in range(5):
                if i > limit:
                    raise asyncio.CancelledError
                yield i
                await asyncio.sleep(0)

        gens = [gen(2), gen()]
        expected = list(range(3)) + list(range(5))
        results = []
        async for value in aio.as_generated(gens):
            results.append(value)
        self.assertListEqual(sorted(expected), sorted(results))

    @async_test
    async def test_as_generated_cancelled(self):
        """Cancelling the consuming task keeps the values already received."""

        async def gen():
            for i in range(5):
                yield i
                await asyncio.sleep(0.1)

        expected = [0, 0, 1, 1]
        results = []

        async def foo():
            gens = [gen(), gen()]
            async for value in aio.as_generated(gens):
                results.append(value)
            return results

        task = asyncio.ensure_future(foo())
        # 0.15s lets each generator yield twice (at t=0 and t=0.1).
        await asyncio.sleep(0.15)
        task.cancel()
        await task

        self.assertListEqual(sorted(expected), sorted(results))

    @async_test
    async def test_gather_input_types(self):
        """gather() accepts bare coroutines, futures, and tasks."""

        async def fn(arg):
            await asyncio.sleep(0.001)
            return arg

        fns = [fn(1), asyncio.ensure_future(fn(2))]
        if hasattr(asyncio, "create_task"):
            # 3.7+ only
            fns.append(asyncio.create_task(fn(3)))
        else:
            fns.append(fn(3))

        result = await aio.gather(*fns)
        self.assertEqual([1, 2, 3], result)

    @async_test
    async def test_gather_limited(self):
        """``limit`` caps concurrency; results keep input order either way."""
        max_counter = 0
        counter = 0

        async def fn(arg):
            nonlocal counter, max_counter
            counter += 1
            max_counter = max(max_counter, counter)
            await asyncio.sleep(0.001)
            counter -= 1
            return arg

        # Limit of 2
        result = await aio.gather(*[fn(i) for i in range(10)], limit=2)
        self.assertEqual(2, max_counter)
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], result)

        # No limit
        result = await aio.gather(*[fn(i) for i in range(10)])
        # TODO: on a loaded machine this might be less?
        self.assertEqual(10, max_counter)
        self.assertEqual([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], result)

    @async_test
    async def test_gather_limited_dupes(self):
        """Duplicate awaitables get their shared result in every position."""

        async def fn(arg):
            await asyncio.sleep(0.001)
            return arg

        f = fn(1)
        g = fn(2)
        result = await aio.gather(f, f, f, g, f, g, limit=2)
        self.assertEqual([1, 1, 1, 2, 1, 2], result)

        f = fn(1)
        g = fn(2)
        result = await aio.gather(f, f, f, g, f, g)
        self.assertEqual([1, 1, 1, 2, 1, 2], result)

    @async_test
    async def test_gather_with_exceptions(self):
        """Errors propagate, or are returned in place with return_exceptions."""

        class MyException(Exception):
            pass

        async def fn(arg, fail=False):
            await asyncio.sleep(arg)
            if fail:
                raise MyException(arg)
            return arg

        with self.assertRaises(MyException):
            await aio.gather(fn(0.002, fail=True), fn(0.001))

        result = await aio.gather(
            fn(0.002, fail=True), fn(0.001), return_exceptions=True
        )
        self.assertEqual(result[1], 0.001)
        self.assertIsInstance(result[0], MyException)

    @async_test
    async def test_gather_cancel(self):
        """Cancelling the task awaiting gather() cancels the gathered work."""
        cancelled = False
        started = False

        async def _fn():
            nonlocal started, cancelled
            try:
                started = True
                await asyncio.sleep(10)  # might as well be forever
            except asyncio.CancelledError:
                cancelled = True
                raise

        async def _gather():
            await aio.gather(_fn())

        if hasattr(asyncio, "create_task"):
            # 3.7+ only
            task = asyncio.create_task(_gather())
        else:
            task = asyncio.ensure_future(_gather())

        # Yield to the loop until _fn() is actually running, so the cancel
        # below reaches the gathered coroutine instead of depending on a
        # single scheduler tick. Bounded so a broken gather fails the
        # assertions below rather than hanging the test.
        for _ in range(100):
            if started:
                break
            await asyncio.sleep(0)
        task.cancel()
        with self.assertRaises(asyncio.CancelledError):
            await task
        self.assertTrue(started)
        self.assertTrue(cancelled)
|