from __future__ import annotations
import json
import zlib
from enum import IntEnum, IntFlag
from struct import Struct, error as struct_error
from typing import ClassVar, Any, Hashable, Optional, Iterator, Sequence, \
Mapping, Dict, List
from .checksum import crc32c as crc32c_checksum
from .typing import PeerCert
# Names exported by ``from <module> import *``.
__all__ = [
    'Type',
    'SSLClient',
    'ExtType',
    'TLV',
    'ProxyProtocolTLV',
    'ProxyProtocolSSLTLV',
    'ProxyProtocolExtTLV',
]
class Type(IntEnum):
    """The PROXY protocol TLV standard type values.

    The ``PP2_SUBTYPE_SSL_*`` values are used as keys of the sub-TLVs that
    follow the client/verify prefix inside a :attr:`PP2_TYPE_SSL` value. The
    ``*_MIN_*``/``*_MAX_*`` pairs mark the bounds of the custom, experimental
    and future type ranges.

    See Also:
        :class:`ProxyProtocolTLV`

    """

    # Top-level TLV types.
    PP2_TYPE_ALPN = 0x01
    PP2_TYPE_AUTHORITY = 0x02
    PP2_TYPE_CRC32C = 0x03
    PP2_TYPE_NOOP = 0x04
    PP2_TYPE_UNIQUE_ID = 0x05
    PP2_TYPE_SSL = 0x20
    PP2_TYPE_NETNS = 0x30

    # Sub-types found inside a PP2_TYPE_SSL value.
    PP2_SUBTYPE_SSL_VERSION = 0x21
    PP2_SUBTYPE_SSL_CN = 0x22
    PP2_SUBTYPE_SSL_CIPHER = 0x23
    PP2_SUBTYPE_SSL_SIG_ALG = 0x24
    PP2_SUBTYPE_SSL_KEY_ALG = 0x25

    # Range bounds.
    PP2_TYPE_MIN_CUSTOM = 0xE0
    PP2_TYPE_MAX_CUSTOM = 0xEF
    PP2_TYPE_MIN_EXPERIMENT = 0xF0
    PP2_TYPE_MAX_EXPERIMENT = 0xF7
    PP2_TYPE_MIN_FUTURE = 0xF8
    PP2_TYPE_MAX_FUTURE = 0xFF
class SSLClient(IntFlag):
    """The PROXY protocol ``PP2_TYPE_SSL`` client flags.

    These bits make up the one-byte *client* field that prefixes a
    ``PP2_TYPE_SSL`` value.

    See Also:
        :class:`ProxyProtocolSSLTLV`

    """

    PP2_CLIENT_SSL = 0x01
    PP2_CLIENT_CERT_CONN = 0x02
    PP2_CLIENT_CERT_SESS = 0x04
class ExtType(IntEnum):
    """Non-standard extension TLV types.

    These are the keys of the TLVs carried inside a ``PP2_TYPE_NOOP`` value
    that starts with :attr:`ProxyProtocolExtTLV.MAGIC_PREFIX`.

    See Also:
        :class:`ProxyProtocolExtTLV`

    """

    PP2_TYPE_EXT_COMPRESSION = 0x01
    PP2_TYPE_EXT_SECRET_BITS = 0x02
    PP2_TYPE_EXT_PEERCERT = 0x03
    PP2_TYPE_EXT_DNSBL = 0x04
class TLV(Mapping[int, bytes], Hashable):
    """Defines the basic parsing and structure of a PROXY protocol TLV vector.

    The unpacked TLV values are available as dict-style keys of this object,
    e.g. ``tlv[0xE2]``. To serialize back to a bytestring, use ``bytes(tlv)``.

    Args:
        data: TLV data to parse.
        init: A mapping of types to values to initialize the TLV, such as
            another :class:`TLV`. Entries here replace any entry of the same
            type parsed from *data*.

    """

    __slots__ = ['_tlv', '_frozen']

    # Header of each entry: 1-byte type, then 2-byte big-endian value length.
    _fmt = Struct('!BH')

    def __init__(self, data: bytes = b'',
                 init: Optional[Mapping[int, bytes]] = None) -> None:
        super().__init__()
        self._tlv = self._unpack(data)
        if init is not None:
            self._tlv.update(init)
        # Snapshot of the contents, used by __hash__ and __eq__.
        self._frozen = self._freeze()

    def _freeze(self) -> Hashable:
        return frozenset(self._tlv.items())

    def _unpack(self, data: bytes) -> Dict[int, bytes]:
        """Parse consecutive type-length-value entries from *data*.

        Parsing stops once fewer bytes than one entry header remain. A value
        whose declared length runs past the end of *data* is truncated to
        the bytes that are available. A repeated type keeps its last value.
        """
        fmt = self._fmt
        header_size = fmt.size
        results: Dict[int, bytes] = {}
        index = 0
        while len(data) >= index + header_size:
            type_num, size = fmt.unpack_from(data, index)
            index += header_size
            results[type_num] = bytes(data[index:index + size])
            index += size
        return results

    def _present_types(self) -> List[int]:
        """Type numbers 0x00-0xFF that have a non-``None`` value, ascending.

        Serialization is ordered by type number regardless of insertion
        order; keys outside the one-byte type range are never emitted.
        """
        tlv = self._tlv
        return [type_num for type_num in range(0x00, 0x100)
                if tlv.get(type_num) is not None]

    def _pack(self) -> bytes:
        fmt = self._fmt
        tlv = self._tlv
        parts: List[bytes] = []
        for type_num in self._present_types():
            val = tlv[type_num]
            parts.append(fmt.pack(type_num, len(val)))
            parts.append(val)
        return b''.join(parts)

    @property
    def size(self) -> int:
        """The size of the TLV when converted to bytes."""
        header_size = self._fmt.size
        tlv = self._tlv
        return sum(header_size + len(tlv[type_num])
                   for type_num in self._present_types())

    def __bytes__(self) -> bytes:
        return self._pack()

    def __getitem__(self, type_num: int) -> bytes:
        return self._tlv[type_num]

    def __iter__(self) -> Iterator[int]:
        return iter(self._tlv)

    def __bool__(self) -> bool:
        return bool(self._tlv)

    def __len__(self) -> int:
        return len(self._tlv)

    def __hash__(self) -> int:
        return hash(self._frozen)

    def __eq__(self, other: Any) -> bool:
        # Same concrete class: compare the frozen snapshots. Otherwise defer
        # to the generic mapping comparison of the base class.
        if isinstance(other, type(self)):
            return self._frozen == other._frozen
        return super().__eq__(other)

    def __repr__(self) -> str:
        arg = repr(bytes(self)) if self else ''
        return f'{type(self).__name__}({arg})'
class ProxyProtocolTLV(TLV):
    """Defines the TLV values that may be appended to a PROXY protocol header.

    These values can provide additional information not stored in the address
    data. Refer to the PROXY protocol spec for more information about each TLV.

    Args:
        data: TLV data to parse.
        init: A mapping of types to values to initialize the TLV, such as
            another :class:`TLV`.
        alpn: The ``PP2_TYPE_ALPN`` value.
        authority: The ``PP2_TYPE_AUTHORITY`` value, stored UTF-8 encoded.
        crc32c: The ``PP2_TYPE_CRC32C`` value.
        unique_id: The ``PP2_TYPE_UNIQUE_ID`` value.
        ssl: The ``PP2_TYPE_SSL`` value; ignored if it is empty (falsy).
        netns: The ``PP2_TYPE_NETNS`` value, stored ASCII encoded.
        ext: Extension TLVs, stored as the ``PP2_TYPE_NOOP`` value; ignored
            if it is empty (falsy).
        auto_crc32c: Request that the ``PP2_TYPE_CRC32C`` checksum be filled
            in later by :meth:`with_checksum`. Only takes effect when
            *crc32c* is not given and a CRC32C implementation is available;
            while in effect, ``bytes()`` on this object raises
            :exc:`ValueError`.

    Keyword values take precedence over *init*, which takes precedence over
    entries parsed from *data*.

    """

    __slots__ = ['_auto_crc32c']

    # The PP2_TYPE_CRC32C value: a 32-bit big-endian unsigned integer.
    _crc32c_fmt = Struct('!L')

    def __init__(self, data: bytes = b'',
                 init: Optional[Mapping[int, bytes]] = None, *,
                 alpn: Optional[bytes] = None,
                 authority: Optional[str] = None,
                 crc32c: Optional[int] = None,
                 unique_id: Optional[bytes] = None,
                 ssl: Optional[ProxyProtocolSSLTLV] = None,
                 netns: Optional[str] = None,
                 ext: Optional[ProxyProtocolExtTLV] = None,
                 auto_crc32c: bool = False) -> None:
        results = dict(init or {})
        if alpn is not None:
            results[Type.PP2_TYPE_ALPN] = alpn
        if authority is not None:
            results[Type.PP2_TYPE_AUTHORITY] = authority.encode('utf-8')
        if crc32c is not None:
            results[Type.PP2_TYPE_CRC32C] = self._crc32c_fmt.pack(crc32c)
        if unique_id is not None:
            results[Type.PP2_TYPE_UNIQUE_ID] = unique_id
        if ssl:
            results[Type.PP2_TYPE_SSL] = bytes(ssl)
        if netns is not None:
            results[Type.PP2_TYPE_NETNS] = netns.encode('ascii')
        if ext:
            results[Type.PP2_TYPE_NOOP] = bytes(ext)
        super().__init__(data, results)
        self._auto_crc32c = auto_crc32c \
            and crc32c is None \
            and crc32c_checksum is not None

    def _pack(self) -> bytes:
        # The checksum also covers header bytes this object never sees, so a
        # pending checksum must be resolved via with_checksum() first.
        if self._auto_crc32c:
            raise ValueError('Cannot convert to bytes with auto_crc32c=True')
        return super()._pack()

    def __repr__(self) -> str:
        # The inherited repr() serializes via bytes(self), which raises while
        # a checksum is pending; describe the values instead of failing.
        if self._auto_crc32c:
            return (f'{type(self).__name__}(init={dict(self._tlv)!r}, '
                    f'auto_crc32c=True)')
        return super().__repr__()

    @property
    def _zero_crc32c(self) -> ProxyProtocolTLV:
        """A copy of these values with ``PP2_TYPE_CRC32C`` set to zero."""
        return ProxyProtocolTLV(init=self, crc32c=0)

    @property
    def size(self) -> int:
        """The size of the TLV when converted to bytes, counting a pending
        ``PP2_TYPE_CRC32C`` checksum as if it were already present.
        """
        if self.crc32c is None and self._auto_crc32c:
            return self._zero_crc32c.size
        else:
            return super().size

    def _compute_checksum(self, before: Sequence[bytes]) -> int:
        """CRC32C over each chunk of *before*, then over these TLVs
        serialized with a zero ``PP2_TYPE_CRC32C`` value.
        """
        # NOTE(review): assumes crc32c_checksum(data, crc) continues a
        # running checksum -- confirm against the .checksum module.
        assert crc32c_checksum is not None
        crc = crc32c_checksum(b'')
        for data in before:
            crc = crc32c_checksum(data, crc)
        return crc32c_checksum(bytes(self._zero_crc32c), crc)

    def with_checksum(self, *before: bytes) -> ProxyProtocolTLV:
        """Return a copy of the current TLV values with the :attr:`.crc32c`
        checksum populated according to the rules in the PROXY protocol spec.

        If no checksum is pending (``auto_crc32c`` not in effect), ``self``
        is returned unchanged.

        Args:
            before: The data in the PROXY protocol header before the TLV, which
                is included in the checksum.

        """
        if not self._auto_crc32c:
            return self
        crc = self._compute_checksum(before)
        return ProxyProtocolTLV(init=self, crc32c=crc)

    def verify_checksum(self, *before: bytes) -> bool:
        """Verifies the :attr:`.crc32c` checksum, if present, correctly matches
        the expected value computed for the PROXY protocol header. If this
        method returns False, the connection should likely be aborted.

        Returns True when there is no checksum or no CRC32C implementation
        is available.

        Args:
            before: The data in the PROXY protocol header before the TLV, which
                is included in the checksum.

        """
        if self.crc32c is None or crc32c_checksum is None:
            return True
        crc = self._compute_checksum(before)
        return self.crc32c == crc

    @property
    def alpn(self) -> Optional[bytes]:
        """The ``PP2_TYPE_ALPN`` value."""
        val = self.get(Type.PP2_TYPE_ALPN)
        if val is not None:
            return bytes(val)
        return None

    @property
    def authority(self) -> Optional[str]:
        """The ``PP2_TYPE_AUTHORITY`` value."""
        val = self.get(Type.PP2_TYPE_AUTHORITY)
        if val is not None:
            return str(val, 'utf-8')
        return None

    @property
    def crc32c(self) -> Optional[int]:
        """The ``PP2_TYPE_CRC32C`` value."""
        val = self.get(Type.PP2_TYPE_CRC32C)
        if val is not None:
            value, = self._crc32c_fmt.unpack(val)
            return int(value)
        return None

    @property
    def unique_id(self) -> bytes:
        """The ``PP2_TYPE_UNIQUE_ID`` value, or ``b''`` if absent."""
        val = self.get(Type.PP2_TYPE_UNIQUE_ID)
        if val is not None:
            return bytes(val)
        return b''

    @property
    def ssl(self) -> ProxyProtocolSSLTLV:
        """The ``PP2_TYPE_SSL`` value, or an empty SSL TLV if absent."""
        val = self.get(Type.PP2_TYPE_SSL)
        if val is not None:
            return ProxyProtocolSSLTLV(val)
        return ProxyProtocolSSLTLV()

    @property
    def netns(self) -> Optional[str]:
        """The ``PP2_TYPE_NETNS`` value."""
        val = self.get(Type.PP2_TYPE_NETNS)
        if val is not None:
            return str(val, 'ascii')
        return None

    @property
    def ext(self) -> ProxyProtocolExtTLV:
        """The ``PP2_TYPE_NOOP`` value, possibly parsed as an extension TLV."""
        val = self.get(Type.PP2_TYPE_NOOP)
        if val is not None:
            return ProxyProtocolExtTLV(val)
        return ProxyProtocolExtTLV()
class ProxyProtocolSSLTLV(TLV):
    """The ``PP2_TYPE_SSL`` TLV, which is prefixed with a struct containing
    *client* and *verify* values, then follows with ``PP2_SUBTYPE_SSL_*`` TLVs.

    Args:
        data: TLV data to parse.
        init: A mapping of types to values to initialize the TLV, such as
            another :class:`TLV`.
        has_ssl: Set (True) or clear (False) ``PP2_CLIENT_SSL``; None leaves
            the parsed/default value.
        has_cert_conn: Set or clear ``PP2_CLIENT_CERT_CONN`` likewise.
        has_cert_sess: Set or clear ``PP2_CLIENT_CERT_SESS`` likewise.
        verified: If given, sets *verify* to 0 (True) or 1 (False).
        version: The ``PP2_SUBTYPE_SSL_VERSION`` value (ASCII).
        cn: The ``PP2_SUBTYPE_SSL_CN`` value (UTF-8).
        cipher: The ``PP2_SUBTYPE_SSL_CIPHER`` value (ASCII).
        sig_alg: The ``PP2_SUBTYPE_SSL_SIG_ALG`` value (ASCII).
        key_alg: The ``PP2_SUBTYPE_SSL_KEY_ALG`` value (ASCII).

    """

    __slots__ = ['_client', '_verify']

    # Prefix before the sub-TLVs: 1-byte client flags, 4-byte verify value.
    _prefix_fmt = Struct('!BL')

    def __init__(self, data: bytes = b'',
                 init: Optional[Mapping[int, bytes]] = None, *,
                 has_ssl: Optional[bool] = None,
                 has_cert_conn: Optional[bool] = None,
                 has_cert_sess: Optional[bool] = None,
                 verified: Optional[bool] = None,
                 version: Optional[str] = None,
                 cn: Optional[str] = None,
                 cipher: Optional[str] = None,
                 sig_alg: Optional[str] = None,
                 key_alg: Optional[str] = None) -> None:
        # Defaults; _unpack() (run by super().__init__) overwrites them when
        # data carries a complete prefix.
        self._client = 0
        self._verify = 1
        results = dict(init or {})
        if version is not None:
            results[Type.PP2_SUBTYPE_SSL_VERSION] = version.encode('ascii')
        if cn is not None:
            results[Type.PP2_SUBTYPE_SSL_CN] = cn.encode('utf-8')
        if cipher is not None:
            results[Type.PP2_SUBTYPE_SSL_CIPHER] = cipher.encode('ascii')
        if sig_alg is not None:
            results[Type.PP2_SUBTYPE_SSL_SIG_ALG] = sig_alg.encode('ascii')
        if key_alg is not None:
            results[Type.PP2_SUBTYPE_SSL_KEY_ALG] = key_alg.encode('ascii')
        super().__init__(data, results)
        self._set_flag(SSLClient.PP2_CLIENT_SSL, has_ssl)
        self._set_flag(SSLClient.PP2_CLIENT_CERT_CONN, has_cert_conn)
        self._set_flag(SSLClient.PP2_CLIENT_CERT_SESS, has_cert_sess)
        if verified is not None:
            self._verify = int(not verified)

    def _set_flag(self, flag: SSLClient, enabled: Optional[bool]) -> None:
        """Set (True), clear (False) or leave alone (None) one client flag."""
        # Plain int arithmetic so clearing one flag leaves every other bit --
        # including bits SSLClient does not define -- untouched; ``~`` on an
        # IntFlag member does not give the same result on all Python versions.
        if enabled is True:
            self._client = int(self._client) | int(flag)
        elif enabled is False:
            self._client = int(self._client) & ~int(flag)

    def _unpack(self, data: bytes) -> Dict[int, bytes]:
        # A missing/short prefix keeps the defaults set in __init__.
        try:
            self._client, self._verify = \
                self._prefix_fmt.unpack_from(data, 0)
        except struct_error:
            pass
        return super()._unpack(data[self._prefix_fmt.size:])

    def _pack(self) -> bytes:
        prefix = self._prefix_fmt.pack(self.client, self.verify)
        return prefix + super()._pack()

    @property
    def size(self) -> int:
        """The size of the TLV when converted to bytes, including the
        client/verify prefix.
        """
        return self._prefix_fmt.size + super().size

    def __bool__(self) -> bool:
        return super().__bool__() or bool(self.client) or self.verified

    def __hash__(self) -> int:
        return hash((self._frozen, self._client, self._verify))

    def __eq__(self, other: Any) -> bool:
        # Sub-TLVs alone are not enough: the prefix fields must match too.
        if isinstance(other, type(self)):
            self_cmp = (self._frozen, self._client, self._verify)
            other_cmp = (other._frozen, other._client, other._verify)
            return self_cmp == other_cmp
        return super().__eq__(other)

    @property
    def client(self) -> int:
        """The client field in the ``PP2_TYPE_SSL`` value."""
        return self._client

    @property
    def verify(self) -> int:
        """The verify field in the ``PP2_TYPE_SSL`` value."""
        return self._verify

    @property
    def has_ssl(self) -> bool:
        """True if the ``PP2_CLIENT_SSL`` flag was set."""
        return self.client & SSLClient.PP2_CLIENT_SSL != 0

    @property
    def has_cert_conn(self) -> bool:
        """True if the ``PP2_CLIENT_CERT_CONN`` flag was set."""
        return self.client & SSLClient.PP2_CLIENT_CERT_CONN != 0

    @property
    def has_cert_sess(self) -> bool:
        """True if the ``PP2_CLIENT_CERT_SESS`` flag was set."""
        return self.client & SSLClient.PP2_CLIENT_CERT_SESS != 0

    @property
    def verified(self) -> bool:
        """True if the client provided a certificate that was successfully
        verified (the verify field is zero).
        """
        return self.verify == 0

    @property
    def version(self) -> Optional[str]:
        """The ``PP2_SUBTYPE_SSL_VERSION`` value."""
        val = self.get(Type.PP2_SUBTYPE_SSL_VERSION)
        if val is not None:
            return str(val, 'ascii')
        return None

    @property
    def cn(self) -> Optional[str]:
        """The ``PP2_SUBTYPE_SSL_CN`` value."""
        val = self.get(Type.PP2_SUBTYPE_SSL_CN)
        if val is not None:
            return str(val, 'utf-8')
        return None

    @property
    def cipher(self) -> Optional[str]:
        """The ``PP2_SUBTYPE_SSL_CIPHER`` value."""
        val = self.get(Type.PP2_SUBTYPE_SSL_CIPHER)
        if val is not None:
            return str(val, 'ascii')
        return None

    @property
    def sig_alg(self) -> Optional[str]:
        """The ``PP2_SUBTYPE_SSL_SIG_ALG`` value."""
        val = self.get(Type.PP2_SUBTYPE_SSL_SIG_ALG)
        if val is not None:
            return str(val, 'ascii')
        return None

    @property
    def key_alg(self) -> Optional[str]:
        """The ``PP2_SUBTYPE_SSL_KEY_ALG`` value."""
        val = self.get(Type.PP2_SUBTYPE_SSL_KEY_ALG)
        if val is not None:
            return str(val, 'ascii')
        return None
class ProxyProtocolExtTLV(TLV):
    """Non-standard extension TLV, which is hidden inside a
    :attr:`~Type.PP2_TYPE_NOOP` and must start with :attr:`.MAGIC_PREFIX`.

    Args:
        data: TLV data to parse. If it does not begin with
            :attr:`.MAGIC_PREFIX`, no values are parsed from it.
        init: A mapping of types to values to initialize the TLV, such as
            another :class:`TLV`.
        compression: The :attr:`~ExtType.PP2_TYPE_EXT_COMPRESSION` value
            (ASCII).
        secret_bits: The :attr:`~ExtType.PP2_TYPE_EXT_SECRET_BITS` value
            (16-bit unsigned).
        peercert: The :attr:`~ExtType.PP2_TYPE_EXT_PEERCERT` value, stored as
            zlib-compressed JSON.
        dnsbl: The :attr:`~ExtType.PP2_TYPE_EXT_DNSBL` value (UTF-8).

    """

    # Match the other TLV classes: no per-instance __dict__.
    __slots__ = ()

    #: The :attr:`~Type.PP2_TYPE_NOOP` value must begin with this byte sequence
    #: to be parsed as a :class:`ProxyProtocolExtTLV`.
    MAGIC_PREFIX: ClassVar[bytes] = b'\x88\x1b\x79\xc1\xce\x96\x85\xb0'

    # PP2_TYPE_EXT_SECRET_BITS: a 16-bit big-endian unsigned integer.
    _secret_bits_fmt = Struct('!H')

    def __init__(self, data: bytes = b'',
                 init: Optional[Mapping[int, bytes]] = None, *,
                 compression: Optional[str] = None,
                 secret_bits: Optional[int] = None,
                 peercert: Optional[PeerCert] = None,
                 dnsbl: Optional[str] = None) -> None:
        results = dict(init or {})
        if compression is not None:
            val = compression.encode('ascii')
            results[ExtType.PP2_TYPE_EXT_COMPRESSION] = val
        if secret_bits is not None:
            val = self._secret_bits_fmt.pack(secret_bits)
            results[ExtType.PP2_TYPE_EXT_SECRET_BITS] = val
        if peercert is not None:
            # json.dumps escapes non-ASCII by default, so ASCII encoding
            # cannot fail here.
            val = zlib.compress(json.dumps(peercert).encode('ascii'))
            results[ExtType.PP2_TYPE_EXT_PEERCERT] = val
        if dnsbl is not None:
            val = dnsbl.encode('utf-8')
            results[ExtType.PP2_TYPE_EXT_DNSBL] = val
        super().__init__(data, results)

    def _unpack(self, data: bytes) -> Dict[int, bytes]:
        # Ordinary NOOP padding (no magic prefix) yields no extension values.
        magic_prefix = self.MAGIC_PREFIX
        magic_prefix_len = len(magic_prefix)
        if data[0:magic_prefix_len] != magic_prefix:
            return {}
        return super()._unpack(data[magic_prefix_len:])

    def _pack(self) -> bytes:
        return self.MAGIC_PREFIX + super()._pack()

    @property
    def size(self) -> int:
        """The size of the TLV when converted to bytes, including
        :attr:`.MAGIC_PREFIX`.
        """
        return len(self.MAGIC_PREFIX) + super().size

    @property
    def compression(self) -> Optional[str]:
        """The :attr:`~ExtType.PP2_TYPE_EXT_COMPRESSION` value. This is used by
        the :attr:`~proxyprotocol.sock.SocketInfo.compression` value.
        """
        val = self.get(ExtType.PP2_TYPE_EXT_COMPRESSION)
        if val is not None:
            return str(val, 'ascii')
        return None

    @property
    def secret_bits(self) -> Optional[int]:
        """The :attr:`~ExtType.PP2_TYPE_EXT_SECRET_BITS` value. This is used to
        populate the third member of the
        :attr:`~proxyprotocol.sock.SocketInfo.cipher` tuple.
        """
        val = self.get(ExtType.PP2_TYPE_EXT_SECRET_BITS)
        if val is not None:
            secret_bits, = self._secret_bits_fmt.unpack(val)
            return int(secret_bits)
        return None

    @property
    def peercert(self) -> Optional[PeerCert]:
        """The :attr:`~ExtType.PP2_TYPE_EXT_PEERCERT` value. This is used by
        the :attr:`~proxyprotocol.sock.SocketInfo.peercert` value.
        """
        val = self.get(ExtType.PP2_TYPE_EXT_PEERCERT)
        if val is not None:
            # NOTE(review): decompression has no output bound; if the sender
            # is not trusted, consider zlib.decompressobj(...).decompress(
            # val, max_length) to cap memory use.
            decompressed = zlib.decompress(val)
            ret: PeerCert = json.loads(decompressed)
            return ret
        return None

    @property
    def dnsbl(self) -> Optional[str]:
        """The :attr:`~ExtType.PP2_TYPE_EXT_DNSBL` value. This is the hostname
        or other identifier that reports a status or reputation of the
        connecting IP address.
        """
        val = self.get(ExtType.PP2_TYPE_EXT_DNSBL)
        if val is not None:
            return str(val, 'utf-8')
        return None