@amcgregor
Created July 12, 2011 14:14
A blog-ish post about bitfield data design in Python.

Enumerated Bitfield Data Type

If it’s worth doing once, it’s worth writing a system to do it.

So I’m exploring writing a DNS server in Python. There are a number of solutions for reading and writing DNS data in a variety of formats (BIND configuration files, over-the-wire encoding, and so on), but I learn best by doing, not by using someone else’s code. Calling it “not invented here” is naïve at best, so let’s get started with the first over-engineered bit I’ve written.

Flags

DNS is a binary protocol, which is a seriously good thing compared to text-based protocols such as SMTP, NNTP, POP, and IMAP. It makes rather extensive use of bit masks to represent flags, so we’ll need a method to encode and decode these bit masks, and a convenient way to display them in human-readable form. Also attached to this Gist is a copy of the flags.py file from the dnspython package, a quite complete and mature package by any standard. Unfortunately, it followed in twisted.names’ footsteps. If you examine that file, you’ll notice that it defines everything several times: constants to represent the bits, a dictionary to map names to bits, separate inverse mappings, and then separate functions to encode and decode the bit fields.

This is horrible, horrible duplication. So to start, I needed an attribute-access dictionary; one is already part of my marrow suite, so I’ll base my Flags class on that:

class Flags(marrow.util.bunch.Bunch):

Now the first thing we’ll need is the inverse mapping. Since the meaning of the bits cannot change at runtime, it’s safe to build it once in the __init__ method:

    def __init__(self, *args, **kw):
        super(Flags, self).__init__(*args, **kw)
        
        self.__dict__['inverse'] = dict(zip(self.itervalues(), self.iterkeys()))
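
The snippet above relies on Python 2's itervalues() and iterkeys(). As a minimal sketch of the same inversion on Python 3, assuming a plain dict in place of marrow's Bunch (the `flags` name and its three entries are illustrative only, not part of the original module):

```python
# Illustrative subset of the standard DNS flags from the attachment.
flags = {"QR": 0x8000, "AA": 0x0400, "RD": 0x0100}

# values() and keys() iterate in the same order, so zip() pairs each
# bit with its own name.
inverse = dict(zip(flags.values(), flags.keys()))

print(inverse[0x0400])  # AA
```

The result is only a true inverse when every name maps to a distinct value; two names sharing a bit would collapse into a single entry.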

To prevent modification later, we disable __setitem__ and __setattr__:

    def __setitem__(self, name, value):
        raise RuntimeError("Flags can not be defined at runtime.")
    
    __setattr__ = __setitem__
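
To see what this buys us, here is a rough Python 3 sketch of the same idea. It assumes a plain dict subclass standing in for Bunch; the ReadOnlyFlags name and the __getattr__ helper are hypothetical and only there to keep the example self-contained.

```python
class ReadOnlyFlags(dict):
    """Hypothetical stand-in for the Flags class, built on a plain dict."""

    def __init__(self, *args, **kw):
        # dict.__init__ fills the mapping without calling __setitem__.
        super().__init__(*args, **kw)
        # Write through __dict__ so the blocked __setattr__ is bypassed.
        self.__dict__['inverse'] = {bit: name for name, bit in self.items()}

    def __setitem__(self, name, value):
        raise RuntimeError("Flags can not be defined at runtime.")

    # Attribute assignment takes the same (self, name, value) arguments.
    __setattr__ = __setitem__

    def __getattr__(self, name):
        # Attribute-style read access, roughly what Bunch provides.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name)


flags = ReadOnlyFlags(QR=0x8000, RD=0x0100)

for attempt in (lambda: flags.__setitem__('XX', 1),
                lambda: setattr(flags, 'XX', 1)):
    try:
        attempt()
    except RuntimeError as exc:
        print("blocked:", exc)
```

One caveat with this approach on a dict subclass: methods such as update() and pop() do not go through __setitem__, so they can still change the mapping and leave inverse stale.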

We then need to be able to encode (turn a textual representation into binary) and decode (the reverse). First, encoding:

    def encode(self, text):
        return reduce(operator.__or__, (self[hunk.strip().upper()] for hunk in text.split()), 0)

The above may look slightly insane, but it works like a hot damn. All it does is look up each flag value (the bits) and OR them together. Now for decoding, which is a little harder:

    def decode(self, value):
        return ' '.join(v for k, v in reversed(sorted(zip(self.itervalues(), self.iterkeys()))) if value & k)

This returns a string of flag names ordered from the highest bit to the lowest. It purposefully does not use the self.inverse mapping because a dictionary has no useful order and the results need to be sorted by bit value.
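
Both methods are easier to trust after a round trip. The following is a sketch for Python 3, where reduce has moved into functools; it uses free functions over a plain dict rather than the Flags class, and the STANDARD name and the encode/decode signatures are assumptions made for the example.

```python
import operator
from functools import reduce  # reduce is no longer a builtin on Python 3

# Standard DNS flag bits, as listed in the attachment.
STANDARD = {
    "QR": 0x8000, "AA": 0x0400, "TC": 0x0200, "RD": 0x0100,
    "RA": 0x0080, "AD": 0x0020, "CD": 0x0010,
}


def encode(table, text):
    """OR together the bits named in a whitespace-separated string."""
    return reduce(operator.or_, (table[name.upper()] for name in text.split()), 0)


def decode(table, value):
    """Return the names of the set bits, highest bit first."""
    pairs = sorted(((bit, name) for name, bit in table.items()), reverse=True)
    return " ".join(name for bit, name in pairs if value & bit)


mask = encode(STANDARD, "rd qr ra")
print(hex(mask))               # 0x8180
print(decode(STANDARD, mask))  # QR RD RA
```

Note the asymmetry: an unknown name makes encode raise KeyError, while decode silently ignores any bits that are not in the table.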

See the second attachment (02-marrow-dns-flags.py) for the complete implementation and example flag sets.

# Copyright (C) 2001-2007, 2009, 2010 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""DNS Message Flags."""

# Standard DNS flags

QR = 0x8000
AA = 0x0400
TC = 0x0200
RD = 0x0100
RA = 0x0080
AD = 0x0020
CD = 0x0010

# EDNS flags

DO = 0x8000

_by_text = {
    'QR' : QR,
    'AA' : AA,
    'TC' : TC,
    'RD' : RD,
    'RA' : RA,
    'AD' : AD,
    'CD' : CD
}

_edns_by_text = {
    'DO' : DO
}

# We construct the inverse mappings programmatically to ensure that we
# cannot make any mistakes (e.g. omissions, cut-and-paste errors) that
# would cause the mappings not to be true inverses.

_by_value = dict([(y, x) for x, y in _by_text.iteritems()])

_edns_by_value = dict([(y, x) for x, y in _edns_by_text.iteritems()])

def _order_flags(table):
    order = list(table.iteritems())
    order.sort()
    order.reverse()
    return order

_flags_order = _order_flags(_by_value)

_edns_flags_order = _order_flags(_edns_by_value)

def _from_text(text, table):
    flags = 0
    tokens = text.split()
    for t in tokens:
        flags = flags | table[t.upper()]
    return flags

def _to_text(flags, table, order):
    text_flags = []
    for k, v in order:
        if flags & k != 0:
            text_flags.append(v)
    return ' '.join(text_flags)

def from_text(text):
    """Convert a space-separated list of flag text values into a flags
    value.
    @rtype: int"""

    return _from_text(text, _by_text)

def to_text(flags):
    """Convert a flags value into a space-separated list of flag text
    values.
    @rtype: string"""

    return _to_text(flags, _by_value, _flags_order)

def edns_from_text(text):
    """Convert a space-separated list of EDNS flag text values into a EDNS
    flags value.
    @rtype: int"""

    return _from_text(text, _edns_by_text)

def edns_to_text(flags):
    """Convert an EDNS flags value into a space-separated list of EDNS flag
    text values.
    @rtype: string"""

    return _to_text(flags, _edns_by_value, _edns_flags_order)

# encoding: utf-8

"""DNS flag constants."""

from __future__ import unicode_literals

import operator

from marrow.util.bunch import Bunch


__all__ = []


class Flags(Bunch):
    def __init__(self, *args, **kw):
        super(Flags, self).__init__(*args, **kw)
        
        self.__dict__['inverse'] = dict(zip(self.itervalues(), self.iterkeys()))
    
    def __setitem__(self, name, value):
        raise RuntimeError("Flags can not be defined at runtime.")
    
    __setattr__ = __setitem__
    
    def encode(self, text):
        return reduce(operator.__or__, (self[hunk.strip().upper()] for hunk in text.split()), 0)
    
    def decode(self, value):
        return ' '.join(v for k, v in reversed(sorted(zip(self.itervalues(), self.iterkeys()))) if value & k)


standard = Flags(
        QR = 0x8000,
        AA = 0x0400,
        TC = 0x0200,
        RD = 0x0100,
        RA = 0x0080,
        AD = 0x0020,
        CD = 0x0010
    )

edns = Flags(
        DO = 0x8000
    )