-
-
Save mivade/97c2dc353a1bb460a1d44010df66e6d7 to your computer and use it in GitHub Desktop.
"""Simple demonstration of using ZMQ's Curve authentication. | |
This demo is adapted from the examples given in the `PyZMQ repository`__. Key | |
differences include: | |
* Using ``setsockopt`` to set Curve parameters instead of setting attributes | |
directly (help out your IDE!) | |
* Integration with ``asyncio`` | |
__ https://github.com/zeromq/pyzmq/tree/master/examples | |
""" | |
from abc import ABC, abstractmethod | |
import asyncio | |
from contextlib import AbstractContextManager | |
import os.path | |
from pathlib import Path | |
from tempfile import TemporaryDirectory | |
from typing import Dict, Optional, Tuple, Union | |
import zmq | |
from zmq.asyncio import Context | |
import zmq.auth | |
from zmq.auth.asyncio import AsyncioAuthenticator | |
def generate_keys(key_dir: str) -> Dict[str, Tuple[str, str]]:
    """Create the server and client Curve key pairs used by this demo.

    Parameters
    ----------
    key_dir
        Directory to write keys to.

    Returns
    -------
    dict
        Maps ``"server"`` and ``"client"`` (in that order) to the
        ``(public_key_path, secret_key_path)`` pair returned by
        :func:`zmq.auth.create_certificates`.
    """
    pairs = {}
    for role in ("server", "client"):
        # The certificate base name doubles as the dictionary key.
        public_path, secret_path = zmq.auth.create_certificates(key_dir, role)
        pairs[role] = (public_path, secret_path)
    return pairs
class BaseServer(AbstractContextManager):
    """Base ZMQ server with Curve authentication support.

    Subclasses implement :meth:`run`. Using the server as a context manager
    (``with server:``) closes the listening socket and stops the
    authenticator on exit.

    Parameters
    ----------
    address
        ZMQ address string to listen on. Must start with ``tcp://``.
    secret_key_path
        Path to the server secret key file.
    client_key_dir
        Path to the directory containing authorized client public keys. When
        not given, accept connections from any client that knows the server's
        public key.
    ctx
        A :class:`Context`. If not given, the shared ``Context.instance()``
        is used.
    socket_type
        Type of socket to create. If not given (or ``None``), ``zmq.REP``
        will be used.

    Raises
    ------
    ValueError
        If ``address`` does not start with ``tcp://``.
    FileNotFoundError
        If ``secret_key_path`` is not an existing file, or ``client_key_dir``
        is given but is not an existing directory.
    """

    def __init__(
        self,
        address: str,
        secret_key_path: Union[str, Path],
        client_key_dir: Optional[Union[str, Path]] = None,
        ctx: Optional[Context] = None,
        socket_type: Optional[int] = zmq.REP,
    ):
        if not address.startswith("tcp://"):
            raise ValueError("CurveZMQ only works over TCP")

        # Validate paths with real exceptions: ``assert`` statements are
        # stripped when Python runs with ``-O``.
        if not os.path.isfile(secret_key_path):
            raise FileNotFoundError(
                f"Server secret key file not found: {secret_key_path}"
            )
        if client_key_dir is not None and not os.path.isdir(client_key_dir):
            raise FileNotFoundError(
                f"Client key directory not found: {client_key_dir}"
            )

        self.address = address
        # Honour the documented default even when None is passed explicitly.
        self.socket_type = zmq.REP if socket_type is None else socket_type
        self.ctx = ctx or Context.instance()
        self._secret_key_file = secret_key_path
        self._client_key_dir = client_key_dir

        # Without a key directory, fall back to CURVE_ALLOW_ANY.
        auth_location = (
            str(client_key_dir)
            if client_key_dir is not None
            else zmq.auth.CURVE_ALLOW_ANY
        )

        # Configure the authenticator
        self.auth = AsyncioAuthenticator(context=self.ctx)
        self.auth.configure_curve(domain="*", location=auth_location)
        # NOTE(review): per the pyzmq docs, once any address is allowed,
        # connections from addresses not allowed are rejected, so only
        # local clients are admitted here.
        self.auth.allow("127.0.0.1")
        self.auth.start()

        # Configure the listening socket. If anything here fails (e.g. bind
        # raises because the port is in use), release the socket and stop the
        # authenticator so they do not outlive the failed constructor.
        sock = None
        try:
            sock = self.ctx.socket(self.socket_type)
            public_key, secret_key = zmq.auth.load_certificate(
                self._secret_key_file
            )
            sock.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
            sock.setsockopt(zmq.CURVE_SECRETKEY, secret_key)
            sock.setsockopt(zmq.CURVE_SERVER, True)
            sock.bind(self.address)
        except Exception:
            if sock is not None:
                sock.close(linger=0)
            self.auth.stop()
            raise
        self.socket = sock

    def __exit__(self, *_exc):
        """Close the listening socket and stop the authenticator."""
        # ``close()`` without a ``linger`` argument keeps the socket's
        # configured LINGER setting rather than forcing 0. A just-queued
        # reply (such as an echoed "quit") therefore gets a chance to be
        # delivered.
        self.socket.close()
        self.auth.stop()

    @abstractmethod
    async def run(self):
        """Implement this method to send and/or receive messages."""
class EchoServer(BaseServer):
    """Echo every request back to its sender until ``b"quit"`` arrives."""

    async def run(self):
        with self:
            msg = None
            # The quit request is echoed like any other message before the
            # loop ends, so the client still receives its reply.
            while msg != b"quit":
                msg = await self.socket.recv()
                await self.socket.send(msg)
            print("Server exiting upon request")
class BaseClient(ABC):
    """Base client that connects to a Curve-authenticated server.

    Subclasses implement :meth:`run`.

    Parameters
    ----------
    address
        ZMQ address for the server.
    server_public_key_path
        Path to the server's public key.
    secret_key_path
        Path to the client's secret key.
    ctx
        Optional :class:`Context`. If not given, the shared
        ``Context.instance()`` is used.
    socket_type
        Type of socket to create. If not given (or ``None``), ``zmq.REQ``
        will be used.
    """

    def __init__(
        self,
        address: str,
        server_public_key_path: Union[str, Path],
        secret_key_path: Union[str, Path],
        ctx: Optional[Context] = None,
        socket_type: Optional[int] = zmq.REQ,
    ):
        self.ctx = ctx or Context.instance()
        # Honour the documented default even when None is passed explicitly.
        self.socket_type = zmq.REQ if socket_type is None else socket_type
        self.socket = self.ctx.socket(self.socket_type)
        self.address = address

        try:
            # Configure the client's own key pair; both halves come from the
            # client's secret key file.
            public_key, secret_key = zmq.auth.load_certificate(secret_key_path)
            self.socket.setsockopt(zmq.CURVE_PUBLICKEY, public_key)
            self.socket.setsockopt(zmq.CURVE_SECRETKEY, secret_key)

            # Load the server public key and register it with the socket.
            server_key, _ = zmq.auth.load_certificate(server_public_key_path)
            self.socket.setsockopt(zmq.CURVE_SERVERKEY, server_key)

            self.socket.connect(self.address)
        except Exception:
            # Don't leak the socket if key loading or connecting fails.
            self.socket.close(linger=0)
            raise

    @abstractmethod
    async def run(self) -> None:
        """Implement this coroutine to communicate with the server."""
class EchoClient(BaseClient):
    """Send ten numbered greetings, then ask the server to shut down."""

    async def run(self) -> None:
        sock = self.socket
        for n in range(10):
            payload = f"Hello, world {n}".encode()
            await sock.send(payload)
            reply = await sock.recv()
            print("Client received", reply)
            await asyncio.sleep(1)

        # Request shutdown and wait for the echoed acknowledgement.
        await sock.send(b"quit")
        await sock.recv()
async def main():
    """Generate throwaway keys, then run the echo server and client together."""
    import logging

    # DEBUG level makes zmq.auth's handshake log messages visible.
    logging.basicConfig(level=logging.DEBUG)

    address = "tcp://127.0.0.1:9999"
    with TemporaryDirectory() as key_dir:
        keys = generate_keys(key_dir)
        server_public, server_secret = keys["server"]
        _, client_secret = keys["client"]

        # The key directory doubles as the server's authorized-client dir.
        server = EchoServer(address, server_secret, key_dir)
        client = EchoClient(address, server_public, client_secret)
        await asyncio.gather(server.run(), client.run())
if __name__ == "__main__":
    # asyncio.run creates, runs and closes a fresh event loop. Calling
    # asyncio.get_event_loop() with no running loop is deprecated in recent
    # Python versions.
    asyncio.run(main())
Both the server and client are run in `main`. This was just a small demo to learn how to use ZMQ's Curve authentication and wasn't meant for anything serious.
Thank you @mivade for the response, and I think the idea of the demo is great! I thought that might be the case, but when I tried it I wasn't getting past `await self.socket.send(f"Hello, world {i}".encode())` in the `EchoClient.run` method:
```python
class EchoClient(BaseClient):
    """A simple echo request client."""
    async def run(self) -> None:
        print('in Class EchoClient run')
        for i in range(10):
            print(f"Under foor loop: {i}")
            await self.socket.send(f"Hello, world {i}".encode())  # <-------------- hangs here
            result = await self.socket.recv()
            print(f"Client received {result}")
            await asyncio.sleep(1)
        await self.socket.send(b"quit")
        await self.socket.recv()
```
~/src/zeromq-rpc-teststuff$ python3 zmq_auth.py
starting
async def main loop
tcp://127.0.0.1:9999
DEBUG:zmq.auth:Configure curve: *[/tmp/tmp2o5wk9x_]
DEBUG:zmq.auth:Allowing 127.0.0.1
DEBUG:zmq.auth:Starting
in Class BaseClient constructor
in Class EchoServer run
in Class EchoClient run
Under foor loop: 0 (appears to hang here)
Should I expect that to return and to have results printed out? I didn't seem to have much luck getting print/debug out of the server side either.
I am trying with: python3 --version
Python 3.8.10
(Pdb) bt
/usr/lib/python3.8/runpy.py(194)_run_module_as_main()
-> return _run_code(code, main_globals, None,
/usr/lib/python3.8/runpy.py(87)_run_code()
-> exec(code, run_globals)
/usr/lib/python3.8/pdb.py(1732)()
-> pdb.main()
/usr/lib/python3.8/pdb.py(1705)main()
-> pdb._runscript(mainpyfile)
/usr/lib/python3.8/pdb.py(1573)_runscript()
-> self.run(statement)
/usr/lib/python3.8/bdb.py(580)run()
-> exec(cmd, globals, locals)
(1)()
/home/js-dev/src/zeromq-rpc-teststuff/zmq_auth.py(208)()
-> loop.run_until_complete(main())
/usr/lib/python3.8/asyncio/base_events.py(603)run_until_complete()
-> self.run_forever()
/usr/lib/python3.8/asyncio/base_events.py(570)run_forever()
-> self._run_once()
/usr/lib/python3.8/asyncio/base_events.py(1823)_run_once()
-> event_list = self._selector.select(timeout)
/usr/lib/python3.8/selectors.py(468)select()
-> fd_event_list = self._selector.poll(timeout, max_ev)
/usr/lib/python3.8/pdb.py(189)sigint_handler()
-> def sigint_handler(self, signum, frame):
(Pdb)
I'm not sure what the issue is. I just tested it and it still works for me. One thing about ZMQ that can be frustrating is that errors can be a bit hard to track down since you'll often just see nothing happening. One thing I do in real world applications is poll with a timeout for sockets being ready to read so things don't get completely stuck.
Thank you @mivade. I am going to try some different versions today. I am running the same code here; I just added some prints for debugging when I realized it was hanging. Could you tell me which Python version you are running? If it is not too much trouble, could you also send your `pip freeze` output? Thank you for your example and help!
Interesting: if I use the same Python 3.8.10 and `pip install zmq` (version 0.0.0), it works:
~/src/rpc$ python3 zmq_auth.py
DEBUG:zmq.auth:Configure curve: [/tmp/tmp1xy9pios]
DEBUG:zmq.auth:Allowing 127.0.0.1
DEBUG:zmq.auth:Starting
DEBUG:zmq.auth:version: b'1.0', request_id: b'1', domain: '', address: '127.0.0.1', identity: b'', mechanism: b'CURVE'
DEBUG:zmq.auth:PASSED (allowed) address=127.0.0.1
DEBUG:zmq.auth:ALLOWED (CURVE) domain= client_key=b'U(>k2e:}:Y:KLufrI6Fz2dw&M-E:D0e?dBGzyX0P'
DEBUG:zmq.auth:ZAP reply code=b'200' text=b'OK'
Client received b'Hello, world 0'
Client received b'Hello, world 1'
Client received b'Hello, world 2'
Client received b'Hello, world 3'
Client received b'Hello, world 4'
Client received b'Hello, world 5'
Client received b'Hello, world 6'
Client received b'Hello, world 7'
Client received b'Hello, world 8'
Client received b'Hello, world 9'
Server exiting upon request
Thank you for helping me get off the ground!
How do you run the client? I wasn't sure which example this was based on, so I couldn't look at it for reference.