Skip to content

Instantly share code, notes, and snippets.

@tcuthbert
Last active May 25, 2016 16:30
Show Gist options
  • Save tcuthbert/3adcddb56f7f81f45a2408e9591633b2 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
"""
The heart and soul of Trigger, NetDevices is an abstract interface to network
device metadata and ACL associations.
Parses :setting:`NETDEVICES_SOURCE` and makes available a dictionary of
`~trigger.netdevices.NetDevice` objects, which is keyed by the FQDN of every
network device.
Other interfaces are non-public.
Example::
>>> from trigger.netdevices import NetDevices
>>> nd = NetDevices()
>>> dev = nd['test1-abc.net.aol.com']
>>> dev.vendor, dev.make
(<Vendor: Juniper>, 'MX960-BASE-AC')
>>> dev.bounce.next_ok('green')
datetime.datetime(2010, 4, 9, 9, 0, tzinfo=<UTC>)
"""
__author__ = 'Jathan McCollum, Eileen Tschetter, Mark Thomas, Michael Shields'
__maintainer__ = 'Jathan McCollum'
__email__ = '[email protected]'
__copyright__ = 'Copyright 2006-2013, AOL Inc.; 2013 Salesforce.com'
__version__ = '2.3.2'
# Imports
import copy
import itertools
import os
import re
import sys
import time
from twisted.python import log
from twisted.internet.protocol import Factory
from twisted.internet import reactor
from twisted.internet import defer
from trigger.conf import settings
from trigger.utils import network, parse_node_port
from trigger.utils.url import parse_url
from trigger.twister2 import generate_endpoint, TriggerEndpointClientFactory, IoslikeSendExpect
from trigger import changemgmt, exceptions, rancid
from UserDict import DictMixin
from crochet import setup, run_in_reactor, wait_for
import xml.etree.cElementTree as ET
from . import loader
# Optional dependency: ACL associations need trigger.acl.db. If it cannot be
# imported, log it and switch ACL support off globally, so that NetDevices()
# (which falls back to settings.WITH_ACLS) does not try to use `AclsDB`.
try:
    from trigger.acl.db import AclsDB
except ImportError:
    log.msg("ACLs database could not be loaded; Loading without ACL support")
    settings.WITH_ACLS = False

# Constants
# Juniper commit commands used by NetDevice._juniper_commit(): a plain
# <commit-configuration/> element, and a copy of it carrying a <full/> child
# (the "commit-configuration full" variant).
JUNIPER_COMMIT = ET.Element('commit-configuration')
JUNIPER_COMMIT_FULL = copy.copy(JUNIPER_COMMIT)
ET.SubElement(JUNIPER_COMMIT_FULL, 'full')

# Exports
__all__ = ['device_match', 'NetDevice', 'NetDevices', 'Vendor']
# Functions
def _munge_source_data(data_source=None):
    """
    Read the source data in the specified format, parse it, and return the
    device metadata produced by ``loader.load_metadata``.

    `_populate` iterates over the return value and builds one `NetDevice`
    per item.

    :param data_source:
        URL or absolute path to the source data file. ``parse_url`` splits it
        into a ``path`` plus extra loader keyword arguments. Defaults to
        :setting:`NETDEVICES_SOURCE`, looked up when the function is called
        rather than frozen at import time.
    """
    if data_source is None:
        data_source = settings.NETDEVICES_SOURCE
    log.msg('LOADING FROM: ', data_source)
    kwargs = parse_url(data_source)
    path = kwargs.pop('path')
    return loader.load_metadata(path, **kwargs)
def _populate(netdevices, data_source, production_only, with_acls):
    """
    Populate ``netdevices`` in place with NetDevice objects keyed by nodeName.

    Abstracted from within NetDevices to prevent accidental repopulation of
    NetDevice objects.

    :param netdevices:
        The dict to fill. A later device with the same nodeName overwrites an
        earlier one.
    :param data_source:
        Passed through to `_munge_source_data`.
    :param production_only:
        If true, skip devices whose ``adminStatus`` is not ``'PRODUCTION'``
        (case-insensitive).
    :param with_acls:
        If truthy, an `AclsDB` is created once and handed to every NetDevice
        so it can look up its ACL associations.
    """
    device_data = _munge_source_data(data_source=data_source)

    # Populate AclsDB if `with_acls` is set
    if with_acls:
        log.msg("NetDevices ACL associations: ENABLED")
        aclsdb = AclsDB()
    else:
        log.msg("NetDevices ACL associations: DISABLED")
        aclsdb = None

    # Populate `netdevices` dictionary with `NetDevice` objects!
    for obj in device_data:
        dev = NetDevice(data=obj, with_acls=aclsdb)

        # Skip empty nodenames. (These checks should really be done when the
        # source data is generated.)
        if dev.nodeName is None:
            continue

        # Only keep non-PRODUCTION devices when `production_only` is False.
        # Test the flag first, and tolerate a missing adminStatus, so that a
        # None status cannot crash an unfiltered load.
        if production_only and (dev.adminStatus or '').upper() != 'PRODUCTION':
            log.msg(
                '[%s] Skipping: adminStatus not PRODUCTION' % dev.nodeName
            )
            continue

        # Add to dict
        netdevices[dev.nodeName] = dev
def device_match(name, production_only=True):
    """
    Return a matching :class:`~trigger.netdevices.NetDevice` object based on
    partial name. Return `None` if there is no match, or if the user cancels
    (or makes an invalid choice) when prompted between multiple matches::

        >>> device_match('test')
        2 possible matches found for 'test':
          [ 1] test1-abc.net.aol.com
          [ 2] test2-abc.net.aol.com
          [ 0] Exit

        Enter a device number: 2
        <NetDevice: test2-abc.net.aol.com>

    If there is only a single match, that device object is returned without
    a prompt::

        >>> device_match('fw')
        Matched 'fw1-xyz.net.aol.com'.
        <NetDevice: fw1-xyz.net.aol.com>

    :param name: Hostname, hostname prefix, or substring to look for.
    :param production_only: Passed to `NetDevices`.
    """
    nd = NetDevices(production_only)

    # Exact name or unique dot-delimited prefix: return it without output.
    try:
        return nd.find(name)
    except KeyError:
        pass

    matches = nd.search(name)
    if not matches:
        print("No matches for '%s'." % name)
        return None

    if len(matches) == 1:
        single = matches[0]
        print("Matched '%s'." % single)
        return single

    print("%d possible matches found for '%s':" % (len(matches), name))
    matches.sort()
    for num, shortname in enumerate(matches):
        print(' [%s] %s' % (str(num + 1).rjust(2), shortname))
    print(' [ 0] Exit\n')

    # On Python 2, input() evaluates whatever the user types; read the raw
    # text instead and parse it as an integer.
    try:
        read_line = raw_input
    except NameError:  # Python 3
        read_line = input
    try:
        choice = int(read_line('Enter a device number: ')) - 1
    except ValueError:
        choice = -1  # Treat non-numeric input like "0) Exit".

    match = matches[choice] if 0 <= choice < len(matches) else None
    log.msg('Choice: %s' % choice)
    log.msg('You chose: %s' % match)
    return match
# Classes
class NetDevice(object):
    """
    An object that represents a distinct network device and its metadata.

    Almost all of the attributes are populated by
    `~trigger.netdevices._populate()` and are mostly dependent upon the source
    data. This is prone to implementation problems and should be revisited in
    the long-run as there are certain fields that are baked into the core
    functionality of Trigger.

    Users usually won't create these objects directly! Rely instead upon
    `~trigger.netdevices.NetDevices` to do this for you.
    """

    def __init__(self, data=None, with_acls=None):
        """
        :param data:
            Optional device metadata (a dict, or an iterable of key/value
            pairs). Every item is copied onto the instance as an attribute.
        :param with_acls:
            Optional `~trigger.acl.db.AclsDB` object. When truthy, the
            device's ACL associations are looked up from it.
        """
        # Here comes all of the bare minimum set of attributes a NetDevice
        # object needs for basic functionality within the existing suite.

        # Hostname
        self.nodeName = None
        self.nodePort = None

        # Hardware Info
        self.deviceType = None
        self.make = None
        self.manufacturer = settings.FALLBACK_MANUFACTURER
        self.vendor = None
        self.model = None
        self.serialNumber = None

        # Administrivia
        self.adminStatus = settings.DEFAULT_ADMIN_STATUS
        self.assetID = None
        self.budgetCode = None
        self.budgetName = None
        self.enablePW = None
        self.owningTeam = None
        self.owner = None
        self.onCallName = None
        self.operationStatus = None
        self.lastUpdate = None
        self.lifecycleStatus = None
        self.projectName = None

        # Location
        self.site = None
        self.room = None
        self.coordinate = None

        # If `data` has been passed, use it to update our attributes
        if data is not None:
            self._populate_data(data)

        # Set node remote port based on "hostname:port" as nodeName
        self._set_node_port()

        # Cleanup the attributes (strip whitespace, lowercase values, etc.)
        self._cleanup_attributes()

        # Map the manufacturer name to a Vendor object that has extra sauce
        if self.manufacturer is not None:
            self.vendor = vendor_factory(self.manufacturer)

        # Use the vendor to populate the deviceType if it's not set already
        if self.deviceType is None:
            self._populate_deviceType()

        # ACLs (default to empty sets). Each attribute gets its OWN set. A
        # chained ``a = b = set()`` would alias them, so adding to one would
        # silently add to all of them.
        self.explicit_acls = set()
        self.implicit_acls = set()
        self.acls = set()
        self.bulk_acls = set()
        if with_acls:
            log.msg('[%s] Populating ACLs' % self.nodeName)
            self._populate_acls(aclsdb=with_acls)

        # Bind trigger.twister's execute/connect functions to this instance
        self._bind_dynamic_methods()

        # Set the correct command(s) to run on startup based on vendor
        self.startup_commands = self._set_startup_commands()

        # Assign the configuration commit commands (e.g. 'write memory')
        self.commit_commands = self._set_commit_commands()

        # Determine whether we require an async pty SSH channel
        self.requires_async_pty = self._set_requires_async_pty()

        # Set the correct line-ending per vendor
        self.delimiter = self._set_delimiter()

        # Set initial endpoint state
        self._connected = False
        self._endpoint = None
        self.results = Results2()

    def _populate_data(self, data):
        """
        Populate the custom attribute data.

        :param data:
            A dict, or an iterable of key/value pairs, accepted by
            ``dict.update``.
        """
        self.__dict__.update(data)

    def _vendor_name(self):
        """Return the canonical vendor name, or None if no vendor is mapped."""
        if self.vendor is None:
            return None
        return self.vendor.name

    def _cleanup_attributes(self):
        """Perform various cleanup actions. Abstracted for customization."""
        # Lowercase the nodeName for completeness.
        if self.nodeName is not None:
            self.nodeName = self.nodeName.lower()

        if self.deviceType is not None:
            self.deviceType = self.deviceType.upper()

        # Make sure the password is bytes not unicode
        if self.enablePW is not None:
            self.enablePW = str(self.enablePW)

        # Cleanup whitespace from owning team
        if self.owningTeam is not None:
            self.owningTeam = self.owningTeam.strip()

        # Map deviceStatus to adminStatus when data source is RANCID.
        # Any value other than 'up'/'down' is treated as 'up' (PRODUCTION).
        if hasattr(self, 'deviceStatus'):
            STATUS_MAP = {
                'up': 'PRODUCTION',
                'down': 'NON-PRODUCTION',
            }
            self.adminStatus = STATUS_MAP.get(self.deviceStatus, STATUS_MAP['up'])

    def _set_node_port(self):
        """
        Set ``nodePort``, preferring a port embedded in ``nodeName``.

        If `parse_node_port` finds a port in ``nodeName`` (``"hostname:port"``),
        that port is used and ``nodeName`` is replaced by the hostname it
        returns. Otherwise a string ``nodePort`` from the source data is
        converted to an int.
        """
        # If nodename is set, try to parse out a nodePort
        if self.nodeName is not None:
            nodeName, nodePort = parse_node_port(self.nodeName)

            # If the nodeName differs, use it to replace the one we parsed
            if nodeName != self.nodeName:
                self.nodeName = nodeName

            # A port parsed out of the nodeName wins over the source data
            if nodePort is not None:
                self.nodePort = nodePort
                return None

        # Make sure the port is an integer if it's not None
        if self.nodePort is not None and isinstance(self.nodePort, basestring):
            self.nodePort = int(self.nodePort)

    def _populate_deviceType(self):
        """Guess the deviceType from the vendor, or use the fallback type."""
        self.deviceType = settings.DEFAULT_TYPES.get(self._vendor_name(),
                                                     settings.FALLBACK_TYPE)

    def _set_requires_async_pty(self):
        """
        Return whether a device requires an async pty (see:
        `~trigger.twister.TriggerSSHAsyncPtyChannel`).
        """
        RULES = (
            self.vendor in ('a10', 'arista', 'aruba', 'cisco', 'force10'),
            self.is_brocade_vdx(),
        )
        return any(RULES)

    def _set_delimiter(self):
        """
        Return the delimiter to use for line-endings ('\\r\\n' for Force10,
        '\\n' for everything else).
        """
        delimiter_map = {
            'force10': '\r\n',
        }
        return delimiter_map.get(self._vendor_name(), '\n')

    def _set_startup_commands(self):
        """
        Return the commands to run at startup. For now they are just ones to
        disable pagination. Unknown vendors get an empty list.
        """
        default = ['terminal length 0']
        vendor_name = self._vendor_name()

        # ScreenOS devices (which may report vendor 'juniper') always win.
        if self.is_netscreen():
            return ['set console page 0']

        # Brocade commands differ by platform.
        if vendor_name == 'brocade':
            if self.is_brocade_vdx():
                return ['terminal length 0']
            return ['skip-page-display']

        # Cisco ASA commands differ from IOS.
        if vendor_name == 'cisco':
            if self.is_cisco_asa():
                return ['terminal pager 0']
            return default

        # Commands used to disable paging.
        paging_map = {
            'a10': default,
            'arista': default,
            'aruba': ['no paging'],  # v6.2.x this is not necessary
            'citrix': ['set cli mode page off'],
            'dell': ['terminal datadump'],
            'f5': ['modify cli preference pager disabled'],
            'force10': default,
            'foundry': ['skip-page-display'],
            'juniper': ['set cli screen-length 0'],
            'mrv': ['no pause'],
            'netscreen': ['set console page 0'],
            'paloalto': ['set cli scripting-mode on', 'set cli pager off'],
        }
        return paging_map.get(vendor_name, [])

    def _set_commit_commands(self):
        """
        Return the proper "commit" command. (e.g. write mem, etc.)
        """
        if self.is_ioslike():
            return self._ioslike_commit()
        elif self.is_netscaler() or self.is_netscreen():
            return ['save config']
        elif self.vendor == 'juniper':
            return self._juniper_commit()
        elif self.vendor == 'paloalto':
            return ['commit']
        elif self.vendor == 'pica8':
            return ['commit']
        elif self.vendor == 'mrv':
            return ['save configuration flash']
        elif self.vendor == 'f5':
            return ['save sys config']
        else:
            return []

    def _ioslike_commit(self):
        """
        Return proper 'write memory' command for IOS-like devices.
        """
        if self.is_brocade_vdx() or self.vendor == 'dell':
            return ['copy running-config startup-config', 'y']
        elif self.is_cisco_nexus():
            return ['copy running-config startup-config']
        else:
            return ['write memory']

    def _juniper_commit(self, fields=settings.JUNIPER_FULL_COMMIT_FIELDS):
        """
        Return proper ``commit-configuration`` element for a Juniper
        device.

        :param fields:
            Mapping of attribute name -> value. "commit-configuration full" is
            used only if every attribute equals its value. A device lacking
            one of the attributes gets the normal commit.
        """
        default = [JUNIPER_COMMIT]
        if not fields:
            return default

        # Either it's a normal "commit-configuration". ``not (a == b)`` keeps
        # the comparison routed through ``__eq__`` (e.g. Vendor.__eq__),
        # which is safe even for types that lack a Python 2 ``__ne__``.
        for attr, val in fields.iteritems():
            if not (getattr(self, attr, None) == val):
                return default

        # Or it's a "commit-configuration full"
        return [JUNIPER_COMMIT_FULL]

    def _bind_dynamic_methods(self):
        """
        Bind dynamic methods to the instance. Currently does these:

            + Dynamically bind `~trigger.twister.execute` to .execute()
            + Dynamically bind `~trigger.twister.connect` to .connect()

        Note that these both rely on the value of the ``vendor`` attribute.
        """
        from trigger import twister
        self.execute = twister.execute.__get__(self, self.__class__)
        self.connect = twister.connect.__get__(self, self.__class__)

    def _populate_acls(self, aclsdb=None):
        """
        Populate the associated ACLs for this device.

        :param aclsdb:
            An `~trigger.acl.db.AclsDB` object.
        """
        if not aclsdb:
            return None

        acls_dict = aclsdb.get_acl_dict(self)
        self.explicit_acls = acls_dict['explicit']
        self.implicit_acls = acls_dict['implicit']
        self.acls = acls_dict['all']

    def __str__(self):
        return self.nodeName

    def __repr__(self):
        return "<NetDevice: %s>" % self.nodeName

    def __cmp__(self, other):
        # Python 2 ordering by nodeName (used by device_match's sort).
        if self.nodeName > other.nodeName:
            return 1
        elif self.nodeName < other.nodeName:
            return -1
        else:
            return 0

    @property
    def bounce(self):
        return changemgmt.bounce(self)

    @property
    def shortName(self):
        """The nodeName up to (not including) its first dot."""
        return self.nodeName.split('.', 1)[0]

    @property
    def os(self):
        """
        Return ``"<vendor>_<operatingsystem>"`` if that OS is listed for this
        vendor in :setting:`TEXTFSM_VENDOR_MAPPINGS`, else None.
        """
        vendor_mapping = settings.TEXTFSM_VENDOR_MAPPINGS
        try:
            oss = vendor_mapping[self.vendor]
            operating_system = self.operatingSystem.lower()
            if operating_system in oss:
                return "{0}_{1}".format(self.vendor, operating_system)
        except (KeyError, AttributeError, TypeError):
            # Unknown vendor, missing/None operatingSystem, or bad mapping.
            log.msg("Unable to find template for given device.\n"
                    "Check to see if your netdevices object has the "
                    "'platform' key.\n"
                    "Otherwise template does not exist.")
        return None

    def _get_endpoint(self, *args):
        """
        Connect to the device and return the Deferred from the endpoint's
        ``connect`` (also stored as ``self._proto``).
        """
        endpoint = generate_endpoint(self).wait()
        factory = TriggerEndpointClientFactory()
        factory.protocol = IoslikeSendExpect
        self._factory = factory  # Track this for later?

        # FIXME(jathan): prompt_pattern could move back to protocol?
        prompt = re.compile(settings.IOSLIKE_PROMPT_PAT)
        proto = endpoint.connect(factory, prompt_pattern=prompt)
        self._proto = proto  # Track this for later, too.

        return proto

    def open(self):
        """Open the connection. Always returns True (not yet validated)."""
        def inject_net_device_into_protocol(proto):
            """Now we're only injecting connection for use later."""
            self._conn = proto.transport.conn
            return proto

        self._endpoint = self._get_endpoint()
        self.d = self._endpoint.addCallback(
            inject_net_device_into_protocol
        )

        # This should be validated somehow
        self._connected = True
        return True

    def close(self):
        """
        Schedule the transport to drop its connection.

        :raises ValueError: if `open` has not been called.
        """
        def disconnect(proto):
            proto.transport.loseConnection()
            return proto

        if self._endpoint is None:
            raise ValueError("Endpoint has not been instantiated.")

        self._endpoint.addCallback(
            disconnect
        )
        self._connected = False
        return

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def get_results(self):
        """
        Wait until ``self._results`` has one entry per ``self.commands``.

        NOTE(review): ``self._results`` is reset to ``[]`` here and nothing
        in this class appends to it, and ``self.commands`` is never assigned
        in this class. As written this raises AttributeError or spins forever
        unless something external populates both. Confirm before use.
        """
        self._results = []
        while len(self._results) != len(self.commands):
            pass
        return self._results

    def run_commands(self, commands):
        """
        Queue ``commands`` on the connection made by `open` and return the
        callable produced by ``self.results.add`` (see `Results2.add`).

        :param commands: A list of commands to send.
        """
        # ``self._proto`` is the Deferred stored by `_get_endpoint` on open().
        proto = self._proto

        def inject_commands_into_protocol(proto):
            proto.add_commands(commands)
            proto.todo.append(defer.Deferred())
            return proto

        proto = proto.addCallback(
            inject_commands_into_protocol
        )
        return self.results.add(proto, commands)

    @property
    def connected(self):
        return self._connected

    def allowable(self, action, when=None):
        """
        Return whether it's okay to perform the specified ``action``.

        False means a bounce window conflict. For now ``'load-acl'`` is the
        only valid action and moratorium status is not checked.

        :param action:
            The action to check.

        :param when:
            A datetime object.
        """
        assert action == 'load-acl'
        return self.bounce.status(when) == changemgmt.BounceStatus('green')

    def next_ok(self, action, when=None):
        """
        Return the next time at or after the specified time (default now)
        that it will be ok to perform the specified action.

        :param action:
            The action to check.

        :param when:
            A datetime object.
        """
        assert action == 'load-acl'
        return self.bounce.next_ok(changemgmt.BounceStatus('green'), when)

    def is_router(self):
        """Am I a router?"""
        return self.deviceType == 'ROUTER'

    def is_switch(self):
        """Am I a switch?"""
        return self.deviceType == 'SWITCH'

    def is_firewall(self):
        """Am I a firewall?"""
        return self.deviceType == 'FIREWALL'

    def is_netscaler(self):
        """Am I a NetScaler?"""
        return self.is_switch() and self.vendor == 'citrix'

    def is_pica8(self):
        """Am I a Pica8?"""
        ## This is only really needed because pica8
        ## doesn't have a global command to disable paging
        ## so we need to do some special magic.
        return self.vendor == 'pica8'

    def is_netscreen(self):
        """Am I a NetScreen running ScreenOS?"""
        # Are we even a firewall?
        if not self.is_firewall():
            return False

        # If vendor or make is netscreen, automatically True
        make_netscreen = self.make is not None and self.make.lower() == 'netscreen'
        if self.vendor == 'netscreen' or make_netscreen:
            return True

        # Final check: Are we made by Juniper and an SSG? This requires that
        # make or model is populated and has the word 'ssg' in it. This still
        # fails if it's an SSG running JunOS, but this is not an edge case we
        # can easily support at this time.
        is_ssg = (
            (self.model is not None and 'ssg' in self.model.lower()) or
            (self.make is not None and 'ssg' in self.make.lower())
        )
        return self.vendor == 'juniper' and is_ssg

    def is_ioslike(self):
        """
        Am I an IOS-like device (as determined by :setting:`IOSLIKE_VENDORS`)?
        """
        return self.vendor in settings.IOSLIKE_VENDORS

    def is_brocade_vdx(self):
        """
        Am I a Brocade VDX switch?

        This is used to account for the disparity between the Brocade FCX
        switches (which behave like Foundry devices) and the Brocade VDX
        switches (which behave differently from classic Foundry devices).

        A Brocade switch without a ``make`` is assumed not to be a VDX. The
        answer is cached on the instance.
        """
        if hasattr(self, '_is_brocade_vdx'):
            return self._is_brocade_vdx

        if not (self.vendor == 'brocade' and self.is_switch()):
            self._is_brocade_vdx = False
            return False

        # Always set the cache before returning; without a make it used to
        # be left unset, raising AttributeError on the return.
        self._is_brocade_vdx = (
            self.make is not None and 'vdx' in self.make.lower()
        )
        return self._is_brocade_vdx

    def is_cisco_asa(self):
        """
        Am I a Cisco ASA Firewall?

        This is used to account for slight differences in the commands that
        may be used between Cisco's ASA and IOS platforms. Cisco ASA is still
        very IOS-like, but there are still several gotcha's between the
        platforms.

        Will return True if vendor is Cisco and platform is Firewall. This
        is to allow operability if using .csv NetDevices and pretty safe to
        assume considering ASA (was PIX) are Cisco's flagship(if not only)
        Firewalls. ``make`` is deliberately not consulted. The answer is
        cached on the instance.
        """
        if hasattr(self, '_is_cisco_asa'):
            return self._is_cisco_asa

        self._is_cisco_asa = self.vendor == 'cisco' and self.is_firewall()
        return self._is_cisco_asa

    def is_cisco_nexus(self):
        """
        Am I a Cisco Nexus device? (make or model matches 'n.k' or 'nexus')
        """
        words = (self.make, self.model)
        patterns = ('n.k', 'nexus')  # Patterns to match
        pairs = itertools.product(patterns, words)

        for pat, word in pairs:
            if word and re.search(pat, word.lower()):
                return True
        return False

    def _ssh_enabled(self, disabled_mapping):
        """Check whether vendor/type is enabled against the given mapping."""
        disabled_types = disabled_mapping.get(self._vendor_name(), [])
        return self.deviceType not in disabled_types

    def has_ssh(self):
        """Am I even listening on SSH?"""
        return network.test_ssh(self.nodeName)

    def _can_ssh(self, method):
        """
        Am I enabled to use SSH for the given method in Trigger settings, and
        if so do I even have SSH?

        :param method: One of ('pty', 'async')
        """
        METHOD_MAP = {
            'pty': settings.SSH_PTY_DISABLED,
            'async': settings.SSH_ASYNC_DISABLED,
        }
        assert method in METHOD_MAP
        method_enabled = self._ssh_enabled(METHOD_MAP[method])
        return method_enabled and self.has_ssh()

    def can_ssh_async(self):
        """Am I enabled to use SSH async?"""
        return self._can_ssh('async')

    def can_ssh_pty(self):
        """Am I enabled to use SSH pty?"""
        return self._can_ssh('pty')

    def is_reachable(self):
        """Do I respond to a ping?"""
        return network.ping(self.nodeName)

    def dump(self):
        """Prints details for a device."""
        dev = self
        vendor_title = dev.vendor.title if dev.vendor is not None else None

        # (label, value) rows; None prints a blank separator line. A single
        # argument to print() behaves the same on Python 2 and 3.
        rows = [
            None,
            ('\tHostname: ', dev.nodeName),
            ('\tOwning Org.: ', dev.owner),
            ('\tOwning Team: ', dev.owningTeam),
            ('\tOnCall Team: ', dev.onCallName),
            None,
            ('\tVendor: ', '%s (%s)' % (vendor_title, dev.manufacturer)),
            ('\tMake: ', dev.make),
            ('\tModel: ', dev.model),
            ('\tType: ', dev.deviceType),
            ('\tLocation: ', '%s %s %s' % (dev.site, dev.room, dev.coordinate)),
            None,
            ('\tProject: ', dev.projectName),
            ('\tSerial: ', dev.serialNumber),
            ('\tAsset Tag: ', dev.assetID),
            ('\tBudget Code: ', '%s (%s)' % (dev.budgetCode, dev.budgetName)),
            None,
            ('\tAdmin Status: ', dev.adminStatus),
            ('\tLifecycle Status: ', dev.lifecycleStatus),
            ('\tOperation Status: ', dev.operationStatus),
            ('\tLast Updated: ', dev.lastUpdate),
            None,
        ]
        for row in rows:
            if row is None:
                print('')
            else:
                print('%s %s' % row)
class Vendor(object):
    """
    Map a manufacturer name to Trigger's canonical name.

    Given a manufacturer name like 'CISCO SYSTEMS', this will attempt to map it
    to the canonical vendor name specified in ``settings.VENDOR_MAP``. If this
    can't be done, the name is lowercased and split into words ('cisco',
    'systems'). The first word found in ``settings.SUPPORTED_VENDORS`` is
    used, and failing that the first word is used as-is.

    This exposes a normalized name that can be used in the event of a
    multi-word canonical name.

    Instances compare (``==``, ``!=``, ``in``) against anything whose ``str()``
    maps to a canonical name, e.g. ``dev.vendor == 'cisco'``.
    """
    def __init__(self, manufacturer=None):
        """
        :param manufacturer:
            The literal or "internal" name for a vendor that is to be mapped to
            its canonical name.

        :raises SyntaxError: if ``manufacturer`` is None.
        """
        if manufacturer is None:
            raise SyntaxError('You must specify a `manufacturer` name')

        self.manufacturer = manufacturer
        self.name = self.determine_vendor(manufacturer)
        self.title = self.name.title()
        self.prompt_pattern = self._get_prompt_pattern(self.name)

    def determine_vendor(self, manufacturer):
        """Try to turn the provided vendor name into the cname."""
        vendor = settings.VENDOR_MAP.get(manufacturer)
        if vendor is not None:
            return vendor

        words = manufacturer.lower().split()
        for word in words:
            if word in settings.SUPPORTED_VENDORS:
                return word

        # Safe fallback to first word (IndexError for a blank name).
        return words[0]

    def _get_prompt_pattern(self, vendor, prompt_patterns=None):
        """
        Map the vendor name to the appropriate ``prompt_pattern`` defined in
        :setting:`PROMPT_PATTERNS`.
        """
        if prompt_patterns is None:
            prompt_patterns = settings.PROMPT_PATTERNS

        # Try to get it by vendor
        pat = prompt_patterns.get(vendor)
        if pat is not None:
            return pat

        # Try to map it by IOS-like vendors...
        if vendor in settings.IOSLIKE_VENDORS:
            return settings.IOSLIKE_PROMPT_PAT

        # Or fall back to the default
        return settings.DEFAULT_PROMPT_PAT

    @property
    def normalized(self):
        """Return the normalized name for the vendor."""
        return self.name.replace(' ', '_').lower()

    def __str__(self):
        return self.name

    def __repr__(self):
        return '<%s: %s>' % (self.__class__.__name__, self.title)

    def __eq__(self, other):
        return self.name.__eq__(Vendor(str(other)).name)

    def __ne__(self, other):
        # Python 2 does not derive ``!=`` from ``__eq__``. Without this,
        # ``vendor != 'cisco'`` falls back to an identity comparison and is
        # always True.
        return not self.__eq__(other)

    def __contains__(self, other):
        return self.name.__contains__(Vendor(str(other)).name)

    def __hash__(self):
        return hash(self.name)

    def lower(self):
        return self.normalized
# Cache of manufacturer name -> Vendor instance, filled by vendor_factory().
_vendor_registry = {}


def vendor_factory(vendor_name):
    """
    Given a full name of a vendor, retrieve or create the canonical
    `~trigger.netdevices.Vendor` object.

    Vendor instances are cached to improve startup speed. A Vendor is only
    constructed on a cache miss; ``dict.setdefault(k, Vendor(k))`` would
    build one on every call.

    :param vendor_name:
        The vendor's full manufacturer name (e.g. 'CISCO SYSTEMS')
    """
    try:
        return _vendor_registry[vendor_name]
    except KeyError:
        vendor = _vendor_registry[vendor_name] = Vendor(vendor_name)
        return vendor
class NetDevices(DictMixin):
    """
    Returns an immutable Singleton dictionary of
    `~trigger.netdevices.NetDevice` objects.

    By default it will only return devices for which
    ``adminStatus=='PRODUCTION'``.

    There are hardly any use cases where ``NON-PRODUCTION`` devices are needed,
    and it can cause real bugs of two sorts:

      1. trying to contact unreachable devices and reporting spurious failures,
      2. hot spares with the same ``nodeName``.

    You may override this by passing ``production_only=False``.

    .. note::
        The singleton is built by the first instantiation only. The
        ``production_only`` and ``with_acls`` arguments of any later
        instantiation are ignored.
    """
    _Singleton = None

    class _actual(object):
        """
        This is the real class that stays active upon instantiation. All
        attributes are inherited by NetDevices from this object. This means you
        do NOT reference ``_actual`` itself, and instead call the methods from
        the parent object.

        Right::

            >>> nd = NetDevices()
            >>> nd.search('fw')
            [<NetDevice: fw1-xyz.net.aol.com>]

        Wrong::

            >>> nd._actual.search('fw')
            Traceback (most recent call last):
              File "<stdin>", line 1, in <module>
            TypeError: unbound method search() must be called with _actual
            instance as first argument (got str instance instead)
        """
        def __init__(self, production_only, with_acls):
            self._dict = {}
            _populate(netdevices=self._dict,
                      data_source=settings.NETDEVICES_SOURCE,
                      production_only=production_only, with_acls=with_acls)

        def __getitem__(self, key):
            return self._dict[key]

        def __contains__(self, item):
            return item in self._dict

        def keys(self):
            return self._dict.keys()

        def values(self):
            return self._dict.values()

        def find(self, key):
            """
            Return either the exact nodename, or a unique dot-delimited
            prefix. For example, if there is a node 'test1-abc.net.aol.com',
            then any of find('test1-abc') or find('test1-abc.net') or
            find('test1-abc.net.aol.com') will match, but not find('test1').

            :param string key: Hostname prefix to find.
            :returns: NetDevice object
            :raises KeyError: if nothing matches, or the prefix is ambiguous.
            """
            key = key.lower()
            if key in self:
                return self[key]

            matches = [x for x in self.keys() if x.startswith(key + '.')]

            # Only a *unique* prefix counts; otherwise any "first" match
            # would depend on arbitrary dict ordering.
            if len(matches) == 1:
                return self[matches[0]]
            raise KeyError(key)

        def all(self):
            """Returns all NetDevice objects."""
            return self.values()

        def search(self, token, field='nodeName'):
            """
            Returns a list of NetDevice objects where other is in
            ``dev.nodeName``. The getattr call in the search will allow a
            ``AttributeError`` from a bogus field lookup so that you
            don't get an empty list thinking you performed a legit query.

            For example, this::

                >>> field = 'bacon'
                >>> [x for x in nd.all() if 'ash' in getattr(x, field)]
                Traceback (most recent call last):
                  File "<stdin>", line 1, in <module>
                AttributeError: 'NetDevice' object has no attribute 'bacon'

            Is better than this::

                >>> [x for x in nd.all() if 'ash' in getattr(x, field, '')]
                []

            Because then you know that 'bacon' isn't a field you can search on.
            Devices whose field value is None simply don't match.

            :param string token: Token to search match on in @field
            :param string field: The field to match on when searching
            :returns: List of NetDevice objects
            """
            # We could actually just make this call match() to make this
            # case-insensitive as well. But we won't yet because of possible
            # implications in outside dependencies.
            found = []
            for dev in self.all():
                value = getattr(dev, field)  # bogus field -> AttributeError
                if value is not None and token in value:
                    found.append(dev)
            return found

        def match(self, **kwargs):
            """
            Attempt to match values to all keys in @kwargs by dynamically
            building a list comprehension. Will throw errors if the keys don't
            match legit NetDevice attributes.

            Keys and values are case IN-sensitive. Values are converted with
            ``str()`` and matched as substrings of ``str(attribute)``.

            Example by reference::

                >>> nd = NetDevices()
                >>> myargs = {'onCallName':'Data Center', 'model':'FCSLB'}
                >>> mydevices = nd.match(**myargs)

            Example by keyword arguments::

                >>> mydevices = nd.match(oncallname='data center', model='fcslb')

            :returns: List of NetDevice objects
            """
            all_field_names = getattr(self, '_all_field_names', {})

            # Cache the field names the first time .match() is called.
            if not all_field_names:
                # Merge in field_names from every NetDevice
                for dev in self.all():
                    all_field_names.update((f.lower(), f) for f in dev.__dict__)
                self._all_field_names = all_field_names

            def map_attr(attr):
                """Map a lowercase attribute name to its real spelling."""
                return self._all_field_names[attr.lower()]

            # Start from a list so a call with no kwargs still returns a list.
            devices = list(self.all())
            for attr, val in kwargs.iteritems():
                attr = map_attr(attr)
                val = str(val).lower()
                devices = [
                    d for d in devices
                    if val in str(getattr(d, attr, '')).lower()
                ]
            return devices

        def get_devices_by_type(self, devtype):
            """
            Returns a list of NetDevice objects with deviceType matching type.

            Known deviceTypes: ['FIREWALL', 'ROUTER', 'SWITCH']
            """
            return [x for x in self._dict.values() if x.deviceType == devtype]

        def list_switches(self):
            """Returns a list of NetDevice objects with deviceType of SWITCH"""
            return self.get_devices_by_type('SWITCH')

        def list_routers(self):
            """Returns a list of NetDevice objects with deviceType of ROUTER"""
            return self.get_devices_by_type('ROUTER')

        def list_firewalls(self):
            """Returns a list of NetDevice objects with deviceType of FIREWALL"""
            return self.get_devices_by_type('FIREWALL')

    def __init__(self, production_only=True, with_acls=None):
        """
        :param production_only:
            Whether to require devices to have ``adminStatus=='PRODUCTION'``.

        :param with_acls:
            Whether to load ACL associations (requires Redis). Defaults to whatever
            is specified in settings.WITH_ACLS
        """
        if with_acls is None:
            with_acls = settings.WITH_ACLS
        classobj = self.__class__
        if classobj._Singleton is None:
            classobj._Singleton = classobj._actual(production_only=production_only,
                                                   with_acls=with_acls)

    # Attribute access is forwarded to the shared `_actual` instance.
    def __getattr__(self, attr):
        return getattr(self.__class__._Singleton, attr)

    def __setattr__(self, attr, value):
        return setattr(self.__class__._Singleton, attr, value)
class Results2(object):
    """Container object for ND `result` objects."""

    def __init__(self):
        self._proto = None
        self._commands = []

    def _generate_new_result(self):
        # NOTE(review): `Result.__init__` requires (proto, commands), so this
        # call raises TypeError as written. Nothing in this class calls it;
        # confirm the intended arguments before using it.
        return Result()

    def result(self, results=None, commands=None):
        """
        Return ``(commands, rv)``, where ``rv`` is the first
        ``len(commands)`` items of ``results`` in reverse order. Return None
        if either argument is missing or unusable.

        Meant to be bound with `functools.partial` (see `add`).
        """
        try:
            clen = len(commands)
            rv = list(results[0:clen])
        except TypeError:
            return None
        rv.reverse()
        return (commands, rv)

    def add(self, d, commands):
        """
        Track the protocol behind ``d`` and return a zero-argument callable
        that yields ``(commands, results)`` for this batch.

        :param d:
            An object whose ``result`` is the protocol. The protocol's
            ``done`` list is waited on and its first entry's ``result`` is
            read (presumably a fired Deferred; confirm against callers).
        :param commands:
            The commands sent (any sequence). They are appended to
            ``self._commands``.
        """
        from functools import partial
        from copy import copy

        self._proto = d.result
        self._commands = self._commands + list(commands)

        # Block until the protocol produces its first deferred. The reactor
        # fills ``done`` from another thread, so sleep briefly between checks
        # rather than spinning the CPU.
        while self._proto.done == []:
            time.sleep(0.01)

        # Get first results state
        first_done = self._proto.done.pop(0)
        results = copy(first_done.result)

        # Return a partial function state at the time results was newest.
        return partial(self.result,
                       results=list(reversed(results)),
                       commands=commands)
class Result(object):
    """Results object returned by persistent shell commands."""

    def __init__(self, proto, commands):
        """
        :param proto:
            An object whose ``result`` attribute, once present and truthy,
            exposes ``results`` (presumably a fired Deferred wrapping the
            shell protocol; confirm against callers).
        :param commands:
            The commands these results belong to.
        """
        # Store the parameter that was actually passed in; the name ``d`` is
        # not defined in this scope.
        self._d = proto
        self._commands = commands
        self._result = None

    @property
    def results(self):
        """
        Return ``proto.result.results`` once available. Otherwise return the
        last value seen, which is None until results first arrive.
        """
        try:
            if self._d.result:
                self._result = self._d.result.results
        except AttributeError:
            # Not fired yet, or no ``results`` on the result yet.
            pass
        return self._result

    def __repr__(self):
        return '<Result: %r>' % self.results
# class Results(object):
# """Results object returned by persistant shell commands"""
# def __init__(self, d, dd, commands):
# self._d = d
# self._dd = dd
# self._commands = commands
# self._ready = False
# self._getter = None
# @property
# def ready(self):
# try:
# # Unknown whether this is threadsafe
# # self._getter = getattr(self._d.result, 'get_results_map')
# self._getter = getattr(self._d.result, 'get_results')
# self._ready = True
# except Exception as e:
# log.msg(">>> RESULTS NOT READY YET << ")
# return self._ready
# @property
# def results(self):
# if self.ready:
# self._results = self._getter(self._commands, self._dd)
# return self._getter(self._commands)
# -*- coding: utf-8 -*-
"""
Login and basic command-line interaction support using the Twisted asynchronous
I/O framework. The Trigger Twister is just like the Mersenne Twister, except
not at all.
"""
import fcntl
import os
import re
import signal
import struct
import sys
import tty
from copy import copy
from twisted.conch.ssh import session
from twisted.conch.ssh.channel import SSHChannel
from twisted.conch.endpoints import (SSHCommandClientEndpoint,
_NewConnectionHelper,
_ExistingConnectionHelper,
_CommandTransport, TCP4ClientEndpoint,
connectProtocol)
from twisted.internet import defer, protocol, reactor, threads
from twisted.internet.task import LoopingCall
from twisted.protocols.policies import TimeoutMixin
from twisted.python import log
from trigger.conf import settings
from trigger import tacacsrc, exceptions
from trigger.twister import is_awaiting_confirmation, has_ioslike_error
from trigger import tacacsrc
from trigger.utils import hash_list
from twisted.internet import reactor
from crochet import wait_for, run_in_reactor, setup, EventLoop
# Initialise crochet at import time (``setup`` is imported from crochet
# above). crochet requires this before any @run_in_reactor-decorated
# function, such as generate_endpoint below, is called.
setup()
@run_in_reactor
def generate_endpoint(device):
    """
    Create an SSH shell endpoint for ``device``, running on the reactor.

    The login credentials come from ``tacacsrc.get_device_password`` for the
    device's ``nodeName``. The username and password are handed to
    ``TriggerSSHShellClientEndpointBase.newConnection`` together with the
    device itself.

    :param device: The device object to connect to; must expose ``nodeName``.
    :returns: The endpoint, wrapped by crochet's ``@run_in_reactor``.
    """
    credentials = tacacsrc.get_device_password(device.nodeName)
    endpoint_cls = TriggerSSHShellClientEndpointBase
    return endpoint_cls.newConnection(
        reactor,
        credentials.username,
        device,
        password=credentials.password,
    )
class SSHSessionAddress(object):
    """
    Plain value object passed to ``buildProtocol`` for an SSH shell session.

    Holds the peer the session is connected to, the login username, and the
    command associated with the session.
    """

    def __init__(self, server, username, command):
        self.server, self.username, self.command = server, username, command
class _TriggerShellChannel(SSHChannel):
    """
    SSH ``session`` channel that runs an interactive shell.

    When opened it requests a PTY sized to the local terminal and then a
    ``shell``. Once the shell request succeeds, a protocol is built from
    ``protocolFactory``, given this channel's session settings, connected to
    the channel, and delivered via the ``commandConnected`` Deferred.
    """
    name = b'session'

    # Terminal type requested for the PTY when $TERM is not set (e.g. when
    # running from cron or a daemon, where os.environ['TERM'] raised KeyError).
    _DEFAULT_TERM = 'vt100'
    # PTY geometry (rows, cols, xpixel, ypixel) used when the local terminal
    # cannot be measured, e.g. stdin is missing, closed, or not a TTY.
    _DEFAULT_WINDOW_SIZE = (24, 80, 0, 0)

    def __init__(self, creator, command, protocolFactory, commandConnected, incremental,
                 with_errors, prompt_pattern, timeout, command_interval):
        SSHChannel.__init__(self)
        self._creator = creator
        self._protocolFactory = protocolFactory
        self._command = command
        self._commandConnected = commandConnected
        self.incremental = incremental
        self.with_errors = with_errors
        self.prompt = prompt_pattern
        self.timeout = timeout
        self.command_interval = command_interval
        self._reason = None

    def openFailed(self, reason):
        """Propagate a failure to open the channel to ``commandConnected``."""
        self._commandConnected.errback(reason)

    def channelOpen(self, ignored):
        """Request a PTY, then an interactive shell, on the new channel."""
        term = os.environ.get('TERM', self._DEFAULT_TERM)
        pr = session.packRequest_pty_req(term, self._get_window_size(), '')
        self.conn.sendRequest(self, 'pty-req', pr)
        command = self.conn.sendRequest(
            self, 'shell', '', wantReply=True)
        # Forwarding of local terminal resizes is currently disabled:
        # signal.signal(signal.SIGWINCH, self._window_resized)
        command.addCallbacks(self._execSuccess, self._execFailure)

    def _window_resized(self, *args):
        """Triggered when the terminal is resized; sends ``window-change``."""
        win_size = self._get_window_size()
        # TIOCGWINSZ yields (rows, cols, ...); window-change wants (cols, rows, ...).
        new_size = win_size[1], win_size[0], win_size[2], win_size[3]
        self.conn.sendRequest(self, 'window-change',
                              struct.pack('!4L', *new_size))

    def _get_window_size(self):
        """
        Measure the local terminal as ``(rows, cols, xpixel, ypixel)``.

        Returns ``_DEFAULT_WINDOW_SIZE`` when stdin is unavailable or is not a
        terminal, instead of raising.
        """
        try:
            stdin_fileno = sys.stdin.fileno()
            winsz = fcntl.ioctl(stdin_fileno, tty.TIOCGWINSZ, '12345678')
        except (AttributeError, IOError, OSError, ValueError):
            return self._DEFAULT_WINDOW_SIZE
        return struct.unpack('4H', winsz)

    def _execFailure(self, reason):
        """Propagate a failed ``shell`` request to ``commandConnected``."""
        self._commandConnected.errback(reason)

    def _execSuccess(self, ignored):
        """
        Build, configure and connect the session protocol, then fire
        ``commandConnected`` with it.
        """
        self._protocol = self._protocolFactory.buildProtocol(
            SSHSessionAddress(
                self.conn.transport.transport.getPeer(),
                self.conn.transport.creator.username,
                self._command
            ))
        self._bind_protocol_data()
        self._protocol.makeConnection(self)
        self._commandConnected.callback(self._protocol)

    def _bind_protocol_data(self):
        """Copy this channel's session settings onto the new protocol."""
        # This was a string before, now it's a NetDevice.
        # NOTE(review): if device is None the next statement raises
        # AttributeError; confirm a device is always present.
        self._protocol.device = self.conn.transport.creator.device or None
        # Copy so the protocol can consume (pop) its startup commands without
        # mutating the device's own list.
        # FIXME(jathan): Is this potentially non-thread-safe?
        self._protocol.startup_commands = copy(
            self._protocol.device.startup_commands
        )
        self._protocol.incremental = self.incremental or None
        self._protocol.prompt = self.prompt or None
        self._protocol.with_errors = self.with_errors or None
        self._protocol.timeout = self.timeout or None
        self._protocol.command_interval = self.command_interval or None

    def dataReceived(self, data):
        """Forward channel data straight to the session protocol."""
        self._protocol.dataReceived(data)
        # SSHChannel.dataReceived(self, data)
class _TriggerSessionTransport(_CommandTransport):
    """SSH client transport used for Trigger interactive shell sessions."""

    def verifyHostKey(self, hostKey, fingerprint):
        """
        Accept the server's host key unconditionally.

        SECURITY NOTE(review): the key is NOT checked against
        ``self.creator.knownHosts`` or anything else, so any key -- including
        one presented by a man-in-the-middle -- is accepted. The acceptance
        is logged so unverified keys are at least visible to operators.

        :returns: A Deferred that has already fired with ``1``.
        """
        hostname = self.creator.hostname
        ip = self.transport.getPeer().host
        log.msg('[%s] Accepting host key %s from %s without verification' %
                (hostname, fingerprint, ip))
        # Presumably mirrors the state change the base class makes before
        # host-key checking -- confirm against twisted's _CommandTransport.
        self._state = b'SECURING'
        return defer.succeed(1)
class _NewTriggerConnectionHelperBase(_NewConnectionHelper):
    """
    Connection helper for an asynchronous interactive session.

    Unlike the stock helper, which runs a single command, this one dials the
    device over TCP and secures the link with ``_TriggerSessionTransport``.
    """

    def __init__(self, reactor, device, port, username, keys, password,
                 agentEndpoint, knownHosts, ui):
        self.reactor = reactor
        self.device = device
        # The device's node name is the host that gets dialled.
        self.hostname = device.nodeName
        self.port = port
        self.username = username
        self.keys = keys
        self.password = password
        self.agentEndpoint = agentEndpoint
        # Fall back to the base class's known-hosts loader when none is given.
        self.knownHosts = (knownHosts if knownHosts is not None
                           else self._knownHosts())
        self.ui = ui

    def secureConnection(self):
        """
        Open the TCP connection and start SSH on it.

        :returns: A Deferred chained so that it completes only after the
            transport's ``connectionReady`` Deferred has fired.
        """
        transport = _TriggerSessionTransport(self)
        # Grab the readiness Deferred before connecting.
        when_ready = transport.connectionReady
        tcp_endpoint = TCP4ClientEndpoint(self.reactor, self.hostname,
                                          self.port)
        connecting = connectProtocol(tcp_endpoint, transport)
        connecting.addCallback(lambda _connected: when_ready)
        return connecting
class TriggerEndpointClientFactory(protocol.Factory):
    """
    Factory for all clients. Subclass me.

    Protocols report output into ``self.results`` and failures into
    ``self.err``. When the connection is lost, ``self.d`` is errbacked with
    ``self.err`` if one was recorded, otherwise called back with
    ``self.results``. A failed connection errbacks ``self.d`` with the reason.

    NOTE(review): ``self.d`` is never assigned in this class; a subclass or
    the caller must set it before the connection ends.
    """

    def __init__(self, creds=None, init_commands=None):
        """
        :param creds: Credentials, validated by
            ``tacacsrc.validate_credentials``.
        :param init_commands: Commands that :meth:`_init_commands` writes
            once; defaults to an empty list.
        """
        self.creds = tacacsrc.validate_credentials(creds)
        self.results = []
        self.err = None

        # Setup and run the initial commands
        if init_commands is None:
            init_commands = []  # We need this to be a list
        self.init_commands = init_commands
        log.msg('INITIAL COMMANDS: %r' % self.init_commands, debug=True)
        self.initialized = False

    def clientConnectionFailed(self, connector, reason):
        """Do this when the connection fails."""
        log.msg('Client connection failed. Reason: %s' % reason)
        self.d.errback(reason)

    def clientConnectionLost(self, connector, reason):
        """Do this when the connection is lost."""
        log.msg('Client connection lost. Reason: %s' % reason)
        if self.err:
            log.msg('Got err: %r' % self.err)
            # log.err(self.err)
            self.d.errback(self.err)
        else:
            log.msg('Got results: %r' % self.results)
            self.d.callback(self.results)

    def stopFactory(self):
        """Log that the factory is stopping."""
        # If we're out of channels, shut it down!
        log.msg('All done!')

    def _init_commands(self, protocol):
        """
        Execute any initial commands specified, exactly once.

        :param protocol: A Protocol instance (e.g. action) to which to write
            the commands.
        """
        if self.initialized:
            return
        log.msg('Not initialized, sending init commands', debug=True)
        for next_init in self.init_commands:
            log.msg('Sending: %r' % next_init, debug=True)
            protocol.write(next_init + '\r\n')
        # Previously this flag was never set after sending, so every call
        # resent the init commands.
        self.initialized = True

    def connection_success(self, conn, transport):
        """Remember the established connection and its transport."""
        log.msg('Connection success.')
        self.conn = conn
        self.transport = transport
        log.msg('Connection information: %s' % self.transport)
class TriggerSSHShellClientEndpointBase(SSHCommandClientEndpoint):
    """
    Base class for SSH endpoints.

    Subclass me when you want to create a new ssh client. :meth:`connect`
    opens an interactive shell channel (``_TriggerShellChannel``) on the
    secured connection instead of executing a single command.
    """

    @classmethod
    def newConnection(cls, reactor, username, device, keys=None, password=None,
                      port=22, agentEndpoint=None, knownHosts=None, ui=None):
        """Return an endpoint that opens a new SSH connection to ``device``."""
        helper = _NewTriggerConnectionHelperBase(
            reactor, device, port, username, keys, password, agentEndpoint,
            knownHosts, ui
        )
        return cls(helper)

    @classmethod
    def existingConnection(cls, connection):
        """Overload stock existingConnection to not require ``commands``."""
        helper = _ExistingConnectionHelper(connection)
        return cls(helper)

    def __init__(self, creator):
        self._creator = creator

    def _executeCommand(self, connection, protocolFactory, command, incremental,
                        with_errors, prompt_pattern, timeout, command_interval):
        """
        Open a ``_TriggerShellChannel`` on ``connection``.

        :returns: A Deferred fired with the connected protocol, or errbacked
            (after cleaning up the connection) if the channel or shell fails.
        """
        commandConnected = defer.Deferred()

        def disconnectOnFailure(passthrough):
            # Close the connection immediately in case of cancellation, since
            # that implies user wants it gone immediately (e.g. a timeout).
            # CancelledError must be referenced via ``defer``: the bare name
            # is not imported and raised NameError, masking the real failure.
            immediate = passthrough.check(defer.CancelledError)
            self._creator.cleanupConnection(connection, immediate)
            return passthrough
        commandConnected.addErrback(disconnectOnFailure)

        channel = _TriggerShellChannel(
            self._creator, command, protocolFactory, commandConnected, incremental,
            with_errors, prompt_pattern, timeout, command_interval)
        connection.openChannel(channel)
        self.connected = True
        return commandConnected

    def connect(self, factory, command='', incremental=None,
                with_errors=None, prompt_pattern=None, timeout=0,
                command_interval=1):
        """
        Secure a connection, then open a shell channel driven by ``factory``.

        :returns: A Deferred fired with the session protocol.
        """
        d = self._creator.secureConnection()
        d.addCallback(self._executeCommand, factory, command, incremental,
                      with_errors, prompt_pattern, timeout, command_interval)
        return d
class IoslikeSendExpect(protocol.Protocol, TimeoutMixin):
    """
    Action for use with TriggerTelnet as a state machine.

    Take a list of commands, and send them to the device until we run out or
    one errors. Wait for a prompt after each.

    Commands arrive in batches through :meth:`add_commands`. When a batch is
    exhausted, the oldest Deferred in ``self.todo`` (if any) is fired with
    ``self.results`` and moved to ``self.done``.

    NOTE(review): ``prompt``, ``incremental``, ``with_errors`` and
    ``command_interval`` are read but never assigned by this class, and
    ``factory`` is read in :meth:`connectionMade`. They are presumably set
    externally before data arrives -- confirm against the caller.
    """

    def __init__(self):
        self.device = None
        self.commands = []
        self.commanditer = iter(self.commands)
        self.connected = False
        self.disconnect = False
        self.initialized = False
        # True while a batch from add_commands() is in flight; cleared in
        # _send_next() once that batch's commands are exhausted.
        self.locked = False
        self.startup_commands = []
        # FIXME(tom) This sux and should be set by trigger settings
        self.timeout = 10
        # Deferreds awaiting a finished batch; popped and fired in
        # _send_next(). NOTE(review): nothing in this class appends to it.
        self.todo = []
        # Deferreds already fired with self.results, oldest first.
        self.done = []

    def connectionMade(self):
        """Do this when we connect."""
        self.connected = True
        # Fired (with None) from connectionLost().
        self.finished = defer.Deferred()
        self.setTimeout(self.timeout)
        # The protocol and its factory share the same results list.
        self.results = self.factory.results = []
        self.data = ''
        log.msg('[%s] connectionMade, data: %r' % (self.device, self.data))
        # self.factory._init_commands(self)

    def connectionLost(self, reason):
        """Fire ``self.finished`` when the connection goes away."""
        self.finished.callback(None)

    # Don't call _send_next, since we expect to see a prompt, which
    # will kick off initialization.

    def add_commands(self, commands):
        """
        Start a new batch of ``commands``.

        Waits while a previous batch holds the lock, installs the new command
        iterator, takes the lock, and calls :meth:`_send_next` to send the
        next pending (startup or batch) command.

        :returns: ``True``.

        NOTE(review): the wait is a CPU-spinning busy loop, and the lock is
        only released inside _send_next(), which runs via the reactor. Calling
        this on the reactor thread while a batch is in flight appears to hang
        forever.
        """
        while self.locked is True:
            pass
        self.commands = commands
        self.commanditer = iter(commands)
        self.locked = True
        self._send_next()
        return True

    def dataReceived(self, bytes):
        """
        Do this when we get data.

        Buffer output until the prompt (or a confirmation prompt) appears,
        extract the output of the command that just ran, then either abort on
        an IOS-like error or schedule the next command.
        """
        log.msg('[%s] BYTES: %r' % (self.device, bytes))
        self.data += bytes

        # See if the prompt matches, and if it doesn't, see if it is waiting
        # for more input (like a [y/n]) prompt), and continue, otherwise return
        # None
        m = self.prompt.search(self.data)
        if not m:
            # If the device is asking for confirmation, treat the position of
            # the latest chunk as the end of the output.
            if is_awaiting_confirmation(self.data):
                log.msg('[%s] Got confirmation prompt: %r' % (self.device,
                                                             self.data))
                prompt_idx = self.data.find(bytes)
            else:
                return None
        else:
            # Or just use the matched regex object...
            prompt_idx = m.start()

        result = self.data[:prompt_idx]
        # Trim off the echoed-back command. This should *not* be necessary
        # since the telnet session is in WONT ECHO. This is confirmed with
        # a packet trace, and running self.transport.dont(ECHO) from
        # connectionMade() returns an AlreadyDisabled error. What's up?
        log.msg('[%s] result BEFORE: %r' % (self.device, result))
        result = result[result.find('\n')+1:]
        log.msg('[%s] result AFTER: %r' % (self.device, result))

        # Output received before initialization completes (the responses to
        # startup commands) is not recorded.
        if self.initialized:
            self.results.append(result)

        if has_ioslike_error(result) and not self.with_errors:
            log.msg('[%s] Command failed: %r' % (self.device, result))
            self.factory.err = exceptions.IoslikeCommandFailure(result)
            self.transport.loseConnection()
        else:
            if self.command_interval:
                log.msg('[%s] Waiting %s seconds before sending next command' %
                        (self.device, self.command_interval))
            # Schedule the next command after command_interval seconds.
            reactor.callLater(self.command_interval, self._send_next)

    def _send_next(self):
        """
        Send the next command in the stack.

        Startup commands are sent one at a time until none remain, which
        marks the protocol initialized. After that, commands come from
        ``self.commanditer``. When it is exhausted, the batch is completed
        and the lock is released.
        """
        self.data = ''
        self.resetTimeout()

        if not self.initialized:
            log.msg('[%s] Not initialized, sending startup commands' %
                    self.device)
            if self.startup_commands:
                next_init = self.startup_commands.pop(0)
                log.msg('[%s] Sending initialize command: %r' % (self.device,
                                                                 next_init))
                self.transport.write(next_init.strip() + self.device.delimiter)
                return None
            else:
                log.msg('[%s] Successfully initialized for command execution' %
                        self.device)
                self.initialized = True

        if self.incremental:
            self.incremental(self.results)

        try:
            next_command = self.commanditer.next()
        except StopIteration:
            log.msg('[%s] No more commands to send, moving on...' %
                    self.device)
            if self.todo:
                d = self.todo.pop(0)
                # NOTE(review): d fires with the live self.results list (not a
                # copy). That list is only reset in connectionMade(), so it
                # keeps accumulating across batches.
                d.addCallback(lambda none: self.results)
                d.callback(None)
                self.done.append(d)
            # Batch finished: let add_commands() start the next one.
            self.locked = False
            return None

        if next_command is None:
            # A None entry records a None result without sending anything.
            self.results.append(None)
            self._send_next()
        else:
            log.msg('[%s] Sending command %r' % (self.device, next_command))
            self.transport.write(next_command + '\n')

    def timeoutConnection(self):
        """Do this when we timeout."""
        log.msg('[%s] Timed out while sending commands' % self.device)
        self.factory.err = exceptions.CommandTimeout('Timed out while '
                                                     'sending commands')
        self.transport.loseConnection()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment