Source code for proxyprotocol.reader
from __future__ import annotations
import asyncio
from asyncio import StreamReader, StreamWriter
from functools import partial
from typing import Any, Awaitable, Callable, Coroutine, Union
from typing_extensions import Final, TypeAlias
from uuid import uuid4
from . import ProxyProtocol, ProxyProtocolIncompleteError, \
ProxyProtocolWantRead
from .result import ProxyResult, ProxyResultUnknown
from .sock import SocketInfo
from .typing import StreamReaderProtocol
__all__ = ['ProxyProtocolReader']

# Shape of the *client_connected_cb* argument accepted by
# asyncio.start_server(); this is what ProxyProtocolReader.get_callback()
# hands back to the caller.
_Callback: TypeAlias = Callable[[StreamReader, StreamWriter], Awaitable[None]]

# Shape of the user-supplied coroutine function given to
# ProxyProtocolReader.get_callback(): the usual stream pair plus the
# SocketInfo derived from the PROXY protocol header.
_WrappedCallback: TypeAlias = Callable[
    [StreamReader, StreamWriter, SocketInfo],
    Coroutine[Any, Any, None],
]
[docs]
class ProxyProtocolReader:
    """Read a PROXY protocol header from a stream.

    Args:
        pp: The PROXY protocol implementation.

    """

    def __init__(self, pp: ProxyProtocol) -> None:
        super().__init__()
        self.pp: Final = pp
        # Strong references to the callback tasks started by
        # _read_then_call(). The event loop only keeps weak references to
        # tasks, so an otherwise unreferenced task could be garbage-collected
        # before it finishes.
        self._tasks: set[asyncio.Task[None]] = set()

    async def _handle_want(self, reader: StreamReaderProtocol,
                           want_read: ProxyProtocolWantRead) -> bytes:
        # Fetch exactly what the parser asked for: an explicit byte count
        # takes precedence over "read one line".
        if want_read.want_bytes is not None:
            return await reader.readexactly(want_read.want_bytes)
        elif want_read.want_line:
            return await reader.readline()
        raise ValueError('No conditions given to complete parsing')

    async def read(self, reader: StreamReaderProtocol) -> ProxyResult:
        """Read a complete PROXY protocol header from the input stream and
        return the result.

        Args:
            reader: The input stream.

        Raises:
            asyncio.IncompleteReadError: The stream ended before a complete
                header could be read.
            ValueError: The parser requested more data without saying how
                much to read.

        """
        data = bytearray()
        want_read: ProxyProtocolWantRead
        while True:
            try:
                with memoryview(data) as view:
                    return self.pp.unpack(view)
            except ProxyProtocolIncompleteError as exc:
                want_read = exc.want_read
            chunk = await self._handle_want(reader, want_read)
            if not chunk:
                # readline() returns b'' at EOF rather than raising. Without
                # this check the parser would see the same bytes again and
                # the loop would retry forever; such a read completes without
                # waiting, so the loop could spin without ever yielding to
                # the event loop (defeating any wait_for() timeout).
                raise asyncio.IncompleteReadError(bytes(data), None)
            data += chunk

    def get_callback(self, callback: _WrappedCallback,
                     timeout: Union[None, int, float] = 3) -> _Callback:
        """Get a callback object for use as the *client_connected_cb* argument
        to :func:`asyncio.start_server`.

        The returned callback will first read the PROXY protocol header before
        starting the provided *callback* as a :class:`~asyncio.Task`. The
        *callback* argument is similar to *client_connected_cb* but with an
        additional positional argument -- the
        :class:`~proxyprotocol.sock.SocketInfo` read from the header.

        Args:
            callback: Async function with arguments ``(reader, writer,
                sock_info)`` called after successfully reading the header.
            timeout: A timeout in seconds to allow for reading the header.

        """
        return partial(self._read_then_call, callback, timeout)

    async def _read_then_call(
            self, callback: _WrappedCallback,
            timeout: Union[None, int, float],
            reader: StreamReader, writer: StreamWriter) -> None:
        try:
            result = await asyncio.wait_for(self.read(reader), timeout)
        except Exception as exc:
            # The header could not be read (timeout, early EOF, parse error,
            # ...): close the connection, but still invoke the callback with
            # the failure wrapped in a ProxyResultUnknown.
            writer.close()
            result = ProxyResultUnknown(exc)
        sock_info = SocketInfo.get(writer, result, unique_id=uuid4().bytes)
        task = asyncio.create_task(callback(reader, writer, sock_info))
        # Keep the task alive until it completes, then drop the reference.
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)