Skip to content

Instantly share code, notes, and snippets.

@claws
Last active July 24, 2021 19:37
Show Gist options
  • Save claws/7856874 to your computer and use it in GitHub Desktop.
Save claws/7856874 to your computer and use it in GitHub Desktop.
An authentication module for pyzmq, based on zauth from czmq. See the included test file for usage; save the module as `authentication.py` (the name the test imports) and run the tests with: `python test_authentication.py`
'''
An authentication module for pyzmq modelled on zauth from czmq.
The functions to read and generate certificates should be interoperable
with czmq's zcert certificate files, though they are not as fully featured.
'''
import datetime
import glob
import json
import os
from threading import Thread
import zmq
from zmq.utils import z85
from zmq.eventloop.ioloop import IOLoop
from zmq.eventloop.zmqstream import ZMQStream
# Pass this as the CURVE "location" to accept every client key unchecked.
CURVE_ALLOW_ANY = '*'

# Header placed at the top of a secret certificate file. The "{}" slot
# receives the generation timestamp.
_cert_secret_banner = (
    "# **** Generated on {} by pyzmq ****\n"
    "# ZeroMQ CURVE **Secret** Certificate\n"
    "# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.\n"
)

# Header placed at the top of a public certificate file. The "{}" slot
# receives the generation timestamp.
_cert_public_banner = (
    "# **** Generated on {} by pyzmq ****\n"
    "# ZeroMQ CURVE Public Certificate\n"
    "# Exchange securely, or use a secure mechanism to verify the contents\n"
    "# of this file after exchange. Store public certificates in your home\n"
    "# directory, in the .curve subdirectory.\n"
)
def create_certificates(key_dir, name, metadata=None):
    '''
    Create zcert-esque public and private certificate files.

    A fresh CURVE keypair is generated and written to two files in
    ``key_dir``:

    - ``<name>.key``: the public certificate (public key only).
    - ``<name>.key_secret``: the secret certificate (public and secret
      key). When newly created, it gets owner-only (0600) permissions
      where the OS honours them.

    If ``metadata`` is a dict, each item is written as a ``k = v`` line in
    the ``metadata`` section of both files, as czmq's zcert does.

    Return the file paths to the public and secret certificate files.
    '''
    def to_text(key):
        # curve_keypair may return bytes; write the Z85 text, not a b'' repr.
        if isinstance(key, bytes):
            return key.decode('ascii')
        return key

    def write_key_file(key_filename, banner, public_key, secret_key=None,
                       metadata=None, private=False):
        """ Create a certificate file whose header is the already-formatted banner """
        if private:
            # Create the secret file readable/writable by its owner only, so
            # the secret key is never exposed through default permissions.
            fd = os.open(key_filename,
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            f = os.fdopen(fd, 'w')
        else:
            f = open(key_filename, 'w')
        with f:
            f.write(banner)
            f.write('metadata\n')
            if metadata and isinstance(metadata, dict):
                for k, v in metadata.items():
                    f.write(" {} = {}\n".format(k, v))
            f.write('curve\n')
            f.write(" public-key = \"{}\"\n".format(to_text(public_key)))
            if secret_key:
                f.write(" secret-key = \"{}\"\n".format(to_text(secret_key)))

    public_key, secret_key = zmq.curve_keypair()
    base_filename = os.path.join(key_dir, name)
    secret_key_file = "{}.key_secret".format(base_filename)
    public_key_file = "{}.key".format(base_filename)
    # Format each banner exactly once, so both files share one timestamp.
    now = datetime.datetime.now()
    write_key_file(public_key_file,
                   _cert_public_banner.format(now),
                   public_key,
                   metadata=metadata)
    write_key_file(secret_key_file,
                   _cert_secret_banner.format(now),
                   public_key,
                   secret_key=secret_key,
                   metadata=metadata,
                   private=True)
    return public_key_file, secret_key_file
def load_certificate(filename):
    '''
    Load a certificate specified by filename and return the public
    and private keys read from the file. If the certificate file
    only contains the public key then secret_key will be None.

    Key lines may be written with or without spaces around the ``=``
    (``public-key = "..."`` or ``public-key="..."``). Double quotes are
    removed from the value. Blank lines, and lines whose first non-blank
    character is ``#``, are ignored.

    If the file does not exist, an error is printed and ``(None, None)``
    is returned.
    '''
    public_key = None
    secret_key = None
    if not os.path.exists(filename):
        print("E: Invalid certificate file: {}".format(filename))
        return public_key, secret_key
    with open(filename, 'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            # Split on the first '=' only: Z85 key text may itself contain '='.
            name, sep, value = line.partition('=')
            if not sep:
                continue
            name = name.strip()
            value = value.strip().replace('"', '')
            if name == 'public-key':
                public_key = value
            elif name == 'secret-key':
                secret_key = value
    return public_key, secret_key
def load_certificates(location):
    '''
    Collect the public keys of every ``*.key`` certificate in ``location``.

    Following the czmq convention, public certificates live in ``*.key``
    files. The result maps each public key found to the string 'OK'.
    A ``location`` that is not a directory yields an empty dict, and
    files without a public key are skipped.
    '''
    if not os.path.isdir(location):
        return {}
    public_keys = (load_certificate(path)[0]
                   for path in glob.glob(os.path.join(location, "*.key")))
    return dict((key, 'OK') for key in public_keys if key)
class AuthAgentThread(Thread):
    '''
    Thread in which ZAP authentication actually happens.

    The agent answers ZAP requests on a REP socket bound to
    "inproc://zeromq.zap.01". It takes configuration commands (ALLOW, DENY,
    PLAIN, CURVE, VERBOSE, TERMINATE) from the front-end over a PAIR socket
    connected to ``endpoint``. Each recognised command is acknowledged
    with b'OK'.
    '''

    def __init__(self, context, endpoint, verbose=False):
        super(AuthAgentThread, self).__init__()
        self.context = context
        self.verbose = verbose
        # When True, any CURVE client key is accepted (CURVE_ALLOW_ANY).
        self.allow_any = False
        self.zap = None
        self.whitelist = []
        self.blacklist = []
        # passwords is a dict keyed by domain and contains values
        # of dicts with username:password pairs.
        self.passwords = {}
        # certs is dict keyed by domain and contains values
        # of dicts keyed by the public keys from the specified location.
        self.certs = {}
        # create a socket to communicate back to main thread.
        self.pipe = context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipe.connect(endpoint)

    def run(self):
        '''
        Start the Authentication Agent thread task.

        Polls the command pipe and the ZAP socket until a TERMINATE command
        arrives or polling is interrupted. Both sockets are closed on exit,
        including when a handler raises.
        '''
        # Create ZAP handler and get ready for requests
        self.zap = self.context.socket(zmq.REP)
        try:
            self.zap.linger = 1
            self.zap.bind("inproc://zeromq.zap.01")
            poller = zmq.Poller()
            poller.register(self.pipe, zmq.POLLIN)
            poller.register(self.zap, zmq.POLLIN)
            while True:
                try:
                    socks = dict(poller.poll())
                except zmq.ZMQError:
                    break  # interrupted
                if self.pipe in socks and socks[self.pipe] == zmq.POLLIN:
                    if self._handle_pipe():
                        break
                if self.zap in socks and socks[self.zap] == zmq.POLLIN:
                    self._authenticate()
        finally:
            # Always release both sockets. Terminating the context blocks
            # while any of its sockets remain open.
            self.pipe.close()
            self.zap.close()

    def _send_zap_reply(self, sequence, status_code, status_text):
        '''
        Send a ZAP reply to the handler socket.

        The user id and metadata frames are always sent empty.
        '''
        # This agent does not map clients to user identities.
        uid = b""
        metadata = b""  # not currently used
        if self.verbose:
            print("I: ZAP reply code={} text={}".format(status_code, status_text))
        reply = [b"1.0", sequence, status_code, status_text, uid, metadata]
        self.zap.send_multipart(reply)

    def _handle_pipe(self):
        '''
        Handle a message from front-end API.

        Returns True when the agent should terminate.
        '''
        terminate = False
        # Get the whole message off the pipe in one go
        msg = self.pipe.recv_multipart()
        command = msg[0]
        if self.verbose:
            print("I: auth received API command {}".format(command))
        if command == b'ALLOW':
            address = msg[1]
            if address not in self.whitelist:
                self.whitelist.append(address)
            self.pipe.send(b'OK')
        elif command == b'DENY':
            address = msg[1]
            if address not in self.blacklist:
                self.blacklist.append(address)
            self.pipe.send(b'OK')
        elif command == b'PLAIN':
            domain = msg[1]
            json_passwords = msg[2]
            self.passwords[domain] = json.loads(json_passwords)
            self.pipe.send(b'OK')
        elif command == b'CURVE':
            # For now we don't do anything with domains
            domain = msg[1]
            # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
            # treat location as a directory that holds the certificates.
            # The directory is read once, here.
            location = msg[2]
            if location == CURVE_ALLOW_ANY:
                self.allow_any = True
            else:
                self.allow_any = False
                if os.path.isdir(location):
                    self.certs[domain] = load_certificates(location)
                else:
                    if self.verbose:
                        print("E: Invalid CURVE certs location: {}".format(location))
            self.pipe.send(b'OK')
        elif command == b'VERBOSE':
            enabled = msg[1]
            self.verbose = enabled == b'1'
            self.pipe.send(b'OK')
        elif command == b'TERMINATE':
            terminate = True
            self.pipe.send(b'OK')
        else:
            print("E: invalid auth command from API: {}".format(command))
        return terminate

    def _authenticate_plain(self, domain, username, password):
        '''
        Perform ZAP authentication check for PLAIN mechanism.

        Returns (allowed, reason). An empty domain is treated as '*'.
        '''
        allowed = False
        reason = b""
        if self.passwords:
            # If no domain is specified then use the default domain
            if not domain:
                domain = '*'
            if domain in self.passwords:
                if username in self.passwords[domain]:
                    if password == self.passwords[domain][username]:
                        allowed = True
                    else:
                        reason = b"Invalid password"
                else:
                    reason = b"Invalid username"
            else:
                reason = b"Invalid domain"
            if self.verbose:
                status = "DENIED"
                if allowed:
                    status = "ALLOWED"
                    print("I: {} (PLAIN) domain={} username={} password={}".format(
                        status, domain, username, password))
                else:
                    print("I: {} {}".format(status, reason))
        else:
            reason = b"No passwords defined"
            if self.verbose:
                print("I: DENIED (PLAIN) {}".format(reason))
        return allowed, reason

    def _authenticate_curve(self, domain, client_key):
        '''
        Perform ZAP authentication check for CURVE mechanism.

        Returns (allowed, reason). ``client_key`` is the binary key from the
        ZAP request. It is Z85-encoded before lookup in ``self.certs``.
        '''
        allowed = False
        reason = b""
        if self.allow_any:
            allowed = True
            reason = b"OK"
            if self.verbose:
                print("I: ALLOWED (CURVE allow any client)")
        else:
            # If no explicit domain is specified then use the default domain
            if not domain:
                domain = '*'
            if domain in self.certs:
                # The certs dict stores keys in z85 format, convert binary key to z85 text
                z85_client_key = z85.encode(client_key)
                if z85_client_key in self.certs[domain]:
                    allowed = True
                    reason = b"OK"
                else:
                    reason = b"Unknown key"
                if self.verbose:
                    status = "DENIED"
                    if allowed:
                        status = "ALLOWED"
                    print("I: {} (CURVE) domain={} client_key={}".format(
                        status, domain, z85_client_key))
            else:
                reason = b"Unknown domain"
        return allowed, reason

    def _check_address(self, address):
        '''
        Apply the whitelist/blacklist policy to a peer address.

        Returns (allowed, denied, reason). A non-empty whitelist takes
        precedence, and the blacklist is then ignored. When neither list
        applies, both flags are False and reason is b"NO ACCESS".
        '''
        allowed = False
        denied = False
        reason = b"NO ACCESS"
        if self.whitelist:
            if address in self.whitelist:
                allowed = True
                if self.verbose:
                    print("I: PASSED (whitelist) address={}".format(address))
            else:
                denied = True
                reason = b"Address not in whitelist"
                if self.verbose:
                    print("I: DENIED (not in whitelist) address={}".format(address))
        elif self.blacklist:
            if address in self.blacklist:
                denied = True
                reason = b"Address is blacklisted"
                if self.verbose:
                    print("I: DENIED (blacklist) address={}".format(address))
            else:
                allowed = True
                if self.verbose:
                    print("I: PASSED (not in blacklist) address={}".format(address))
        return allowed, denied, reason

    def _authenticate(self):
        '''
        Perform ZAP authentication.

        Receives one ZAP request and sends exactly one reply: b"200" when
        the peer is allowed, otherwise b"400" with a reason. A REP socket
        must reply before it can receive again, so malformed requests get
        an error reply instead of raising and killing the agent.
        '''
        msg = self.zap.recv_multipart()
        if len(msg) < 6:
            # Not a valid ZAP request; echo the sequence frame if present.
            sequence = msg[1] if len(msg) > 1 else b""
            self._send_zap_reply(sequence, b"400", b"Malformed request")
            return
        version, sequence, domain, address, identity, mechanism = msg[:6]
        credentials = msg[6:]
        if version != b"1.0":
            self._send_zap_reply(sequence, b"400", b"Invalid version")
            return
        if self.verbose:
            print("version: {}".format(version))
            print("sequence: {}".format(sequence))
            print("domain: {}".format(domain))
            print("address: {}".format(address))
            print("identity: {}".format(identity))
            print("mechanism: {}".format(mechanism))
        # Check if address is explicitly whitelisted or blacklisted
        allowed, denied, reason = self._check_address(address)
        # Mechanism-specific checks
        if not denied:
            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't blacklisted
                if self.verbose:
                    print("I: ALLOWED (NULL)")
                allowed = True
            elif mechanism == b'PLAIN':
                # For PLAIN, even a whitelisted address must authenticate
                if len(credentials) == 2:
                    username, password = credentials
                    allowed, reason = self._authenticate_plain(domain, username, password)
                else:
                    allowed, reason = False, b"Malformed request"
            elif mechanism == b'CURVE':
                # For CURVE, even a whitelisted address must authenticate
                if credentials:
                    allowed, reason = self._authenticate_curve(domain, credentials[0])
                else:
                    allowed, reason = False, b"Malformed request"
        if allowed:
            self._send_zap_reply(sequence, b"200", b"OK")
        else:
            self._send_zap_reply(sequence, b"400", reason)
class Authenticator(object):
    '''
    An Authenticator object takes over authentication for all incoming
    connections in its context.

    Note:
    - libzmq provides four levels of security: default NULL (which zauth does
      not see), and authenticated NULL, PLAIN, and CURVE, which zauth can see.
    - until you add policies, all incoming NULL connections are allowed
      (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.

    All work is done by a background thread, the "agent", which we talk
    to over a pipe. This lets the agent do work asynchronously in the
    background while our application does other things. This is invisible to
    the caller, who sees a classic API.
    '''

    def __init__(self, context, verbose=False):
        # Initialise first so stop()/__del__ stay safe even if the version
        # check below raises.
        self.pipe = None
        self.pipestream = None
        self.thread = None
        if zmq.zmq_version_info() < (4, 0):
            raise Exception("Security is only available in libzmq >= 4.0")
        self.context = context
        self.pipe_endpoint = "inproc://{}.inproc".format(id(self))
        self.start(verbose)

    def allow(self, address):
        '''
        Allow (whitelist) a single IP address. For NULL, all clients from this
        address will be accepted. For PLAIN and CURVE, they will be allowed to
        continue with authentication. You can call this method multiple times
        to whitelist multiple IP addresses. If you whitelist a single address,
        any non-whitelisted addresses are treated as blacklisted.
        '''
        self.pipe.send_multipart([b'ALLOW', address])

    def deny(self, address):
        '''
        Deny (blacklist) a single IP address. For all security mechanisms, this
        rejects the connection without any further authentication. Use either a
        whitelist or a blacklist, not both. If you define both a whitelist
        and a blacklist, only the whitelist takes effect.
        '''
        self.pipe.send_multipart([b'DENY', address])

    def verbose(self, enabled):
        '''
        Enable verbose tracing of commands and activity.
        '''
        self.pipe.send_multipart([b'VERBOSE', b'1' if enabled else b'0'])

    def configure_plain(self, domain='*', passwords=None):
        '''
        Configure PLAIN authentication for a given domain. To cover all
        domains, use "*".

        ``passwords`` is a dict of {username: password} pairs, or that dict
        already encoded as JSON. Calling this again for the same domain
        replaces its passwords. If ``passwords`` is empty or None, nothing
        is sent and the current configuration is left unchanged.
        '''
        if passwords:
            if isinstance(passwords, dict):
                passwords = json.dumps(passwords)
            self.pipe.send_multipart([b'PLAIN', domain, passwords])

    def configure_curve(self, domain='*', location=None):
        '''
        Configure CURVE authentication for a given domain. CURVE authentication
        uses a directory that holds all public client certificates, i.e. their
        public keys. The certificates must be in zcert_save () format.
        To cover all domains, use "*".

        The directory is read when this method is called. After adding or
        removing certificates, call this method again to reload them.
        To allow all client keys without checking, specify CURVE_ALLOW_ANY for
        the location. A location must be given; None cannot be sent.
        '''
        self.pipe.send_multipart([b'CURVE', domain, location])

    def start(self, verbose=False):
        '''
        Start performing ZAP authentication.

        Creates the pipe to the agent, wraps it in a ZMQStream on the global
        IOLoop so replies are routed to _on_message, and starts the agent.
        '''
        # create a socket to communicate with auth thread.
        self.pipe = self.context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipestream = ZMQStream(self.pipe, IOLoop.instance())
        self.pipestream.on_recv(self._on_message)
        self.pipestream.bind(self.pipe_endpoint)
        self.thread = AuthAgentThread(self.context,
                                      self.pipe_endpoint, verbose=verbose)
        self.thread.start()

    def stop(self):
        '''
        Stop performing ZAP authentication.

        Safe to call more than once, and on an object whose constructor
        failed part-way.
        '''
        pipe = getattr(self, 'pipe', None)
        if not pipe:
            return
        # Only ask a live agent to terminate. With no peer on the other end,
        # a send on the PAIR socket can block indefinitely.
        if self.is_alive():
            pipe.send(b'TERMINATE')
            self.thread.join()
        self.thread = None
        stream = getattr(self, 'pipestream', None)
        if stream is not None:
            # Closing the stream unregisters it from the IOLoop and closes
            # the underlying pipe socket.
            stream.close()
        else:
            pipe.close()
        self.pipe = None
        self.pipestream = None

    def is_alive(self):
        ''' Is the Auth thread currently running ? '''
        thread = getattr(self, 'thread', None)
        return bool(thread and thread.is_alive())

    def __del__(self):
        self.stop()

    def _on_message(self, msg):
        '''
        Process a message from the Auth thread.

        Any status other than b"OK" is reported as an error.
        '''
        status = msg[0]
        if status != b"OK":
            print("E: status from auth thread indicates error: {}".format(status))
#!/usr/bin/python2.7
import os
import shutil
import zmq
import authentication
if __name__ == '__main__':

    def can_connect(server, client):
        '''
        Return True if a message sent by ``server`` reaches ``client``.

        ``server`` is bound to a random TCP port on 127.0.0.1 and
        ``client`` connects to it. The client then waits up to 100 ms for
        the test message.
        '''
        result = False
        iface = 'tcp://127.0.0.1'
        port = server.bind_to_random_port(iface)
        client.connect("%s:%i" % (iface, port))
        msg = ["Hello World"]
        server.send_multipart(msg)
        poller = zmq.Poller()
        poller.register(client, zmq.POLLIN)
        socks = dict(poller.poll(100))
        if client in socks and socks[client] == zmq.POLLIN:
            rcvd_msg = client.recv_multipart()
            result = rcvd_msg == msg
        return result

    # Bound before the try block so the cleanup in ``finally`` can always
    # tell what still needs releasing.
    context = zmq.Context()
    auth = None
    server = None
    client = None
    try:
        auth = authentication.Authenticator(context, verbose=True)

        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        print("Try connecting using NULL and no authentication enabled, connection should pass")
        server = context.socket(zmq.PUSH)
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        ######################################################################
        # test NULL authentication

        # When we set a domain on the server, will switch on authentication
        # for NULL sockets, but with no policies, the client connection
        # will still be allowed.
        print("Try connecting using NULL and authentication enabled, connection should pass")
        server = context.socket(zmq.PUSH)
        server.zap_domain = 'global'
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        # Blacklist 127.0.0.1, connection should fail
        print("Blacklist 127.0.0.1, connection should fail")
        auth.deny('127.0.0.1')
        server = context.socket(zmq.PUSH)
        server.zap_domain = 'global'
        client = context.socket(zmq.PULL)
        assert not can_connect(server, client)
        client.close()
        server.close()
        print("")

        # Whitelist 127.0.0.1, which overrides the blacklist
        print("Whitelist 127.0.0.1, which overrides the blacklist, connection should pass")
        auth.allow('127.0.0.1')
        server = context.socket(zmq.PUSH)
        server.zap_domain = 'global'
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        ######################################################################
        # test PLAIN authentication

        # attempt PLAIN authentication - without configuring server for PLAIN
        print("Try PLAIN authentication - without configuring server, connection should fail")
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = 'Password'
        assert not can_connect(server, client)
        client.close()
        server.close()
        print("")

        # try PLAIN authentication
        print("Try PLAIN authentication - with server configured, connection should pass")
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = 'Password'
        auth.configure_plain(domain='*', passwords={'admin': 'Password'})
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        # attempt PLAIN using bogus credentials
        print("Try PLAIN authentication - with bogus credentials, connection should fail")
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = 'Bogus'
        assert not can_connect(server, client)
        client.close()
        server.close()
        print("")

        ######################################################################
        # test CURVE authentication

        # Generate new CURVE keypairs for this test
        print("Creating CURVE authentication certificates")
        base_dir = os.path.dirname(os.path.abspath(__file__))
        keys_dir = os.path.join(base_dir, '.certs')
        public_keys_dir = os.path.join(base_dir, '.certs_public')
        secret_keys_dir = os.path.join(base_dir, '.certs_private')
        # Start from empty directories so stale keys cannot affect the test.
        for directory in (keys_dir, public_keys_dir, secret_keys_dir):
            if os.path.exists(directory):
                shutil.rmtree(directory)
            os.mkdir(directory)

        # create new keys in .certs dir
        authentication.create_certificates(keys_dir, "server")
        authentication.create_certificates(keys_dir, "client")

        # move public keys (*.key) and secret keys (*.key_secret) into
        # their own directories
        for key_file in os.listdir(keys_dir):
            if key_file.endswith(".key"):
                shutil.move(os.path.join(keys_dir, key_file),
                            os.path.join(public_keys_dir, '.'))
            elif key_file.endswith(".key_secret"):
                shutil.move(os.path.join(keys_dir, key_file),
                            os.path.join(secret_keys_dir, '.'))

        server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
        client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")
        server_public, server_secret = authentication.load_certificate(server_secret_file)
        client_public, client_secret = authentication.load_certificate(client_secret_file)

        # test without setting up any authentication
        print("Try CURVE authentication - without configuring server, connection should fail")
        server = context.socket(zmq.PUSH)
        server.curve_secretkey = server_secret
        server.curve_server = True
        assert (server.mechanism == zmq.CURVE), "unexpected mechanism"
        client = context.socket(zmq.PULL)
        client.curve_serverkey = server_public
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        assert not can_connect(server, client), "expected connect == false"
        client.close()
        server.close()
        print("")

        # test CURVE_ALLOW_ANY
        print("Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass")
        auth.configure_curve(domain='*', location=authentication.CURVE_ALLOW_ANY)
        server = context.socket(zmq.PUSH)
        server.curve_secretkey = server_secret
        server.curve_server = True
        assert (server.mechanism == zmq.CURVE), "unexpected mechanism"
        assert server.get(zmq.CURVE_SERVER), "expected CURVE server true"
        client = context.socket(zmq.PULL)
        client.curve_serverkey = server_public
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        assert can_connect(server, client), "expected connect == true"
        client.close()
        server.close()
        print("")

        # Test full client authentication using certificates
        print("Try CURVE authentication - with server configured, connection should pass")
        auth.configure_curve(domain='*', location=public_keys_dir)
        server = context.socket(zmq.PUSH)
        server.curve_publickey = server_public
        server.curve_secretkey = server_secret
        server.curve_server = True
        client = context.socket(zmq.PULL)
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        client.curve_serverkey = server_public
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        # Remove authenticator and check a normal connection works
        auth.stop()
        auth = None
        print("")
        print("Try connecting using NULL and no authentication enabled, connection should pass")
        server = context.socket(zmq.PUSH)
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
    finally:
        # Context.term() blocks until every socket in the context is closed,
        # including the agent thread's. Release everything here, even when an
        # assertion above failed, so a failure is reported instead of hanging.
        if client is not None:
            client.close()
        if server is not None:
            server.close()
        if auth is not None:
            auth.stop()
        context.term()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment