Skip to content

Instantly share code, notes, and snippets.

@rcoder
Created March 10, 2010 22:02
Show Gist options
  • Save rcoder/328486 to your computer and use it in GitHub Desktop.
Save rcoder/328486 to your computer and use it in GitHub Desktop.
"""
TCP proxy that limits which MongoDB wire-protocol operations may be forwarded
to a MongoDB server. Use it to enforce read-only, read-and-insert-only, or
full-access restrictions.
Based on protocol docs from the MongoDB wiki:
http://www.mongodb.org/display/DOCS/Mongo+Wire+Protocol
"""
import optparse
import struct
import sys
from datetime import datetime

from twisted.internet import reactor
from twisted.protocols import portforward
# Symbolic name -> numeric code for every opcode this proxy knows about.
OPCODES = dict(
    REPLY=1,
    MSG=1000,
    UPDATE=2001,
    INSERT=2002,
    GET_BY_OID=2003,
    QUERY=2004,
    GET_MORE=2005,
    DELETE=2006,
    KILL_CURSORS=2007,
)

# Reverse lookup table: numeric code -> symbolic name.
OPCODE_LABELS = dict((code, name) for name, code in OPCODES.items())

# Every known opcode name; used as the "allow everything" default.
ALL_OPCODES = OPCODES.keys()
# Every message begins with a standard header of four little-endian int32s:
#
#     struct MsgHeader {
#         int32 messageLength;  // total message size in bytes, header included
#         int32 requestID;      // identifier chosen by the sender
#         int32 responseTo;     // requestID being answered (used by replies)
#         int32 opCode;         // operation type; see OPCODES
#     }
#
# The opcode is therefore the fourth int32: it follows messageLength,
# requestID and responseTo (there is no padding), i.e. byte offset 12.
OPCODE_OFFSET = struct.calcsize('<iii')


def get_opcode(data, offset=0):
    """
    Get the MongoDB wire protocol opcode for a client request packet.

    `data` is a byte string holding a wire-protocol message whose header
    starts at byte `offset` (default 0, i.e. the start of `data`).
    Returns the numeric opcode as an int; map it to a name with
    OPCODE_LABELS.

    Raises struct.error if `data` is too short to contain the opcode field
    (a truncated header).
    """
    (opcode,) = struct.unpack_from('<i', data, offset + OPCODE_OFFSET)
    return opcode
class MongoFilteringProxy(portforward.ProxyServer):
    """
    TCP socket proxy which forwards only those client messages whose opcode
    name is in the factory's `allowed_opcodes` set; every other message
    (including messages with an unknown opcode) is passed to handleReject()
    and dropped.

    TCP delivers a byte stream, not discrete messages: a single
    dataReceived() call may contain part of a message, or several messages
    back to back. Incoming bytes are therefore buffered and split on the
    messageLength header field, so that *every* message is checked. This
    prevents a denied operation from slipping through when it is pipelined
    after an allowed one.
    """

    # Size of the standard message header (four int32 fields). The
    # messageLength field counts the header too, so no valid message is
    # shorter than this.
    _HEADER_SIZE = struct.calcsize('<iiii')

    # Client bytes that do not yet form a complete message. The immutable
    # class-level default keeps subclasses that skip __init__ working.
    _buffer = b''

    def __init__(self, *args, **kwargs):
        # by default, allow all opcodes
        # NOTE: filtering decisions use self.factory.allowed_opcodes (see
        # _filterMessage); this per-instance set is kept only for backward
        # compatibility and is not consulted.
        self.allowed_opcodes = set(ALL_OPCODES)
        self._buffer = b''

    def handleReject(self, opcode, data):
        """
        By default, simply prints a message including the blocked opcode
        to stdout; override in a subclass to take other action (logging, etc.)
        when a message is rejected.

        `opcode` is the opcode name from OPCODES, or the raw integer when the
        opcode is not a known one. `data` is the complete rejected message.
        """
        peer = self.transport.getPeer()
        print("[%s] %s rejected opcode %s" % (datetime.now().ctime(), peer.host, opcode))

    def dataReceived(self, data):
        """
        Buffer client bytes and filter each complete message in turn.

        A header whose messageLength is smaller than the header itself cannot
        be framed. Consuming it would never advance the buffer, so in that
        case the buffer is discarded and the connection is closed.
        """
        self._buffer += data
        while len(self._buffer) >= self._HEADER_SIZE:
            (length,) = struct.unpack_from('<i', self._buffer, 0)
            if length < self._HEADER_SIZE:
                # Malformed or hostile framing: the next message boundary
                # cannot be located, so stop proxying this client.
                self._buffer = b''
                self.transport.loseConnection()
                return
            if len(self._buffer) < length:
                # The rest of this message has not arrived yet.
                return
            message = self._buffer[:length]
            self._buffer = self._buffer[length:]
            self._filterMessage(message)

    def _filterMessage(self, message):
        """Forward one complete message to the server, or reject it."""
        opcode_number = get_opcode(message)
        # An unknown opcode stays an int, which never matches the
        # name-based whitelist, so it is rejected rather than raising KeyError.
        opcode = OPCODE_LABELS.get(opcode_number, opcode_number)
        if opcode in self.factory.allowed_opcodes:
            self.peer.transport.write(message)
        else:
            self.handleReject(opcode, message)
class MongoProxyFactory(portforward.ProxyFactory):
    """
    Proxy factory providing configuration hooks for determining which
    MongoDB opcodes will be passed through the proxy.

    Each protocol instance it builds (MongoFilteringProxy) consults this
    factory's `allowed_opcodes`, a set of opcode names from OPCODES.
    Construct it like portforward.ProxyFactory, e.g.
    MongoProxyFactory(host, port). Then call allow() or deny(). Until one
    of them is called, every known opcode is allowed.
    """
    protocol = MongoFilteringProxy
    # Class-level default shared by instances until allow()/deny() replaces
    # it with a per-instance set; avoid mutating it in place.
    allowed_opcodes = set(ALL_OPCODES)

    def _checkOpcodes(self, opcodes):
        """
        Return `opcodes` as a set, raising ValueError for any name that is
        not a key of OPCODES. This is done so that a typo can't silently
        weaken the filter (e.g. deny(['DELET']) leaving DELETE allowed).
        """
        names = set(opcodes)
        unknown = names - set(ALL_OPCODES)
        if unknown:
            raise ValueError("unknown MongoDB opcode name(s): %s"
                             % ', '.join(sorted(repr(n) for n in unknown)))
        return names

    def allow(self, opcodes):
        """
        Set a whitelist of allowed opcodes (deny all others).

        `opcodes` is an iterable of opcode names from OPCODES; raises
        ValueError if any name is unknown.
        """
        self.allowed_opcodes = self._checkOpcodes(opcodes)

    def deny(self, opcodes):
        """
        Set a blacklist of disallowed opcodes (allow all others).

        `opcodes` is an iterable of opcode names from OPCODES; raises
        ValueError if any name is unknown.
        """
        self.allowed_opcodes = set(ALL_OPCODES) - self._checkOpcodes(opcodes)
if __name__ == '__main__':
    # Command-line entry point: parse options, configure the opcode filter
    # for the chosen mode, then run the proxy until the reactor stops.
    parser = optparse.OptionParser()
    parser.add_option('-d', '--dbhost', dest='hostname', default='localhost',
                      help="MongoDB hostname")
    parser.add_option('-p', '--dbport', dest='port', default=27017, type='int',
                      help="MongoDB port")
    parser.add_option('-l', '--localport', dest='localport', default=29017, type='int',
                      help="Local listener port")
    parser.add_option('-m', '--mode', dest='mode', default='read',
                      help="filtering mode: read (read-only), "
                           "insert (read and insert; no update/delete), "
                           "full (full access) [default: %default]")
    opts, args = parser.parse_args()

    factory = MongoProxyFactory(opts.hostname, opts.port)
    if opts.mode == 'full':
        factory.allow(ALL_OPCODES)
    elif opts.mode == 'insert':
        factory.deny(('DELETE', 'UPDATE'))
    elif opts.mode == 'read':
        factory.deny(('DELETE', 'INSERT', 'UPDATE'))
    else:
        # Requires `import sys` at module level.
        sys.stderr.write("invalid mode\n")
        parser.print_help()
        sys.exit(1)

    # print() with a single argument behaves the same on Python 2 and 3.
    print("Starting MongoDB proxy on port %i, connected to %s:%i"
          % (opts.localport, opts.hostname, opts.port))
    print("Running in '%s' mode. Allowed opcodes:\n %s"
          % (opts.mode, ' '.join(sorted(factory.allowed_opcodes))))
    print("---")

    reactor.listenTCP(opts.localport, factory)
    reactor.run()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment