Source code for proxyprotocol.v1
from __future__ import annotations
from ipaddress import IPv4Address, IPv6Address
from typing import Sequence
from . import ProxyProtocolWantRead, ProxyProtocol, ProxyProtocolSyntaxError, \
ProxyProtocolIncompleteError
from .result import is_ipv4, is_ipv6, is_unknown, ProxyResult, \
ProxyResultUnknown, ProxyResultIPv4, ProxyResultIPv6
[docs]
class ProxyProtocolV1(ProxyProtocol):
"""Implements version 1 of the PROXY protocol."""
__slots__: Sequence[str] = []
[docs]
def is_valid(self, signature: bytes) -> bool:
return signature[0:6] == b'PROXY '
[docs]
def unpack(self, data: bytes) -> ProxyResult:
if data[-1:] != b'\n':
want_read = ProxyProtocolWantRead(want_line=True)
raise ProxyProtocolIncompleteError(want_read)
return self.unpack_line(data)
[docs]
def unpack_line(self, data: bytes) -> ProxyResult:
"""Parse the PROXY protocol v1 header line.
Args:
data: The bytestring to parse.
"""
if data[0:6] != b'PROXY ' or data[-2:] != b'\r\n':
raise ProxyProtocolSyntaxError(
'Invalid proxy protocol v1 signature')
line = bytes(data[6:-2])
parts = line.split(b' ')
family_string = parts[0]
if family_string == b'UNKNOWN':
return ProxyResultUnknown()
elif len(parts) != 5:
raise ProxyProtocolSyntaxError(
'Invalid proxy protocol header format')
elif family_string == b'TCP4':
source_addr4 = (self._get_ip4(parts[1]), self._get_port(parts[3]))
dest_addr4 = (self._get_ip4(parts[2]), self._get_port(parts[4]))
return ProxyResultIPv4(source_addr4, dest_addr4)
elif family_string == b'TCP6':
source_addr6 = (self._get_ip6(parts[1]), self._get_port(parts[3]))
dest_addr6 = (self._get_ip6(parts[2]), self._get_port(parts[4]))
return ProxyResultIPv6(source_addr6, dest_addr6)
else:
raise ProxyProtocolSyntaxError(
'Invalid proxy protocol address family')
def _get_ip4(self, ip_string: bytes) -> IPv4Address:
return IPv4Address(ip_string.decode('ascii'))
def _get_ip6(self, ip_string: bytes) -> IPv6Address:
return IPv6Address(ip_string.decode('ascii'))
def _get_port(self, port_string: bytes) -> int:
port_num = int(port_string)
if port_num < 0 or port_num > 65535:
raise ValueError(port_num)
return port_num
[docs]
def pack(self, result: ProxyResult) -> bytes:
if not result.proxied:
raise ValueError('proxied must be True in v1')
family_b = self._pack_family(result)
if is_ipv4(result) or is_ipv6(result):
source_ip: bytes = result.peername[0].encode('ascii')
source_port: bytes = str(result.peername[1]).encode('ascii')
dest_ip: bytes = result.sockname[0].encode('ascii')
dest_port: bytes = str(result.sockname[1]).encode('ascii')
else:
source_ip = b''
source_port = b''
dest_ip = b''
dest_port = b''
return b'PROXY %b %b %b %b %b\r\n' % \
(family_b, source_ip, dest_ip, source_port, dest_port)
def _pack_family(self, result: ProxyResult) -> bytes:
if is_ipv4(result):
return b'TCP4'
elif is_ipv6(result):
return b'TCP6'
elif is_unknown(result):
return b'UNKNOWN'
else:
raise KeyError(type)