Source code for proxyprotocol.detect


from __future__ import annotations

from . import ProxyProtocolWantRead, ProxyProtocol, \
    ProxyProtocolSyntaxError, ProxyProtocolIncompleteError
from .result import ProxyResult
from .v1 import ProxyProtocolV1
from .v2 import ProxyProtocolV2

__all__ = ['ProxyProtocolDetect']


class ProxyProtocolDetect(ProxyProtocol):
    """A PROXY protocol implementation that detects the version based on the
    first 8 bytes from the stream and passes it on to the version parser.

    This adds minimal overhead and *should* be used instead of a specific
    version.

    Args:
        versions: Override the default set of PROXY protocol implementations.

    """

    __slots__ = ['versions']

    def __init__(self, *versions: ProxyProtocol) -> None:
        super().__init__()
        # Implementations are consulted in this order. When none are given,
        # V2 is tried before V1.
        self.versions = versions or [ProxyProtocolV2(), ProxyProtocolV1()]

    def is_valid(self, signature: bytes) -> bool:
        """Return True if any configured version recognizes *signature*.

        Args:
            signature: The signature bytestring.

        """
        return any(v.is_valid(signature) for v in self.versions)

    def choose_version(self, signature: bytes) -> ProxyProtocol:
        """Choose the PROXY protocol version based on the 8-byte signature.

        The first configured version whose ``is_valid`` accepts the signature
        is returned.

        Args:
            signature: The signature bytestring.

        Raises:
            ProxyProtocolSyntaxError: No configured version recognizes the
                signature.

        """
        for version in self.versions:
            if version.is_valid(signature):
                return version
        raise ProxyProtocolSyntaxError(
            'Unrecognized proxy protocol version signature')

    def unpack(self, data: bytes) -> ProxyResult:
        """Detect the version from the first 8 bytes of *data* and unpack it.

        Args:
            data: The bytes read from the stream so far.

        Raises:
            ProxyProtocolIncompleteError: Fewer than 8 bytes are available.
                It wraps a ``ProxyProtocolWantRead`` for the number of
                missing bytes.
            ProxyProtocolSyntaxError: No configured version recognizes the
                signature.

        """
        if len(data) < 8:
            want_read = ProxyProtocolWantRead(8 - len(data))
            raise ProxyProtocolIncompleteError(want_read)
        pp = self.choose_version(data[0:8])
        # The chosen implementation parses the full buffer, not just the
        # signature.
        return pp.unpack(data)

    def pack(self, result: ProxyResult) -> bytes:
        """Build a PROXY protocol header using the first version that can.

        Versions are tried in configured order. A ``KeyError`` or
        ``ValueError`` from a version is treated as "this version cannot
        encode *result*", and the next version is tried.

        Args:
            result: The PROXY protocol result to encode.

        Raises:
            ValueError: No configured version could encode *result*. The last
                per-version error, if any, is chained as ``__cause__``.

        """
        last_error: Exception | None = None
        for version in self.versions:
            try:
                return version.pack(result)
            except (KeyError, ValueError) as exc:
                # Remember why this version failed so the final error is
                # debuggable instead of silently discarding it.
                last_error = exc
        raise ValueError('Could not build PROXY protocol header') \
            from last_error