chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1 @@
|
||||
b01999d409b29bd916e067bc963d5f2d9ee63cfc9ae0bccb769910131417bf93 /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pxd
|
||||
@@ -0,0 +1 @@
|
||||
0478ceb55d0ed30ef1a7da742cd003449bc69a07cf9fdb06789bd2b347cbfffe /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/mask.pyx
|
||||
@@ -0,0 +1 @@
|
||||
9e5fe78ed0ebce5414d2b8e01868d90c1facc20b84d2d5ff6c23e86e44a155ae /home/runner/work/aiohttp/aiohttp/aiohttp/_websocket/reader_c.pxd
|
||||
@@ -0,0 +1 @@
|
||||
"""WebSocket protocol versions 13 and 8."""
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,147 @@
|
||||
"""Helpers for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import functools
|
||||
import re
|
||||
from struct import Struct
|
||||
from typing import TYPE_CHECKING, Final, List, Optional, Pattern, Tuple
|
||||
|
||||
from ..helpers import NO_EXTENSIONS
|
||||
from .models import WSHandshakeError
|
||||
|
||||
UNPACK_LEN3 = Struct("!Q").unpack_from
|
||||
UNPACK_CLOSE_CODE = Struct("!H").unpack
|
||||
PACK_LEN1 = Struct("!BB").pack
|
||||
PACK_LEN2 = Struct("!BBH").pack
|
||||
PACK_LEN3 = Struct("!BBQ").pack
|
||||
PACK_CLOSE_CODE = Struct("!H").pack
|
||||
PACK_RANDBITS = Struct("!L").pack
|
||||
MSG_SIZE: Final[int] = 2**14
|
||||
MASK_LEN: Final[int] = 4
|
||||
|
||||
WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||
|
||||
|
||||
# Used by _websocket_mask_python
|
||||
@functools.lru_cache
|
||||
def _xor_table() -> List[bytes]:
|
||||
return [bytes(a ^ b for a in range(256)) for b in range(256)]
|
||||
|
||||
|
||||
def _websocket_mask_python(mask: bytes, data: bytearray) -> None:
|
||||
"""Websocket masking function.
|
||||
|
||||
`mask` is a `bytes` object of length 4; `data` is a `bytearray`
|
||||
object of any length. The contents of `data` are masked with `mask`,
|
||||
as specified in section 5.3 of RFC 6455.
|
||||
|
||||
Note that this function mutates the `data` argument.
|
||||
|
||||
This pure-python implementation may be replaced by an optimized
|
||||
version when available.
|
||||
|
||||
"""
|
||||
assert isinstance(data, bytearray), data
|
||||
assert len(mask) == 4, mask
|
||||
|
||||
if data:
|
||||
_XOR_TABLE = _xor_table()
|
||||
a, b, c, d = (_XOR_TABLE[n] for n in mask)
|
||||
data[::4] = data[::4].translate(a)
|
||||
data[1::4] = data[1::4].translate(b)
|
||||
data[2::4] = data[2::4].translate(c)
|
||||
data[3::4] = data[3::4].translate(d)
|
||||
|
||||
|
||||
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
||||
websocket_mask = _websocket_mask_python
|
||||
else:
|
||||
try:
|
||||
from .mask import _websocket_mask_cython # type: ignore[import-not-found]
|
||||
|
||||
websocket_mask = _websocket_mask_cython
|
||||
except ImportError: # pragma: no cover
|
||||
websocket_mask = _websocket_mask_python
|
||||
|
||||
|
||||
_WS_EXT_RE: Final[Pattern[str]] = re.compile(
|
||||
r"^(?:;\s*(?:"
|
||||
r"(server_no_context_takeover)|"
|
||||
r"(client_no_context_takeover)|"
|
||||
r"(server_max_window_bits(?:=(\d+))?)|"
|
||||
r"(client_max_window_bits(?:=(\d+))?)))*$"
|
||||
)
|
||||
|
||||
_WS_EXT_RE_SPLIT: Final[Pattern[str]] = re.compile(r"permessage-deflate([^,]+)?")
|
||||
|
||||
|
||||
def ws_ext_parse(extstr: Optional[str], isserver: bool = False) -> Tuple[int, bool]:
|
||||
if not extstr:
|
||||
return 0, False
|
||||
|
||||
compress = 0
|
||||
notakeover = False
|
||||
for ext in _WS_EXT_RE_SPLIT.finditer(extstr):
|
||||
defext = ext.group(1)
|
||||
# Return compress = 15 when get `permessage-deflate`
|
||||
if not defext:
|
||||
compress = 15
|
||||
break
|
||||
match = _WS_EXT_RE.match(defext)
|
||||
if match:
|
||||
compress = 15
|
||||
if isserver:
|
||||
# Server never fail to detect compress handshake.
|
||||
# Server does not need to send max wbit to client
|
||||
if match.group(4):
|
||||
compress = int(match.group(4))
|
||||
# Group3 must match if group4 matches
|
||||
# Compress wbit 8 does not support in zlib
|
||||
# If compress level not support,
|
||||
# CONTINUE to next extension
|
||||
if compress > 15 or compress < 9:
|
||||
compress = 0
|
||||
continue
|
||||
if match.group(1):
|
||||
notakeover = True
|
||||
# Ignore regex group 5 & 6 for client_max_window_bits
|
||||
break
|
||||
else:
|
||||
if match.group(6):
|
||||
compress = int(match.group(6))
|
||||
# Group5 must match if group6 matches
|
||||
# Compress wbit 8 does not support in zlib
|
||||
# If compress level not support,
|
||||
# FAIL the parse progress
|
||||
if compress > 15 or compress < 9:
|
||||
raise WSHandshakeError("Invalid window size")
|
||||
if match.group(2):
|
||||
notakeover = True
|
||||
# Ignore regex group 5 & 6 for client_max_window_bits
|
||||
break
|
||||
# Return Fail if client side and not match
|
||||
elif not isserver:
|
||||
raise WSHandshakeError("Extension for deflate not supported" + ext.group(1))
|
||||
|
||||
return compress, notakeover
|
||||
|
||||
|
||||
def ws_ext_gen(
|
||||
compress: int = 15, isserver: bool = False, server_notakeover: bool = False
|
||||
) -> str:
|
||||
# client_notakeover=False not used for server
|
||||
# compress wbit 8 does not support in zlib
|
||||
if compress < 9 or compress > 15:
|
||||
raise ValueError(
|
||||
"Compress wbits must between 9 and 15, zlib does not support wbits=8"
|
||||
)
|
||||
enabledext = ["permessage-deflate"]
|
||||
if not isserver:
|
||||
enabledext.append("client_max_window_bits")
|
||||
|
||||
if compress < 15:
|
||||
enabledext.append("server_max_window_bits=" + str(compress))
|
||||
if server_notakeover:
|
||||
enabledext.append("server_no_context_takeover")
|
||||
# if client_notakeover:
|
||||
# enabledext.append('client_no_context_takeover')
|
||||
return "; ".join(enabledext)
|
||||
Binary file not shown.
@@ -0,0 +1,3 @@
|
||||
"""Cython declarations for websocket masking."""
|
||||
|
||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data)
|
||||
@@ -0,0 +1,48 @@
|
||||
from cpython cimport PyBytes_AsString
|
||||
|
||||
|
||||
#from cpython cimport PyByteArray_AsString # cython still not exports that
|
||||
cdef extern from "Python.h":
|
||||
char* PyByteArray_AsString(bytearray ba) except NULL
|
||||
|
||||
from libc.stdint cimport uint32_t, uint64_t, uintmax_t
|
||||
|
||||
|
||||
cpdef void _websocket_mask_cython(bytes mask, bytearray data):
|
||||
"""Note, this function mutates its `data` argument
|
||||
"""
|
||||
cdef:
|
||||
Py_ssize_t data_len, i
|
||||
# bit operations on signed integers are implementation-specific
|
||||
unsigned char * in_buf
|
||||
const unsigned char * mask_buf
|
||||
uint32_t uint32_msk
|
||||
uint64_t uint64_msk
|
||||
|
||||
assert len(mask) == 4
|
||||
|
||||
data_len = len(data)
|
||||
in_buf = <unsigned char*>PyByteArray_AsString(data)
|
||||
mask_buf = <const unsigned char*>PyBytes_AsString(mask)
|
||||
uint32_msk = (<uint32_t*>mask_buf)[0]
|
||||
|
||||
# TODO: align in_data ptr to achieve even faster speeds
|
||||
# does it need in python ?! malloc() always aligns to sizeof(long) bytes
|
||||
|
||||
if sizeof(size_t) >= 8:
|
||||
uint64_msk = uint32_msk
|
||||
uint64_msk = (uint64_msk << 32) | uint32_msk
|
||||
|
||||
while data_len >= 8:
|
||||
(<uint64_t*>in_buf)[0] ^= uint64_msk
|
||||
in_buf += 8
|
||||
data_len -= 8
|
||||
|
||||
|
||||
while data_len >= 4:
|
||||
(<uint32_t*>in_buf)[0] ^= uint32_msk
|
||||
in_buf += 4
|
||||
data_len -= 4
|
||||
|
||||
for i in range(0, data_len):
|
||||
in_buf[i] ^= mask_buf[i]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Models for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import json
|
||||
from enum import IntEnum
|
||||
from typing import Any, Callable, Final, NamedTuple, Optional, cast
|
||||
|
||||
WS_DEFLATE_TRAILING: Final[bytes] = bytes([0x00, 0x00, 0xFF, 0xFF])
|
||||
|
||||
|
||||
class WSCloseCode(IntEnum):
|
||||
OK = 1000
|
||||
GOING_AWAY = 1001
|
||||
PROTOCOL_ERROR = 1002
|
||||
UNSUPPORTED_DATA = 1003
|
||||
ABNORMAL_CLOSURE = 1006
|
||||
INVALID_TEXT = 1007
|
||||
POLICY_VIOLATION = 1008
|
||||
MESSAGE_TOO_BIG = 1009
|
||||
MANDATORY_EXTENSION = 1010
|
||||
INTERNAL_ERROR = 1011
|
||||
SERVICE_RESTART = 1012
|
||||
TRY_AGAIN_LATER = 1013
|
||||
BAD_GATEWAY = 1014
|
||||
|
||||
|
||||
class WSMsgType(IntEnum):
|
||||
# websocket spec types
|
||||
CONTINUATION = 0x0
|
||||
TEXT = 0x1
|
||||
BINARY = 0x2
|
||||
PING = 0x9
|
||||
PONG = 0xA
|
||||
CLOSE = 0x8
|
||||
|
||||
# aiohttp specific types
|
||||
CLOSING = 0x100
|
||||
CLOSED = 0x101
|
||||
ERROR = 0x102
|
||||
|
||||
text = TEXT
|
||||
binary = BINARY
|
||||
ping = PING
|
||||
pong = PONG
|
||||
close = CLOSE
|
||||
closing = CLOSING
|
||||
closed = CLOSED
|
||||
error = ERROR
|
||||
|
||||
|
||||
class WSMessage(NamedTuple):
|
||||
type: WSMsgType
|
||||
# To type correctly, this would need some kind of tagged union for each type.
|
||||
data: Any
|
||||
extra: Optional[str]
|
||||
|
||||
def json(self, *, loads: Callable[[Any], Any] = json.loads) -> Any:
|
||||
"""Return parsed JSON data.
|
||||
|
||||
.. versionadded:: 0.22
|
||||
"""
|
||||
return loads(self.data)
|
||||
|
||||
|
||||
# Constructing the tuple directly to avoid the overhead of
|
||||
# the lambda and arg processing since NamedTuples are constructed
|
||||
# with a run time built lambda
|
||||
# https://github.com/python/cpython/blob/d83fcf8371f2f33c7797bc8f5423a8bca8c46e5c/Lib/collections/__init__.py#L441
|
||||
WS_CLOSED_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSED, None, None))
|
||||
WS_CLOSING_MESSAGE = tuple.__new__(WSMessage, (WSMsgType.CLOSING, None, None))
|
||||
|
||||
|
||||
class WebSocketError(Exception):
|
||||
"""WebSocket protocol parser error."""
|
||||
|
||||
def __init__(self, code: int, message: str) -> None:
|
||||
self.code = code
|
||||
super().__init__(code, message)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return cast(str, self.args[1])
|
||||
|
||||
|
||||
class WSHandshakeError(Exception):
|
||||
"""WebSocket protocol handshake error."""
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ..helpers import NO_EXTENSIONS
|
||||
|
||||
if TYPE_CHECKING or NO_EXTENSIONS: # pragma: no cover
|
||||
from .reader_py import (
|
||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
||||
WebSocketReader as WebSocketReaderPython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderPython
|
||||
WebSocketDataQueue = WebSocketDataQueuePython
|
||||
else:
|
||||
try:
|
||||
from .reader_c import ( # type: ignore[import-not-found]
|
||||
WebSocketDataQueue as WebSocketDataQueueCython,
|
||||
WebSocketReader as WebSocketReaderCython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderCython
|
||||
WebSocketDataQueue = WebSocketDataQueueCython
|
||||
except ImportError: # pragma: no cover
|
||||
from .reader_py import (
|
||||
WebSocketDataQueue as WebSocketDataQueuePython,
|
||||
WebSocketReader as WebSocketReaderPython,
|
||||
)
|
||||
|
||||
WebSocketReader = WebSocketReaderPython
|
||||
WebSocketDataQueue = WebSocketDataQueuePython
|
||||
Binary file not shown.
@@ -0,0 +1,110 @@
|
||||
import cython
|
||||
|
||||
from .mask cimport _websocket_mask_cython as websocket_mask
|
||||
|
||||
|
||||
cdef unsigned int READ_HEADER
|
||||
cdef unsigned int READ_PAYLOAD_LENGTH
|
||||
cdef unsigned int READ_PAYLOAD_MASK
|
||||
cdef unsigned int READ_PAYLOAD
|
||||
|
||||
cdef int OP_CODE_NOT_SET
|
||||
cdef int OP_CODE_CONTINUATION
|
||||
cdef int OP_CODE_TEXT
|
||||
cdef int OP_CODE_BINARY
|
||||
cdef int OP_CODE_CLOSE
|
||||
cdef int OP_CODE_PING
|
||||
cdef int OP_CODE_PONG
|
||||
|
||||
cdef int COMPRESSED_NOT_SET
|
||||
cdef int COMPRESSED_FALSE
|
||||
cdef int COMPRESSED_TRUE
|
||||
|
||||
cdef object UNPACK_LEN3
|
||||
cdef object UNPACK_CLOSE_CODE
|
||||
cdef object TUPLE_NEW
|
||||
|
||||
cdef object WSMsgType
|
||||
cdef object WSMessage
|
||||
|
||||
cdef object WS_MSG_TYPE_TEXT
|
||||
cdef object WS_MSG_TYPE_BINARY
|
||||
|
||||
cdef set ALLOWED_CLOSE_CODES
|
||||
cdef set MESSAGE_TYPES_WITH_CONTENT
|
||||
|
||||
cdef tuple EMPTY_FRAME
|
||||
cdef tuple EMPTY_FRAME_ERROR
|
||||
|
||||
cdef class WebSocketDataQueue:
|
||||
|
||||
cdef unsigned int _size
|
||||
cdef public object _protocol
|
||||
cdef unsigned int _limit
|
||||
cdef object _loop
|
||||
cdef bint _eof
|
||||
cdef object _waiter
|
||||
cdef object _exception
|
||||
cdef public object _buffer
|
||||
cdef object _get_buffer
|
||||
cdef object _put_buffer
|
||||
|
||||
cdef void _release_waiter(self)
|
||||
|
||||
cpdef void feed_data(self, object data, unsigned int size)
|
||||
|
||||
@cython.locals(size="unsigned int")
|
||||
cdef _read_from_buffer(self)
|
||||
|
||||
cdef class WebSocketReader:
|
||||
|
||||
cdef WebSocketDataQueue queue
|
||||
cdef unsigned int _max_msg_size
|
||||
|
||||
cdef Exception _exc
|
||||
cdef bytearray _partial
|
||||
cdef unsigned int _state
|
||||
|
||||
cdef int _opcode
|
||||
cdef bint _frame_fin
|
||||
cdef int _frame_opcode
|
||||
cdef list _payload_fragments
|
||||
cdef Py_ssize_t _frame_payload_len
|
||||
|
||||
cdef bytes _tail
|
||||
cdef bint _has_mask
|
||||
cdef bytes _frame_mask
|
||||
cdef Py_ssize_t _payload_bytes_to_read
|
||||
cdef unsigned int _payload_len_flag
|
||||
cdef int _compressed
|
||||
cdef object _decompressobj
|
||||
cdef bint _compress
|
||||
|
||||
cpdef tuple feed_data(self, object data)
|
||||
|
||||
@cython.locals(
|
||||
is_continuation=bint,
|
||||
fin=bint,
|
||||
has_partial=bint,
|
||||
payload_merged=bytes,
|
||||
)
|
||||
cpdef void _handle_frame(self, bint fin, int opcode, object payload, int compressed) except *
|
||||
|
||||
@cython.locals(
|
||||
start_pos=Py_ssize_t,
|
||||
data_len=Py_ssize_t,
|
||||
length=Py_ssize_t,
|
||||
chunk_size=Py_ssize_t,
|
||||
chunk_len=Py_ssize_t,
|
||||
data_len=Py_ssize_t,
|
||||
data_cstr="const unsigned char *",
|
||||
first_byte="unsigned char",
|
||||
second_byte="unsigned char",
|
||||
f_start_pos=Py_ssize_t,
|
||||
f_end_pos=Py_ssize_t,
|
||||
has_mask=bint,
|
||||
fin=bint,
|
||||
had_fragments=Py_ssize_t,
|
||||
payload_bytearray=bytearray,
|
||||
)
|
||||
cpdef void _feed_data(self, bytes data) except *
|
||||
@@ -0,0 +1,476 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
from collections import deque
|
||||
from typing import Deque, Final, Optional, Set, Tuple, Union
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..compression_utils import ZLibDecompressor
|
||||
from ..helpers import _EXC_SENTINEL, set_exception
|
||||
from ..streams import EofStream
|
||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
||||
from .models import (
|
||||
WS_DEFLATE_TRAILING,
|
||||
WebSocketError,
|
||||
WSCloseCode,
|
||||
WSMessage,
|
||||
WSMsgType,
|
||||
)
|
||||
|
||||
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
|
||||
|
||||
# States for the reader, used to parse the WebSocket frame
|
||||
# integer values are used so they can be cythonized
|
||||
READ_HEADER = 1
|
||||
READ_PAYLOAD_LENGTH = 2
|
||||
READ_PAYLOAD_MASK = 3
|
||||
READ_PAYLOAD = 4
|
||||
|
||||
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
||||
|
||||
# WSMsgType values unpacked so they can by cythonized to ints
|
||||
OP_CODE_NOT_SET = -1
|
||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
||||
OP_CODE_PING = WSMsgType.PING.value
|
||||
OP_CODE_PONG = WSMsgType.PONG.value
|
||||
|
||||
EMPTY_FRAME_ERROR = (True, b"")
|
||||
EMPTY_FRAME = (False, b"")
|
||||
|
||||
COMPRESSED_NOT_SET = -1
|
||||
COMPRESSED_FALSE = 0
|
||||
COMPRESSED_TRUE = 1
|
||||
|
||||
TUPLE_NEW = tuple.__new__
|
||||
|
||||
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
|
||||
|
||||
|
||||
class WebSocketDataQueue:
|
||||
"""WebSocketDataQueue resumes and pauses an underlying stream.
|
||||
|
||||
It is a destination for WebSocket data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self._size = 0
|
||||
self._protocol = protocol
|
||||
self._limit = limit * 2
|
||||
self._loop = loop
|
||||
self._eof = False
|
||||
self._waiter: Optional[asyncio.Future[None]] = None
|
||||
self._exception: Union[BaseException, None] = None
|
||||
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
|
||||
self._get_buffer = self._buffer.popleft
|
||||
self._put_buffer = self._buffer.append
|
||||
|
||||
def is_eof(self) -> bool:
|
||||
return self._eof
|
||||
|
||||
def exception(self) -> Optional[BaseException]:
|
||||
return self._exception
|
||||
|
||||
def set_exception(
|
||||
self,
|
||||
exc: BaseException,
|
||||
exc_cause: builtins.BaseException = _EXC_SENTINEL,
|
||||
) -> None:
|
||||
self._eof = True
|
||||
self._exception = exc
|
||||
if (waiter := self._waiter) is not None:
|
||||
self._waiter = None
|
||||
set_exception(waiter, exc, exc_cause)
|
||||
|
||||
def _release_waiter(self) -> None:
|
||||
if (waiter := self._waiter) is None:
|
||||
return
|
||||
self._waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self._eof = True
|
||||
self._release_waiter()
|
||||
self._exception = None # Break cyclic references
|
||||
|
||||
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
|
||||
self._size += size
|
||||
self._put_buffer((data, size))
|
||||
self._release_waiter()
|
||||
if self._size > self._limit and not self._protocol._reading_paused:
|
||||
self._protocol.pause_reading()
|
||||
|
||||
async def read(self) -> WSMessage:
|
||||
if not self._buffer and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = self._loop.create_future()
|
||||
try:
|
||||
await self._waiter
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
self._waiter = None
|
||||
raise
|
||||
return self._read_from_buffer()
|
||||
|
||||
def _read_from_buffer(self) -> WSMessage:
|
||||
if self._buffer:
|
||||
data, size = self._get_buffer()
|
||||
self._size -= size
|
||||
if self._size < self._limit and self._protocol._reading_paused:
|
||||
self._protocol.resume_reading()
|
||||
return data
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
raise EofStream
|
||||
|
||||
|
||||
class WebSocketReader:
|
||||
def __init__(
|
||||
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
|
||||
) -> None:
|
||||
self.queue = queue
|
||||
self._max_msg_size = max_msg_size
|
||||
|
||||
self._exc: Optional[Exception] = None
|
||||
self._partial = bytearray()
|
||||
self._state = READ_HEADER
|
||||
|
||||
self._opcode: int = OP_CODE_NOT_SET
|
||||
self._frame_fin = False
|
||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
||||
self._payload_fragments: list[bytes] = []
|
||||
self._frame_payload_len = 0
|
||||
|
||||
self._tail: bytes = b""
|
||||
self._has_mask = False
|
||||
self._frame_mask: Optional[bytes] = None
|
||||
self._payload_bytes_to_read = 0
|
||||
self._payload_len_flag = 0
|
||||
self._compressed: int = COMPRESSED_NOT_SET
|
||||
self._decompressobj: Optional[ZLibDecompressor] = None
|
||||
self._compress = compress
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self.queue.feed_eof()
|
||||
|
||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
||||
# coerce data to bytes if it is not
|
||||
def feed_data(
|
||||
self, data: Union[bytes, bytearray, memoryview]
|
||||
) -> Tuple[bool, bytes]:
|
||||
if type(data) is not bytes:
|
||||
data = bytes(data)
|
||||
|
||||
if self._exc is not None:
|
||||
return True, data
|
||||
|
||||
try:
|
||||
self._feed_data(data)
|
||||
except Exception as exc:
|
||||
self._exc = exc
|
||||
set_exception(self.queue, exc)
|
||||
return EMPTY_FRAME_ERROR
|
||||
|
||||
return EMPTY_FRAME
|
||||
|
||||
def _handle_frame(
|
||||
self,
|
||||
fin: bool,
|
||||
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
||||
payload: Union[bytes, bytearray],
|
||||
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
||||
) -> None:
|
||||
msg: WSMessage
|
||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
||||
# load text/binary
|
||||
if not fin:
|
||||
# got partial frame payload
|
||||
if opcode != OP_CODE_CONTINUATION:
|
||||
self._opcode = opcode
|
||||
self._partial += payload
|
||||
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {len(self._partial)} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
return
|
||||
|
||||
has_partial = bool(self._partial)
|
||||
if opcode == OP_CODE_CONTINUATION:
|
||||
if self._opcode == OP_CODE_NOT_SET:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Continuation frame for non started message",
|
||||
)
|
||||
opcode = self._opcode
|
||||
self._opcode = OP_CODE_NOT_SET
|
||||
# previous frame was non finished
|
||||
# we should get continuation opcode
|
||||
elif has_partial:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"The opcode in non-fin frame is expected "
|
||||
f"to be zero, got {opcode!r}",
|
||||
)
|
||||
|
||||
assembled_payload: Union[bytes, bytearray]
|
||||
if has_partial:
|
||||
assembled_payload = self._partial + payload
|
||||
self._partial.clear()
|
||||
else:
|
||||
assembled_payload = payload
|
||||
|
||||
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {len(assembled_payload)} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
|
||||
# Decompress process must to be done after all packets
|
||||
# received.
|
||||
if compressed:
|
||||
if not self._decompressobj:
|
||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
||||
# XXX: It's possible that the zlib backend (isal is known to
|
||||
# do this, maybe others too?) will return max_length bytes,
|
||||
# but internally buffer more data such that the payload is
|
||||
# >max_length, so we return one extra byte and if we're able
|
||||
# to do that, then the message is too big.
|
||||
payload_merged = self._decompressobj.decompress_sync(
|
||||
assembled_payload + WS_DEFLATE_TRAILING,
|
||||
(
|
||||
self._max_msg_size + 1
|
||||
if self._max_msg_size
|
||||
else self._max_msg_size
|
||||
),
|
||||
)
|
||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
||||
)
|
||||
elif type(assembled_payload) is bytes:
|
||||
payload_merged = assembled_payload
|
||||
else:
|
||||
payload_merged = bytes(assembled_payload)
|
||||
|
||||
if opcode == OP_CODE_TEXT:
|
||||
try:
|
||||
text = payload_merged.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
|
||||
# XXX: The Text and Binary messages here can be a performance
|
||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
||||
# This is not type safe, but many tests should fail in
|
||||
# test_client_ws_functional.py if this is wrong.
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
elif opcode == OP_CODE_CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close code: {close_code}",
|
||||
)
|
||||
try:
|
||||
close_message = payload[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
||||
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
||||
)
|
||||
else:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
||||
|
||||
self.queue.feed_data(msg, 0)
|
||||
elif opcode == OP_CODE_PING:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
elif opcode == OP_CODE_PONG:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
else:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
||||
)
|
||||
|
||||
def _feed_data(self, data: bytes) -> None:
|
||||
"""Return the next frame from the socket."""
|
||||
if self._tail:
|
||||
data, self._tail = self._tail + data, b""
|
||||
|
||||
start_pos: int = 0
|
||||
data_len = len(data)
|
||||
data_cstr = data
|
||||
|
||||
while True:
|
||||
# read header
|
||||
if self._state == READ_HEADER:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0xF
|
||||
|
||||
# frame-fin = %x0 ; more frames of this message follow
|
||||
# / %x1 ; final frame of this message
|
||||
# frame-rsv1 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv2 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv3 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
#
|
||||
# Remove rsv1 from this test for deflate development
|
||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
if opcode > 0x7 and fin == 0:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received fragmented control frame",
|
||||
)
|
||||
|
||||
has_mask = (second_byte >> 7) & 1
|
||||
length = second_byte & 0x7F
|
||||
|
||||
# Control frames MUST have a payload
|
||||
# length of 125 bytes or less
|
||||
if opcode > 0x7 and length > 125:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Control frame payload cannot be larger than 125 bytes",
|
||||
)
|
||||
|
||||
# Set compress status if last package is FIN
|
||||
# OR set compress status if this is first fragment
|
||||
# Raise error if not first fragment with rsv1 = 0x1
|
||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
||||
elif rsv1:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
self._frame_fin = bool(fin)
|
||||
self._frame_opcode = opcode
|
||||
self._has_mask = bool(has_mask)
|
||||
self._payload_len_flag = length
|
||||
self._state = READ_PAYLOAD_LENGTH
|
||||
|
||||
# read payload length
|
||||
if self._state == READ_PAYLOAD_LENGTH:
|
||||
len_flag = self._payload_len_flag
|
||||
if len_flag == 126:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
||||
elif len_flag > 126:
|
||||
if data_len - start_pos < 8:
|
||||
break
|
||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
||||
start_pos += 8
|
||||
else:
|
||||
self._payload_bytes_to_read = len_flag
|
||||
|
||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
||||
|
||||
# read payload mask
|
||||
if self._state == READ_PAYLOAD_MASK:
|
||||
if data_len - start_pos < 4:
|
||||
break
|
||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
||||
start_pos += 4
|
||||
self._state = READ_PAYLOAD
|
||||
|
||||
if self._state == READ_PAYLOAD:
|
||||
chunk_len = data_len - start_pos
|
||||
if self._payload_bytes_to_read >= chunk_len:
|
||||
f_end_pos = data_len
|
||||
self._payload_bytes_to_read -= chunk_len
|
||||
else:
|
||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
||||
self._payload_bytes_to_read = 0
|
||||
|
||||
had_fragments = self._frame_payload_len
|
||||
self._frame_payload_len += f_end_pos - start_pos
|
||||
f_start_pos = start_pos
|
||||
start_pos = f_end_pos
|
||||
|
||||
if self._payload_bytes_to_read != 0:
|
||||
# If we don't have a complete frame, we need to save the
|
||||
# data for the next call to feed_data.
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
break
|
||||
|
||||
payload: Union[bytes, bytearray]
|
||||
if had_fragments:
|
||||
# We have to join the payload fragments get the payload
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
if self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = b"".join(self._payload_fragments)
|
||||
self._payload_fragments.clear()
|
||||
elif self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
||||
# Cython will do the conversion for us
|
||||
# but we need to do it for Python and we
|
||||
# will always get here in Python
|
||||
payload_bytearray = bytearray(payload_bytearray)
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = data_cstr[f_start_pos:f_end_pos]
|
||||
|
||||
self._handle_frame(
|
||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
||||
)
|
||||
self._frame_payload_len = 0
|
||||
self._state = READ_HEADER
|
||||
|
||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
||||
@@ -0,0 +1,476 @@
|
||||
"""Reader for WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import builtins
|
||||
from collections import deque
|
||||
from typing import Deque, Final, Optional, Set, Tuple, Union
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..compression_utils import ZLibDecompressor
|
||||
from ..helpers import _EXC_SENTINEL, set_exception
|
||||
from ..streams import EofStream
|
||||
from .helpers import UNPACK_CLOSE_CODE, UNPACK_LEN3, websocket_mask
|
||||
from .models import (
|
||||
WS_DEFLATE_TRAILING,
|
||||
WebSocketError,
|
||||
WSCloseCode,
|
||||
WSMessage,
|
||||
WSMsgType,
|
||||
)
|
||||
|
||||
ALLOWED_CLOSE_CODES: Final[Set[int]] = {int(i) for i in WSCloseCode}
|
||||
|
||||
# States for the reader, used to parse the WebSocket frame
|
||||
# integer values are used so they can be cythonized
|
||||
READ_HEADER = 1
|
||||
READ_PAYLOAD_LENGTH = 2
|
||||
READ_PAYLOAD_MASK = 3
|
||||
READ_PAYLOAD = 4
|
||||
|
||||
WS_MSG_TYPE_BINARY = WSMsgType.BINARY
|
||||
WS_MSG_TYPE_TEXT = WSMsgType.TEXT
|
||||
|
||||
# WSMsgType values unpacked so they can by cythonized to ints
|
||||
OP_CODE_NOT_SET = -1
|
||||
OP_CODE_CONTINUATION = WSMsgType.CONTINUATION.value
|
||||
OP_CODE_TEXT = WSMsgType.TEXT.value
|
||||
OP_CODE_BINARY = WSMsgType.BINARY.value
|
||||
OP_CODE_CLOSE = WSMsgType.CLOSE.value
|
||||
OP_CODE_PING = WSMsgType.PING.value
|
||||
OP_CODE_PONG = WSMsgType.PONG.value
|
||||
|
||||
EMPTY_FRAME_ERROR = (True, b"")
|
||||
EMPTY_FRAME = (False, b"")
|
||||
|
||||
COMPRESSED_NOT_SET = -1
|
||||
COMPRESSED_FALSE = 0
|
||||
COMPRESSED_TRUE = 1
|
||||
|
||||
TUPLE_NEW = tuple.__new__
|
||||
|
||||
cython_int = int # Typed to int in Python, but cython with use a signed int in the pxd
|
||||
|
||||
|
||||
class WebSocketDataQueue:
|
||||
"""WebSocketDataQueue resumes and pauses an underlying stream.
|
||||
|
||||
It is a destination for WebSocket data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
|
||||
) -> None:
|
||||
self._size = 0
|
||||
self._protocol = protocol
|
||||
self._limit = limit * 2
|
||||
self._loop = loop
|
||||
self._eof = False
|
||||
self._waiter: Optional[asyncio.Future[None]] = None
|
||||
self._exception: Union[BaseException, None] = None
|
||||
self._buffer: Deque[Tuple[WSMessage, int]] = deque()
|
||||
self._get_buffer = self._buffer.popleft
|
||||
self._put_buffer = self._buffer.append
|
||||
|
||||
def is_eof(self) -> bool:
|
||||
return self._eof
|
||||
|
||||
def exception(self) -> Optional[BaseException]:
|
||||
return self._exception
|
||||
|
||||
def set_exception(
|
||||
self,
|
||||
exc: BaseException,
|
||||
exc_cause: builtins.BaseException = _EXC_SENTINEL,
|
||||
) -> None:
|
||||
self._eof = True
|
||||
self._exception = exc
|
||||
if (waiter := self._waiter) is not None:
|
||||
self._waiter = None
|
||||
set_exception(waiter, exc, exc_cause)
|
||||
|
||||
def _release_waiter(self) -> None:
|
||||
if (waiter := self._waiter) is None:
|
||||
return
|
||||
self._waiter = None
|
||||
if not waiter.done():
|
||||
waiter.set_result(None)
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self._eof = True
|
||||
self._release_waiter()
|
||||
self._exception = None # Break cyclic references
|
||||
|
||||
def feed_data(self, data: "WSMessage", size: "cython_int") -> None:
|
||||
self._size += size
|
||||
self._put_buffer((data, size))
|
||||
self._release_waiter()
|
||||
if self._size > self._limit and not self._protocol._reading_paused:
|
||||
self._protocol.pause_reading()
|
||||
|
||||
async def read(self) -> WSMessage:
|
||||
if not self._buffer and not self._eof:
|
||||
assert not self._waiter
|
||||
self._waiter = self._loop.create_future()
|
||||
try:
|
||||
await self._waiter
|
||||
except (asyncio.CancelledError, asyncio.TimeoutError):
|
||||
self._waiter = None
|
||||
raise
|
||||
return self._read_from_buffer()
|
||||
|
||||
def _read_from_buffer(self) -> WSMessage:
|
||||
if self._buffer:
|
||||
data, size = self._get_buffer()
|
||||
self._size -= size
|
||||
if self._size < self._limit and self._protocol._reading_paused:
|
||||
self._protocol.resume_reading()
|
||||
return data
|
||||
if self._exception is not None:
|
||||
raise self._exception
|
||||
raise EofStream
|
||||
|
||||
|
||||
class WebSocketReader:
|
||||
def __init__(
|
||||
self, queue: WebSocketDataQueue, max_msg_size: int, compress: bool = True
|
||||
) -> None:
|
||||
self.queue = queue
|
||||
self._max_msg_size = max_msg_size
|
||||
|
||||
self._exc: Optional[Exception] = None
|
||||
self._partial = bytearray()
|
||||
self._state = READ_HEADER
|
||||
|
||||
self._opcode: int = OP_CODE_NOT_SET
|
||||
self._frame_fin = False
|
||||
self._frame_opcode: int = OP_CODE_NOT_SET
|
||||
self._payload_fragments: list[bytes] = []
|
||||
self._frame_payload_len = 0
|
||||
|
||||
self._tail: bytes = b""
|
||||
self._has_mask = False
|
||||
self._frame_mask: Optional[bytes] = None
|
||||
self._payload_bytes_to_read = 0
|
||||
self._payload_len_flag = 0
|
||||
self._compressed: int = COMPRESSED_NOT_SET
|
||||
self._decompressobj: Optional[ZLibDecompressor] = None
|
||||
self._compress = compress
|
||||
|
||||
def feed_eof(self) -> None:
|
||||
self.queue.feed_eof()
|
||||
|
||||
# data can be bytearray on Windows because proactor event loop uses bytearray
|
||||
# and asyncio types this to Union[bytes, bytearray, memoryview] so we need
|
||||
# coerce data to bytes if it is not
|
||||
def feed_data(
|
||||
self, data: Union[bytes, bytearray, memoryview]
|
||||
) -> Tuple[bool, bytes]:
|
||||
if type(data) is not bytes:
|
||||
data = bytes(data)
|
||||
|
||||
if self._exc is not None:
|
||||
return True, data
|
||||
|
||||
try:
|
||||
self._feed_data(data)
|
||||
except Exception as exc:
|
||||
self._exc = exc
|
||||
set_exception(self.queue, exc)
|
||||
return EMPTY_FRAME_ERROR
|
||||
|
||||
return EMPTY_FRAME
|
||||
|
||||
def _handle_frame(
|
||||
self,
|
||||
fin: bool,
|
||||
opcode: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
||||
payload: Union[bytes, bytearray],
|
||||
compressed: Union[int, cython_int], # Union intended: Cython pxd uses C int
|
||||
) -> None:
|
||||
msg: WSMessage
|
||||
if opcode in {OP_CODE_TEXT, OP_CODE_BINARY, OP_CODE_CONTINUATION}:
|
||||
# load text/binary
|
||||
if not fin:
|
||||
# got partial frame payload
|
||||
if opcode != OP_CODE_CONTINUATION:
|
||||
self._opcode = opcode
|
||||
self._partial += payload
|
||||
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {len(self._partial)} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
return
|
||||
|
||||
has_partial = bool(self._partial)
|
||||
if opcode == OP_CODE_CONTINUATION:
|
||||
if self._opcode == OP_CODE_NOT_SET:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Continuation frame for non started message",
|
||||
)
|
||||
opcode = self._opcode
|
||||
self._opcode = OP_CODE_NOT_SET
|
||||
# previous frame was non finished
|
||||
# we should get continuation opcode
|
||||
elif has_partial:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"The opcode in non-fin frame is expected "
|
||||
f"to be zero, got {opcode!r}",
|
||||
)
|
||||
|
||||
assembled_payload: Union[bytes, bytearray]
|
||||
if has_partial:
|
||||
assembled_payload = self._partial + payload
|
||||
self._partial.clear()
|
||||
else:
|
||||
assembled_payload = payload
|
||||
|
||||
if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Message size {len(assembled_payload)} "
|
||||
f"exceeds limit {self._max_msg_size}",
|
||||
)
|
||||
|
||||
# Decompress process must to be done after all packets
|
||||
# received.
|
||||
if compressed:
|
||||
if not self._decompressobj:
|
||||
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
|
||||
# XXX: It's possible that the zlib backend (isal is known to
|
||||
# do this, maybe others too?) will return max_length bytes,
|
||||
# but internally buffer more data such that the payload is
|
||||
# >max_length, so we return one extra byte and if we're able
|
||||
# to do that, then the message is too big.
|
||||
payload_merged = self._decompressobj.decompress_sync(
|
||||
assembled_payload + WS_DEFLATE_TRAILING,
|
||||
(
|
||||
self._max_msg_size + 1
|
||||
if self._max_msg_size
|
||||
else self._max_msg_size
|
||||
),
|
||||
)
|
||||
if self._max_msg_size and len(payload_merged) > self._max_msg_size:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.MESSAGE_TOO_BIG,
|
||||
f"Decompressed message exceeds size limit {self._max_msg_size}",
|
||||
)
|
||||
elif type(assembled_payload) is bytes:
|
||||
payload_merged = assembled_payload
|
||||
else:
|
||||
payload_merged = bytes(assembled_payload)
|
||||
|
||||
if opcode == OP_CODE_TEXT:
|
||||
try:
|
||||
text = payload_merged.decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
|
||||
# XXX: The Text and Binary messages here can be a performance
|
||||
# bottleneck, so we use tuple.__new__ to improve performance.
|
||||
# This is not type safe, but many tests should fail in
|
||||
# test_client_ws_functional.py if this is wrong.
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_TEXT, text, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
else:
|
||||
self.queue.feed_data(
|
||||
TUPLE_NEW(WSMessage, (WS_MSG_TYPE_BINARY, payload_merged, "")),
|
||||
len(payload_merged),
|
||||
)
|
||||
elif opcode == OP_CODE_CLOSE:
|
||||
if len(payload) >= 2:
|
||||
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
|
||||
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close code: {close_code}",
|
||||
)
|
||||
try:
|
||||
close_message = payload[2:].decode("utf-8")
|
||||
except UnicodeDecodeError as exc:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
|
||||
) from exc
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, close_code, close_message))
|
||||
elif payload:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
f"Invalid close frame: {fin} {opcode} {payload!r}",
|
||||
)
|
||||
else:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.CLOSE, 0, ""))
|
||||
|
||||
self.queue.feed_data(msg, 0)
|
||||
elif opcode == OP_CODE_PING:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PING, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
elif opcode == OP_CODE_PONG:
|
||||
msg = TUPLE_NEW(WSMessage, (WSMsgType.PONG, payload, ""))
|
||||
self.queue.feed_data(msg, len(payload))
|
||||
else:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
|
||||
)
|
||||
|
||||
def _feed_data(self, data: bytes) -> None:
|
||||
"""Return the next frame from the socket."""
|
||||
if self._tail:
|
||||
data, self._tail = self._tail + data, b""
|
||||
|
||||
start_pos: int = 0
|
||||
data_len = len(data)
|
||||
data_cstr = data
|
||||
|
||||
while True:
|
||||
# read header
|
||||
if self._state == READ_HEADER:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
|
||||
fin = (first_byte >> 7) & 1
|
||||
rsv1 = (first_byte >> 6) & 1
|
||||
rsv2 = (first_byte >> 5) & 1
|
||||
rsv3 = (first_byte >> 4) & 1
|
||||
opcode = first_byte & 0xF
|
||||
|
||||
# frame-fin = %x0 ; more frames of this message follow
|
||||
# / %x1 ; final frame of this message
|
||||
# frame-rsv1 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv2 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
# frame-rsv3 = %x0 ;
|
||||
# 1 bit, MUST be 0 unless negotiated otherwise
|
||||
#
|
||||
# Remove rsv1 from this test for deflate development
|
||||
if rsv2 or rsv3 or (rsv1 and not self._compress):
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
if opcode > 0x7 and fin == 0:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received fragmented control frame",
|
||||
)
|
||||
|
||||
has_mask = (second_byte >> 7) & 1
|
||||
length = second_byte & 0x7F
|
||||
|
||||
# Control frames MUST have a payload
|
||||
# length of 125 bytes or less
|
||||
if opcode > 0x7 and length > 125:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Control frame payload cannot be larger than 125 bytes",
|
||||
)
|
||||
|
||||
# Set compress status if last package is FIN
|
||||
# OR set compress status if this is first fragment
|
||||
# Raise error if not first fragment with rsv1 = 0x1
|
||||
if self._frame_fin or self._compressed == COMPRESSED_NOT_SET:
|
||||
self._compressed = COMPRESSED_TRUE if rsv1 else COMPRESSED_FALSE
|
||||
elif rsv1:
|
||||
raise WebSocketError(
|
||||
WSCloseCode.PROTOCOL_ERROR,
|
||||
"Received frame with non-zero reserved bits",
|
||||
)
|
||||
|
||||
self._frame_fin = bool(fin)
|
||||
self._frame_opcode = opcode
|
||||
self._has_mask = bool(has_mask)
|
||||
self._payload_len_flag = length
|
||||
self._state = READ_PAYLOAD_LENGTH
|
||||
|
||||
# read payload length
|
||||
if self._state == READ_PAYLOAD_LENGTH:
|
||||
len_flag = self._payload_len_flag
|
||||
if len_flag == 126:
|
||||
if data_len - start_pos < 2:
|
||||
break
|
||||
first_byte = data_cstr[start_pos]
|
||||
second_byte = data_cstr[start_pos + 1]
|
||||
start_pos += 2
|
||||
self._payload_bytes_to_read = first_byte << 8 | second_byte
|
||||
elif len_flag > 126:
|
||||
if data_len - start_pos < 8:
|
||||
break
|
||||
self._payload_bytes_to_read = UNPACK_LEN3(data, start_pos)[0]
|
||||
start_pos += 8
|
||||
else:
|
||||
self._payload_bytes_to_read = len_flag
|
||||
|
||||
self._state = READ_PAYLOAD_MASK if self._has_mask else READ_PAYLOAD
|
||||
|
||||
# read payload mask
|
||||
if self._state == READ_PAYLOAD_MASK:
|
||||
if data_len - start_pos < 4:
|
||||
break
|
||||
self._frame_mask = data_cstr[start_pos : start_pos + 4]
|
||||
start_pos += 4
|
||||
self._state = READ_PAYLOAD
|
||||
|
||||
if self._state == READ_PAYLOAD:
|
||||
chunk_len = data_len - start_pos
|
||||
if self._payload_bytes_to_read >= chunk_len:
|
||||
f_end_pos = data_len
|
||||
self._payload_bytes_to_read -= chunk_len
|
||||
else:
|
||||
f_end_pos = start_pos + self._payload_bytes_to_read
|
||||
self._payload_bytes_to_read = 0
|
||||
|
||||
had_fragments = self._frame_payload_len
|
||||
self._frame_payload_len += f_end_pos - start_pos
|
||||
f_start_pos = start_pos
|
||||
start_pos = f_end_pos
|
||||
|
||||
if self._payload_bytes_to_read != 0:
|
||||
# If we don't have a complete frame, we need to save the
|
||||
# data for the next call to feed_data.
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
break
|
||||
|
||||
payload: Union[bytes, bytearray]
|
||||
if had_fragments:
|
||||
# We have to join the payload fragments get the payload
|
||||
self._payload_fragments.append(data_cstr[f_start_pos:f_end_pos])
|
||||
if self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = bytearray(b"".join(self._payload_fragments))
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = b"".join(self._payload_fragments)
|
||||
self._payload_fragments.clear()
|
||||
elif self._has_mask:
|
||||
assert self._frame_mask is not None
|
||||
payload_bytearray = data_cstr[f_start_pos:f_end_pos] # type: ignore[assignment]
|
||||
if type(payload_bytearray) is not bytearray: # pragma: no branch
|
||||
# Cython will do the conversion for us
|
||||
# but we need to do it for Python and we
|
||||
# will always get here in Python
|
||||
payload_bytearray = bytearray(payload_bytearray)
|
||||
websocket_mask(self._frame_mask, payload_bytearray)
|
||||
payload = payload_bytearray
|
||||
else:
|
||||
payload = data_cstr[f_start_pos:f_end_pos]
|
||||
|
||||
self._handle_frame(
|
||||
self._frame_fin, self._frame_opcode, payload, self._compressed
|
||||
)
|
||||
self._frame_payload_len = 0
|
||||
self._state = READ_HEADER
|
||||
|
||||
# XXX: Cython needs slices to be bounded, so we can't omit the slice end here.
|
||||
self._tail = data_cstr[start_pos:data_len] if start_pos < data_len else b""
|
||||
@@ -0,0 +1,262 @@
|
||||
"""WebSocket protocol versions 13 and 8."""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import sys
|
||||
from functools import partial
|
||||
from typing import Final, Optional, Set, Union
|
||||
|
||||
from ..base_protocol import BaseProtocol
|
||||
from ..client_exceptions import ClientConnectionResetError
|
||||
from ..compression_utils import ZLibBackend, ZLibCompressor
|
||||
from .helpers import (
|
||||
MASK_LEN,
|
||||
MSG_SIZE,
|
||||
PACK_CLOSE_CODE,
|
||||
PACK_LEN1,
|
||||
PACK_LEN2,
|
||||
PACK_LEN3,
|
||||
PACK_RANDBITS,
|
||||
websocket_mask,
|
||||
)
|
||||
from .models import WS_DEFLATE_TRAILING, WSMsgType
|
||||
|
||||
DEFAULT_LIMIT: Final[int] = 2**16
|
||||
|
||||
# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames
|
||||
# Control frames (ping, pong, close) are never compressed
|
||||
WS_CONTROL_FRAME_OPCODE: Final[int] = 8
|
||||
|
||||
# For websockets, keeping latency low is extremely important as implementations
|
||||
# generally expect to be able to send and receive messages quickly. We use a
|
||||
# larger chunk size to reduce the number of executor calls and avoid task
|
||||
# creation overhead, since both are significant sources of latency when chunks
|
||||
# are small. A size of 16KiB was chosen as a balance between avoiding task
|
||||
# overhead and not blocking the event loop too long with synchronous compression.
|
||||
|
||||
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024
|
||||
|
||||
|
||||
class WebSocketWriter:
|
||||
"""WebSocket writer.
|
||||
|
||||
The writer is responsible for sending messages to the client. It is
|
||||
created by the protocol when a connection is established. The writer
|
||||
should avoid implementing any application logic and should only be
|
||||
concerned with the low-level details of the WebSocket protocol.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
protocol: BaseProtocol,
|
||||
transport: asyncio.Transport,
|
||||
*,
|
||||
use_mask: bool = False,
|
||||
limit: int = DEFAULT_LIMIT,
|
||||
random: random.Random = random.Random(),
|
||||
compress: int = 0,
|
||||
notakeover: bool = False,
|
||||
) -> None:
|
||||
"""Initialize a WebSocket writer."""
|
||||
self.protocol = protocol
|
||||
self.transport = transport
|
||||
self.use_mask = use_mask
|
||||
self.get_random_bits = partial(random.getrandbits, 32)
|
||||
self.compress = compress
|
||||
self.notakeover = notakeover
|
||||
self._closing = False
|
||||
self._limit = limit
|
||||
self._output_size = 0
|
||||
self._compressobj: Optional[ZLibCompressor] = None
|
||||
self._send_lock = asyncio.Lock()
|
||||
self._background_tasks: Set[asyncio.Task[None]] = set()
|
||||
|
||||
async def send_frame(
|
||||
self, message: bytes, opcode: int, compress: Optional[int] = None
|
||||
) -> None:
|
||||
"""Send a frame over the websocket with message as its payload."""
|
||||
if self._closing and not (opcode & WSMsgType.CLOSE):
|
||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
||||
|
||||
if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE:
|
||||
# Non-compressed frames don't need lock or shield
|
||||
self._write_websocket_frame(message, opcode, 0)
|
||||
elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
|
||||
# Small compressed payloads - compress synchronously in event loop
|
||||
# We need the lock even though sync compression has no await points.
|
||||
# This prevents small frames from interleaving with large frames that
|
||||
# compress in the executor, avoiding compressor state corruption.
|
||||
async with self._send_lock:
|
||||
self._send_compressed_frame_sync(message, opcode, compress)
|
||||
else:
|
||||
# Large compressed frames need shield to prevent corruption
|
||||
# For large compressed frames, the entire compress+send
|
||||
# operation must be atomic. If cancelled after compression but
|
||||
# before send, the compressor state would be advanced but data
|
||||
# not sent, corrupting subsequent frames.
|
||||
# Create a task to shield from cancellation
|
||||
# The lock is acquired inside the shielded task so the entire
|
||||
# operation (lock + compress + send) completes atomically.
|
||||
# Use eager_start on Python 3.12+ to avoid scheduling overhead
|
||||
loop = asyncio.get_running_loop()
|
||||
coro = self._send_compressed_frame_async_locked(message, opcode, compress)
|
||||
if sys.version_info >= (3, 12):
|
||||
send_task = asyncio.Task(coro, loop=loop, eager_start=True)
|
||||
else:
|
||||
send_task = loop.create_task(coro)
|
||||
# Keep a strong reference to prevent garbage collection
|
||||
self._background_tasks.add(send_task)
|
||||
send_task.add_done_callback(self._background_tasks.discard)
|
||||
await asyncio.shield(send_task)
|
||||
|
||||
# It is safe to return control to the event loop when using compression
|
||||
# after this point as we have already sent or buffered all the data.
|
||||
# Once we have written output_size up to the limit, we call the
|
||||
# drain helper which waits for the transport to be ready to accept
|
||||
# more data. This is a flow control mechanism to prevent the buffer
|
||||
# from growing too large. The drain helper will return right away
|
||||
# if the writer is not paused.
|
||||
if self._output_size > self._limit:
|
||||
self._output_size = 0
|
||||
if self.protocol._paused:
|
||||
await self.protocol._drain_helper()
|
||||
|
||||
def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
|
||||
"""
|
||||
Write a websocket frame to the transport.
|
||||
|
||||
This method handles frame header construction, masking, and writing to transport.
|
||||
It does not handle compression or flow control - those are the responsibility
|
||||
of the caller.
|
||||
"""
|
||||
msg_length = len(message)
|
||||
|
||||
use_mask = self.use_mask
|
||||
mask_bit = 0x80 if use_mask else 0
|
||||
|
||||
# Depending on the message length, the header is assembled differently.
|
||||
# The first byte is reserved for the opcode and the RSV bits.
|
||||
first_byte = 0x80 | rsv | opcode
|
||||
if msg_length < 126:
|
||||
header = PACK_LEN1(first_byte, msg_length | mask_bit)
|
||||
header_len = 2
|
||||
elif msg_length < 65536:
|
||||
header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length)
|
||||
header_len = 4
|
||||
else:
|
||||
header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length)
|
||||
header_len = 10
|
||||
|
||||
if self.transport.is_closing():
|
||||
raise ClientConnectionResetError("Cannot write to closing transport")
|
||||
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
|
||||
# If we are using a mask, we need to generate it randomly
|
||||
# and apply it to the message before sending it. A mask is
|
||||
# a 32-bit value that is applied to the message using a
|
||||
# bitwise XOR operation. It is used to prevent certain types
|
||||
# of attacks on the websocket protocol. The mask is only used
|
||||
# when aiohttp is acting as a client. Servers do not use a mask.
|
||||
if use_mask:
|
||||
mask = PACK_RANDBITS(self.get_random_bits())
|
||||
message = bytearray(message)
|
||||
websocket_mask(mask, message)
|
||||
self.transport.write(header + mask + message)
|
||||
self._output_size += MASK_LEN
|
||||
elif msg_length > MSG_SIZE:
|
||||
self.transport.write(header)
|
||||
self.transport.write(message)
|
||||
else:
|
||||
self.transport.write(header + message)
|
||||
|
||||
self._output_size += header_len + msg_length
|
||||
|
||||
def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
|
||||
"""Get or create a compressor object for the given compression level."""
|
||||
if compress:
|
||||
# Do not set self._compress if compressing is for this frame
|
||||
return ZLibCompressor(
|
||||
level=ZLibBackend.Z_BEST_SPEED,
|
||||
wbits=-compress,
|
||||
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
|
||||
)
|
||||
if not self._compressobj:
|
||||
self._compressobj = ZLibCompressor(
|
||||
level=ZLibBackend.Z_BEST_SPEED,
|
||||
wbits=-self.compress,
|
||||
max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
|
||||
)
|
||||
return self._compressobj
|
||||
|
||||
def _send_compressed_frame_sync(
|
||||
self, message: bytes, opcode: int, compress: Optional[int]
|
||||
) -> None:
|
||||
"""
|
||||
Synchronous send for small compressed frames.
|
||||
|
||||
This is used for small compressed payloads that compress synchronously in the event loop.
|
||||
Since there are no await points, this is inherently cancellation-safe.
|
||||
"""
|
||||
# RSV are the reserved bits in the frame header. They are used to
|
||||
# indicate that the frame is using an extension.
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
||||
compressobj = self._get_compressor(compress)
|
||||
# (0x40) RSV1 is set for compressed frames
|
||||
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
|
||||
self._write_websocket_frame(
|
||||
(
|
||||
compressobj.compress_sync(message)
|
||||
+ compressobj.flush(
|
||||
ZLibBackend.Z_FULL_FLUSH
|
||||
if self.notakeover
|
||||
else ZLibBackend.Z_SYNC_FLUSH
|
||||
)
|
||||
).removesuffix(WS_DEFLATE_TRAILING),
|
||||
opcode,
|
||||
0x40,
|
||||
)
|
||||
|
||||
async def _send_compressed_frame_async_locked(
|
||||
self, message: bytes, opcode: int, compress: Optional[int]
|
||||
) -> None:
|
||||
"""
|
||||
Async send for large compressed frames with lock.
|
||||
|
||||
Acquires the lock and compresses large payloads asynchronously in
|
||||
the executor. The lock is held for the entire operation to ensure
|
||||
the compressor state is not corrupted by concurrent sends.
|
||||
|
||||
MUST be run shielded from cancellation. If cancelled after
|
||||
compression but before sending, the compressor state would be
|
||||
advanced but data not sent, corrupting subsequent frames.
|
||||
"""
|
||||
async with self._send_lock:
|
||||
# RSV are the reserved bits in the frame header. They are used to
|
||||
# indicate that the frame is using an extension.
|
||||
# https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
|
||||
compressobj = self._get_compressor(compress)
|
||||
# (0x40) RSV1 is set for compressed frames
|
||||
# https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
|
||||
self._write_websocket_frame(
|
||||
(
|
||||
await compressobj.compress(message)
|
||||
+ compressobj.flush(
|
||||
ZLibBackend.Z_FULL_FLUSH
|
||||
if self.notakeover
|
||||
else ZLibBackend.Z_SYNC_FLUSH
|
||||
)
|
||||
).removesuffix(WS_DEFLATE_TRAILING),
|
||||
opcode,
|
||||
0x40,
|
||||
)
|
||||
|
||||
async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
|
||||
"""Close the websocket, sending the specified code and message."""
|
||||
if isinstance(message, str):
|
||||
message = message.encode("utf-8")
|
||||
try:
|
||||
await self.send_frame(
|
||||
PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE
|
||||
)
|
||||
finally:
|
||||
self._closing = True
|
||||
Reference in New Issue
Block a user