from __future__ import annotations
import asyncio
import binascii
import logging
import re
import sys
from argparse import ArgumentParser
from asyncio import shield, StreamReader, StreamWriter, AbstractServer, \
CancelledError, TimeoutError
from base64 import b64encode, b64decode
from collections.abc import Awaitable, Iterable
from contextlib import closing, AsyncExitStack
from ssl import SSLError
from typing import TypeVar
from proxyprotocol.reader import ProxyProtocolReader
from proxyprotocol.sock import SocketInfo
from proxyprotocol.version import ProxyProtocolVersion
from pymap.concurrent import Event
from pymap.config import IMAPConfig
from pymap.context import subsystem, current_command, socket_info, \
connection_exit
from pymap.exceptions import ResponseError
from pymap.interfaces.backend import ServiceInterface
from pymap.interfaces.login import LoginInterface
from pymap.parsing.command import Command
from pymap.parsing.commands import Commands
from pymap.parsing.command.nonauth import AuthenticateCommand, StartTLSCommand
from pymap.parsing.command.select import IdleCommand
from pymap.parsing.response import ResponseContinuation, Response, \
ResponseCode, ResponseBad, ResponseNo, ResponseBye, ResponseOk, \
CommandResponse
from pymap.parsing.state import ParsingState, ParsingInterrupt, \
ExpectContinuation
from pymap.sockets import InheritedSockets
from pysasl.creds.server import ServerCredentials
from pysasl.exception import AuthenticationError
from pysasl.mechanism import ServerChallenge, ChallengeResponse
from .state import ConnectionState
# Public API of this module.
__all__ = ['IMAPService', 'IMAPServer', 'IMAPConnection']
# Generic result type used by IMAPConnection._exec, which hands an awaitable
# to the current subsystem and returns an awaitable of the same result type.
_Ret = TypeVar('_Ret')
# Module logger; IMAPConnection._print writes protocol traffic to it, but only
# when DEBUG is enabled.
_log = logging.getLogger(__name__)
class IMAPService(ServiceInterface):  # pragma: no cover
    """A pymap service implementing an IMAP server."""

    @classmethod
    def add_arguments(cls, parser: ArgumentParser) -> None:
        """Register the ``imap service`` command-line options on *parser*.

        ``--systemd-sockets`` is only offered when
        ``InheritedSockets.supports('systemd')`` is true; otherwise the
        ``inherited_sockets`` destination simply defaults to ``None`` so
        :meth:`start` can test it unconditionally.

        Args:
            parser: The argument parser to extend.

        """
        group = parser.add_argument_group('imap service')
        group.add_argument('--host', metavar='IFACE', action='append',
                           help='the network interface to listen on')
        # Kept as a string so a service name (not only a number) is accepted.
        group.add_argument('--port', metavar='NUM', default='143',
                           help='the port or service name to listen on')
        group.add_argument('--cert', metavar='FILE', help='cert file for TLS')
        group.add_argument('--key', metavar='FILE', help='key file for TLS')
        if InheritedSockets.supports('systemd'):
            group.add_argument('--systemd-sockets', action='store_const',
                               dest='inherited_sockets', const='systemd',
                               help='use systemd inherited sockets')
        else:
            parser.set_defaults(inherited_sockets=None)
        group.add_argument('--no-tls', dest='tls', action='store_false',
                           help='disable TLS')
        pp_choices = [v.name.lower() for v in ProxyProtocolVersion]
        group.add_argument('--proxy-protocol', choices=pp_choices,
                           help='the PROXY protocol version string')

    async def start(self, stack: AsyncExitStack) -> None:
        """Start accepting IMAP connections.

        When ``config.args.inherited_sockets`` is set (e.g. by
        ``--systemd-sockets``), one server is started per inherited socket.
        Otherwise a single server is bound to ``config.host`` and
        ``config.port``. Every accepted connection first passes through a
        :class:`ProxyProtocolReader` (configured by ``config.proxy_protocol``)
        before reaching :class:`IMAPServer`.

        Each server is entered into *stack* and served by a background
        ``serve_forever`` task, which is cancelled when *stack* unwinds.

        Args:
            stack: Exit stack that owns the servers and their tasks.

        Raises:
            ValueError: Inherited sockets were requested but none were found.

        """
        backend = self.backend
        config = self.config
        servers: list[AbstractServer] = []
        imap_server = IMAPServer(backend.login, config)
        pp_reader = ProxyProtocolReader(config.proxy_protocol)
        # The PROXY protocol callback supplies the extra ``sock_info``
        # argument that IMAPServer.__call__ expects.
        imap_server_cb = pp_reader.get_callback(imap_server)
        if config.args.inherited_sockets:
            sockets = InheritedSockets.of(config.args.inherited_sockets).get()
            if not sockets:
                raise ValueError('No inherited sockets found')
            for sock in sockets:
                servers.append(await asyncio.start_server(
                    imap_server_cb, sock=sock))
        else:
            servers.append(await asyncio.start_server(
                imap_server_cb, host=config.host, port=config.port))
        for server in servers:
            await stack.enter_async_context(server)
            task = asyncio.create_task(server.serve_forever())
            # Callbacks unwind LIFO, so the serving task is cancelled before
            # the server's own async context exits.
            stack.callback(task.cancel)
class IMAPServer:
    """Connection handler that creates and runs one :class:`IMAPConnection`
    for every accepted client connection.

    Instances are called with the stream pair *and* a
    :class:`~proxyprotocol.sock.SocketInfo`. Because of that third argument,
    they are not passed to :func:`asyncio.start_server` directly. Instead, they
    are wrapped by a PROXY protocol reader callback (see
    ``IMAPService.start``), which supplies the socket info.

    Args:
        login: Login callback that takes authentication credentials and returns
            a :class:`~pymap.interfaces.session.SessionInterface` object.
        config: Settings to use for the IMAP server.

    """

    __slots__ = ['commands', '_login', '_config']

    def __init__(self, login: LoginInterface, config: IMAPConfig) -> None:
        super().__init__()
        self.commands = config.commands
        self._login = login
        self._config = config

    async def __call__(self, reader: StreamReader, writer: StreamWriter,
                       sock_info: SocketInfo) -> None:
        """Serve a single client connection until it ends.

        Args:
            reader: The input stream for the socket.
            writer: The output stream for the socket.
            sock_info: Information about the connected socket.

        """
        conn = IMAPConnection(self.commands, self._config,
                              reader, writer, sock_info)
        state = ConnectionState(self._login, self._config)
        async with AsyncExitStack() as stack:
            # Publish the per-connection exit stack; presumably so other code
            # can register cleanup that runs when this connection finishes.
            connection_exit.set(stack)
            # Guarantees conn.close() (which closes the writer) on any exit.
            stack.enter_context(closing(conn))
            await conn.run(state)
class IMAPConnection:
    """Runs a single IMAP connection from start to finish.

    Args:
        commands: Defines the IMAP commands available to the connection.
        config: Settings to use for the IMAP connection.
        reader: The input stream for the socket.
        writer: The output stream for the socket.
        sock_info: Information about the connected socket; stored with
            ``socket_info.set`` so logging can read the connection's id.

    """

    # Splits logged output into individual lines.
    _lines = re.compile(r'\r?\n')
    # Matches a non-synchronizing literal marker ``{N+}`` at the end of a line.
    _literal_plus = re.compile(br'{(\d+)\+}\r?\n$')

    __slots__ = ['commands', 'config', 'params', 'bad_command_limit',
                 'reader', 'writer', 'pp_reader', 'pp_result']

    def __init__(self, commands: Commands, config: IMAPConfig,
                 reader: StreamReader, writer: StreamWriter,
                 sock_info: SocketInfo) -> None:
        super().__init__()
        self.commands = commands
        self.config = config
        self.params = config.parsing_params
        self.bad_command_limit = config.bad_command_limit
        self.reader = reader
        self.writer = writer
        socket_info.set(sock_info)

    def close(self) -> None:
        """Close the writer side of the connection."""
        self.writer.close()

    @classmethod
    def _print(cls, log_format: str, output: str | bytes) -> None:
        """Log *output* line by line at DEBUG level.

        Each line is formatted with *log_format* as
        ``log_format % (unique_id_hex, line)``. Bytes are decoded as UTF-8,
        and undecodable bytes are replaced. Nothing is done unless DEBUG
        logging is enabled.

        """
        if _log.isEnabledFor(logging.DEBUG):
            uid = socket_info.get().unique_id.hex()
            if not isinstance(output, str):
                output = str(output, 'utf-8', 'replace')
            lines = cls._lines.split(output)
            # A trailing newline yields an empty final element; skip it.
            if not lines[-1]:
                lines = lines[:-1]
            for line in lines:
                _log.debug(log_format, uid, line)

    def _exec(self, future: Awaitable[_Ret]) -> Awaitable[_Ret]:
        """Hand *future* to the current subsystem's ``execute`` and return
        the awaitable it produces."""
        return subsystem.get().execute(future)

    async def readline(self) -> memoryview:
        """Read one complete line from the client.

        A line that ends with a non-synchronizing literal marker (``{N+}``)
        is extended with the N literal bytes and the rest of the following
        line. This repeats until a line ends without such a marker.

        Returns:
            The accumulated line, including its line terminator.

        Raises:
            EOFError: The stream ended before a full line was received.

        """
        buf = bytearray(await self.reader.readline())
        while True:
            if not buf.endswith(b'\n'):
                raise EOFError()
            elif buf.endswith((b'+}\n', b'+}\r\n')):
                # Cheap suffix test first; the regex confirms ``{digits+}``.
                lit_plus = self._literal_plus.search(buf)
            else:
                lit_plus = None
            if lit_plus:
                literal_length = int(lit_plus.group(1))
                buf += await self.reader.readexactly(literal_length)
                buf += await self.reader.readline()
            else:
                self._print('%s -->| %s', buf)
                return memoryview(buf)

    async def read_continuation(self, literal_length: int) -> memoryview:
        """Read a literal of *literal_length* bytes, then the rest of the line
        after it (see :meth:`readline`).

        Passing ``0`` simply reads the next line.

        """
        extra_literal = await self.reader.readexactly(literal_length)
        self._print('%s -->| %s', extra_literal)
        extra_line = await self.readline()
        extra = extra_literal + bytes(extra_line)
        return memoryview(extra)

    async def authenticate(self, state: ConnectionState, mech_name: bytes) \
            -> ServerCredentials | None:
        """Run the SASL exchange for an ``AUTHENTICATE`` command.

        Each challenge raised by the mechanism is sent base64-encoded as a
        continuation. The client's reply is then read and decoded, and the
        mechanism is retried with all responses collected so far.

        Args:
            state: The connection state providing ``auth.get_server``.
            mech_name: The requested SASL mechanism name.

        Returns:
            The credentials produced by the mechanism, or ``None`` if no
            server mechanism exists for *mech_name*.

        Raises:
            AuthenticationError: Raised when the client canceled with ``*`` or
                sent a reply that is not valid base64.

        """
        mech = state.auth.get_server(mech_name)
        if not mech:
            return None
        responses: list[ChallengeResponse] = []
        while True:
            try:
                creds, final = mech.server_attempt(responses)
            except ServerChallenge as chal:
                chal_bytes = b64encode(chal.data)
                cont = ResponseContinuation(chal_bytes)
                await self.write_response(cont)
                resp_bytes = bytes(await self.read_continuation(0))
                if resp_bytes.rstrip(b'\r\n') == b'*':
                    raise AuthenticationError('Authentication canceled.') \
                        from None
                try:
                    # The non-validating default discards the trailing CRLF.
                    resp_dec = b64decode(resp_bytes)
                except binascii.Error as exc:
                    raise AuthenticationError() from exc
                else:
                    responses.append(ChallengeResponse(chal.data, resp_dec))
            else:
                if final is not None:
                    # Deliver the mechanism's final data, then consume the
                    # client's (empty) acknowledgement line.
                    cont = ResponseContinuation(b64encode(final))
                    await self.write_response(cont)
                    await self.read_continuation(0)
                return creds

    async def _interrupt(self, state: ConnectionState,
                         interrupt: ParsingInterrupt,
                         continuations: list[memoryview]) -> None:
        """Satisfy a parser interruption by requesting more client data.

        Only :class:`ExpectContinuation` is supported. A continuation
        response is sent, and the requested literal plus the rest of its line
        is appended to *continuations*.

        Raises:
            TypeError: The interruption expects something else.

        """
        expected = interrupt.expected
        if isinstance(expected, ExpectContinuation):
            cont = ResponseContinuation(expected.message)
            await self.write_response(cont)
            ret = await self.read_continuation(expected.literal_length)
            continuations.append(ret)
        else:
            raise TypeError(expected) from interrupt

    async def read_command(self, state: ConnectionState) -> Command:
        """Read and parse the next command.

        Parsing restarts from the first line each time a
        :class:`ParsingInterrupt` adds another continuation, until the whole
        command parses.

        """
        line = await self.readline()
        conts: list[memoryview] = []
        while True:
            parsing_state = ParsingState(continuations=conts)
            params = self.params.copy(parsing_state)
            try:
                cmd, _ = self.commands.parse(line, params)
            except ParsingInterrupt as interrupt:
                await self._interrupt(state, interrupt, conts)
            else:
                return cmd

    async def read_idle_done(self, cmd: IdleCommand) -> bool:
        """Read the line that ends ``IDLE`` and report whether *cmd* accepted
        it as ``DONE``."""
        buf = await self.read_continuation(0)
        ok, _ = cmd.parse_done(buf)
        return ok

    async def write_response(self, resp: Response) -> None:
        """Write *resp* to the client and drain the writer.

        :class:`ConnectionError` is deliberately swallowed here. A dead
        connection presumably surfaces on the next read, which
        :meth:`_run_state` handles. The response is only logged once it was
        written successfully.

        """
        try:
            await resp.async_write(self.writer)
            await self.writer.drain()
        except ConnectionError:
            pass
        else:
            self._print('%s <--| %s', bytes(resp))

    async def start_tls(self) -> None:
        """Upgrade the connection in place to TLS using
        ``config.ssl_context``."""
        ssl_context = self.config.ssl_context
        await self.writer.start_tls(ssl_context)
        self._print('%s <->| %s', '<TLS handshake>')

    async def send_error_disconnect(self) -> None:
        """Best-effort ``BYE`` after an unexpected failure.

        This must be called from inside an ``except`` block, because it reads
        the active exception from :func:`sys.exc_info`. Cancellation is
        reported as ``UNAVAILABLE`` and anything else as ``SERVERBUG``.

        """
        _, exc, _ = sys.exc_info()
        if isinstance(exc, CancelledError):
            resp = ResponseBye(b'Server has closed the connection.',
                               ResponseCode.of(b'UNAVAILABLE'))
        else:
            resp = ResponseBye(b'Unhandled server error.',
                               ResponseCode.of(b'SERVERBUG'))
        try:
            await self.write_response(resp)
        except IOError:
            pass

    async def write_updates(self, untagged: Iterable[Response]) -> None:
        """Write each untagged response in order."""
        for resp in untagged:
            await self.write_response(resp)

    async def handle_updates(self, state: ConnectionState, done: Event,
                             cmd: IdleCommand) -> None:
        """Send untagged updates to the client until *done* is set."""
        while not done.is_set():
            untagged = await self._exec(state.receive_updates(cmd, done))
            # Shielded, presumably so cancellation cannot interrupt a
            # partially written batch of updates.
            await shield(self.write_updates(untagged))

    async def idle(self, state: ConnectionState, cmd: IdleCommand) \
            -> CommandResponse:
        """Handle an ``IDLE`` command.

        Updates are streamed concurrently while waiting for the client's
        ``DONE`` line. Errors from either side are re-raised; errors from
        the update stream take priority.

        Returns:
            The backend's non-OK response if ``IDLE`` was refused, a ``BAD``
            response if the terminating line was not ``DONE``, or the
            backend's OK response otherwise.

        """
        response = await self._exec(state.do_command(cmd))
        if not isinstance(response, ResponseOk):
            return response
        await self.write_response(ResponseContinuation(b'Idling.'))
        done = subsystem.get().new_event()
        updates_task = asyncio.create_task(
            self.handle_updates(state, done, cmd))
        done_task = asyncio.create_task(self.read_idle_done(cmd))
        updates_exc: Exception | None = None
        done_exc: Exception | None = None
        # NOTE(review): done_task is awaited first. A failure in updates_task
        # is therefore only raised after the client's terminating line arrives
        # (or the read itself fails).
        try:
            ok = await done_task
        except Exception as exc:
            done_exc = exc
        finally:
            done.set()
            try:
                await updates_task
            except Exception as exc:
                updates_exc = exc
        if updates_exc:
            raise updates_exc
        elif done_exc:
            raise done_exc
        elif not ok:
            return ResponseBad(cmd.tag, b'Expected "DONE".')
        else:
            return response

    async def run(self, state: ConnectionState) -> None:
        """Start the socket communication with the IMAP greeting, and then
        enter the command/response cycle.

        Args:
            state: Defines the interaction with the backend plugin.

        """
        self._print('%s +++| %s', str(socket_info.get()))
        try:
            await self._run_state(state)
        finally:
            self._print('%s ---| %s', b'<disconnected>')

    async def _run_state(self, state: ConnectionState) -> None:
        """Send the greeting, then read and dispatch commands until the
        client disconnects or a terminal response is produced."""
        bad_commands = 0
        try:
            greeting = await self._exec(state.do_greeting())
        except ResponseError as exc:
            resp = exc.get_response(b'*')
            # A refused greeting is sent as BYE, then the connection ends.
            resp.condition = ResponseBye.condition
            await self.write_response(resp)
            return
        else:
            await self.write_response(greeting)
        while True:
            try:
                cmd = await self.read_command(state)
            except (ConnectionError, EOFError):
                break
            except CancelledError:
                await self.send_error_disconnect()
                break
            except Exception:
                await self.send_error_disconnect()
                raise
            else:
                prev_cmd = current_command.set(cmd)
                try:
                    if isinstance(cmd, AuthenticateCommand):
                        creds = await self.authenticate(state, cmd.mech_name)
                        response = await self._exec(
                            state.do_authenticate(cmd, creds))
                    elif isinstance(cmd, IdleCommand):
                        response = await self.idle(state, cmd)
                    else:
                        response = await self._exec(state.do_command(cmd))
                except ResponseError as exc:
                    resp = exc.get_response(cmd.tag)
                    await self.write_response(resp)
                    if resp.is_terminal:
                        break
                except AuthenticationError as exc:
                    msg = bytes(str(exc), 'utf-8', 'surrogateescape')
                    resp = ResponseBad(cmd.tag, msg)
                    await self.write_response(resp)
                except TimeoutError:
                    resp = ResponseNo(cmd.tag, b'Operation timed out.',
                                      ResponseCode.of(b'TIMEOUT'))
                    await self.write_response(resp)
                except (CancelledError, ConnectionError, EOFError):
                    await self.send_error_disconnect()
                    break
                except Exception:
                    await self.send_error_disconnect()
                    raise
                else:
                    if response.is_bad:
                        bad_commands += 1
                        if self.bad_command_limit \
                                and bad_commands >= self.bad_command_limit:
                            msg = b'Too many errors, disconnecting.'
                            response.add_untagged(ResponseBye(msg))
                    else:
                        bad_commands = 0
                    # Written only after the BYE above may have been attached.
                    # Otherwise the client is disconnected without seeing it.
                    await self.write_response(response)
                    if response.is_terminal:
                        break
                    if isinstance(cmd, StartTLSCommand) \
                            and isinstance(response, ResponseOk):
                        try:
                            await self.start_tls()
                        except ConnectionError:
                            break
                        except SSLError as exc:
                            # ``reason`` may be None; _print needs str/bytes.
                            self._print('%s <->| <TLS failure: %s>',
                                        str(exc.reason))
                            return
                finally:
                    await state.do_cleanup()
                    current_command.reset(prev_cmd)