Last active
August 29, 2015 14:00
-
-
Save gregswift/11401891 to your computer and use it in GitHub Desktop.
Standalone implementation of an enhancement to Ansible's wait_for module that adds state=drained.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import socket | |
import datetime | |
import time | |
import sys | |
import re | |
import binascii | |
# psutil is an optional third-party dependency.  Record whether the import
# succeeded so that code needing it can report a clear error later instead
# of failing at import time.
try:
    import psutil
except ImportError:
    HAS_PSUTIL = False
else:
    # just because we can import it on Linux doesn't mean we will use it
    HAS_PSUTIL = True
def load_platform_subclass(cls, *args, **kwargs):
    '''
    Instantiate the direct subclass of ``cls`` that best matches this host.

    Lets a strategy class have different implementations per platform: each
    subclass declares ``platform`` and ``distribution`` class attributes and
    the best match for the running system is selected.

    Selection order:
      1. a subclass whose ``platform`` and ``distribution`` both match,
      2. a subclass whose ``platform`` matches and whose ``distribution``
         is None,
      3. ``cls`` itself.
    When several subclasses match at the same level, the last one listed by
    ``cls.__subclasses__()`` wins.

    Args:
        cls: base class whose direct subclasses are candidates
        *args, **kwargs: accepted for call compatibility; they are not
            forwarded to ``__new__``
    Returns:
        A new, un-initialised instance of the selected class
    '''
    # Local import keeps this function usable without touching the file's
    # import block.
    import platform as _platform

    this_platform = _platform.system()

    # Distribution detection only applies to Linux.  The stdlib helper for it
    # is absent on newer Python versions; in that case distribution-specific
    # subclasses are simply never selected.
    distribution = None
    if this_platform == 'Linux':
        detect = getattr(_platform, 'linux_distribution', None)
        if detect is not None:
            distribution = detect()[0].capitalize() or None

    subclass = None

    # get the most specific superclass for this platform
    if distribution is not None:
        for sc in cls.__subclasses__():
            if (sc.distribution is not None
                    and sc.distribution == distribution
                    and sc.platform == this_platform):
                subclass = sc
    if subclass is None:
        for sc in cls.__subclasses__():
            if sc.platform == this_platform and sc.distribution is None:
                subclass = sc
    if subclass is None:
        subclass = cls

    return super(cls, subclass).__new__(subclass)
class TCPConnectionInfo(object):
    """
    This is a generic TCP Connection Info strategy class that relies
    on the psutil module, which is not ideal for targets, but necessary
    for cross platform support.

    A subclass may wish to override some or all of these methods.
      - _get_exclude_ips()
      - get_active_connections_count()

    All subclasses MUST define platform and distribution (which may be None).
    """
    platform = 'Generic'
    distribution = None

    # Per address family, the "any address" value: a configured host equal to
    # this matches connections on every local IP.
    match_all_ips = {
        socket.AF_INET: '0.0.0.0',
        socket.AF_INET6: '::',
    }
    # Connection states counted as "active".  This class compares the values
    # against psutil's ``conn.status``.
    connection_states = {
        '01': 'ESTABLISHED',
        '02': 'SYN_SENT',
        '03': 'SYN_RECV',
        '04': 'FIN_WAIT1',
        '05': 'FIN_WAIT2',
        '06': 'TIME_WAIT',
    }

    def __new__(cls, *args, **kwargs):
        # Always dispatch from the base class so that constructing
        # TCPConnectionInfo yields the best platform-specific subclass.
        return load_platform_subclass(TCPConnectionInfo, *args, **kwargs)

    def __init__(self, module):
        """
        Args:
            module: object exposing ``params`` (dict with 'host', 'port' and
                optionally 'name' and 'exclude_hosts') and ``fail_json(msg=...)``
        """
        self.module = module
        # Report the missing dependency before doing any other work.
        if not HAS_PSUTIL:
            module.fail_json(msg="psutil module required for wait_for")
        self.name = module.params.get('name')
        (self.family, self.host) = _convert_host_to_ip(module.params['host'])
        self.port = int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """
        Resolve the 'exclude_hosts' parameter to a list of IP strings.

        Accepts None, a comma separated string, or an iterable of hosts.
        Empty entries are ignored.
        """
        raw = self.module.params.get('exclude_hosts')
        if not raw:
            return []
        hosts = raw.split(',') if hasattr(raw, 'split') else raw
        # psutil reports textual IPs, so the exclusions must be textual too.
        return [_convert_host_to_ip(h.strip())[1] for h in hosts if h.strip()]

    def get_active_connections_count(self):
        """
        Count active TCP connections to the configured host/port, ignoring
        those whose remote IP is in the exclude list.

        Returns:
            Integer number of matching connections
        """
        active_connections = 0
        wanted_states = set(self.connection_states.values())
        wildcard = self.match_all_ips[self.family]
        for proc in psutil.process_iter():
            # The per-process connection accessor has been renamed across
            # psutil releases; use whichever one this version provides.
            fetch = (getattr(proc, 'net_connections', None)
                     or getattr(proc, 'connections', None)
                     or proc.get_connections)
            try:
                connections = fetch(kind='inet')
            except psutil.NoSuchProcess:
                # the process exited between enumeration and inspection
                continue
            for conn in connections:
                if conn.status not in wanted_states:
                    continue
                # Address attributes were renamed as well (laddr/raddr vs
                # local_address/remote_address).
                local_addr = conn.laddr if hasattr(conn, 'laddr') else conn.local_address
                remote_addr = conn.raddr if hasattr(conn, 'raddr') else conn.remote_address
                (local_ip, local_port) = local_addr[:2]
                if self.port != local_port or self.host not in (wildcard, local_ip):
                    continue
                remote_ip = remote_addr[0] if remote_addr else None
                if remote_ip not in self.exclude_ips:
                    active_connections += 1
        return active_connections
class LinuxTCPConnectionInfo(TCPConnectionInfo):
    """
    This is a TCP Connection Info evaluation strategy class
    that utilizes information from Linux's procfs. While less universal,
    does allow Linux targets to not require an additional library.
    """
    platform = 'Linux'
    distribution = None

    # procfs table to read for each address family
    source_file = {
        socket.AF_INET: '/proc/net/tcp',
        socket.AF_INET6: '/proc/net/tcp6'
    }
    # "any address" in the hex form produced by _convert_host_to_hex
    match_all_ips = {
        socket.AF_INET: '00000000',
        socket.AF_INET6: '00000000000000000000000000000000',
    }
    # column indexes within a whitespace-split table row
    local_address_field = 1
    remote_address_field = 2
    connection_state_field = 3

    def __init__(self, module):
        """
        Args:
            module: object exposing ``params`` (dict with 'host', 'port' and
                optionally 'name' and 'exclude_hosts')
        """
        self.module = module
        self.name = module.params.get('name')
        (self.family, self.host) = _convert_host_to_hex(module.params['host'])
        # ports are compared as 4-digit upper-case hex strings
        self.port = "%0.4X" % int(module.params['port'])
        self.exclude_ips = self._get_exclude_ips()

    def _get_exclude_ips(self):
        """
        Resolve the 'exclude_hosts' parameter to a list of hex encoded IPs.

        Accepts None, a comma separated string, or an iterable of hosts.
        Empty entries are ignored.
        """
        raw = self.module.params.get('exclude_hosts')
        if not raw:
            return []
        hosts = raw.split(',') if hasattr(raw, 'split') else raw
        # keep only the hex string of each (family, hex) pair
        return [_convert_host_to_hex(h.strip())[1] for h in hosts if h.strip()]

    def get_active_connections_count(self, family=None):
        """
        Count active TCP connections to the configured host/port, ignoring
        those whose remote IP is in the exclude list.

        Args:
            family: address family whose table is read; defaults to the
                family of the configured host
        Returns:
            Integer number of matching connections (0 when the table cannot
            be opened)
        """
        if family is None:
            family = self.family
        active_connections = 0
        try:
            table = open(self.source_file[family])
        except IOError:
            # table missing or unreadable: report no connections
            return active_connections

        wildcard = self.match_all_ips[self.family]
        with table:
            for line in table:
                fields = line.split()
                if len(fields) <= self.connection_state_field:
                    # blank or truncated row
                    continue
                if fields[self.local_address_field] == 'local_address':
                    # header row
                    continue
                if fields[self.connection_state_field] not in self.connection_states:
                    continue
                (local_ip, local_port) = fields[self.local_address_field].split(':')
                if self.port != local_port or self.host not in (wildcard, local_ip):
                    continue
                remote_ip = fields[self.remote_address_field].split(':')[0]
                if remote_ip not in self.exclude_ips:
                    active_connections += 1
        return active_connections
def _convert_host_to_ip(host): | |
""" | |
Perform forward DNS resolution on host, IP will give the same IP | |
Args: | |
host: String with either hostname, IPv4, or IPv6 address | |
Returns: | |
Tuple containing address family and IP | |
""" | |
addrinfo = socket.getaddrinfo(host, 80, 0, 0, socket.SOL_TCP)[0] | |
return (addrinfo[0], addrinfo[4][0]) | |
def _convert_host_to_hex(host):
    """
    Convert the provided host to the format in /proc/net/tcp*

    /proc/net/tcp uses little-endian four byte hex for ipv4
    /proc/net/tcp6 uses little-endian per 4B word for ipv6

    Args:
        host: String with either hostname, IPv4, or IPv6 address
    Returns:
        Tuple containing address family and the little-endian converted host
    """
    (family, ip) = _convert_host_to_ip(host)
    packed = socket.inet_pton(family, ip)
    # hexlify returns bytes on Python 3; decode so the result is text and can
    # be sliced, joined and compared with text read from procfs
    hexed = binascii.hexlify(packed).decode('ascii').upper()
    if family == socket.AF_INET:
        hexed = _little_endian_convert_32bit(hexed)
    elif family == socket.AF_INET6:
        # range loops through each 8 character (4B) set in the 128bit total
        hexed = "".join(
            _little_endian_convert_32bit(hexed[x:x + 8]) for x in range(0, 32, 8)
        )
    return (family, hexed)
def _little_endian_convert_32bit(block): | |
""" | |
Convert to little-endian, effectively transposing | |
the order of the four byte word | |
12345678 -> 78563412 | |
Args: | |
block: String containing a 4 byte hex representation | |
Returns: | |
String containing the little-endian converted block | |
""" | |
# xrange starts at 6, and increments by -2 until it reaches -2 | |
# which lets us start at the end of the string block and work to the begining | |
return "".join([ block[x:x+2] for x in xrange(6, -2, -2) ]) | |
class _StandaloneModule(object):
    """Minimal stand-in for the module object that TCPConnectionInfo reads."""

    def __init__(self, **params):
        self.params = params

    def fail_json(self, **kwargs):
        # abort with the supplied message
        raise SystemExit(kwargs.get('msg', 'wait_for_drain failed'))


def wait_for_drain(host, port, exclude_hosts=None, timeout=30):
    """
    Poll once per second until no active connections remain or the timeout
    expires.  At least one check is always performed.

    @host can be any resolveable DNS or an IP
    @port numerical representation only
    @exclude_hosts array of host entries (or a comma separated string),
                   entries same as @host
    @timeout how long to wait for the drain, in seconds

    returns count of active connections
    """
    # TCPConnectionInfo splits 'exclude_hosts' on commas, so hand it a string.
    if exclude_hosts is None:
        excluded = ''
    elif hasattr(exclude_hosts, 'split'):
        excluded = exclude_hosts
    else:
        excluded = ','.join(exclude_hosts)

    module = _StandaloneModule(name=None, host=host, port=port,
                               exclude_hosts=excluded)
    tcpconns = TCPConnectionInfo(module)

    deadline = datetime.datetime.now() + datetime.timedelta(seconds=timeout)
    while True:
        active_connections = tcpconns.get_active_connections_count()
        if active_connections == 0:
            break
        if datetime.datetime.now() >= deadline:
            print("Timeout when waiting for %s:%s to drain" % (host, port))
            break
        time.sleep(1)
    return active_connections
if __name__ == '__main__':
    # Demo run: wait for connections to local port 22 to drain and report
    # how many were still active.  print() with a single argument behaves
    # the same on Python 2 and Python 3.
    print("Found {0} connections to localhost:22".format(wait_for_drain('127.0.0.1', 22)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment