-
-
Save jexio/510187240f99916ceb90a2855fe13855 to your computer and use it in GitHub Desktop.
ZeroMQ Curve authentication demo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Simple demonstration of using ZMQ's Curve authentication.

This demo is adapted from the examples given in the `PyZMQ repository`__. Key
differences include:

* Using ``setsockopt`` to set Curve parameters instead of setting attributes
  directly (help out your IDE!)
* Integration with ``asyncio``

__ https://github.com/zeromq/pyzmq/tree/master/examples

"""
from abc import ABC, abstractmethod | |
import asyncio | |
from contextlib import AbstractContextManager | |
import os.path | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from typing import Dict, Optional, Tuple, Union | |
import zmq | |
from zmq.asyncio import Context | |
import zmq.auth | |
from zmq.auth.asyncio import AsyncioAuthenticator | |
def generate_keys(key_dir: str) -> Dict[str, Tuple[str, str]]:
    """Create the server and client Curve certificates used by this demo.

    Parameters
    ----------
    key_dir
        Directory the certificate files are written to.

    Returns
    -------
    dict
        Maps ``"server"`` and ``"client"`` to the two-item tuple returned by
        :func:`zmq.auth.create_certificates` for that name (public key file
        first, secret key file second).

    """
    certificates = {}
    # Server first, then client, so files are created in a fixed order.
    for role in ("server", "client"):
        public_file, secret_file = zmq.auth.create_certificates(key_dir, role)
        certificates[role] = (public_file, secret_file)
    return certificates
class BaseServer(AbstractContextManager):
    """Base ZMQ server with authentication support.

    The constructor starts the authenticator and binds the listening socket.
    Using the instance as a context manager stops the authenticator when the
    ``with`` block exits.

    Parameters
    ----------
    address
        ZMQ address string to listen on. Must start with ``tcp://``.
    secret_key_path
        Path to the server secret key file.
    client_key_dir
        Path to the directory containing authorized client public keys. When
        not given, accept connections from any client that knows the server's
        public key.
    ctx
        A :class:`Context`. If not given, ``Context.instance()`` is used.
    socket_type
        Type of socket to create. If not given (or ``None``), ``zmq.REP``
        will be used.

    Raises
    ------
    ValueError
        If ``address`` is not a ``tcp://`` address.
    FileNotFoundError
        If ``secret_key_path`` is not an existing file.
    NotADirectoryError
        If ``client_key_dir`` is given but is not an existing directory.

    """

    def __init__(
        self,
        address: str,
        secret_key_path: Union[str, Path],
        client_key_dir: Optional[Union[str, Path]] = None,
        ctx: Optional[Context] = None,
        socket_type: Optional[int] = zmq.REP
    ):
        if not address.startswith("tcp://"):
            raise ValueError("CurveZMQ only works over TCP")

        # Raise real exceptions instead of using ``assert``: assertions are
        # removed when Python runs with ``-O``, which would silently skip
        # these checks.
        if not os.path.isfile(secret_key_path):
            raise FileNotFoundError(
                f"Server secret key file not found: {secret_key_path}"
            )
        if client_key_dir is not None and not os.path.isdir(client_key_dir):
            raise NotADirectoryError(
                f"Client key directory not found: {client_key_dir}"
            )

        self.address = address
        # ``socket_type`` is Optional; an explicit ``None`` means "default".
        self.socket_type = zmq.REP if socket_type is None else socket_type
        self.ctx = ctx or Context.instance()
        self._secret_key_file = secret_key_path
        self._client_key_dir = client_key_dir

        auth_location = (
            str(client_key_dir)
            if client_key_dir is not None
            else zmq.auth.CURVE_ALLOW_ANY
        )

        # Configure the authenticator
        self.auth = AsyncioAuthenticator(context=self.ctx)
        self.auth.configure_curve(domain="*", location=auth_location)
        self.auth.allow("127.0.0.1")
        self.auth.start()

        # Configure the listening socket. If anything here fails (for
        # example the address is already in use), stop the authenticator
        # that was just started so it is not left running.
        sock = None
        try:
            sock = self.ctx.socket(self.socket_type)
            public_key, secret_key = zmq.auth.load_certificate(
                self._secret_key_file
            )
            sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
            sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
            sock.setsockopt(zmq.CURVE_SERVER, True)
            sock.bind(self.address)
        except Exception:
            if sock is not None:
                # Nothing has been sent yet, so there is nothing to linger on.
                sock.close(linger=0)
            self.auth.stop()
            raise
        self.socket = sock

    def __exit__(self, *_exc):
        """Stop the authenticator; exceptions from the block propagate."""
        self.auth.stop()

    @abstractmethod
    async def run(self):
        """Implement this method to send and/or receive messages."""
class EchoServer(BaseServer):
    """Server that replies to every request with the bytes it received."""

    async def run(self):
        """Echo messages until ``b"quit"`` has been received and echoed.

        The loop runs inside ``with self`` so the authenticator is stopped
        when the loop ends or an exception escapes.
        """
        with self:
            request = None
            # The quit message itself is echoed before the loop ends.
            while request != b"quit":
                request = await self.socket.recv()
                await self.socket.send(request)
            print("Server exiting upon request")
class BaseClient(ABC):
    """Base class for a Curve-enabled ``zmq.REQ`` client.

    The constructor creates the socket, loads the key material, and connects
    to the server.

    Parameters
    ----------
    address
        ZMQ address for the server.
    server_public_key_path
        Path to the server's public key.
    secret_key_path
        Path to the client's secret key.
    ctx
        Optional :class:`Context`. ``Context.instance()`` is used when omitted.

    """

    def __init__(
        self,
        address: str,
        server_public_key_path: Union[str, Path],
        secret_key_path: Union[str, Path],
        ctx: Optional[Context] = None,
    ):
        self.ctx = ctx or Context.instance()
        self.socket = self.ctx.socket(zmq.REQ)
        self.address = address

        # The client's own key pair comes from its secret certificate file.
        client_keys = zmq.auth.load_certificate(secret_key_path)
        client_public = client_keys[0]
        client_secret = client_keys[1]
        self.socket.setsockopt(zmq.CURVE_PUBLICKEY, client_public)
        self.socket.setsockopt(zmq.CURVE_SECRETKEY, client_secret)

        # Only the first item of the server certificate (its public key) is
        # used.
        server_public, _ = zmq.auth.load_certificate(server_public_key_path)
        self.socket.setsockopt(zmq.CURVE_SERVERKEY, server_public)

        self.socket.connect(self.address)

    @abstractmethod
    async def run(self) -> None:
        """Implement this coroutine to communicate with the server."""
class EchoClient(BaseClient):
    """Client that sends ten greetings and then asks the server to quit."""

    async def run(self) -> None:
        """Send ten numbered greetings, print each reply, then send ``quit``.

        Each request waits for its reply before the next one is sent, with a
        one-second pause after every greeting.
        """
        greetings = [f"Hello, world {n}".encode() for n in range(10)]
        for payload in greetings:
            await self.socket.send(payload)
            reply = await self.socket.recv()
            print("Client received", reply)
            await asyncio.sleep(1)

        # Ask the server to stop and wait for its final reply.
        await self.socket.send(b"quit")
        await self.socket.recv()
async def main():
    """Run the echo server and echo client together on a local TCP port.

    Keys are generated into a temporary directory, which is also passed to
    the server as its directory of authorized client keys.
    """
    import logging

    # Set debug logging so we can see zmq.auth's logs
    logging.basicConfig(level=logging.DEBUG)

    with TemporaryDirectory() as key_dir:
        endpoint = "tcp://127.0.0.1:9999"
        certs = generate_keys(key_dir)
        server_public, server_secret = certs["server"]
        client_secret = certs["client"][1]

        echo_server = EchoServer(endpoint, server_secret, key_dir)
        echo_client = EchoClient(endpoint, server_public, client_secret)
        await asyncio.gather(echo_server.run(), echo_client.run())
if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop, runs main() to completion and
    # closes the loop afterwards. Calling asyncio.get_event_loop() without a
    # running loop is deprecated in recent Python versions.
    asyncio.run(main())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment