Created
November 2, 2014 19:49
-
-
Save davebshow/6ac8fa577cc31e3b5d79 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*-
from itertools import chain | |
import networkx as nx | |
class ProjX(object):
    """Prototype query/schema manipulation language for NetworkX graphs.

    A query is a verb followed by a dash-separated sequence of node types
    ("modes"), for example ``"MATCH Person-City-Person"``.  The supported
    verbs are ``MATCH``, ``TRANSFER`` and ``PROJECT`` (case-insensitive).

    The wrapped graph is never modified; every verb returns a new graph.
    """

    def __init__(self, graph, type_attr='type'):
        """Wrap ``graph`` for querying.

        :param graph: graph whose nodes carry a type attribute.
        :param type_attr: name of the node attribute holding the node type.
        """
        self.graph = graph
        self.type = type_attr

    @staticmethod
    def _node_attrs(graph, node):
        """Return the attribute dict of ``node`` in ``graph``.

        ``graph.nodes[node]`` is tried first; when ``graph.nodes`` is a plain
        method (it cannot be subscripted and raises ``TypeError``) the older
        ``graph.node[node]`` mapping is used instead.
        """
        try:
            return graph.nodes[node]
        except TypeError:
            return graph.node[node]

    @staticmethod
    def _unique_attr_name(attrs, base, value):
        """Return the attribute name under which ``value`` can be stored.

        ``base`` is returned when it is unused in ``attrs`` or already holds
        ``value``.  Otherwise ``base1``, ``base2``, ... are tried in order
        until a name is found that is free or already holds ``value``, so an
        existing, different value is never overwritten.
        """
        name = base
        suffix = 1
        while name in attrs and attrs[name] != value:
            name = '{0}{1}'.format(base, suffix)
            suffix += 1
        return name

    def _copy_nodes(self, graph, nodes):
        """Add ``nodes`` to ``graph`` with copies of their source attributes.

        The attribute dicts are copied so that the result graph does not
        share mutable state with ``self.graph``.
        """
        graph.add_nodes_from(
            (node, dict(self._node_attrs(self.graph, node)))
            for node in nodes)

    def execute(self, query):
        """Main api: run ``query`` and return the resulting graph.

        :param query: ``"<VERB> <Type>-<Type>-..."``; exactly two
            whitespace-separated tokens.
        :raises ValueError: if ``query`` does not split into two tokens.
        :raises SyntaxError: if the verb is not MATCH, TRANSFER or PROJECT.
        """
        verb, pred = query.split()
        handlers = {
            'match': self.match,
            'transfer': self.transfer,
            'project': self.project,
        }
        handler = handlers.get(verb.lower())
        if handler is None:
            raise SyntaxError('Expected query to begin '
                              'with "MATCH", "TRANSFER", or "PROJECT".')
        return handler(pred)

    def match(self, pred):
        """Return a graph made of every path matching the mode sequence."""
        return self.build_subraph(self._match(pred))

    def _match(self, pred):
        """Executes pred after MATCH.

        :param pred: dash-separated mode sequence; the first mode selects
            the start nodes.
        :returns: list of paths, each a list of nodes.
        """
        modes = pred.split('-')
        start_type = modes[0]
        remaining = modes[1:]
        # Nodes lacking the type attribute can never match a mode, so they
        # are skipped rather than raising KeyError.
        return list(chain.from_iterable(
            self.traverse(node, remaining)
            for node, attrs in self.graph.nodes(data=True)
            if attrs.get(self.type) == start_type))

    def project(self, pred):
        """Project across a mode sequence.

        The first and last node of every matching path are connected
        directly; node attributes are copied from the source graph.
        """
        graph = nx.Graph()
        for path in self._match(pred):
            source, target = path[0], path[-1]
            self._copy_nodes(graph, (source, target))
            # A single-mode query yields one-node paths; keep the node but
            # do not create a self-loop.
            if source != target:
                graph.add_edge(source, target)
        return graph

    def transfer(self, pred):
        """Transfer the node attributes of a mode onto the connected nodes
        of another mode through a valid mode sequence.

        Works on a copy of the wrapped graph.  For every matching path the
        attributes of the first node (except the type attribute and the
        ``'visited'`` key) are written onto the last node as
        ``<type>_<attribute>``, with a numeric suffix when that name already
        holds a different value.  The first nodes of all paths are removed
        once every transfer is done.
        """
        graph = self.graph.copy()
        sources = set()
        for path in self._match(pred):
            source, target = path[0], path[-1]
            if source == target:
                # One-node path (single-mode query): nothing to transfer.
                continue
            sources.add(source)
            source_attrs = self._node_attrs(self.graph, source)
            target_attrs = self._node_attrs(graph, target)
            prefix = source_attrs[self.type].lower()
            for key, value in source_attrs.items():
                if key in (self.type, 'visited'):
                    continue
                base = '{0}_{1}'.format(prefix, key)
                name = self._unique_attr_name(target_attrs, base, value)
                target_attrs[name] = value
        # Removal is deferred so that a node which is the source of one path
        # and the target of another still exists while attributes are
        # being written to it.
        graph.remove_nodes_from(sources)
        return graph

    def traverse(self, start, modes):
        """Controlled-depth, depth-first traversal from ``start``.

        Finds every path ``[start, n1, ..., nk]`` such that the type of
        ``ni`` equals ``modes[i - 1]`` and no node occurs twice on the path.

        :param start: node at which every returned path begins.
        :param modes: sequence of node types the path must follow after
            ``start``; an empty sequence yields the single path ``[start]``.
        :returns: list of paths, each a list of nodes.
        """
        max_depth = len(modes)
        paths = []
        path = [start]
        on_path = set(path)
        # One neighbour iterator per node on the current path.  Resuming an
        # iterator after backtracking continues with the next untried
        # neighbour, so no per-node bookkeeping has to be stored on the
        # graph and every distinct path prefix is explored independently.
        frontier = [iter(self.graph[start])]
        while frontier:
            depth = len(path) - 1
            if depth < max_depth:
                wanted = modes[depth]
                for nbr in frontier[-1]:
                    if nbr in on_path:
                        # No cycles or backtracking along the same path.
                        continue
                    if self._node_attrs(self.graph, nbr).get(self.type) != wanted:
                        continue
                    path.append(nbr)
                    on_path.add(nbr)
                    frontier.append(iter(self.graph[nbr]))
                    break
                else:
                    # Neighbours exhausted at this position: backtrack.
                    frontier.pop()
                    on_path.discard(path.pop())
            else:
                # Full mode sequence matched: record it, then backtrack.
                paths.append(list(path))
                frontier.pop()
                on_path.discard(path.pop())
        return paths

    def build_subraph(self, paths):
        """Just takes the paths returned by _match and builds a graph.

        Every node on a path is added (with a copy of its attributes) and
        consecutive nodes on a path are joined by an edge.
        """
        g = nx.Graph()
        for path in paths:
            self._copy_nodes(g, path)
            g.add_edges_from(_combine_paths(path))
        return g
def _combine_paths(path): | |
"""Turn path list into edge list.""" | |
edges = [] | |
for i, node in enumerate(path[1:]): | |
edges.append((path[i], node)) | |
return edges | |
def test_graph():
    """Build the small demo graph used to try out queries.

    The graph has eleven nodes typed as Person, Institution, Publication
    or City, each with a ``name`` attribute.

    Node attributes are set through ``add_nodes_from`` with
    ``(node, attribute dict)`` tuples rather than by assigning into
    ``g.node``, which is not available in every NetworkX release.
    """
    g = nx.Graph([(1, 2), (1, 3), (1, 4), (1, 5), (2, 5),
                  (3, 9), (4, 9), (5, 6), (6, 7), (7, 8),
                  (10, 5), (11, 5)])
    g.add_nodes_from([
        (1, {'type': 'Person', 'name': 'davebshow'}),
        (2, {'type': 'Institution', 'name': 'western'}),
        (3, {'type': 'Publication', 'name': 'LLC'}),
        (4, {'type': 'Publication', 'name': 'socialnet'}),
        (5, {'type': 'City', 'name': 'london'}),
        (6, {'type': 'Institution', 'name': 'stats canada'}),
        (7, {'type': 'City', 'name': 'toronto'}),
        (8, {'type': 'Person', 'name': 'adam'}),
        (9, {'type': 'City', 'name': 'new york'}),
        (10, {'type': 'Person', 'name': 'javi'}),
        (11, {'type': 'Person', 'name': 'chong'}),
    ])
    return g
def labels(g):
    """Map every node of ``g`` to a multi-line text label.

    The label has one ``"key: value"`` line per node attribute, in the
    attribute dict's iteration order; the ``'visited'`` key is left out.
    """
    return {
        node: ''.join(
            '{0}: {1}\n'.format(key, value)
            for key, value in attrs.items()
            if key != 'visited')
        for node, attrs in g.nodes(data=True)
    }
def colors(g):
    """Return one float colour code per node of ``g``, in node order.

    Each distinct value of the ``'type'`` attribute gets the next code
    (1.0, 2.0, ...) the first time it is seen; nodes sharing a type share
    a code.
    """
    palette = {}
    sequence = []
    for _, attrs in g.nodes(data=True):
        kind = attrs['type']
        if kind not in palette:
            # Codes are handed out in order of first appearance.
            palette[kind] = float(len(palette) + 1)
        sequence.append(palette[kind])
    return sequence
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment