from __future__ import annotations
import socket
from ipaddress import IPv4Address, IPv6Address
from socket import SocketKind
from struct import Struct
from typing import Optional, Sequence
from typing_extensions import Final
from . import ProxyProtocolWantRead, ProxyProtocol, ProxyProtocolSyntaxError, \
ProxyProtocolChecksumError, ProxyProtocolIncompleteError
from .result import is_ipv4, is_ipv6, is_unix, ProxyResultType, ProxyResult, \
ProxyResultLocal, ProxyResultUnknown, ProxyResultIPv4, ProxyResultIPv6, \
ProxyResultUnix
from .tlv import ProxyProtocolTLV
# Names exported by ``from <this module> import *``.
__all__ = [
    'ProxyProtocolV2Header',
    'ProxyProtocolV2',
]
class ProxyProtocolV2(ProxyProtocol):
    """Implements version 2 of the PROXY protocol.

    A v2 header starts with a fixed 16-byte prefix. That prefix is the 12-byte
    binary signature, one byte combining version and command, one byte
    combining address family and transport protocol, and a 2-byte big-endian
    length of the data that follows. The following data holds the address
    block and then any TLV vectors.

    """

    __slots__: Sequence[str] = []

    # The full 12-byte binary signature that begins every v2 header.
    _signature = b'\r\n\r\n\x00\r\nQUIT\n'

    # (wire value, name) pairs for the low nibble of byte 12.
    _commands = [(0x00, 'local'),
                 (0x01, 'proxy')]
    # (wire value, result type) pairs for the high nibble of byte 13.
    # LOCAL and UNKNOWN intentionally share 0x00. When encoding, both map to
    # 0x00. When decoding, 0x00 resolves to UNKNOWN because the later pair
    # wins in the dict built below.
    _types = [(0x00, ProxyResultType.LOCAL),
              (0x00, ProxyResultType.UNKNOWN),
              (0x10, ProxyResultType.IPv4),
              (0x20, ProxyResultType.IPv6),
              (0x30, ProxyResultType.UNIX)]
    # (wire value, socket kind) pairs for the low nibble of byte 13.
    # 0x00 leaves the transport protocol unspecified (None).
    _protocols = [(0x00, None),
                  (0x01, socket.SOCK_STREAM),
                  (0x02, socket.SOCK_DGRAM)]

    # Forward (wire -> value) and reverse (value -> wire) lookup tables.
    _commands_l = {left: right for left, right in _commands}
    _commands_r = {right: left for left, right in _commands}
    _types_l = {left: right for left, right in _types}
    _types_r = {right: left for left, right in _types}
    _protocols_l = {left: right for left, right in _protocols}
    _protocols_r = {right: left for left, right in _protocols}

    # The 4 bytes after the signature: byte 12, byte 13, data length.
    _header_format = Struct('!BBH')
    # Address blocks: source addr, dest addr, source port, dest port.
    _ipv4_format = Struct('!4s4sHH')
    _ipv6_format = Struct('!16s16sHH')
    # Two NUL-padded 108-byte socket paths: source, then destination.
    _unix_format = Struct('!108s108s')
    _tlv_format = Struct('!BH')

    def is_valid(self, signature: bytes) -> bool:
        """Check whether *signature* begins like a version 2 header.

        Only the first 8 bytes of the v2 signature are compared. Presumably
        this is the number of bytes the caller reads to detect the protocol
        version (confirm against the detection code). A shorter input never
        matches.

        Args:
            signature: The leading bytes read from the connection.

        """
        return signature[0:8] == self._signature[0:8]

    def unpack(self, data: bytes) -> ProxyResult:
        """Parse a complete version 2 header from *data*.

        Args:
            data: The bytes read so far, starting at the signature.

        Raises:
            :exc:`~proxyprotocol.ProxyProtocolIncompleteError`: *data* is
                shorter than the 16-byte prefix, or shorter than the length
                that prefix declares. The wrapped want-read carries the
                number of missing bytes.

        """
        if len(data) < 16:
            want_read = ProxyProtocolWantRead(16 - len(data))
            raise ProxyProtocolIncompleteError(want_read)
        header_data, data = data[0:16], data[16:]
        header = self.unpack_header(header_data)
        if len(data) < header.data_len:
            want_read = ProxyProtocolWantRead(header.data_len - len(data))
            raise ProxyProtocolIncompleteError(want_read)
        # Only the declared data length belongs to the PROXY header.
        # Anything after it is application data. That data must not reach
        # TLV parsing or the checksum verification.
        return self.unpack_data(header, header_data,
                                data[:header.data_len])

    def unpack_data(self, header: ProxyProtocolV2Header,
                    header_data: bytes, data: bytes) \
            -> ProxyResult:
        """Parse the address information read from the stream after the v2
        header.

        Args:
            header: The version 2 header info.
            header_data: The header bytestring.
            data: The addresses bytestring to parse.

        Raises:
            :exc:`~proxyprotocol.ProxyProtocolSyntaxError`: The command is
                invalid. Also raised when the address block is shorter than
                its family requires, or a UNIX path is not ASCII.
            :exc:`~proxyprotocol.ProxyProtocolChecksumError`

        """
        if header.command not in ('local', 'proxy'):
            raise ProxyProtocolSyntaxError('Invalid proxy protocol command')
        result: ProxyResult
        if header.command == 'local':
            # NOTE(review): every byte after the header is treated as TLVs
            # here. This assumes LOCAL senders do not include an address
            # block (UNSPEC family). Confirm for non-UNSPEC LOCAL headers.
            addr_data, tlv_data = b'', data
            tlv = ProxyProtocolTLV(tlv_data)
            result = ProxyResultLocal(tlv=tlv)
        elif header.type == ProxyResultType.IPv4:
            addr_data, tlv_data = self._split_addresses(
                data, self._ipv4_format)
            source_ip, dest_ip, source_port, dest_port = \
                self._ipv4_format.unpack(addr_data)
            source_addr4 = (IPv4Address(source_ip), source_port)
            dest_addr4 = (IPv4Address(dest_ip), dest_port)
            tlv = ProxyProtocolTLV(tlv_data)
            result = ProxyResultIPv4(source_addr4, dest_addr4,
                                     protocol=header.protocol, tlv=tlv)
        elif header.type == ProxyResultType.IPv6:
            addr_data, tlv_data = self._split_addresses(
                data, self._ipv6_format)
            source_ip, dest_ip, source_port, dest_port = \
                self._ipv6_format.unpack(addr_data)
            source_addr6 = (IPv6Address(source_ip), source_port)
            dest_addr6 = (IPv6Address(dest_ip), dest_port)
            tlv = ProxyProtocolTLV(tlv_data)
            result = ProxyResultIPv6(source_addr6, dest_addr6,
                                     protocol=header.protocol, tlv=tlv)
        elif header.type == ProxyResultType.UNIX:
            addr_data, tlv_data = self._split_addresses(
                data, self._unix_format)
            source_addr_b, dest_addr_b = self._unix_format.unpack(addr_data)
            source_addru = self._decode_unix_path(source_addr_b)
            dest_addru = self._decode_unix_path(dest_addr_b)
            tlv = ProxyProtocolTLV(tlv_data)
            result = ProxyResultUnix(source_addru, dest_addru,
                                     protocol=header.protocol, tlv=tlv)
        else:
            # No address layout for this family. Nothing is decoded and no
            # checksum is verified.
            return ProxyResultUnknown()
        if not tlv.verify_checksum(header_data, addr_data):
            raise ProxyProtocolChecksumError(result)
        return result

    @staticmethod
    def _split_addresses(data: bytes, addr_format: Struct) \
            -> tuple[bytes, bytes]:
        """Split *data* into the fixed-size address block and the TLV bytes.

        Raises:
            :exc:`~proxyprotocol.ProxyProtocolSyntaxError`: *data* is
                shorter than ``addr_format.size``. Without this check,
                ``struct.error`` would escape from the unpack call.

        """
        addr_len = addr_format.size
        if len(data) < addr_len:
            raise ProxyProtocolSyntaxError(
                'Invalid proxy protocol address length')
        return data[:addr_len], data[addr_len:]

    @staticmethod
    def _decode_unix_path(raw: bytes) -> str:
        """Strip the trailing NUL padding from *raw* and decode it as ASCII.

        Raises:
            :exc:`~proxyprotocol.ProxyProtocolSyntaxError`: The path is not
                ASCII.

        """
        try:
            return raw.rstrip(b'\x00').decode('ascii')
        except UnicodeDecodeError as exc:
            raise ProxyProtocolSyntaxError(
                'Invalid proxy protocol unix address') from exc

    def pack(self, result: ProxyResult) -> bytes:
        """Serialize *result* into a complete version 2 header.

        The TLV block is created with ``auto_crc32c=True``. Its size is
        counted in the header's data length before ``with_checksum`` is
        given the finished header and address bytes.

        Args:
            result: The PROXY result to encode.

        """
        addresses = self._pack_addresses(result)
        tlv = ProxyProtocolTLV(init=result.tlv, crc32c=None, auto_crc32c=True)
        data_len = len(addresses) + tlv.size
        header = self._pack_header(data_len, result)
        tlv = tlv.with_checksum(header, addresses)
        return header + addresses + bytes(tlv)

    def _pack_header(self, data_len: int, result: ProxyResult) \
            -> bytes:
        """Build the 16-byte prefix: signature, byte 12, byte 13 and
        *data_len* as a big-endian unsigned short."""
        command = 'proxy' if result.proxied else 'local'
        # 0x20 marks protocol version 2 in the high nibble; the command code
        # fills the low nibble.
        byte_12 = 0x20 + self._commands_r[command]
        # Address family code (high nibble) plus transport protocol code
        # (low nibble).
        byte_13 = self._types_r[result.type] + \
            self._protocols_r[result.protocol]
        return self._signature + \
            self._header_format.pack(byte_12, byte_13, data_len)

    def _pack_addresses(self, result: ProxyResult) -> bytes:
        """Encode the address block for *result*.

        Returns empty bytes for results that are not IPv4, IPv6 or UNIX.

        Raises:
            ValueError: A UNIX socket path is not ASCII (UnicodeEncodeError)
                or does not fit in its 108-byte field.

        """
        if is_ipv4(result):
            source_ip = result.source[0].packed
            source_port = result.source[1]
            dest_ip = result.dest[0].packed
            dest_port = result.dest[1]
            return self._ipv4_format.pack(source_ip, dest_ip,
                                          source_port, dest_port)
        elif is_ipv6(result):
            source_ip = result.source[0].packed
            source_port = result.source[1]
            dest_ip = result.dest[0].packed
            dest_port = result.dest[1]
            return self._ipv6_format.pack(source_ip, dest_ip,
                                          source_port, dest_port)
        elif is_unix(result):
            source_b = result.source.encode('ascii')
            dest_b = result.dest.encode('ascii')
            # Each path field is half of the '!108s108s' struct. struct would
            # silently truncate a longer path, so reject it instead.
            max_len = self._unix_format.size // 2
            for path in (source_b, dest_b):
                if len(path) > max_len:
                    raise ValueError(
                        'UNIX socket path exceeds %d bytes' % max_len)
            return self._unix_format.pack(source_b, dest_b)
        else:
            return b''