373 lines
10 KiB
Python
373 lines
10 KiB
Python
# Copyright 2022 Amethyst Reese
|
|
# Licensed under the MIT license
|
|
|
|
import asyncio
|
|
from collections.abc import AsyncIterator
|
|
from unittest import TestCase
|
|
|
|
import aioitertools as ait
|
|
from .helpers import async_test
|
|
|
|
slist = ["A", "B", "C"]
|
|
srange = range(3)
|
|
srange1 = range(1, 4)
|
|
srange0 = range(1)
|
|
|
|
|
|
class BuiltinsTest(TestCase):
|
|
|
|
# aioitertools.all()
|
|
|
|
@async_test
|
|
async def test_all_list(self):
|
|
self.assertTrue(await ait.all([True, 1, "string"]))
|
|
self.assertFalse(await ait.all([True, 0, "string"]))
|
|
|
|
@async_test
|
|
async def test_all_range(self):
|
|
self.assertTrue(await ait.all(srange1))
|
|
self.assertFalse(await ait.all(srange))
|
|
|
|
@async_test
|
|
async def test_all_generator(self):
|
|
self.assertTrue(await ait.all(x for x in srange1))
|
|
self.assertFalse(await ait.all(x for x in srange))
|
|
|
|
@async_test
|
|
async def test_all_async_generator(self):
|
|
self.assertTrue(await ait.all(ait.iter(srange1)))
|
|
self.assertFalse(await ait.all(ait.iter(srange)))
|
|
|
|
# aioitertools.any()
|
|
|
|
@async_test
|
|
async def test_any_list(self):
|
|
self.assertTrue(await ait.any([False, 1, ""]))
|
|
self.assertFalse(await ait.any([False, 0, ""]))
|
|
|
|
@async_test
|
|
async def test_any_range(self):
|
|
self.assertTrue(await ait.any(srange))
|
|
self.assertTrue(await ait.any(srange1))
|
|
self.assertFalse(await ait.any(srange0))
|
|
|
|
@async_test
|
|
async def test_any_generator(self):
|
|
self.assertTrue(await ait.any(x for x in srange))
|
|
self.assertTrue(await ait.any(x for x in srange1))
|
|
self.assertFalse(await ait.any(x for x in srange0))
|
|
|
|
@async_test
|
|
async def test_any_async_generator(self):
|
|
self.assertTrue(await ait.any(ait.iter(srange)))
|
|
self.assertTrue(await ait.any(ait.iter(srange1)))
|
|
self.assertFalse(await ait.any(ait.iter(srange0)))
|
|
|
|
# aioitertools.iter()
|
|
|
|
@async_test
|
|
async def test_iter_list(self):
|
|
it = ait.iter(slist)
|
|
self.assertIsInstance(it, AsyncIterator)
|
|
idx = 0
|
|
async for item in it:
|
|
self.assertEqual(item, slist[idx])
|
|
idx += 1
|
|
|
|
@async_test
|
|
async def test_iter_range(self):
|
|
it = ait.iter(srange)
|
|
self.assertIsInstance(it, AsyncIterator)
|
|
idx = 0
|
|
async for item in it:
|
|
self.assertEqual(item, srange[idx])
|
|
idx += 1
|
|
|
|
@async_test
|
|
async def test_iter_iterable(self):
|
|
sentinel = object()
|
|
|
|
class async_iterable:
|
|
def __aiter__(self):
|
|
return sentinel
|
|
|
|
aiter = async_iterable()
|
|
self.assertEqual(ait.iter(aiter), sentinel)
|
|
|
|
@async_test
|
|
async def test_iter_iterator(self):
|
|
sentinel = object()
|
|
|
|
class async_iterator:
|
|
def __aiter__(self):
|
|
return sentinel
|
|
|
|
def __anext__(self):
|
|
return sentinel
|
|
|
|
aiter = async_iterator()
|
|
self.assertEqual(ait.iter(aiter), aiter)
|
|
|
|
@async_test
|
|
async def test_iter_async_generator(self):
|
|
async def async_gen():
|
|
yield 1
|
|
yield 2
|
|
|
|
agen = async_gen()
|
|
self.assertEqual(ait.iter(agen), agen)
|
|
|
|
# aioitertools.next()
|
|
|
|
@async_test
|
|
async def test_next_list(self):
|
|
it = ait.iter(slist)
|
|
self.assertEqual(await ait.next(it), "A")
|
|
self.assertEqual(await ait.next(it), "B")
|
|
self.assertEqual(await ait.next(it), "C")
|
|
with self.assertRaises(StopAsyncIteration):
|
|
await ait.next(it)
|
|
|
|
@async_test
|
|
async def test_next_range(self):
|
|
it = ait.iter(srange)
|
|
self.assertEqual(await ait.next(it), 0)
|
|
self.assertEqual(await ait.next(it), 1)
|
|
self.assertEqual(await ait.next(it), 2)
|
|
with self.assertRaises(StopAsyncIteration):
|
|
await ait.next(it)
|
|
|
|
@async_test
|
|
async def test_next_iterable(self):
|
|
class async_iter:
|
|
def __init__(self):
|
|
self.index = 0
|
|
|
|
def __aiter__(self):
|
|
return self
|
|
|
|
def __anext__(self):
|
|
if self.index > 2:
|
|
raise StopAsyncIteration()
|
|
return self.fake_next()
|
|
|
|
async def fake_next(self):
|
|
value = slist[self.index]
|
|
self.index += 1
|
|
return value
|
|
|
|
it = ait.iter(async_iter())
|
|
self.assertEqual(await ait.next(it), "A")
|
|
self.assertEqual(await ait.next(it), "B")
|
|
self.assertEqual(await ait.next(it), "C")
|
|
with self.assertRaises(StopAsyncIteration):
|
|
await ait.next(it)
|
|
|
|
it = iter(slist)
|
|
self.assertEqual(await ait.next(it), "A")
|
|
self.assertEqual(await ait.next(it), "B")
|
|
self.assertEqual(await ait.next(it), "C")
|
|
with self.assertRaises(StopAsyncIteration):
|
|
await ait.next(it)
|
|
|
|
@async_test
|
|
async def test_next_async_generator(self):
|
|
async def async_gen():
|
|
for item in slist:
|
|
yield item
|
|
|
|
it = ait.iter(async_gen())
|
|
self.assertEqual(await ait.next(it), "A")
|
|
self.assertEqual(await ait.next(it), "B")
|
|
self.assertEqual(await ait.next(it), "C")
|
|
with self.assertRaises(StopAsyncIteration):
|
|
await ait.next(it)
|
|
|
|
@async_test
|
|
async def test_next_default_iterable(self):
|
|
it = iter(["A"])
|
|
|
|
self.assertEqual(await ait.next(it, "?"), "A")
|
|
# End of iteration
|
|
self.assertEqual(await ait.next(it, "?"), "?")
|
|
|
|
@async_test
|
|
async def test_next_default_async_iterable(self):
|
|
it = ait.iter(["A"])
|
|
self.assertEqual(await ait.next(it, "?"), "A")
|
|
# End of iteration
|
|
self.assertEqual(await ait.next(it, "?"), "?")
|
|
|
|
# aioitertools.list()
|
|
|
|
@async_test
|
|
async def test_list(self):
|
|
self.assertEqual(await ait.list(ait.iter(slist)), slist)
|
|
|
|
@async_test
|
|
async def test_tuple(self):
|
|
self.assertEqual(await ait.tuple(ait.iter(slist)), tuple(slist))
|
|
|
|
# aioitertools.set()
|
|
|
|
@async_test
|
|
async def test_set(self):
|
|
self.assertEqual(await ait.set(ait.iter(slist)), set(slist))
|
|
|
|
# aioitertools.enumerate()
|
|
|
|
@async_test
|
|
async def test_enumerate(self):
|
|
async for index, value in ait.enumerate(slist):
|
|
self.assertEqual(value, slist[index])
|
|
|
|
@async_test
|
|
async def test_enumerate_start(self):
|
|
async for index, value in ait.enumerate(slist, 4):
|
|
self.assertEqual(value, slist[index - 4])
|
|
|
|
# aioitertools.map()
|
|
|
|
@async_test
|
|
async def test_map_function_list(self):
|
|
idx = 0
|
|
async for value in ait.map(str.lower, slist):
|
|
self.assertEqual(value, slist[idx].lower())
|
|
idx += 1
|
|
|
|
@async_test
|
|
async def test_map_function_async_generator(self):
|
|
async def gen():
|
|
for item in slist:
|
|
yield item
|
|
|
|
idx = 0
|
|
async for value in ait.map(str.lower, gen()):
|
|
self.assertEqual(value, slist[idx].lower())
|
|
idx += 1
|
|
|
|
@async_test
|
|
async def test_map_coroutine_list(self):
|
|
async def double(x):
|
|
await asyncio.sleep(0.0001)
|
|
return x * 2
|
|
|
|
idx = 0
|
|
async for value in ait.map(double, slist):
|
|
self.assertEqual(value, slist[idx] * 2)
|
|
idx += 1
|
|
|
|
@async_test
|
|
async def test_map_coroutine_generator(self):
|
|
async def gen():
|
|
for item in slist:
|
|
yield item
|
|
|
|
async def double(x):
|
|
await asyncio.sleep(0.0001)
|
|
return x * 2
|
|
|
|
idx = 0
|
|
async for value in ait.map(double, gen()):
|
|
self.assertEqual(value, slist[idx] * 2)
|
|
idx += 1
|
|
|
|
# aioitertools.max()
|
|
|
|
@async_test
|
|
async def test_max_basic(self):
|
|
async def gen():
|
|
for item in slist:
|
|
yield item
|
|
|
|
self.assertEqual(await ait.max(gen()), "C")
|
|
self.assertEqual(await ait.max(range(4)), 3)
|
|
|
|
with self.assertRaisesRegex(ValueError, "iterable is empty"):
|
|
await ait.max([])
|
|
|
|
with self.assertRaisesRegex(ValueError, "kwarg .+ not supported"):
|
|
await ait.max(None, foo="foo")
|
|
|
|
@async_test
|
|
async def test_max_default(self):
|
|
self.assertEqual(await ait.max(range(2), default="x"), 1)
|
|
self.assertEqual(await ait.max([], default="x"), "x")
|
|
self.assertEqual(await ait.max([], default=None), None)
|
|
|
|
@async_test
|
|
async def test_max_key(self):
|
|
words = ["star", "buzz", "guard"]
|
|
|
|
def reverse(s):
|
|
return s[::-1]
|
|
|
|
self.assertEqual(reverse("python"), "nohtyp")
|
|
|
|
self.assertEqual(await ait.max(words), "star")
|
|
self.assertEqual(await ait.max(words, key=reverse), "buzz")
|
|
|
|
# aioitertools.min()
|
|
|
|
@async_test
|
|
async def test_min_basic(self):
|
|
async def gen():
|
|
for item in slist:
|
|
yield item
|
|
|
|
self.assertEqual(await ait.min(gen()), "A")
|
|
self.assertEqual(await ait.min(range(4)), 0)
|
|
|
|
with self.assertRaisesRegex(ValueError, "iterable is empty"):
|
|
await ait.min([])
|
|
|
|
with self.assertRaisesRegex(ValueError, "kwarg .+ not supported"):
|
|
await ait.min(None, foo="foo")
|
|
|
|
@async_test
|
|
async def test_min_default(self):
|
|
self.assertEqual(await ait.min(range(2), default="x"), 0)
|
|
self.assertEqual(await ait.min([], default="x"), "x")
|
|
self.assertEqual(await ait.min([], default=None), None)
|
|
|
|
@async_test
|
|
async def test_min_key(self):
|
|
words = ["star", "buzz", "guard"]
|
|
|
|
def reverse(s):
|
|
return s[::-1]
|
|
|
|
self.assertEqual(reverse("python"), "nohtyp")
|
|
|
|
self.assertEqual(await ait.min(words), "buzz")
|
|
self.assertEqual(await ait.min(words, key=reverse), "guard")
|
|
|
|
# aioitertools.sum()
|
|
|
|
@async_test
|
|
async def test_sum_range_default(self):
|
|
self.assertEqual(await ait.sum(srange), sum(srange))
|
|
|
|
@async_test
|
|
async def test_sum_list_string(self):
|
|
self.assertEqual(await ait.sum(slist, "foo"), "fooABC")
|
|
|
|
# aioitertools.zip()
|
|
|
|
@async_test
|
|
async def test_zip_equal(self):
|
|
idx = 0
|
|
async for a, b in ait.zip(slist, srange):
|
|
self.assertEqual(a, slist[idx])
|
|
self.assertEqual(b, srange[idx])
|
|
idx += 1
|
|
|
|
@async_test
|
|
async def test_zip_shortest(self):
|
|
short = ["a", "b", "c"]
|
|
long = [0, 1, 2, 3, 5]
|
|
|
|
result = await ait.list(ait.zip(short, long))
|
|
expected = [("a", 0), ("b", 1), ("c", 2)]
|
|
self.assertListEqual(expected, result)
|