Created
February 26, 2018 18:40
-
-
Save astrofrog/34cb0f84d2a57ba0ae64150dd42caf3e to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np
# qtpy exports the binding-independent name ``Signal``; ``QSignal`` is kept as
# an alias so any existing reference to it keeps working.
from qtpy.QtCore import Qt, Signal, Signal as QSignal
from qtpy.QtGui import QPainter, QTransform, QPen
from qtpy.QtWidgets import (QGraphicsView, QGraphicsScene, QApplication,
                            QGraphicsTextItem, QGraphicsEllipseItem,
                            QGraphicsLineItem)
from glue.utils.qt import mpl_to_qt4_color, qt4_to_mpl_color
# Pi and its double. PI was previously mistyped as 3.14256.
PI = np.pi
TWOPI = PI * 2

# Highlight colours as (R, G, B) floats in the 0-1 range, in the format
# accepted by mpl_to_qt4_color.
COLOR_SELECTED = (0.2, 0.9, 0.2)      # items selected by a mouse click
COLOR_DIRECT = (0.6, 0.9, 0.6)        # hovered item and its direct connections
COLOR_INDIRECT = (0.6, 0.9, 0.9)      # nodes reachable only via other nodes
COLOR_DISCONNECTED = (0.9, 0.6, 0.6)  # nodes not connected to the hovered one
def get_pen(color, linewidth=1):
    """Build a solid ``QPen`` with rounded caps and joins.

    Parameters
    ----------
    color
        Colour in the form accepted by ``mpl_to_qt4_color`` (this module
        passes grey-level strings such as ``'0.5'`` and RGB float tuples).
    linewidth : int, optional
        Pen width handed directly to ``QPen``.

    Returns
    -------
    QPen
    """
    qt_color = mpl_to_qt4_color(color)
    line_style, cap_style, join_style = Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
    return QPen(qt_color, linewidth, line_style, cap_style, join_style)
class Edge(QGraphicsLineItem):
    """Straight line joining two :class:`DataNode` objects in the graph.

    Parameters
    ----------
    node_source, node_dest : DataNode
        End points of the edge. Their ``node_position`` is read whenever the
        line is repositioned or hit-tested.
    linewidth : int, optional
        Pen width. It also sets the hit-test tolerance used by
        :meth:`contains` (twice the line width).
    """

    def __init__(self, node_source, node_dest, linewidth=3):
        # Initialise the Qt base class before touching any Qt state.
        super(Edge, self).__init__(0, 0, 1, 1)
        self.linewidth = linewidth
        self.node_source = node_source
        self.node_dest = node_dest
        # Low z-value so that items with a higher z (the nodes and labels
        # created in this module) are drawn on top of edges.
        self.setZValue(5)
        # Must come after ``linewidth``: the colour setter builds a pen from it.
        self.color = '0.5'

    def update_position(self):
        """Move the line so it runs between the two end nodes' current positions."""
        x0, y0 = self.node_source.node_position
        x1, y1 = self.node_dest.node_position
        self.setLine(x0, y0, x1, y1)

    @property
    def color(self):
        """Pen colour, converted with ``qt4_to_mpl_color``."""
        return qt4_to_mpl_color(self.pen().color())

    @color.setter
    def color(self, value):
        self.setPen(get_pen(value, self.linewidth))

    def add_to_scene(self, scene):
        """Add this edge to the ``QGraphicsScene`` *scene*."""
        scene.addItem(self)

    def contains(self, point):
        """Return True if *point* lies close to the edge segment.

        *point* must expose ``x()`` and ``y()`` in the same coordinate system
        as the nodes' ``node_position``. A point matches when its
        perpendicular projection falls within the segment and lies closer
        than ``2 * linewidth`` to the segment.
        """
        x0, y0 = point.x(), point.y()
        x1, y1 = self.node_source.node_position
        x2, y2 = self.node_dest.node_position
        dx, dy = x2 - x1, y2 - y1
        length_sq = dx * dx + dy * dy
        tolerance = self.linewidth * 2
        if length_sq == 0:
            # Zero-length edge (both nodes at the same place): there is no
            # segment to project onto, so use the distance to the end point
            # instead of dividing by zero.
            return np.hypot(x0 - x1, y0 - y1) < tolerance
        # Position of the projection along the segment (0 = source, 1 = dest)
        frac = ((x0 - x1) * dx + (y0 - y1) * dy) / length_sq
        if frac < 0 or frac > 1:
            return False
        x3 = x1 + dx * frac
        y3 = y1 + dy * frac
        return np.hypot(x3 - x0, y3 - y0) < tolerance
class DataNode(object):
    """Circular graph node representing one dataset, plus its text label.

    A node consists of four Qt items: a circle, a text label showing
    ``data.label``, and two line segments that join the label to the circle.
    Inheriting from ``object`` makes this a new-style class under Python 2,
    which is required for the property setters below to be invoked.

    Parameters
    ----------
    data
        Object with a ``label`` attribute (a glue dataset in this module).
    radius : int, optional
        Radius of the circle. It is also used by :meth:`contains` for hit
        testing.
    """

    def __init__(self, data, radius=15):
        self.data = data
        self.radius = radius

        # Circle centred on the item's position
        self.node = QGraphicsEllipseItem(-self.radius, -self.radius,
                                         2 * self.radius, 2 * self.radius)

        # Small-font text label
        self.label = QGraphicsTextItem(data.label)
        font = self.label.font()
        font.setPointSize(8)
        self.label.setFont(font)

        # Two-segment leader line between the label and the node
        self.line1 = QGraphicsLineItem(0, 0, 1, 1)
        self.line2 = QGraphicsLineItem(0, 0, 1, 1)

        # The circle is drawn above the label and leader lines
        self.node.setZValue(20)
        self.label.setZValue(10)
        self.line1.setZValue(10)
        self.line2.setZValue(10)

        self.line1.setPen(get_pen('0.5'))
        self.line2.setPen(get_pen('0.5'))

        self.color = '0.8'

    def contains(self, point):
        """Return True if *point* (with ``x()``/``y()``) lies inside the circle."""
        x, y = self.node_position
        return np.hypot(point.x() - x, point.y() - y) <= self.radius

    def update(self):
        """Schedule a repaint of the circle item."""
        self.node.update()

    def add_to_scene(self, scene):
        """Add the circle, label and both leader lines to *scene*."""
        scene.addItem(self.node)
        scene.addItem(self.label)
        scene.addItem(self.line1)
        scene.addItem(self.line2)

    @property
    def node_position(self):
        """``(x, y)`` position of the circle's centre."""
        pos = self.node.pos()
        return pos.x(), pos.y()

    @node_position.setter
    def node_position(self, value):
        self.node.setPos(value[0], value[1])
        self.update_lines()

    @property
    def label_position(self):
        """``(x, y)`` position of the text label."""
        pos = self.label.pos()
        return pos.x(), pos.y()

    @label_position.setter
    def label_position(self, value):
        self.label.setPos(value[0], value[1])
        self.update_lines()

    def update_lines(self):
        """Redraw the leader line from the label to the node.

        The first segment runs horizontally from the label to the x-midpoint
        between label and node; the second runs from there to the node.
        """
        x0, y0 = self.label_position
        x2, y2 = self.node_position
        x1 = 0.5 * (x0 + x2)
        y1 = y0
        self.line1.setLine(x0, y0, x1, y1)
        self.line2.setLine(x1, y1, x2, y2)

    @property
    def color(self):
        """Fill colour of the circle, converted with ``qt4_to_mpl_color``."""
        return qt4_to_mpl_color(self.node.brush().color())

    @color.setter
    def color(self, value):
        self.node.setBrush(mpl_to_qt4_color(value))
def get_connections(data_collection):
    """Return the set of dataset pairs joined by at least one link.

    For every link in ``data_collection.links`` and every one of its input
    component IDs, the pair ``(input_parent, output_parent)`` is recorded.
    Links between components of the same dataset are ignored. A pair is not
    recorded if it, or its reverse, is already present, so each pair of
    datasets appears only once regardless of direction.
    """
    pairs = set()
    for link in data_collection.links:
        target_cid = link.get_to_id()
        for source_cid in link.get_from_ids():
            source, target = source_cid.parent, target_cid.parent
            if source is target:
                continue
            if (source, target) in pairs or (target, source) in pairs:
                continue
            pairs.add((source, target))
    return pairs
def layout_simple_circle(nodes, edges, center=None, radius=None):
    """Place *nodes* at equal angular spacing around a circle.

    The nodes are first reordered with ``order_nodes_by_connections``. The
    i-th node of that order is then placed at angle ``2 * pi * i / N``,
    starting at angle 0.

    Parameters
    ----------
    nodes : iterable of DataNode
        Nodes to place; each one's ``node_position`` is assigned.
    edges : iterable of Edge
        Edges used to order the nodes.
    center : (float, float), optional
        Centre of the circle. Defaults to ``(0, 0)``; previously a ``None``
        value raised TypeError.
    radius : float, optional
        Radius of the circle. Defaults to 1.
    """
    if center is None:
        center = (0., 0.)
    if radius is None:
        radius = 1.
    ordered = order_nodes_by_connections(nodes, edges)
    count = len(ordered)
    for i, node in enumerate(ordered):
        angle = 2 * np.pi * i / count
        node.node_position = (radius * np.cos(angle) + center[0],
                              radius * np.sin(angle) + center[1])
def order_nodes_by_connections(nodes, edges):
    """Return *nodes* reordered so that connected groups are listed together.

    On each round, every not-yet-ordered node is scored by the tuple
    ``(len(indirect), len(direct))`` from ``find_connections``. For every
    node with the maximal score, its indirect and then its direct
    connections are appended to the result, skipping nodes already placed.
    Rounds repeat until every node of *nodes* has been placed.
    """
    ordered = []
    remaining = list(nodes)
    while remaining:
        candidates = []
        for candidate in remaining:
            direct, indirect = find_connections(candidate, remaining, edges)
            score = (len(indirect), len(direct))
            candidates.append((score, indirect + direct))
        best = max(score for score, _ in candidates)
        for score, group in candidates:
            if score != best:
                continue
            for member in group:
                if member not in ordered:
                    ordered.append(member)
        remaining = [node for node in nodes if node not in ordered]
    return ordered
class DataGraphWidget(QGraphicsView):
    """Graph view of the datasets in a glue data collection and their links.

    Each dataset becomes a :class:`DataNode` placed on a circle. Each linked
    pair of datasets becomes an :class:`Edge`. Labels are stacked in two
    columns to the left and right of the circle.

    Hovering over a node highlights the nodes it is directly and indirectly
    connected to; hovering over an edge highlights it and its end nodes.
    Clicking selects up to two nodes (plus the edge joining them, if any) or
    a single edge.
    """

    # Emitted after a mouse click changes the selected nodes and/or edge.
    selection_changed = Signal()

    def __init__(self, data_collection):
        super(DataGraphWidget, self).__init__()

        # Painting/selection state, set first so that every event handler can
        # rely on these attributes existing.
        self.text_adjusted = False
        self.selected_edge = None
        self.selected_node1 = None
        self.selected_node2 = None

        # Set up scene
        self.scene = QGraphicsScene(self)
        self.scene.setItemIndexMethod(QGraphicsScene.NoIndex)
        self.scene.setSceneRect(0, 0, 800, 400)
        self.setScene(self.scene)

        # One node per dataset, one edge per linked pair of datasets
        self.nodes = {data: DataNode(data) for data in data_collection}
        self.edges = [Edge(self.nodes[data1], self.nodes[data2])
                      for data1, data2 in get_connections(data_collection)]

        # Place nodes on a circle centred in the scene; edges follow them
        layout_simple_circle(self.nodes.values(), self.edges,
                             center=(400, 200), radius=120)
        for edge in self.edges:
            edge.update_position()

        self._position_labels()

        # Add nodes and edges to graph
        for node in self.nodes.values():
            node.add_to_scene(self.scene)
        for edge in self.edges:
            edge.add_to_scene(self.scene)

        self.setMinimumSize(800, 400)
        self.setWindowTitle("Glue data graph")
        self.setRenderHint(QPainter.Antialiasing)
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        self.setResizeAnchor(QGraphicsView.AnchorViewCenter)

    def _position_labels(self):
        """Stack node labels in two columns on either side of the circle.

        Nodes with x < 400 go to ``self.left_nodes`` and get a label at
        x=200; all other nodes go to ``self.right_nodes`` and get a label at
        x=600. Within a column, labels are spread evenly over the scene
        height. The node with the largest y gets the label with the largest y.
        """
        nodes = list(self.nodes.values())
        self.left_nodes = sorted(
            [node for node in nodes if node.node_position[0] < 400],
            key=lambda node: node.node_position[1], reverse=True)
        # ``>=`` rather than ``>``: a node whose x is exactly 400 (possible at
        # the top/bottom of the circle) used to land in neither column and
        # so never had its label positioned.
        self.right_nodes = sorted(
            [node for node in nodes if node.node_position[0] >= 400],
            key=lambda node: node.node_position[1], reverse=True)
        for column, label_x in ((self.left_nodes, 200), (self.right_nodes, 600)):
            for i, node in enumerate(column):
                # float() avoids integer (floor) division under Python 2.
                y = 400 - float(i + 1) / (len(column) + 1) * 400
                node.label_position = label_x, y

    def paintEvent(self, event):
        """Paint the view; on the first call, also align the labels.

        Each label is shifted up by half its height. Left-column labels are
        additionally shifted left by their width, so they end at their
        position rather than start there.
        """
        super(DataGraphWidget, self).paintEvent(event)
        if self.text_adjusted:
            return
        for node in self.nodes.values():
            rect = node.label.boundingRect()
            dx = -rect.width() if node in self.left_nodes else 0
            transform = QTransform()
            transform.translate(dx, -rect.height() / 2)
            node.label.setTransform(transform)
        self.text_adjusted = True

    def find_object(self, event):
        """Return the first node (checked first) or edge under the mouse, else None."""
        # Nodes and edges are positioned in scene coordinates, so map the
        # event's viewport position into the scene before hit-testing.
        point = self.mapToScene(event.pos())
        for obj in list(self.nodes.values()) + self.edges:
            if obj.contains(point):
                return obj
        return None

    def mouseMoveEvent(self, event):
        """Highlight the hovered node/edge and what it is connected to.

        Hover highlighting is skipped while any node is selected.
        """
        # TODO: Don't update until the end
        # TODO: Only select object on top
        if self.selected_node1 is not None or self.selected_node2 is not None:
            return
        selected = self.find_object(event)
        colors = {}
        if isinstance(selected, DataNode):
            colors[selected] = COLOR_DIRECT
            direct, indirect = find_connections(selected, self.nodes.values(), self.edges)
            for node in self.nodes.values():
                if node in direct:
                    colors[node] = COLOR_DIRECT
                elif node in indirect:
                    colors[node] = COLOR_INDIRECT
                else:
                    colors[node] = COLOR_DISCONNECTED
        elif isinstance(selected, Edge):
            colors[selected] = COLOR_DIRECT
            colors[selected.node_source] = COLOR_DIRECT
            colors[selected.node_dest] = COLOR_DIRECT
        self.set_colors(colors)

    def mousePressEvent(self, event):
        """Update the selection for a click and recolour the graph.

        Clicking a selected node deselects it. Clicking another node fills
        the first free selection slot, or replaces the second node if both
        slots are taken. The edge joining the two selected nodes, if any,
        becomes the selected edge. Clicking an edge toggles the selection of
        that edge together with its two end nodes. ``selection_changed`` is
        emitted when the selection differs from before the click.
        """
        # TODO: Don't update until the end
        # TODO: Only select object on top
        previous = (self.selected_node1, self.selected_node2, self.selected_edge)
        selected = self.find_object(event)

        if isinstance(selected, DataNode):
            # Check for deselection first so that clicking the second node
            # while no first node is selected deselects it instead of
            # selecting the same node in both slots.
            if selected is self.selected_node1:
                self.selected_node1 = None
            elif selected is self.selected_node2:
                self.selected_node2 = None
            elif self.selected_node1 is None:
                self.selected_node1 = selected
            else:
                self.selected_node2 = selected
            self.selected_edge = None
            node1, node2 = self.selected_node1, self.selected_node2
            for edge in self.edges:
                if ((edge.node_source is node1 and edge.node_dest is node2) or
                        (edge.node_source is node2 and edge.node_dest is node1)):
                    self.selected_edge = edge
                    break
        elif isinstance(selected, Edge):
            if self.selected_edge is selected:
                self.selected_edge = None
                self.selected_node1 = None
                self.selected_node2 = None
            else:
                self.selected_edge = selected
                self.selected_node1 = selected.node_source
                self.selected_node2 = selected.node_dest

        colors = {}
        for item in (self.selected_edge, self.selected_node1, self.selected_node2):
            if item is not None:
                colors[item] = COLOR_SELECTED
        self.set_colors(colors)
        # Restore hover highlighting when nothing is selected any more.
        self.mouseMoveEvent(event)

        current = (self.selected_node1, self.selected_node2, self.selected_edge)
        if any(old is not new for old, new in zip(previous, current)):
            self.selection_changed.emit()

    def set_colors(self, colors):
        """Recolour every node and edge from *colors*, a dict mapping items to colours.

        Items missing from *colors* get the default grey ('0.8' for nodes,
        '0.5' for edges).
        """
        for obj in list(self.nodes.values()) + self.edges:
            default_color = '0.8' if isinstance(obj, DataNode) else '0.5'
            obj.color = colors.get(obj, default_color)
            obj.update()
def find_connections(node, nodes, edges):
    """Split the nodes reachable from *node* into direct and indirect lists.

    Parameters
    ----------
    node
        Starting node.
    nodes : iterable
        Currently unused; accepted for compatibility with existing callers.
    edges : iterable of Edge
        Edges exposing ``node_source`` and ``node_dest``; they are treated as
        undirected. Iterated once per search pass.

    Returns
    -------
    direct : list
        *node* itself followed by its immediate neighbours.
    indirect : list
        Nodes reachable in two or more hops. Each reachable node appears
        exactly once across ``direct`` and ``indirect``.
    """
    direct = [node]
    indirect = []
    connected = [node]
    # The first pass collects immediate neighbours into ``direct``; every
    # later pass adds the next ring of nodes to ``indirect``.
    current = direct
    while True:
        found = []
        for edge in edges:
            source = edge.node_source
            dest = edge.node_dest
            if source in connected and dest not in connected:
                new = dest
            elif dest in connected and source not in connected:
                new = source
            else:
                continue
            if new not in found:
                found.append(new)
        if not found:
            break
        current.extend(found)
        # Mark this ring as visited before the next pass. Previously the
        # visited list was extended only after switching to ``indirect``, so
        # direct neighbours were rediscovered and duplicated as indirect ones.
        connected.extend(found)
        current = indirect
    return direct, indirect
if __name__ == '__main__':
    # Stand-alone demo: load 'links.glu' with glue's state loader and show
    # the data graph for the loaded collection.
    import sys
    application = QApplication(sys.argv)
    application.setAttribute(Qt.AA_UseHighDpiPixmaps)

    from glue.core.state import load
    window = DataGraphWidget(load('links.glu'))
    window.show()
    sys.exit(application.exec_())
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment