Source code for swimprotocol.udp.protocol
from __future__ import annotations
import asyncio
from asyncio import Queue, Protocol, DatagramProtocol
from concurrent.futures import ThreadPoolExecutor
from typing import Final, Optional
from .pack import UdpPack
from ..packet import Packet
from ..tasks import TaskOwner
__all__ = ['BaseProtocol', 'UdpProtocol', 'TcpProtocol']
[docs]class BaseProtocol(TaskOwner):
"""Base class of :class:`UdpProtocol` and :class:`TcpProtocol`. Each will
call :meth:`.handle_packet` upon receipt of a full packet.
Args:
thread_pool: A thread pool for CPU-heavy operations.
"""
def __init__(self, thread_pool: ThreadPoolExecutor, udp_pack: UdpPack,
recv_queue: Queue[Packet]) -> None:
super().__init__()
self.thread_pool: Final = thread_pool
self.udp_pack: Final = udp_pack
self.recv_queue: Final = recv_queue
[docs] async def handle_packet(self, data: bytes) -> None:
"""Parse the *data* into a packet and push it onto the worker
:attr:`~swimprotocol.worker.Worker.recv_queue`.
Args:
data: The bytes representing a packet to be parsed by
:class:`~swimprotocol.udp.pack.UdpPack`.
"""
loop = asyncio.get_running_loop()
packet = await loop.run_in_executor(
self.thread_pool, self.udp_pack.unpack, data)
if packet is None:
return
await self.recv_queue.put(packet)
[docs]class UdpProtocol(BaseProtocol, DatagramProtocol):
"""Implements :class:`~asyncio.DatagramProtocol` to receive SWIM protocol
packets by UDP.
Each packet received is passed directly to :meth:`.handle_packet`.
"""
[docs] def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
self.run_subtask(self.handle_packet(data))
[docs]class TcpProtocol(BaseProtocol, Protocol):
"""Implements :class:`~asyncio.Protocol` to receive SWIM protocol packets
by TCP.
All data received is accumulated until the connection is closed, with the
result treated as a complete packet and sent to :meth:`.handle_packet`.
"""
def __init__(self, thread_pool: ThreadPoolExecutor, udp_pack: UdpPack,
recv_queue: Queue[Packet]) -> None:
super().__init__(thread_pool, udp_pack, recv_queue)
self._buf = bytearray()
[docs] def data_received(self, data: bytes) -> None:
self._buf += data
[docs] def connection_lost(self, exc: Optional[Exception]) -> None:
data = self._buf
self._buf = bytearray()
if exc is not None:
return
self.run_subtask(self.handle_packet(data))