增加环绕侦察场景适配
This commit is contained in:
Binary file not shown.
Binary file not shown.
@@ -34,7 +34,7 @@ from collections.abc import (
|
||||
from concurrent.futures import Future
|
||||
from contextlib import AbstractContextManager, suppress
|
||||
from contextvars import Context, copy_context
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial, wraps
|
||||
from inspect import (
|
||||
CORO_RUNNING,
|
||||
@@ -59,8 +59,6 @@ from typing import (
|
||||
)
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import sniffio
|
||||
|
||||
from .. import (
|
||||
CapacityLimiterStatistics,
|
||||
EventStatistics,
|
||||
@@ -68,7 +66,11 @@ from .. import (
|
||||
TaskInfo,
|
||||
abc,
|
||||
)
|
||||
from .._core._eventloop import claim_worker_thread, threadlocals
|
||||
from .._core._eventloop import (
|
||||
claim_worker_thread,
|
||||
set_current_async_library,
|
||||
threadlocals,
|
||||
)
|
||||
from .._core._exceptions import (
|
||||
BrokenResourceError,
|
||||
BusyResourceError,
|
||||
@@ -151,18 +153,18 @@ else:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: type[BaseException],
|
||||
exc_val: BaseException,
|
||||
exc_tb: TracebackType,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Shutdown and close event loop."""
|
||||
if self._state is not _State.INITIALIZED:
|
||||
loop = self._loop
|
||||
if self._state is not _State.INITIALIZED or loop is None:
|
||||
return
|
||||
try:
|
||||
loop = self._loop
|
||||
_cancel_all_tasks(loop)
|
||||
loop.run_until_complete(loop.shutdown_asyncgens())
|
||||
if hasattr(loop, "shutdown_default_executor"):
|
||||
@@ -801,6 +803,11 @@ class TaskGroup(abc.TaskGroup):
|
||||
task_status_future: asyncio.Future | None = None,
|
||||
) -> asyncio.Task:
|
||||
def task_done(_task: asyncio.Task) -> None:
|
||||
if sys.version_info >= (3, 14) and self.cancel_scope._host_task is not None:
|
||||
asyncio.future_discard_from_awaited_by(
|
||||
_task, self.cancel_scope._host_task
|
||||
)
|
||||
|
||||
task_state = _task_states[_task]
|
||||
assert task_state.cancel_scope is not None
|
||||
assert _task in task_state.cancel_scope._tasks
|
||||
@@ -882,6 +889,9 @@ class TaskGroup(abc.TaskGroup):
|
||||
)
|
||||
self.cancel_scope._tasks.add(task)
|
||||
self._tasks.add(task)
|
||||
if sys.version_info >= (3, 14) and self.cancel_scope._host_task is not None:
|
||||
asyncio.future_add_to_awaited_by(task, self.cancel_scope._host_task)
|
||||
|
||||
task.add_done_callback(task_done)
|
||||
return task
|
||||
|
||||
@@ -1005,29 +1015,6 @@ _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
|
||||
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
|
||||
|
||||
|
||||
class BlockingPortal(abc.BlockingPortal):
|
||||
def __new__(cls) -> BlockingPortal:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._loop = get_running_loop()
|
||||
|
||||
def _spawn_task_from_thread(
|
||||
self,
|
||||
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
|
||||
args: tuple[Unpack[PosArgsT]],
|
||||
kwargs: dict[str, Any],
|
||||
name: object,
|
||||
future: Future[T_Retval],
|
||||
) -> None:
|
||||
AsyncIOBackend.run_sync_from_thread(
|
||||
partial(self._task_group.start_soon, name=name),
|
||||
(self._call_func, func, args, kwargs, future),
|
||||
self._loop,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Subprocesses
|
||||
#
|
||||
@@ -1052,12 +1039,30 @@ class StreamReaderWrapper(abc.ByteReceiveStream):
|
||||
@dataclass(eq=False)
|
||||
class StreamWriterWrapper(abc.ByteSendStream):
|
||||
_stream: asyncio.StreamWriter
|
||||
_closed: bool = field(init=False, default=False)
|
||||
|
||||
async def send(self, item: bytes) -> None:
|
||||
self._stream.write(item)
|
||||
await self._stream.drain()
|
||||
await AsyncIOBackend.checkpoint_if_cancelled()
|
||||
stream_paused = self._stream._protocol._paused # type: ignore[attr-defined]
|
||||
try:
|
||||
self._stream.write(item)
|
||||
await self._stream.drain()
|
||||
except (ConnectionResetError, BrokenPipeError, RuntimeError) as exc:
|
||||
# If closed by us and/or the peer:
|
||||
# * on stdlib, drain() raises ConnectionResetError or BrokenPipeError
|
||||
# * on uvloop and Winloop, write() eventually starts raising RuntimeError
|
||||
if self._closed:
|
||||
raise ClosedResourceError from exc
|
||||
elif self._stream.is_closing():
|
||||
raise BrokenResourceError from exc
|
||||
|
||||
raise
|
||||
|
||||
if not stream_paused:
|
||||
await AsyncIOBackend.cancel_shielded_checkpoint()
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._closed = True
|
||||
self._stream.close()
|
||||
await AsyncIOBackend.checkpoint()
|
||||
|
||||
@@ -1125,7 +1130,7 @@ def _forcibly_shutdown_process_pool_on_exit(
|
||||
) -> None:
|
||||
"""
|
||||
Forcibly shuts down worker processes belonging to this event loop."""
|
||||
child_watcher: asyncio.AbstractChildWatcher | None = None
|
||||
child_watcher: asyncio.AbstractChildWatcher | None = None # type: ignore[name-defined]
|
||||
if sys.version_info < (3, 12):
|
||||
try:
|
||||
child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
|
||||
@@ -1133,7 +1138,7 @@ def _forcibly_shutdown_process_pool_on_exit(
|
||||
pass
|
||||
|
||||
# Close as much as possible (w/o async/await) to avoid warnings
|
||||
for process in workers:
|
||||
for process in workers.copy():
|
||||
if process.returncode is None:
|
||||
continue
|
||||
|
||||
@@ -1157,6 +1162,7 @@ async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
|
||||
try:
|
||||
await sleep(math.inf)
|
||||
except asyncio.CancelledError:
|
||||
workers = workers.copy()
|
||||
for process in workers:
|
||||
if process.returncode is None:
|
||||
process.kill()
|
||||
@@ -1599,8 +1605,8 @@ class UDPSocket(abc.UDPSocket):
|
||||
return self._transport.get_extra_info("socket")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._closed = True
|
||||
if not self._transport.is_closing():
|
||||
self._closed = True
|
||||
self._transport.close()
|
||||
|
||||
async def receive(self) -> tuple[bytes, IPSockAddrType]:
|
||||
@@ -1647,8 +1653,8 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket):
|
||||
return self._transport.get_extra_info("socket")
|
||||
|
||||
async def aclose(self) -> None:
|
||||
self._closed = True
|
||||
if not self._transport.is_closing():
|
||||
self._closed = True
|
||||
self._transport.close()
|
||||
|
||||
async def receive(self) -> bytes:
|
||||
@@ -1971,8 +1977,9 @@ class CapacityLimiter(BaseCapacityLimiter):
|
||||
def total_tokens(self, value: float) -> None:
|
||||
if not isinstance(value, int) and not math.isinf(value):
|
||||
raise TypeError("total_tokens must be an int or math.inf")
|
||||
if value < 1:
|
||||
raise ValueError("total_tokens must be >= 1")
|
||||
|
||||
if value < 0:
|
||||
raise ValueError("total_tokens must be >= 0")
|
||||
|
||||
waiters_to_notify = max(value - self._total_tokens, 0)
|
||||
self._total_tokens = value
|
||||
@@ -2158,9 +2165,14 @@ class TestRunner(abc.TestRunner):
|
||||
loop_factory: Callable[[], AbstractEventLoop] | None = None,
|
||||
) -> None:
|
||||
if use_uvloop and loop_factory is None:
|
||||
import uvloop
|
||||
if sys.platform != "win32":
|
||||
import uvloop
|
||||
|
||||
loop_factory = uvloop.new_event_loop
|
||||
loop_factory = uvloop.new_event_loop
|
||||
else:
|
||||
import winloop
|
||||
|
||||
loop_factory = winloop.new_event_loop
|
||||
|
||||
self._runner = Runner(debug=debug, loop_factory=loop_factory)
|
||||
self._exceptions: list[BaseException] = []
|
||||
@@ -2317,9 +2329,14 @@ class AsyncIOBackend(AsyncBackend):
|
||||
debug = options.get("debug", None)
|
||||
loop_factory = options.get("loop_factory", None)
|
||||
if loop_factory is None and options.get("use_uvloop", False):
|
||||
import uvloop
|
||||
if sys.platform != "win32":
|
||||
import uvloop
|
||||
|
||||
loop_factory = uvloop.new_event_loop
|
||||
loop_factory = uvloop.new_event_loop
|
||||
else:
|
||||
import winloop
|
||||
|
||||
loop_factory = winloop.new_event_loop
|
||||
|
||||
with Runner(debug=debug, loop_factory=loop_factory) as runner:
|
||||
return runner.run(wrapper())
|
||||
@@ -2475,7 +2492,7 @@ class AsyncIOBackend(AsyncBackend):
|
||||
expired_worker.stop()
|
||||
|
||||
context = copy_context()
|
||||
context.run(sniffio.current_async_library_cvar.set, None)
|
||||
context.run(set_current_async_library, None)
|
||||
if abandon_on_cancel or scope._parent_scope is None:
|
||||
worker_scope = scope
|
||||
else:
|
||||
@@ -2524,7 +2541,7 @@ class AsyncIOBackend(AsyncBackend):
|
||||
raise RunFinishedError
|
||||
|
||||
context = copy_context()
|
||||
context.run(sniffio.current_async_library_cvar.set, "asyncio")
|
||||
context.run(set_current_async_library, "asyncio")
|
||||
scope = getattr(threadlocals, "current_cancel_scope", None)
|
||||
f: concurrent.futures.Future[T_Retval] = context.run(
|
||||
asyncio.run_coroutine_threadsafe, task_wrapper(), loop=loop
|
||||
@@ -2541,7 +2558,7 @@ class AsyncIOBackend(AsyncBackend):
|
||||
@wraps(func)
|
||||
def wrapper() -> None:
|
||||
try:
|
||||
sniffio.current_async_library_cvar.set("asyncio")
|
||||
set_current_async_library("asyncio")
|
||||
f.set_result(func(*args))
|
||||
except BaseException as exc:
|
||||
f.set_exception(exc)
|
||||
@@ -2558,10 +2575,6 @@ class AsyncIOBackend(AsyncBackend):
|
||||
loop.call_soon_threadsafe(wrapper)
|
||||
return f.result()
|
||||
|
||||
@classmethod
|
||||
def create_blocking_portal(cls) -> abc.BlockingPortal:
|
||||
return BlockingPortal()
|
||||
|
||||
@classmethod
|
||||
async def open_process(
|
||||
cls,
|
||||
|
||||
@@ -17,10 +17,8 @@ from collections.abc import (
|
||||
Iterable,
|
||||
Sequence,
|
||||
)
|
||||
from concurrent.futures import Future
|
||||
from contextlib import AbstractContextManager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from io import IOBase
|
||||
from os import PathLike
|
||||
from signal import Signals
|
||||
@@ -224,38 +222,6 @@ class TaskGroup(abc.TaskGroup):
|
||||
return await self._nursery.start(func, *args, name=name)
|
||||
|
||||
|
||||
#
|
||||
# Threads
|
||||
#
|
||||
|
||||
|
||||
class BlockingPortal(abc.BlockingPortal):
|
||||
def __new__(cls) -> BlockingPortal:
|
||||
return object.__new__(cls)
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._token = trio.lowlevel.current_trio_token()
|
||||
|
||||
def _spawn_task_from_thread(
|
||||
self,
|
||||
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
|
||||
args: tuple[Unpack[PosArgsT]],
|
||||
kwargs: dict[str, Any],
|
||||
name: object,
|
||||
future: Future[T_Retval],
|
||||
) -> None:
|
||||
trio.from_thread.run_sync(
|
||||
partial(self._task_group.start_soon, name=name),
|
||||
self._call_func,
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
future,
|
||||
trio_token=self._token,
|
||||
)
|
||||
|
||||
|
||||
#
|
||||
# Subprocesses
|
||||
#
|
||||
@@ -1114,10 +1080,6 @@ class TrioBackend(AsyncBackend):
|
||||
except trio.RunFinishedError:
|
||||
raise RunFinishedError from None
|
||||
|
||||
@classmethod
|
||||
def create_blocking_portal(cls) -> abc.BlockingPortal:
|
||||
return BlockingPortal()
|
||||
|
||||
@classmethod
|
||||
async def open_process(
|
||||
cls,
|
||||
|
||||
Reference in New Issue
Block a user