Last active
July 24, 2021 19:37
-
-
Save claws/7856874 to your computer and use it in GitHub Desktop.
An authentication module for pyzmq based on zauth from czmq. See the included test file for usage. To run the test function use: python test_authentication.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
An authentication module for pyzmq modelled on zauth from czmq.

The functions to read and generate certificates should be interoperable
with czmq's zcert files, though they are not as fully featured.
'''
import datetime | |
import glob | |
import json | |
import os | |
from threading import Thread | |
import zmq | |
from zmq.utils import z85 | |
from zmq.eventloop.ioloop import IOLoop | |
from zmq.eventloop.zmqstream import ZMQStream | |
# Special "location" value: when passed to the CURVE configuration the agent
# accepts every client key instead of loading certificates from a directory.
CURVE_ALLOW_ANY = '*'

# Header of a secret certificate file. The single {} placeholder receives
# the generation timestamp.
_cert_secret_banner = (
    "# **** Generated on {} by pyzmq ****\n"
    "# ZeroMQ CURVE **Secret** Certificate\n"
    "# DO NOT PROVIDE THIS FILE TO OTHER USERS nor change its permissions.\n"
)

# Header of a public certificate file. The single {} placeholder receives
# the generation timestamp.
_cert_public_banner = (
    "# **** Generated on {} by pyzmq ****\n"
    "# ZeroMQ CURVE Public Certificate\n"
    "# Exchange securely, or use a secure mechanism to verify the contents\n"
    "# of this file after exchange. Store public certificates in your home\n"
    "# directory, in the .curve subdirectory.\n"
)
def create_certificates(key_dir, name, metadata=None):
    '''
    Create zcert-esque public and secret certificate files.

    A fresh CURVE key pair is generated and written to
    ``<key_dir>/<name>.key`` (public key only) and
    ``<key_dir>/<name>.key_secret`` (public and secret key). ``key_dir``
    is not created here, so it must already exist.

    ``metadata``, when given as a dict, is written as ``key = value`` lines
    in the metadata section of the secret certificate only.

    Return the file paths to the public and secret certificate files.
    '''
    def as_text(key):
        """ Return key as text; Z85 keys are plain ASCII. """
        # On Python 3 the key pair is bytes; formatting it directly would
        # write "b'...'" into the file. On Python 2 bytes is str: no-op.
        if isinstance(key, bytes) and not isinstance(key, str):
            return key.decode('ascii')
        return key

    def write_key_file(key_filename, banner, public_key,
                       secret_key=None, metadata=None):
        """ Write one certificate file; banner must already be formatted. """
        lines = [banner, 'metadata\n']
        if metadata and isinstance(metadata, dict):
            for k, v in metadata.items():
                lines.append(" {} = {}\n".format(k, v))
        lines.append('curve\n')
        lines.append(" public-key = \"{}\"\n".format(as_text(public_key)))
        if secret_key:
            lines.append(" secret-key = \"{}\"\n".format(as_text(secret_key)))
            # A secret certificate must not be readable by other users, so
            # it is created with owner-only permissions.
            fd = os.open(key_filename,
                         os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
            key_file = os.fdopen(fd, 'w')
        else:
            key_file = open(key_filename, 'w')
        with key_file:
            key_file.writelines(lines)

    public_key, secret_key = zmq.curve_keypair()
    base_filename = os.path.join(key_dir, name)
    secret_key_file = "{}.key_secret".format(base_filename)
    public_key_file = "{}.key".format(base_filename)

    # Both files carry the same timestamp. The banners are formatted exactly
    # once, here, and written verbatim by the helper.
    now = datetime.datetime.now()

    write_key_file(public_key_file,
                   _cert_public_banner.format(now),
                   public_key)

    write_key_file(secret_key_file,
                   _cert_secret_banner.format(now),
                   public_key,
                   secret_key=secret_key,
                   metadata=metadata)

    return public_key_file, secret_key_file
def load_certificate(filename):
    '''
    Load a certificate specified by filename and return the public
    and secret keys read from the file as a (public_key, secret_key) tuple.

    If the certificate file only contains the public key then secret_key
    will be None. If the file cannot be opened an error line is printed and
    (None, None) is returned.
    '''
    public_key = None
    secret_key = None

    # Open first and handle the failure, instead of checking for existence
    # beforehand: this also covers directories and unreadable files.
    try:
        cert_file = open(filename, 'r')
    except (IOError, OSError):
        print("E: Invalid certificate file: {}".format(filename))
        return public_key, secret_key

    with cert_file:
        for raw_line in cert_file:
            # Strip before testing for a comment so indented comments are
            # skipped as well.
            line = raw_line.strip()
            if not line or line.startswith('#'):
                continue
            # Split on the first '=' only: Z85 key text may itself contain
            # '=' characters, and spacing around the separator may vary.
            key_name, separator, value = line.partition('=')
            if not separator:
                continue  # section header such as "metadata" or "curve"
            key_name = key_name.strip()
            value = value.strip().replace('"', '')
            if key_name == 'public-key':
                public_key = value
            elif key_name == 'secret-key':
                secret_key = value

    return public_key, secret_key
def load_certificates(location):
    '''
    Collect the public keys of all certificates found in the location
    directory. Returns a dict keyed by public key (every value is 'OK');
    the dict is empty when location is not a directory.
    '''
    if not os.path.isdir(location):
        return {}
    # czmq convention: public certificates live in *.key files.
    pattern = os.path.join(location, "*.key")
    public_keys = (load_certificate(path)[0] for path in glob.glob(pattern))
    return {key: 'OK' for key in public_keys if key}
class AuthAgentThread(Thread):
    '''
    Thread in which ZAP authentication actually happens.

    The agent polls two sockets: a PAIR socket ("pipe") that carries API
    commands from the front-end, and a REP socket bound to the ZAP endpoint
    ``inproc://zeromq.zap.01`` that carries authentication requests.
    '''

    def __init__(self, context, endpoint, verbose=False):
        '''
        Create the agent and connect its command pipe to endpoint.
        The ZAP socket itself is only created once the thread runs.
        '''
        super(AuthAgentThread, self).__init__()
        self.context = context
        self.verbose = verbose
        self.allow_any = False
        self.zap = None
        self.whitelist = []
        self.blacklist = []
        # passwords is a dict keyed by domain and contains values
        # of dicts with username:password pairs.
        self.passwords = {}
        # certs is dict keyed by domain and contains values
        # of dicts keyed by the public keys from the specified location.
        self.certs = {}

        # create a socket to communicate back to main thread.
        self.pipe = context.socket(zmq.PAIR)
        self.pipe.linger = 1
        self.pipe.connect(endpoint)

    def run(self):
        ''' Start the Authentication Agent thread task '''
        # Create ZAP handler and get ready for requests
        self.zap = self.context.socket(zmq.REP)
        self.zap.linger = 1
        self.zap.bind("inproc://zeromq.zap.01")

        poller = zmq.Poller()
        poller.register(self.pipe, zmq.POLLIN)
        poller.register(self.zap, zmq.POLLIN)

        # Close both sockets even if a handler raises; sockets left open
        # would keep the owning context from terminating.
        try:
            while True:
                try:
                    socks = dict(poller.poll())
                except zmq.ZMQError:
                    break  # interrupted

                if socks.get(self.pipe) == zmq.POLLIN:
                    if self._handle_pipe():
                        break

                if socks.get(self.zap) == zmq.POLLIN:
                    self._authenticate()
        finally:
            self.pipe.close()
            self.zap.close()

    def _send_zap_reply(self, sequence, status_code, status_text):
        '''
        Send a ZAP reply to the handler socket.

        The user id frame is only filled in for a successful ("200") reply;
        every other status carries an empty user id.
        '''
        # Callers pass the numeric status (b"200"/b"400"), so that is what
        # must be tested here. str(...).encode works on Python 2 and 3.
        if status_code == b"200":
            uid = str(os.getuid()).encode('ascii')
        else:
            uid = b""
        metadata = b""  # not currently used
        if self.verbose:
            print("I: ZAP reply code={} text={}".format(status_code, status_text))
        reply = [b"1.0", sequence, status_code, status_text, uid, metadata]
        self.zap.send_multipart(reply)

    def _handle_pipe(self):
        '''
        Handle a message from front-end API.

        Returns True when the agent should terminate, False otherwise.
        Every recognised command is acknowledged with b'OK' on the pipe.
        '''
        terminate = False

        # Get the whole message off the pipe in one go
        msg = self.pipe.recv_multipart()

        if not msg:
            return True

        command = msg[0]
        if self.verbose:
            print("I: auth received API command {}".format(command))

        if command == b'ALLOW':
            address = msg[1]
            if address not in self.whitelist:
                self.whitelist.append(address)
            self.pipe.send(b'OK')

        elif command == b'DENY':
            address = msg[1]
            if address not in self.blacklist:
                self.blacklist.append(address)
            self.pipe.send(b'OK')

        elif command == b'PLAIN':
            domain = msg[1]
            json_passwords = msg[2]
            self.passwords[domain] = json.loads(json_passwords)
            self.pipe.send(b'OK')

        elif command == b'CURVE':
            # For now we don't do anything with domains
            domain = msg[1]

            # If location is CURVE_ALLOW_ANY, allow all clients. Otherwise
            # treat location as a directory that holds the certificates.
            location = msg[2]
            if location == CURVE_ALLOW_ANY:
                self.allow_any = True
            else:
                self.allow_any = False
                if os.path.isdir(location):
                    self.certs[domain] = load_certificates(location)
                elif self.verbose:
                    print("E: Invalid CURVE certs location: {}".format(location))
            self.pipe.send(b'OK')

        elif command == b'VERBOSE':
            enabled = msg[1]
            self.verbose = enabled == b'1'
            self.pipe.send(b'OK')

        elif command == b'TERMINATE':
            terminate = True
            self.pipe.send(b'OK')

        else:
            print("E: invalid auth command from API: {}".format(command))

        return terminate

    def _authenticate_plain(self, domain, username, password):
        '''
        Perform ZAP authentication check for PLAIN mechanism.

        Returns (allowed, reason); reason is empty when access is allowed.
        '''
        allowed = False
        reason = b""

        if not self.passwords:
            reason = b"No passwords defined"
            if self.verbose:
                print("I: DENIED (PLAIN) {}".format(reason))
            return allowed, reason

        # If no domain is specified then use the default domain
        if not domain:
            domain = '*'

        if domain not in self.passwords:
            reason = b"Invalid domain"
        elif username not in self.passwords[domain]:
            reason = b"Invalid username"
        elif password != self.passwords[domain][username]:
            reason = b"Invalid password"
        else:
            allowed = True

        if self.verbose:
            if allowed:
                # NOTE(review): this trace prints the plain-text password;
                # consider dropping it from the log line.
                print("I: {} (PLAIN) domain={} username={} password={}".format(
                    "ALLOWED", domain, username, password))
            else:
                print("I: {} {}".format("DENIED", reason))

        return allowed, reason

    def _authenticate_curve(self, domain, client_key):
        '''
        Perform ZAP authentication check for CURVE mechanism.

        client_key is the binary client public key from the ZAP request.
        Returns (allowed, reason).
        '''
        allowed = False
        reason = b""

        if self.allow_any:
            allowed = True
            reason = b"OK"
            if self.verbose:
                print("I: ALLOWED (CURVE allow any client)")
            return allowed, reason

        # If no explicit domain is specified then use the default domain
        if not domain:
            domain = '*'

        if domain not in self.certs:
            reason = b"Unknown domain"
            return allowed, reason

        # The certs dict stores keys in z85 format, convert binary key to z85 text
        z85_client_key = z85.encode(client_key)
        if z85_client_key in self.certs[domain]:
            allowed = True
            reason = b"OK"
        else:
            reason = b"Unknown key"

        if self.verbose:
            status = "ALLOWED" if allowed else "DENIED"
            print("I: {} (CURVE) domain={} client_key={}".format(
                status, domain, z85_client_key))

        return allowed, reason

    def _authenticate(self):
        '''
        Perform ZAP authentication.

        Reads one request from the ZAP socket, applies the address
        whitelist/blacklist and then the mechanism-specific check, and sends
        a "200" reply when access is allowed or a "400" reply otherwise.
        '''
        msg = self.zap.recv_multipart()
        if not msg:
            return

        if len(msg) < 6:
            # A malformed request must not raise out of the agent thread,
            # and the REP socket needs a reply before its next receive.
            sequence = msg[1] if len(msg) > 1 else b""
            self._send_zap_reply(sequence, b"400", b"Not enough frames")
            return

        version, sequence, domain, address, identity, mechanism = msg[:6]
        credentials = msg[6:]

        if version != b"1.0":
            self._send_zap_reply(sequence, b"400", b"Invalid version")
            return

        if self.verbose:
            print("version: {}".format(version))
            print("sequence: {}".format(sequence))
            print("domain: {}".format(domain))
            print("address: {}".format(address))
            print("identity: {}".format(identity))
            print("mechanism: {}".format(mechanism))

        # Check if address is explicitly whitelisted or blacklisted.
        # When a whitelist exists the blacklist is not consulted.
        allowed = False
        denied = False
        reason = b"NO ACCESS"

        if self.whitelist:
            if address in self.whitelist:
                allowed = True
                if self.verbose:
                    print("I: PASSED (whitelist) address={}".format(address))
            else:
                denied = True
                reason = b"Address not in whitelist"
                if self.verbose:
                    print("I: DENIED (not in whitelist) address={}".format(address))

        elif self.blacklist:
            if address in self.blacklist:
                denied = True
                reason = b"Address is blacklisted"
                if self.verbose:
                    print("I: DENIED (blacklist) address={}".format(address))
            else:
                allowed = True
                if self.verbose:
                    print("I: PASSED (not in blacklist) address={}".format(address))

        # Mechanism-specific checks
        if not denied:
            if mechanism == b'NULL' and not allowed:
                # For NULL, we allow if the address wasn't blacklisted
                if self.verbose:
                    print("I: ALLOWED (NULL)")
                allowed = True

            elif mechanism == b'PLAIN':
                # For PLAIN, even a whitelisted address must authenticate
                if len(credentials) != 2:
                    self._send_zap_reply(sequence, b"400", b"Invalid credentials")
                    return
                username, password = credentials
                allowed, reason = self._authenticate_plain(domain, username, password)

            elif mechanism == b'CURVE':
                # For CURVE, even a whitelisted address must authenticate
                if len(credentials) != 1:
                    self._send_zap_reply(sequence, b"400", b"Invalid credentials")
                    return
                allowed, reason = self._authenticate_curve(domain, credentials[0])

        if allowed:
            self._send_zap_reply(sequence, b"200", b"OK")
        else:
            self._send_zap_reply(sequence, b"400", reason)
class Authenticator(object):
    '''
    An Authenticator object takes over authentication for all incoming
    connections in its context.

    Note:
    - libzmq provides four levels of security: default NULL (which zauth does
      not see), and authenticated NULL, PLAIN, and CURVE, which zauth can see.
    - until you add policies, all incoming NULL connections are allowed
      (classic ZeroMQ behavior), and all PLAIN and CURVE connections are denied.

    All work is done by a background thread, the "agent", which we talk
    to over a pipe. This lets the agent do work asynchronously in the
    background while our application does other things. This is invisible to
    the caller, who sees a classic API.
    '''

    # Class-level defaults keep stop() and __del__ safe on an instance whose
    # __init__ raised before these attributes were assigned.
    context = None
    pipe = None
    pipestream = None
    thread = None

    def __init__(self, context, verbose=False):
        '''
        Start authenticating connections of context. Raises Exception when
        the linked libzmq is older than 4.0.
        '''
        if zmq.zmq_version_info() < (4, 0):
            raise Exception("Security is only available in libzmq >= 4.0")
        self.context = context
        self.pipe = None
        self.pipestream = None
        self.pipe_endpoint = "inproc://{}.inproc".format(id(self))
        self.thread = None
        self.start(verbose)

    def allow(self, address):
        '''
        Allow (whitelist) a single IP address. For NULL, all clients from this
        address will be accepted. For PLAIN and CURVE, they will be allowed to
        continue with authentication. You can call this method multiple times
        to whitelist multiple IP addresses. If you whitelist a single address,
        any non-whitelisted addresses are treated as blacklisted.
        '''
        self.pipe.send_multipart([b'ALLOW', address])

    def deny(self, address):
        '''
        Deny (blacklist) a single IP address. For all security mechanisms, this
        rejects the connection without any further authentication. Use either a
        whitelist, or a blacklist, not both. If you define both a whitelist
        and a blacklist, only the whitelist takes effect.
        '''
        self.pipe.send_multipart([b'DENY', address])

    def verbose(self, enabled):
        '''
        Enable verbose tracing of commands and activity.
        '''
        self.pipe.send_multipart([b'VERBOSE', b'1' if enabled else b'0'])

    def configure_plain(self, domain='*', passwords=None):
        '''
        Configure PLAIN authentication for a given domain. passwords is a
        dict of username:password pairs (or an already JSON-encoded string
        of such a dict). To cover all domains, use "*".
        Calling this again for the same domain replaces its passwords.
        Nothing is sent to the agent when passwords is empty or None.
        '''
        if not passwords:
            return
        if isinstance(passwords, dict):
            passwords = json.dumps(passwords)
        self.pipe.send_multipart([b'PLAIN', domain, passwords])

    def configure_curve(self, domain='*', location=None):
        '''
        Configure CURVE authentication for a given domain. CURVE authentication
        uses a directory that holds all public client certificates, i.e. their
        public keys. The certificates must be in zcert_save () format.
        To cover all domains, use "*".
        The certificates are read when this method is called; call it again
        after adding or removing certificates in that directory.
        To allow all client keys without checking, specify CURVE_ALLOW_ANY for
        the location.
        '''
        self.pipe.send_multipart([b'CURVE', domain, location])

    def start(self, verbose=False):
        '''
        Start performing ZAP authentication
        '''
        # create a socket to communicate with auth thread.
        self.pipe = self.context.socket(zmq.PAIR)
        self.pipe.linger = 1
        # Replies from the agent are delivered through the IOLoop.
        self.pipestream = ZMQStream(self.pipe, IOLoop.instance())
        self.pipestream.on_recv(self._on_message)
        self.pipestream.bind(self.pipe_endpoint)
        self.thread = AuthAgentThread(self.context,
                                      self.pipe_endpoint, verbose=verbose)
        self.thread.start()

    def stop(self):
        '''
        Stop performing ZAP authentication. Safe to call more than once.
        '''
        if self.pipe is None:
            return
        self.pipe.send(b'TERMINATE')
        if self.is_alive():
            self.thread.join()
        self.thread = None
        # Close the stream first so its handler is removed from the IOLoop
        # rather than being left registered on a closed socket.
        if self.pipestream is not None:
            self.pipestream.close()
        self.pipe.close()
        self.pipe = None
        self.pipestream = None

    def is_alive(self):
        ''' Is the Auth thread currently running ? '''
        return bool(self.thread and self.thread.is_alive())

    def __del__(self):
        self.stop()

    def _on_message(self, msg):
        '''
        Process a message from the Auth thread
        '''
        status = msg[0]
        if status != b"OK":
            print("E: status from auth thread indicates error: {}".format(status))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python2.6 | |
import os | |
import shutil | |
import zmq | |
import authentication | |
def can_connect(server, client):
    '''
    Bind server to a random local TCP port, connect client to it and send
    one message from server to client. Return True when the client
    receives exactly that message within 100 ms, False otherwise.
    '''
    iface = 'tcp://127.0.0.1'
    port = server.bind_to_random_port(iface)
    client.connect("%s:%i" % (iface, port))
    msg = ["Hello World"]
    server.send_multipart(msg)
    poller = zmq.Poller()
    poller.register(client, zmq.POLLIN)
    socks = dict(poller.poll(100))
    if client in socks and socks[client] == zmq.POLLIN:
        return client.recv_multipart() == msg
    return False


def main():
    ''' Run the NULL, PLAIN and CURVE authentication checks in sequence. '''
    context = zmq.Context()
    # Bound before the try block so the cleanup in `finally` never hits an
    # unbound name when an early step fails.
    auth = None
    server = None
    client = None
    try:
        auth = authentication.Authenticator(context, verbose=True)

        # A default NULL connection should always succeed, and not
        # go through our authentication infrastructure at all.
        print("Try connecting using NULL and no authentication enabled, connection should pass")
        server = context.socket(zmq.PUSH)
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        ######################################################################
        # test NULL authentication

        # When we set a domain on the server, will switch on authentication
        # for NULL sockets, but with no policies, the client connection
        # will still be allowed.
        print("Try connecting using NULL and authentication enabled, connection should pass")
        server = context.socket(zmq.PUSH)
        server.zap_domain = 'global'
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        # Blacklist 127.0.0.1, connection should fail
        print("Blacklist 127.0.0.1, connection should fail")
        auth.deny('127.0.0.1')
        server = context.socket(zmq.PUSH)
        server.zap_domain = 'global'
        client = context.socket(zmq.PULL)
        assert not can_connect(server, client)
        client.close()
        server.close()
        print("")

        # Whitelist 127.0.0.1, which overrides the blacklist
        print("Whitelist 127.0.0.1, which overrides the blacklist, connection should pass")
        auth.allow('127.0.0.1')
        server = context.socket(zmq.PUSH)
        server.zap_domain = 'global'
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        ######################################################################
        # test PLAIN authentication

        # attempt PLAIN authentication - without configuring server for PLAIN
        print("Try PLAIN authentication - without configuring server, connection should fail")
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = 'Password'
        assert not can_connect(server, client)
        client.close()
        server.close()
        print("")

        # try PLAIN authentication
        print("Try PLAIN authentication - with server configured, connection should pass")
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = 'Password'
        auth.configure_plain(domain='*', passwords={'admin': 'Password'})
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        # attempt PLAIN using bogus credentials
        print("Try PLAIN authentication - with bogus credentials, connection should fail")
        server = context.socket(zmq.PUSH)
        server.plain_server = True
        client = context.socket(zmq.PULL)
        client.plain_username = 'admin'
        client.plain_password = 'Bogus'
        assert not can_connect(server, client)
        client.close()
        server.close()
        print("")

        ######################################################################
        # test CURVE authentication

        # Generate new CURVE keypairs for this test
        print("Creating CURVE authentication certificates")
        base_dir = os.path.dirname(os.path.abspath(__file__))
        keys_dir = os.path.join(base_dir, '.certs')
        public_keys_dir = os.path.join(base_dir, '.certs_public')
        secret_keys_dir = os.path.join(base_dir, '.certs_private')

        # Start from empty directories on every run.
        for cert_dir in (keys_dir, public_keys_dir, secret_keys_dir):
            if os.path.exists(cert_dir):
                shutil.rmtree(cert_dir)
        for cert_dir in (keys_dir, public_keys_dir, secret_keys_dir):
            os.mkdir(cert_dir)

        # create new keys in .certs dir
        authentication.create_certificates(keys_dir, "server")
        authentication.create_certificates(keys_dir, "client")

        # move public keys to appropriate directory
        for key_file in os.listdir(keys_dir):
            if key_file.endswith(".key"):
                shutil.move(os.path.join(keys_dir, key_file),
                            os.path.join(public_keys_dir, '.'))

        # move secret keys to their own directory
        for key_file in os.listdir(keys_dir):
            if key_file.endswith(".key_secret"):
                shutil.move(os.path.join(keys_dir, key_file),
                            os.path.join(secret_keys_dir, '.'))

        server_secret_file = os.path.join(secret_keys_dir, "server.key_secret")
        client_secret_file = os.path.join(secret_keys_dir, "client.key_secret")

        server_public, server_secret = authentication.load_certificate(server_secret_file)
        client_public, client_secret = authentication.load_certificate(client_secret_file)

        # test without setting up any authentication
        print("Try CURVE authentication - without configuring server, connection should fail")
        server = context.socket(zmq.PUSH)
        server.curve_secretkey = server_secret
        server.curve_server = True
        assert (server.mechanism == zmq.CURVE), "unexpected mechanism"
        client = context.socket(zmq.PULL)
        client.curve_serverkey = server_public
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        assert not can_connect(server, client), "expected connect == false"
        client.close()
        server.close()
        print("")

        # test CURVE_ALLOW_ANY
        print("Try CURVE authentication - with server configured to CURVE_ALLOW_ANY, connection should pass")
        auth.configure_curve(domain='*', location=authentication.CURVE_ALLOW_ANY)
        server = context.socket(zmq.PUSH)
        server.curve_secretkey = server_secret
        server.curve_server = True
        assert (server.mechanism == zmq.CURVE), "unexpected mechanism"
        assert server.get(zmq.CURVE_SERVER) == True, "expected CURVE server true"
        client = context.socket(zmq.PULL)
        client.curve_serverkey = server_public
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        assert can_connect(server, client), "expected connect == true"
        client.close()
        server.close()
        print("")

        # Test full client authentication using certificates
        print("Try CURVE authentication - with server configured, connection should pass")
        auth.configure_curve(domain='*', location=public_keys_dir)
        server = context.socket(zmq.PUSH)
        server.curve_publickey = server_public
        server.curve_secretkey = server_secret
        server.curve_server = True
        client = context.socket(zmq.PULL)
        client.curve_publickey = client_public
        client.curve_secretkey = client_secret
        client.curve_serverkey = server_public
        assert can_connect(server, client)
        client.close()
        server.close()
        print("")

        # Remove authenticator and check a normal connection works
        auth.stop()
        auth = None
        print("")

        print("Try connecting using NULL and no authentication enabled, connection should pass")
        server = context.socket(zmq.PUSH)
        client = context.socket(zmq.PULL)
        assert can_connect(server, client)

    finally:
        # Stop the agent first: its sockets belong to this context, and
        # context.term() cannot complete while they are still open.
        if auth is not None:
            auth.stop()
        for sock in (client, server):
            if sock is not None:
                sock.close()
        context.term()


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment