Skip to content

Commit

Permalink
Add underlying_transport to WSTransport
Browse files Browse the repository at this point in the history
  • Loading branch information
taras committed Aug 17, 2024
1 parent 5f7b97f commit 5af87a1
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 125 deletions.
2 changes: 1 addition & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ The API follows the low-level `transport/protocol design from asyncio <https://d
It passes frames instead of messages to a user handler. A message can potentially consist of multiple frames but it is up to user to choose the best strategy for merging them.
Same principle applies for compression and flow control. User can implement their own strategies using the most appropriate tools.

That being said that the most common use-case is when messages and frames are the same, i.e. a message consists of only a single frame, and no compression is being used.
That being said the most common use-case is when messages and frames are the same, i.e. a message consists of only a single frame, and no compression is being used.

Getting started
===============
Expand Down
7 changes: 7 additions & 0 deletions docs/source/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,13 @@ Classes
.. autoclass:: WSTransport
:members:

.. py:attribute:: underlying_transport
:type: asyncio.Transport

Underlying TCP or SSL transport. Can be used to set buffer limits, check connection state, etc.

**Please don't use it to send data. Use only WSTransport.send_* methods to send frames.**

.. py:method:: send_reuse_external_buffer(WSMsgType msg_type, char* message, size_t message_size)
:param msg_type: Message type
Expand Down
20 changes: 8 additions & 12 deletions picows/picows.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -68,21 +68,14 @@ cdef class WSFrame:
cpdef bytes get_close_message(self)


cdef class WSFrameBuilder:
cdef:
MemoryBuffer _write_buf
bint is_client_side

cdef prepare_frame_in_external_buffer(self, WSMsgType msg_type, uint8_t* msg_ptr, size_t msg_length)
cpdef prepare_frame(self, WSMsgType msg_type, message)


cdef class WSTransport:
cdef:
object _transport #: Optional[asyncio.Transport]
readonly object underlying_transport #: asyncio.Transport

object _logger #: Logger
object _disconnected_future
WSFrameBuilder _frame_builder
object _disconnected_future #: asyncio.Future
MemoryBuffer _write_buf
bint _is_client_side

cdef send_reuse_external_buffer(self, WSMsgType msg_type, char* message, size_t message_size)
cpdef send(self, WSMsgType msg_type, message)
Expand All @@ -95,6 +88,9 @@ cdef class WSTransport:
cdef send_http_handshake_response(self, bytes accept_val)
cdef mark_disconnected(self)

cdef bytes _prepare_frame_in_external_buffer(self, WSMsgType msg_type, uint8_t* msg_ptr, size_t msg_length)
cdef bytes _prepare_frame(self, WSMsgType msg_type, message)


cdef class WSListener:
cpdef on_ws_connected(self, WSTransport transport)
Expand Down
226 changes: 114 additions & 112 deletions picows/picows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ from typing import cast, Tuple, Optional, Callable, Union
cimport cython

from cpython.bytes cimport PyBytes_GET_SIZE, PyBytes_AS_STRING, PyBytes_FromStringAndSize, PyBytes_CheckExact
from cpython.bytearray cimport PyByteArray_AS_STRING, PyByteArray_GET_SIZE, PyByteArray_CheckExact
from cpython.memoryview cimport PyMemoryView_FromMemory
from cpython.mem cimport PyMem_Malloc, PyMem_Realloc, PyMem_Free
from cpython.buffer cimport PyBUF_WRITE, PyBUF_READ, PyBUF_SIMPLE, PyObject_GetBuffer, PyBuffer_Release
Expand Down Expand Up @@ -242,107 +243,6 @@ cdef class MemoryBuffer:
self.size = new_size


cdef class WSFrameBuilder:
def __init__(self, bint is_client_side):
self._write_buf = MemoryBuffer(1024)
self.is_client_side = is_client_side

cdef prepare_frame_in_external_buffer(self, WSMsgType msg_type, uint8_t* msg_ptr, size_t msg_length):
cdef:
# Just fin byte and msg_type
# No support for rsv/compression
uint8_t* header_ptr = msg_ptr
uint64_t extended_payload_length_64
uint32_t mask = <uint32_t> rand() if self.is_client_side else 0
uint16_t extended_payload_length_16
uint8_t first_byte = 0x80 | <uint8_t> msg_type
uint8_t second_byte = 0x80 if self.is_client_side else 0

if msg_length < 126:
header_ptr -= 2
header_ptr[0] = first_byte
header_ptr[1] = second_byte | <uint8_t>msg_length
elif msg_length < (1 << 16):
header_ptr -= 4
header_ptr[0] = first_byte
header_ptr[1] = second_byte | 126
extended_payload_length_16 = htons(<uint16_t> msg_length)
(<uint16_t*>(header_ptr + 2))[0] = extended_payload_length_16
else:
header_ptr -= 10
header_ptr[0] = first_byte
header_ptr[1] = second_byte | 127
extended_payload_length_64 = htobe64(<uint64_t> msg_length)
(<uint64_t*> (header_ptr + 2))[0] = extended_payload_length_64

if self.is_client_side:
_mask_payload(msg_ptr, msg_length, mask)

cdef Py_ssize_t total_length = msg_length + (msg_ptr - header_ptr)

return PyBytes_FromStringAndSize(<char*>header_ptr, total_length)
# return PyMemoryView_FromMemory(header_ptr, total_length, PyBUF_READ)

cpdef prepare_frame(self, WSMsgType msg_type, message):
"""Send a frame over the websocket with message as its payload."""
cdef:
Py_buffer msg_buffer
char* msg_ptr
Py_ssize_t msg_length

if message is None:
msg_ptr = b""
msg_length = 0
elif PyBytes_CheckExact(message):
# Just a small optimization for bytes type as the most used type for sending data
msg_ptr = PyBytes_AS_STRING(message)
msg_length = PyBytes_GET_SIZE(message)
else:
PyObject_GetBuffer(message, &msg_buffer, PyBUF_SIMPLE)
msg_ptr = <char*>msg_buffer.buf
msg_length = msg_buffer.len
# We can already release because we still keep the reference to the message
PyBuffer_Release(&msg_buffer)

cdef:
# Just fin byte and msg_type
# No support for rsv/compression
uint8_t first_byte = 0x80 | <uint8_t>msg_type
uint8_t second_byte = 0x80 if self.is_client_side else 0
uint32_t mask = <uint32_t>rand() if self.is_client_side else 0
uint16_t extended_payload_length_16
uint64_t extended_payload_length_64
Py_ssize_t payload_start_idx

self._write_buf.clear()
self._write_buf.push_back(first_byte)

if msg_length < 126:
second_byte |= <uint8_t>msg_length
self._write_buf.push_back(second_byte)
elif msg_length < (1 << 16):
second_byte |= 126
self._write_buf.push_back(second_byte)
extended_payload_length_16 = htons(<uint16_t>msg_length)
self._write_buf.append(<const char*>&extended_payload_length_16, 2)
else:
second_byte |= 127
extended_payload_length_64 = htobe64(<uint64_t>msg_length)
self._write_buf.push_back(second_byte)
self._write_buf.append(<const char*>&extended_payload_length_64, 8)

if self.is_client_side:
self._write_buf.append(<const char*>&mask, 4)
payload_start_idx = self._write_buf.size
self._write_buf.append(msg_ptr, msg_length)
_mask_payload(<uint8_t*>self._write_buf.data + payload_start_idx, msg_length, mask)
else:
self._write_buf.append(msg_ptr, msg_length)

# return PyMemoryView_FromMemory(<char*>&self._write_buf[0], self._write_buf.size(), PyBUF_READ)
return PyBytes_FromStringAndSize(self._write_buf.data, self._write_buf.size)


cdef class WSListener:
"""
Base class for user handlers.
Expand Down Expand Up @@ -397,14 +297,15 @@ cdef class WSListener:

cdef class WSTransport:
def __init__(self, bint is_client_side, underlying_transport, logger, loop):
self._transport = underlying_transport
self.underlying_transport = underlying_transport
self._logger = logger
self._disconnected_future = loop.create_future()
self._frame_builder = WSFrameBuilder(is_client_side)
self._write_buf = MemoryBuffer(1024)
self._is_client_side = is_client_side

cdef send_reuse_external_buffer(self, WSMsgType msg_type, char* message, size_t message_size):
frame = self._frame_builder.prepare_frame_in_external_buffer(msg_type, <uint8_t*>message, message_size)
self._transport.write(frame)
frame = self._prepare_frame_in_external_buffer(msg_type, <uint8_t*>message, message_size)
self.underlying_transport.write(frame)

cpdef send(self, WSMsgType msg_type, message):
"""
Expand All @@ -413,8 +314,8 @@ cdef class WSTransport:
Send a frame over websocket with a message as its payload.
"""
frame = self._frame_builder.prepare_frame(msg_type, message)
self._transport.write(frame)
frame = self._prepare_frame(msg_type, message)
self.underlying_transport.write(frame)

cpdef send_ping(self, message=None):
"""
Expand All @@ -441,7 +342,7 @@ cdef class WSTransport:
This method doesn't disconnect the underlying transport.
Does nothing if the underlying transport is already disconnected.
"""
if self._transport.is_closing():
if self.underlying_transport.is_closing():
return

cdef bytes close_payload = struct.pack("!H", <uint16_t>close_code)
Expand All @@ -455,9 +356,9 @@ cdef class WSTransport:
Immediately disconnect the underlying transport.
It is ok to call this method multiple times. It does nothing if the transport is already disconnected.
"""
if self._transport.is_closing():
if self.underlying_transport.is_closing():
return
self._transport.close()
self.underlying_transport.close()

async def wait_until_closed(self):
"""
Expand All @@ -475,7 +376,7 @@ cdef class WSTransport:
b"Sec-WebSocket-Version: 13\r\n"
b"Sec-WebSocket-Key: %b\r\n"
b"\r\n" % (ws_path, host_port, websocket_key_b64))
self._transport.write(initial_handshake)
self.underlying_transport.write(initial_handshake)

cdef send_http_handshake_response(self, bytes accept_val):
cdef bytes handshake_response = (b"HTTP/1.1 101 Switching Protocols\r\n"
Expand All @@ -485,12 +386,113 @@ cdef class WSTransport:
b"\r\n" % (accept_val,))

self._logger.log(PICOWS_DEBUG_LL, "Send upgrade response: %s", handshake_response)
self._transport.write(handshake_response)
self.underlying_transport.write(handshake_response)

cdef mark_disconnected(self):
if not self._disconnected_future.done():
self._disconnected_future.set_result(None)

cdef bytes _prepare_frame_in_external_buffer(self, WSMsgType msg_type, uint8_t* msg_ptr, size_t msg_length):
cdef:
# Just fin byte and msg_type
# No support for rsv/compression
uint8_t* header_ptr = msg_ptr
uint64_t extended_payload_length_64
uint32_t mask = <uint32_t> rand() if self._is_client_side else 0
uint16_t extended_payload_length_16
uint8_t first_byte = 0x80 | <uint8_t> msg_type
uint8_t second_byte = 0x80 if self._is_client_side else 0

if msg_length < 126:
header_ptr -= 2
header_ptr[0] = first_byte
header_ptr[1] = second_byte | <uint8_t>msg_length
elif msg_length < (1 << 16):
header_ptr -= 4
header_ptr[0] = first_byte
header_ptr[1] = second_byte | 126
extended_payload_length_16 = htons(<uint16_t> msg_length)
(<uint16_t*>(header_ptr + 2))[0] = extended_payload_length_16
else:
header_ptr -= 10
header_ptr[0] = first_byte
header_ptr[1] = second_byte | 127
extended_payload_length_64 = htobe64(<uint64_t> msg_length)
(<uint64_t*> (header_ptr + 2))[0] = extended_payload_length_64

if self._is_client_side:
_mask_payload(msg_ptr, msg_length, mask)

cdef Py_ssize_t total_length = msg_length + (msg_ptr - header_ptr)

return PyBytes_FromStringAndSize(<char*>header_ptr, total_length)

cdef bytes _prepare_frame(self, WSMsgType msg_type, message):
"""Send a frame over the websocket with message as its payload."""
cdef:
Py_buffer msg_buffer
char* msg_ptr
Py_ssize_t msg_length

if message is None:
msg_ptr = b""
msg_length = 0
elif PyBytes_CheckExact(message):
# Just a small optimization for bytes type as the most used type for sending data
msg_ptr = PyBytes_AS_STRING(message)
msg_length = PyBytes_GET_SIZE(message)
elif PyByteArray_CheckExact(message):
# Just a small optimization for bytes type as the most used type for sending data
msg_ptr = PyByteArray_AS_STRING(message)
msg_length = PyByteArray_GET_SIZE(message)
else:
PyObject_GetBuffer(message, &msg_buffer, PyBUF_SIMPLE)
msg_ptr = <char*>msg_buffer.buf
msg_length = msg_buffer.len
# We can already release because we still keep the reference to the message
PyBuffer_Release(&msg_buffer)

cdef:
# Just fin byte and msg_type
# No support for rsv/compression
uint8_t first_byte = 0x80 | <uint8_t>msg_type
uint8_t second_byte = 0x80 if self._is_client_side else 0
uint32_t mask = <uint32_t>rand() if self._is_client_side else 0
uint16_t extended_payload_length_16
uint64_t extended_payload_length_64
Py_ssize_t payload_start_idx

self._write_buf.clear()
self._write_buf.push_back(first_byte)

if msg_length < 126:
second_byte |= <uint8_t>msg_length
self._write_buf.push_back(second_byte)
elif msg_length < (1 << 16):
second_byte |= 126
self._write_buf.push_back(second_byte)
extended_payload_length_16 = htons(<uint16_t>msg_length)
self._write_buf.append(<const char*>&extended_payload_length_16, 2)
else:
second_byte |= 127
extended_payload_length_64 = htobe64(<uint64_t>msg_length)
self._write_buf.push_back(second_byte)
self._write_buf.append(<const char*>&extended_payload_length_64, 8)

if self._is_client_side:
self._write_buf.append(<const char*>&mask, 4)
payload_start_idx = self._write_buf.size
self._write_buf.append(msg_ptr, msg_length)
_mask_payload(<uint8_t*>self._write_buf.data + payload_start_idx, msg_length, mask)
else:
self._write_buf.append(msg_ptr, msg_length)

# Unfortunately we can't return a memoryview from write buffer like this
# return PyMemoryView_FromMemory(<char *> &self._write_buf[0], self._write_buf.size(), PyBUF_READ)
# because uvloop.Transport.write may delay sending and it doesn't copy the content of the buffer

return PyBytes_FromStringAndSize(self._write_buf.data, self._write_buf.size)


cdef class WSProtocol:
cdef:
Expand Down

0 comments on commit 5af87a1

Please sign in to comment.