增加环绕侦察场景适配

This commit is contained in:
2026-01-08 15:44:38 +08:00
parent 3eba1f962b
commit 10c5bb5a8a
5441 changed files with 40219 additions and 379695 deletions

View File

@@ -34,7 +34,7 @@ from collections.abc import (
from concurrent.futures import Future
from contextlib import AbstractContextManager, suppress
from contextvars import Context, copy_context
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import partial, wraps
from inspect import (
CORO_RUNNING,
@@ -59,8 +59,6 @@ from typing import (
)
from weakref import WeakKeyDictionary
import sniffio
from .. import (
CapacityLimiterStatistics,
EventStatistics,
@@ -68,7 +66,11 @@ from .. import (
TaskInfo,
abc,
)
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._eventloop import (
claim_worker_thread,
set_current_async_library,
threadlocals,
)
from .._core._exceptions import (
BrokenResourceError,
BusyResourceError,
@@ -151,18 +153,18 @@ else:
def __exit__(
self,
exc_type: type[BaseException],
exc_val: BaseException,
exc_tb: TracebackType,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
self.close()
def close(self) -> None:
"""Shutdown and close event loop."""
if self._state is not _State.INITIALIZED:
loop = self._loop
if self._state is not _State.INITIALIZED or loop is None:
return
try:
loop = self._loop
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, "shutdown_default_executor"):
@@ -801,6 +803,11 @@ class TaskGroup(abc.TaskGroup):
task_status_future: asyncio.Future | None = None,
) -> asyncio.Task:
def task_done(_task: asyncio.Task) -> None:
if sys.version_info >= (3, 14) and self.cancel_scope._host_task is not None:
asyncio.future_discard_from_awaited_by(
_task, self.cancel_scope._host_task
)
task_state = _task_states[_task]
assert task_state.cancel_scope is not None
assert _task in task_state.cancel_scope._tasks
@@ -882,6 +889,9 @@ class TaskGroup(abc.TaskGroup):
)
self.cancel_scope._tasks.add(task)
self._tasks.add(task)
if sys.version_info >= (3, 14) and self.cancel_scope._host_task is not None:
asyncio.future_add_to_awaited_by(task, self.cancel_scope._host_task)
task.add_done_callback(task_done)
return task
@@ -1005,29 +1015,6 @@ _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
class BlockingPortal(abc.BlockingPortal):
def __new__(cls) -> BlockingPortal:
return object.__new__(cls)
def __init__(self) -> None:
super().__init__()
self._loop = get_running_loop()
def _spawn_task_from_thread(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
future: Future[T_Retval],
) -> None:
AsyncIOBackend.run_sync_from_thread(
partial(self._task_group.start_soon, name=name),
(self._call_func, func, args, kwargs, future),
self._loop,
)
#
# Subprocesses
#
@@ -1052,12 +1039,30 @@ class StreamReaderWrapper(abc.ByteReceiveStream):
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
_stream: asyncio.StreamWriter
_closed: bool = field(init=False, default=False)
async def send(self, item: bytes) -> None:
self._stream.write(item)
await self._stream.drain()
await AsyncIOBackend.checkpoint_if_cancelled()
stream_paused = self._stream._protocol._paused # type: ignore[attr-defined]
try:
self._stream.write(item)
await self._stream.drain()
except (ConnectionResetError, BrokenPipeError, RuntimeError) as exc:
# If closed by us and/or the peer:
# * on stdlib, drain() raises ConnectionResetError or BrokenPipeError
# * on uvloop and Winloop, write() eventually starts raising RuntimeError
if self._closed:
raise ClosedResourceError from exc
elif self._stream.is_closing():
raise BrokenResourceError from exc
raise
if not stream_paused:
await AsyncIOBackend.cancel_shielded_checkpoint()
async def aclose(self) -> None:
self._closed = True
self._stream.close()
await AsyncIOBackend.checkpoint()
@@ -1125,7 +1130,7 @@ def _forcibly_shutdown_process_pool_on_exit(
) -> None:
"""
Forcibly shuts down worker processes belonging to this event loop."""
child_watcher: asyncio.AbstractChildWatcher | None = None
child_watcher: asyncio.AbstractChildWatcher | None = None # type: ignore[name-defined]
if sys.version_info < (3, 12):
try:
child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
@@ -1133,7 +1138,7 @@ def _forcibly_shutdown_process_pool_on_exit(
pass
# Close as much as possible (w/o async/await) to avoid warnings
for process in workers:
for process in workers.copy():
if process.returncode is None:
continue
@@ -1157,6 +1162,7 @@ async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
try:
await sleep(math.inf)
except asyncio.CancelledError:
workers = workers.copy()
for process in workers:
if process.returncode is None:
process.kill()
@@ -1599,8 +1605,8 @@ class UDPSocket(abc.UDPSocket):
return self._transport.get_extra_info("socket")
async def aclose(self) -> None:
self._closed = True
if not self._transport.is_closing():
self._closed = True
self._transport.close()
async def receive(self) -> tuple[bytes, IPSockAddrType]:
@@ -1647,8 +1653,8 @@ class ConnectedUDPSocket(abc.ConnectedUDPSocket):
return self._transport.get_extra_info("socket")
async def aclose(self) -> None:
self._closed = True
if not self._transport.is_closing():
self._closed = True
self._transport.close()
async def receive(self) -> bytes:
@@ -1971,8 +1977,9 @@ class CapacityLimiter(BaseCapacityLimiter):
def total_tokens(self, value: float) -> None:
if not isinstance(value, int) and not math.isinf(value):
raise TypeError("total_tokens must be an int or math.inf")
if value < 1:
raise ValueError("total_tokens must be >= 1")
if value < 0:
raise ValueError("total_tokens must be >= 0")
waiters_to_notify = max(value - self._total_tokens, 0)
self._total_tokens = value
@@ -2158,9 +2165,14 @@ class TestRunner(abc.TestRunner):
loop_factory: Callable[[], AbstractEventLoop] | None = None,
) -> None:
if use_uvloop and loop_factory is None:
import uvloop
if sys.platform != "win32":
import uvloop
loop_factory = uvloop.new_event_loop
loop_factory = uvloop.new_event_loop
else:
import winloop
loop_factory = winloop.new_event_loop
self._runner = Runner(debug=debug, loop_factory=loop_factory)
self._exceptions: list[BaseException] = []
@@ -2317,9 +2329,14 @@ class AsyncIOBackend(AsyncBackend):
debug = options.get("debug", None)
loop_factory = options.get("loop_factory", None)
if loop_factory is None and options.get("use_uvloop", False):
import uvloop
if sys.platform != "win32":
import uvloop
loop_factory = uvloop.new_event_loop
loop_factory = uvloop.new_event_loop
else:
import winloop
loop_factory = winloop.new_event_loop
with Runner(debug=debug, loop_factory=loop_factory) as runner:
return runner.run(wrapper())
@@ -2475,7 +2492,7 @@ class AsyncIOBackend(AsyncBackend):
expired_worker.stop()
context = copy_context()
context.run(sniffio.current_async_library_cvar.set, None)
context.run(set_current_async_library, None)
if abandon_on_cancel or scope._parent_scope is None:
worker_scope = scope
else:
@@ -2524,7 +2541,7 @@ class AsyncIOBackend(AsyncBackend):
raise RunFinishedError
context = copy_context()
context.run(sniffio.current_async_library_cvar.set, "asyncio")
context.run(set_current_async_library, "asyncio")
scope = getattr(threadlocals, "current_cancel_scope", None)
f: concurrent.futures.Future[T_Retval] = context.run(
asyncio.run_coroutine_threadsafe, task_wrapper(), loop=loop
@@ -2541,7 +2558,7 @@ class AsyncIOBackend(AsyncBackend):
@wraps(func)
def wrapper() -> None:
try:
sniffio.current_async_library_cvar.set("asyncio")
set_current_async_library("asyncio")
f.set_result(func(*args))
except BaseException as exc:
f.set_exception(exc)
@@ -2558,10 +2575,6 @@ class AsyncIOBackend(AsyncBackend):
loop.call_soon_threadsafe(wrapper)
return f.result()
@classmethod
def create_blocking_portal(cls) -> abc.BlockingPortal:
return BlockingPortal()
@classmethod
async def open_process(
cls,

View File

@@ -17,10 +17,8 @@ from collections.abc import (
Iterable,
Sequence,
)
from concurrent.futures import Future
from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import partial
from io import IOBase
from os import PathLike
from signal import Signals
@@ -224,38 +222,6 @@ class TaskGroup(abc.TaskGroup):
return await self._nursery.start(func, *args, name=name)
#
# Threads
#
class BlockingPortal(abc.BlockingPortal):
def __new__(cls) -> BlockingPortal:
return object.__new__(cls)
def __init__(self) -> None:
super().__init__()
self._token = trio.lowlevel.current_trio_token()
def _spawn_task_from_thread(
self,
func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
args: tuple[Unpack[PosArgsT]],
kwargs: dict[str, Any],
name: object,
future: Future[T_Retval],
) -> None:
trio.from_thread.run_sync(
partial(self._task_group.start_soon, name=name),
self._call_func,
func,
args,
kwargs,
future,
trio_token=self._token,
)
#
# Subprocesses
#
@@ -1114,10 +1080,6 @@ class TrioBackend(AsyncBackend):
except trio.RunFinishedError:
raise RunFinishedError from None
@classmethod
def create_blocking_portal(cls) -> abc.BlockingPortal:
return BlockingPortal()
@classmethod
async def open_process(
cls,