# Source code for blackhole.protocols

# -*- coding: utf-8 -*-

# (The MIT License)
#
# Copyright (c) 2013-2021 Kura
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the 'Software'), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

"""Communication protocols used by the worker and child processes."""


import asyncio
import logging

from .config import Config


# Public names exported by a star-import of this module.
__all__ = ("StreamReaderProtocol", "PING", "PONG")
"""Tuple all the things."""


# Module-level logger; the name is spelled out rather than using __name__.
logger = logging.getLogger("blackhole.protocols")


# NOTE(review): this literal is the three ASCII bytes "x", "0", "1" -- it is
# not the single control byte written as b"\x01". Confirm which was intended
# before changing it, since both ends of the channel must agree on the value.
PING = b"x01"
"""Protocol message used by the worker and child processes to communicate."""

# NOTE(review): likewise three ASCII bytes ("x", "0", "2"), not b"\x02".
PONG = b"x02"
"""Protocol message used by the worker and child processes to communicate."""


class StreamReaderProtocol(asyncio.StreamReaderProtocol):
    """
    Stream protocol that tracks connected clients and offers line I/O.

    Builds on :py:class:`asyncio.StreamReaderProtocol`: the writer of each
    connection is recorded in a shared ``clients`` list, and small helpers
    are provided to read a line (:py:meth:`wait`), send a line
    (:py:meth:`push`) and close the connection (:py:meth:`close`).

    .. note::

       :py:meth:`flags_from_transport` reads ``self.transport`` and
       :py:meth:`wait` awaits ``self.timeout()``. Neither is defined in
       this class, so subclasses are presumably expected to provide them.
    """

    def __init__(self, clients, loop=None):
        """
        Initialise the protocol.

        :param list clients: A list of connected clients. The writer for
                             this connection is appended to it on connect
                             and removed on close/disconnect.
        :param loop: The event loop to use. When :py:obj:`None`, the
                     result of :py:func:`asyncio.get_event_loop` is used.
        :type loop: :py:obj:`None` or
                    :py:class:`asyncio.unix_events._UnixSelectorEventLoop`
        """
        logger.debug("init")
        self.loop = loop if loop is not None else asyncio.get_event_loop()
        logger.debug("loop")
        super().__init__(
            asyncio.StreamReader(loop=self.loop),
            client_connected_cb=self._client_connected_cb,
            loop=self.loop,
        )
        logger.debug("super")
        self.clients = clients
        # Give the connection state well-defined starting values so that
        # wait(), close() and connection_lost() cannot hit AttributeError
        # when they run before _client_connected_cb has been invoked.
        self._reader = None
        self._writer = None
        self.connection_closed = False
        self._connection_closed = False
        self.config = Config()
        logger.debug(self.config)
        # This is not a nice way to do this but, socket.getfqdn silently fails
        # and crashes inbound connections when called after os.fork
        self.fqdn = self.config.mailname

    def flags_from_transport(self):
        """
        Adapt internal flags for the transport in use.

        Looks up the listener flags for the local address and port of the
        underlying socket. When any flags are configured they are stored on
        ``self._flags`` and ``self._disable_dynamic_switching`` is set to
        :py:obj:`True`; otherwise nothing is changed.
        """
        # This has to be done here since passing it as part of init causes
        # flags to become garbled and mixed up. Artifact of loop.create_server
        sock = self.transport.get_extra_info("socket")
        # Ideally this would use transport.get_extra_info('sockname') but that
        # crashes the child process for some weird reason. Getting the socket
        # and interacting directly does not cause a crash, hence...
        sock_name = sock.getsockname()
        flags = self.config.flags_from_listener(sock_name[0], sock_name[1])
        if flags:
            self._flags = flags
            self._disable_dynamic_switching = True
            logger.debug("Flags enabled, disabling dynamic switching")
            logger.debug("Flags for this connection: %s", self._flags)

    def _client_connected_cb(self, reader, writer):
        """
        Bind a stream reader and writer to the protocol.

        Stores both objects on the instance and registers the writer in the
        shared ``clients`` list.

        :param asyncio.streams.StreamReader reader: An object for reading
                                                    incoming data.
        :param asyncio.streams.StreamWriter writer: An object for writing
                                                    outgoing data.
        """
        self._reader = reader
        self._writer = writer
        self.clients.append(writer)

    def connection_lost(self, exc):
        """
        Client connection is closed or lost.

        Marks the connection as closed and removes its writer from the
        shared ``clients`` list.

        :param exc: The exception passed on to the parent implementation.
        """
        logger.debug("Peer disconnected")
        super().connection_lost(exc)
        self.connection_closed = self._connection_closed = True
        try:
            self.clients.remove(self._writer)
        except ValueError:
            # The writer was never registered, or close() already removed it.
            pass

    async def wait(self):
        """
        Wait for data from the client.

        :returns: A line of received data, or :py:obj:`None` when the
                  connection is already closed or the client timed out.
        :rtype: :py:obj:`bytes` or :py:obj:`None`

        .. note::

           Also handles client timeouts if they wait too long before sending
           data. -- https://kura.gg/blackhole/configuration.html#timeout
        """
        if self.connection_closed:
            return None
        try:
            return await asyncio.wait_for(
                self._reader.readline(),
                self.config.timeout,
            )
        except asyncio.TimeoutError:
            # timeout() is not defined in this class; see the class note.
            await self.timeout()
            return None

    async def close(self):
        """
        Close the connection from the client.

        Removes the writer from the shared ``clients`` list, closes it and
        waits for it to drain. The connection is marked as closed even when
        no writer is bound or when draining raises.
        """
        logger.debug("Closing connection")
        try:
            if self._writer:
                try:
                    self.clients.remove(self._writer)
                except ValueError:
                    # Already removed, e.g. by connection_lost().
                    pass
                self._writer.close()
                await self._writer.drain()
        finally:
            # Set both flags: wait() checks ``connection_closed`` while
            # ``_connection_closed`` was the only one set here previously,
            # so an explicit close was invisible to wait().
            self.connection_closed = self._connection_closed = True

    async def push(self, msg):
        """
        Write a response message to the client.

        The message is terminated with CRLF, encoded as UTF-8, written to
        the bound writer and then drained.

        :param str msg: The message for the SMTP code
        """
        response = f"{msg}\r\n".encode("utf-8")
        logger.debug("SEND %s", response)
        self._writer.write(response)
        await self._writer.drain()