Created
July 2, 2020 13:49
-
-
Save agronholm/6f193219072dff5c00993c8e67c9d602 to your computer and use it in GitHub Desktop.
AnyIO SMTP client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import logging | |
import socket | |
from dataclasses import dataclass, field | |
from email.headerregistry import Address | |
from email.message import EmailMessage | |
from email.utils import getaddresses, parseaddr | |
from ssl import SSLContext | |
from typing import Optional, Iterable, Callable, Union, List, Dict, Any | |
from anyio import connect_tcp, fail_after, BlockingPortal, start_blocking_portal | |
from anyio.abc import SocketStream | |
from .auth import SMTPAuthenticator | |
from .protocol import SMTPClientProtocol, SMTPResponse, ClientState, SMTPException | |
logger = logging.getLogger(__name__) | |
@dataclass
class AsyncSMTPClient:
    """
    An example asynchronous SMTP client.

    The client can be used as an async context manager: entering it calls
    :meth:`connect` and leaving it calls :meth:`aclose`.

    :param host: host name or IP address of the SMTP server
    :param port: port on the SMTP server to connect to
    :param connect_timeout: connection timeout (in seconds)
    :param read_timeout: timeout for reading responses (in seconds); also applied to
        each write since there is no separate write timeout
    :param domain: domain name to send to the server as part of the greeting message
    :param ssl_context: SSL context to use for establishing TLS encrypted sessions
    :param authenticator: authenticator to use for authenticating with the SMTP server
    """

    host: str
    port: int = 587
    connect_timeout: float = 30
    read_timeout: float = 60
    domain: str = field(default_factory=socket.gethostname)
    ssl_context: Optional[SSLContext] = None
    authenticator: Optional[SMTPAuthenticator] = None
    # Protocol state machine and the open socket stream (None while disconnected)
    _protocol: SMTPClientProtocol = field(init=False, default_factory=SMTPClientProtocol)
    _stream: Optional[SocketStream] = field(init=False, default=None)

    async def __aenter__(self):
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.aclose()

    async def connect(self) -> None:
        """
        Connect to the server and perform the initial handshake.

        Reads the server's initial response, sends the greeting, upgrades the connection
        with STARTTLS if the server advertises it, and finally authenticates if an
        authenticator was provided. Does nothing if a stream is already open.

        If any step of the handshake fails, the connection is closed and the exception
        is re-raised.
        """
        if self._stream:
            return

        async with fail_after(self.connect_timeout):
            self._stream = await connect_tcp(
                self.host, self.port, ssl_context=self.ssl_context,
                tls_standard_compatible=False)

        try:
            await self._wait_response()
            await self._send_command(self._protocol.send_greeting, self.domain)

            # Do the TLS handshake if supported by the server
            if 'STARTTLS' in self._protocol.extensions:
                await self._send_command(self._protocol.start_tls)
                await self._stream.start_tls()

                # Send a new EHLO command to determine new capabilities
                await self._send_command(self._protocol.send_greeting, self.domain)

            # Use the authenticator if one was provided
            if self.authenticator:
                await self._authenticate(self.authenticator)
        except BaseException:
            await self.aclose()
            raise

    async def _authenticate(self, authenticator: SMTPAuthenticator) -> None:
        """
        Run the authentication exchange driven by the authenticator's async generator.

        The generator's first value is sent along with the mechanism name. While the
        protocol remains in the ``authenticating`` state, the message of the latest
        server response is passed to the generator and the value it yields is sent back.
        """
        auth_gen = authenticator.authenticate()
        try:
            auth_data = await auth_gen.asend(None)
            response = await self._send_command(
                self._protocol.authenticate, authenticator.mechanism, auth_data)
            while self._protocol.state is ClientState.authenticating:
                auth_data = await auth_gen.asend(response.message)
                # Wait for the server's reply to every payload: without it neither
                # ``response`` nor the protocol state can ever change inside this loop
                response = await self._send_command(
                    self._protocol.send_authentication_data, auth_data)
        except StopAsyncIteration:
            # The authenticator has nothing more to send
            pass
        finally:
            await auth_gen.aclose()

    async def aclose(self) -> None:
        """
        Close the connection, sending the quit command first if the session is not finished.

        The stream is closed and forgotten even if sending the quit command fails.
        Does nothing if there is no open stream.
        """
        if not self._stream:
            return

        try:
            if self._protocol.state is not ClientState.finished:
                await self._send_command(self._protocol.quit)
        finally:
            try:
                if self._stream:
                    await self._stream.close()
            finally:
                # Drop the reference even if close() itself raised
                self._stream = None

    async def _wait_response(self) -> SMTPResponse:
        """
        Read from the stream until the protocol produces a complete response.

        :return: the first non-error response produced by the protocol
        :raises SMTPException: if not connected, or if the stream ends before a
            response has been received (error responses are raised through
            ``response.raise_as_exception()``)
        """
        while True:
            if not self._stream:
                raise SMTPException('Not connected')

            if self._protocol.needs_incoming_data:
                # Bound every read so an unresponsive server cannot hang the client
                async with fail_after(self.read_timeout):
                    data = await self._stream.receive_some(65536)

                logger.debug('Received: %s', data)
                if not data:
                    # NOTE(review): assumes receive_some() signals EOF with empty bytes
                    # (anyio 1.x behaviour); without this check the loop would spin
                    raise SMTPException('Connection closed by the server')

                response = self._protocol.feed_bytes(data)
                if response:
                    if response.is_error():
                        response.raise_as_exception()
                    else:
                        return response

            data = self._protocol.get_outgoing_data()
            if data:
                await self._stream.send_all(data)
                logger.debug('Sent: %s', data)

    async def _flush_output(self) -> None:
        """
        Send any data queued by the protocol to the server.

        :raises SMTPException: if not connected
        """
        if not self._stream:
            raise SMTPException('Not connected')

        data = self._protocol.get_outgoing_data()
        if not data:
            return

        async with fail_after(self.read_timeout):
            await self._stream.send_all(data)

        logger.debug('Sent: %s', data)

    async def _send_command(self, command: Callable, *args) -> SMTPResponse:
        """
        Invoke a protocol command, send its output and wait for the response.

        :param command: a callable (normally a bound protocol method) that queues the
            command's bytes on the protocol
        :param args: positional arguments passed to ``command``
        :return: the server's response
        :raises SMTPException: if not connected
        """
        if not self._stream:
            raise SMTPException('Not connected')

        command(*args)
        await self._flush_output()
        return await self._wait_response()

    async def send_message(self, message: EmailMessage, *,
                           sender: Union[str, Address, None] = None,
                           recipients: Optional[Iterable[str]] = None) -> SMTPResponse:
        """
        Send an email message.

        :param message: the message to send
        :param sender: envelope sender; if omitted, the address in the ``From`` header
            is used
        :param recipients: envelope recipients; if omitted or empty, the addresses in
            the ``To``, ``Cc``, ``Resent-To`` and ``Resent-Cc`` headers are used
        :return: the response to the message data
        """
        sender = sender or parseaddr(message.get('From'))[1]
        await self._send_command(self._protocol.mail, sender)

        if not recipients:
            tos: List[str] = message.get_all('to', [])
            ccs: List[str] = message.get_all('cc', [])
            resent_tos: List[str] = message.get_all('resent-to', [])
            resent_ccs: List[str] = message.get_all('resent-cc', [])
            recipients = [email for name, email in
                          getaddresses(tos + ccs + resent_tos + resent_ccs)]

        for recipient in recipients:
            await self._send_command(self._protocol.recipient, recipient)

        await self._send_command(self._protocol.start_data)
        return await self._send_command(self._protocol.data, message)
class SyncSMTPClient:
    """
    A synchronous wrapper around :class:`AsyncSMTPClient`.

    The asynchronous client's methods are run through a blocking portal. The client
    can be used as a context manager: entering it connects to the server and leaving
    it closes the connection and stops the portal.

    :param args: positional arguments passed to :class:`AsyncSMTPClient`
    :param async_backend: name of the backend passed to ``start_blocking_portal()``
    :param async_backend_options: backend options passed to ``start_blocking_portal()``
    :param kwargs: keyword arguments passed to :class:`AsyncSMTPClient`
    """

    def __init__(self, *args, async_backend: str = 'asyncio',
                 async_backend_options: Optional[Dict[str, Any]] = None, **kwargs):
        self._async_backend = async_backend
        self._async_backend_options = async_backend_options
        self._async_client = AsyncSMTPClient(*args, **kwargs)
        self._portal: Optional[BlockingPortal] = None

    def __enter__(self):
        # Mirror AsyncSMTPClient.__aenter__(): connect and hand back the client so that
        # "with SyncSMTPClient(...) as client:" binds something usable
        self.connect()
        return self

    def __exit__(self, exc_type=None, exc_val=None, exc_tb=None):
        # The context manager protocol passes three exception arguments; they default
        # to None so that a direct zero-argument call keeps working
        self.close()

    def connect(self) -> None:
        """
        Start the blocking portal (if not already running) and connect to the server.

        If connecting fails and the portal was started by this call, the portal is
        stopped again before the exception is re-raised.
        """
        started_here = self._portal is None
        if started_here:
            self._portal = start_blocking_portal(
                self._async_backend, self._async_backend_options)

        try:
            self._portal.call(self._async_client.connect)
        except BaseException:
            if started_here:
                portal, self._portal = self._portal, None
                portal.stop_from_external_thread()

            raise

    def close(self) -> None:
        """
        Close the connection and stop the blocking portal.

        Does nothing if the portal is not running. The portal is stopped even if
        closing the connection raises an exception.
        """
        portal = self._portal
        if portal is None:
            return

        try:
            portal.call(self._async_client.aclose)
        finally:
            self._portal = None
            portal.stop_from_external_thread()

    def send_message(self, message: EmailMessage, *,
                     sender: Union[str, Address, None] = None,
                     recipients: Optional[Iterable[str]] = None) -> SMTPResponse:
        """
        Send an email message through the wrapped asynchronous client.

        :param message: the message to send
        :param sender: envelope sender, passed through to the asynchronous client
        :param recipients: envelope recipients, passed through to the asynchronous client
        :return: the response returned by the asynchronous client
        :raises SMTPException: if the portal is not running
        """
        if self._portal is None:
            raise SMTPException('Not connected')

        async def send() -> SMTPResponse:
            # ``sender`` and ``recipients`` are keyword-only in the async client, and
            # the portal call below forwards only positional arguments
            return await self._async_client.send_message(
                message, sender=sender, recipients=recipients)

        return self._portal.call(send)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment