Source code for proxyprotocol.server


from __future__ import annotations

from urllib.parse import urlsplit, urlunsplit, parse_qs
from ssl import SSLContext, Purpose, VerifyMode, create_default_context
from typing import Optional
from typing_extensions import Final

from .. import ProxyProtocol
from ..version import ProxyProtocolVersion

__all__ = ['Address']


class Address:
    """Parses an address on the command-line.

    Valid examples include:

    * ``HOST``
    * ``HOST:PORT``
    * ``HOST:PORT?pp=v1``
    * ``ssl://HOST:PORT`` (outbound addresses only)
    * ``ssl://HOST:PORT?cert=/path/to/cert.pem``
    * ``ssl://HOST:PORT?cert=cert.pem&key=privkey.pem&verify=CERT_REQUIRED``

    The query parameters read by this class are ``pp``, ``cert``, ``key``,
    ``verify``, ``cafile``, ``capath`` and ``cadata``. When a parameter is
    given more than once, the last value is used.

    Args:
        addr: The address string to parse.
        server: True for server-side (listen) addresses.

    """

    __slots__ = ['url', 'query', 'server', '_ssl']

    def __init__(self, addr: str, *, server: bool = False) -> None:
        super().__init__()
        url = urlsplit(addr)
        if not url.scheme or not url.netloc:
            # A bare ``HOST`` or ``HOST:PORT`` does not yield both a scheme
            # and a netloc, so re-parse it as a scheme-less network location.
            url = urlsplit('//' + addr)
        query = parse_qs(url.query) if url.query else {}
        self.url: Final = url
        self.query: Final = query
        self.server: Final = server
        # Built lazily by the ``ssl`` property and reused afterwards.
        self._ssl: Optional[SSLContext] = None

    def _last_value(self, name: str) -> Optional[str]:
        """Return the last value given for query parameter *name*, or None
        if it is missing or empty."""
        return self.query.get(name, [''])[-1] or None

    @property
    def host(self) -> str:
        """The hostname parsed from the address, or an empty string if the
        address has none."""
        return self.url.hostname or ''

    @property
    def port(self) -> Optional[int]:
        """The port parsed from the address, or None if no port was given.

        Note that a port of ``0`` is also reported as None.

        """
        return self.url.port or None

    @property
    def pp(self) -> ProxyProtocol:
        """The PROXY protocol implementation.

        Selected by the ``pp`` query parameter, falling back to ``detect``
        when it is not given.

        """
        pp_version = self._last_value('pp') or 'detect'
        return ProxyProtocolVersion.get(pp_version)

    @property
    def ssl(self) -> Optional[SSLContext]:
        """The :class:`~ssl.SSLContext` to use on the address.

        Returns None unless the address uses the ``ssl`` scheme. The context
        is created on first access and the same object is returned on later
        accesses.

        Raises:
            KeyError: A server-side address has no ``cert`` parameter, or
                ``verify`` does not name a :class:`~ssl.VerifyMode` member.

        """
        if self.url.scheme != 'ssl':
            return None
        if self._ssl is None:
            self._ssl = self._build_ssl()
        return self._ssl

    def _build_ssl(self) -> SSLContext:
        """Create the SSL context described by the query parameters."""
        if self.server:
            ctx = create_default_context(Purpose.CLIENT_AUTH)
        else:
            ctx = create_default_context(Purpose.SERVER_AUTH)
        # A listening address always needs a certificate; an outbound one
        # only loads a client certificate when ``cert`` is given.
        if self.server or 'cert' in self.query:
            cert = self.query['cert'][-1]
            ctx.load_cert_chain(cert, self._last_value('key'))
        if 'verify' in self.query:
            verify_mode = VerifyMode[self.query['verify'][-1]]
            if verify_mode == VerifyMode.CERT_NONE:
                # The ssl module raises ValueError when verify_mode is set
                # to CERT_NONE while check_hostname is still enabled, which
                # is the default for outbound (SERVER_AUTH) contexts.
                ctx.check_hostname = False
            ctx.verify_mode = verify_mode
        cafile = self._last_value('cafile')
        capath = self._last_value('capath')
        cadata = self._last_value('cadata')
        # Any one of the three sources is enough, including cadata alone.
        if cafile is not None or capath is not None or cadata is not None:
            ctx.load_verify_locations(cafile, capath, cadata)
        return ctx

    def __str__(self) -> str:
        return urlunsplit(self.url)