Skip to content

Instantly share code, notes, and snippets.

@davebshow
Created November 2, 2014 19:49
Show Gist options
  • Save davebshow/6ac8fa577cc31e3b5d79 to your computer and use it in GitHub Desktop.
Save davebshow/6ac8fa577cc31e3b5d79 to your computer and use it in GitHub Desktop.
# -*- coding: utf-8 -*-
from itertools import chain
import networkx as nx
class ProjX(object):
    """Prototype query/schema manipulation language for NetworkX graphs.

    A query is a verb followed by a mode sequence, for example
    ``"MATCH Person-City"``.  A *mode* is the value a node stores under the
    attribute named by ``type_attr``; a mode sequence is a ``-`` separated
    list of modes that a path has to follow node by node.

    Node attributes are read through ``graph.node[n]``.
    NOTE(review): newer NetworkX releases expose this mapping as
    ``graph.nodes[n]`` -- confirm against the installed version.

    The wrapped graph is only read; every operation returns a new graph.
    """

    def __init__(self, graph, type_attr='type'):
        """Wrap ``graph``.

        :param graph: graph to query; it is not modified.
        :param type_attr: name of the node attribute holding the mode.
        """
        self.graph = graph
        self.type = type_attr

    def execute(self, query):
        """Main api: run ``"<VERB> <mode sequence>"`` and return a graph.

        The verb is case-insensitive and must be MATCH, TRANSFER or
        PROJECT.

        :raises ValueError: if ``query`` is not exactly two
            whitespace-separated tokens (raised by the tuple unpacking).
        :raises SyntaxError: if the verb is not recognised.
        """
        verb, pred = query.split()
        handlers = {
            'match': self.match,
            'transfer': self.transfer,
            'project': self.project,
        }
        handler = handlers.get(verb.lower())
        if handler is None:
            raise SyntaxError('Expected query to begin '
                              'with "MATCH", "TRANSFER", or "PROJECT".')
        return handler(pred)

    def match(self, pred):
        """Return the subgraph made of every path matching ``pred``."""
        return self.build_subraph(self._match(pred))

    def _match(self, pred):
        """Return all paths (lists of nodes) matching mode sequence ``pred``.

        The first mode selects the start nodes; the remaining modes are
        handed to :meth:`traverse` for each start node.
        """
        statements = pred.split('-')
        start_type, modes = statements[0], statements[1:]
        # Nodes lacking the mode attribute simply never match.
        starts = [node for node, attrs in self.graph.nodes(data=True)
                  if attrs.get(self.type) == start_type]
        return list(chain.from_iterable(
            self.traverse(node, modes) for node in starts))

    def project(self, pred):
        """Project across a mode sequence.

        Returns a new graph with one edge between the first and the last
        node of every matched path.  A path consisting of a single node
        contributes just that node.
        """
        graph = nx.Graph()
        for path in self._match(pred):
            first, last = path[0], path[-1]
            graph.add_nodes_from((first, last))
            if len(path) > 1:
                graph.add_edge(first, last)
        # Copy the attributes instead of sharing the source graph's dicts,
        # so editing the projection cannot silently edit the source graph.
        for node in graph.nodes():
            graph.node[node].update(self.graph.node[node])
        return graph

    def transfer(self, pred):
        """Transfer the node attributes of a mode onto the connected nodes
        of another mode through a valid mode sequence.

        For every matched path the attributes of the first node are copied
        onto the last node under the name ``<mode>_<attribute>`` (mode
        lower-cased).  When the target already holds a different value
        under that name, a numeric suffix is appended, using the first
        suffix that is free or already holds the same value.

        Returns a copy of the graph in which the source nodes have been
        removed; the wrapped graph is left untouched.
        """
        graph = self.graph.copy()
        sources = []
        for path in self._match(pred):
            transfer_source, transfer_target = path[0], path[-1]
            attrs = self.graph.node[transfer_source]
            tp = attrs[self.type]
            target_attrs = graph.node[transfer_target]
            for k, v in attrs.items():
                # Skip the mode attribute itself and any 'visited'
                # bookkeeping key a node may carry.
                if k in (self.type, 'visited'):
                    continue
                base = '{0}_{1}'.format(tp.lower(), k)
                attname = base
                suffix = 1
                # Probe suffixes until a free slot (or the same value) is
                # found, so a later source never overwrites an earlier one.
                while (attname in target_attrs and
                       target_attrs[attname] != v):
                    attname = '{0}{1}'.format(base, suffix)
                    suffix += 1
                target_attrs[attname] = v
            sources.append(transfer_source)
        # Remove the sources only once every transfer is done: a source may
        # also be the target of another path (e.g. Person-City-Person).
        for source in sources:
            if source in graph:
                graph.remove_node(source)
        return graph

    def traverse(self, start, modes):
        """Controlled-depth, depth-first enumeration of matching paths.

        Starting at ``start``, returns every path whose i-th node after
        the start has mode ``modes[i]``.  A node is never repeated within
        one path (no cycles, no backtracking along the same path), but
        different paths may share nodes and edges.  Nodes without the mode
        attribute are treated as non-matching.

        :param start: node to start from (its own mode is not checked).
        :param modes: sequence of modes for the nodes after ``start``.
        :returns: list of paths, each a list of nodes beginning with
            ``start``; ``[[start]]`` when ``modes`` is empty.
        """
        max_depth = len(modes)
        if max_depth == 0:
            return [[start]]
        paths = []
        # Nodes of the path currently being explored, in order.
        path = [start]
        on_path = {start}
        # One partially consumed neighbour iterator per node on ``path``;
        # resuming an iterator is what continues the search on backtrack.
        pending = [iter(self.graph[start])]
        while pending:
            depth = len(path) - 1
            wanted = modes[depth]
            for nbr in pending[-1]:
                if nbr in on_path:
                    continue
                if self.graph.node[nbr].get(self.type) != wanted:
                    continue
                if depth + 1 == max_depth:
                    # Full sequence matched: record it and keep scanning
                    # the remaining neighbours at this depth.
                    paths.append(path + [nbr])
                    continue
                # Descend one level and continue from the new node.
                path.append(nbr)
                on_path.add(nbr)
                pending.append(iter(self.graph[nbr]))
                break
            else:
                # Neighbours exhausted at this position: backtrack.
                pending.pop()
                on_path.discard(path.pop())
        return paths

    def build_subraph(self, paths):
        """Build a graph from the paths returned by :meth:`_match`.

        Every node on a path is added along with the edges between
        consecutive nodes; node attributes are copied from the wrapped
        graph.
        """
        g = nx.Graph()
        for path in paths:
            # Adding the nodes first keeps single-node paths in the result.
            g.add_nodes_from(path)
            g.add_edges_from(_combine_paths(path))
        for node in g.nodes():
            g.node[node].update(self.graph.node[node])
        return g
def _combine_paths(path):
    """Return the edges of ``path`` as a list of (node, next_node) pairs.

    A path with fewer than two nodes yields an empty list.
    """
    # Pair every node with its successor; zip stops at the shorter input.
    return list(zip(path, path[1:]))
def test_graph():
    """Build the small sample graph used to try out queries.

    Returns an ``nx.Graph`` with eleven nodes (1-11), each carrying a
    ``type`` and a ``name`` attribute.
    """
    edges = [
        (1, 2), (1, 3), (1, 4), (1, 5), (2, 5),
        (3, 9), (4, 9), (5, 6), (6, 7), (7, 8),
        (10, 5), (11, 5),
    ]
    # (node, type, name) for every node of the sample graph.
    node_data = [
        (1, 'Person', 'davebshow'),
        (2, 'Institution', 'western'),
        (3, 'Publication', 'LLC'),
        (4, 'Publication', 'socialnet'),
        (5, 'City', 'london'),
        (6, 'Institution', 'stats canada'),
        (7, 'City', 'toronto'),
        (8, 'Person', 'adam'),
        (9, 'City', 'new york'),
        (10, 'Person', 'javi'),
        (11, 'Person', 'chong'),
    ]
    g = nx.Graph(edges)
    for node, kind, name in node_data:
        # Replace the node's attribute dict wholesale.
        g.node[node] = {'type': kind, 'name': name}
    return g
def labels(g):
    """Return a ``{node: label}`` dict for the nodes of ``g``.

    A label holds one ``"key: value"`` line (newline-terminated) per node
    attribute, leaving out the ``visited`` attribute.
    """
    result = {}
    for node, attrs in g.nodes(data=True):
        lines = ['{0}: {1}\n'.format(key, value)
                 for key, value in attrs.items()
                 if key != 'visited']
        result[node] = ''.join(lines)
    return result
def colors(g):
    """Return one float colour code per node of ``g``, in node order.

    Codes are handed out per distinct value of the ``type`` attribute:
    the first type seen gets 1.0, the next new type 2.0, and so on; nodes
    of the same type share a code.
    """
    code_by_type = {}
    codes = []
    for _, attrs in g.nodes(data=True):
        kind = attrs['type']
        # setdefault only stores the fresh code the first time a type
        # shows up; afterwards it returns the code already assigned.
        code = code_by_type.setdefault(kind, float(len(code_by_type) + 1))
        codes.append(code)
    return codes
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment