Skip to content

Instantly share code, notes, and snippets.

@astrofrog
Created February 26, 2018 18:40
Show Gist options
  • Save astrofrog/34cb0f84d2a57ba0ae64150dd42caf3e to your computer and use it in GitHub Desktop.
Save astrofrog/34cb0f84d2a57ba0ae64150dd42caf3e to your computer and use it in GitHub Desktop.
import numpy as np
from qtpy.QtCore import Qt, Signal
# qtpy exposes Qt signals as ``Signal``; there is no ``QSignal``. Keep the old
# name bound as an alias so existing references keep working.
from qtpy.QtCore import Signal as QSignal
from qtpy.QtGui import QPainter, QTransform, QPen
from qtpy.QtWidgets import (QGraphicsView, QGraphicsScene, QApplication,
                            QGraphicsTextItem, QGraphicsEllipseItem,
                            QGraphicsLineItem)
from glue.utils.qt import mpl_to_qt4_color, qt4_to_mpl_color
# Circle constants. Taken from numpy rather than hard-coded: the previous
# literal (3.14256) was wrong in the third decimal place.
PI = np.pi
TWOPI = PI * 2

# RGB colors used to highlight graph items (converted for Qt via
# mpl_to_qt4_color when applied).
COLOR_SELECTED = (0.2, 0.9, 0.2)  # items picked by clicking
COLOR_DIRECT = (0.6, 0.9, 0.6)  # hovered item and its immediate neighbours
COLOR_INDIRECT = (0.6, 0.9, 0.9)  # nodes reachable only via other nodes
COLOR_DISCONNECTED = (0.9, 0.6, 0.6)  # nodes not reachable from hovered node
def get_pen(color, linewidth=1):
    """
    Build a solid QPen with rounded caps and joins.

    Parameters
    ----------
    color
        Any color accepted by ``mpl_to_qt4_color`` (this file passes grey
        level strings such as ``'0.5'`` and RGB tuples).
    linewidth : int or float, optional
        Width of the pen.
    """
    qt_color = mpl_to_qt4_color(color)
    line_style, cap_style, join_style = Qt.SolidLine, Qt.RoundCap, Qt.RoundJoin
    return QPen(qt_color, linewidth, line_style, cap_style, join_style)
class Edge(QGraphicsLineItem):
    """
    Straight line joining two nodes of the graph.

    Parameters
    ----------
    node_source, node_dest : `DataNode`
        The nodes at either end of the edge; only their ``node_position``
        is read.
    """

    def __init__(self, node_source, node_dest):
        # linewidth must exist before ``self.color`` is assigned below, since
        # the color setter builds the pen from it.
        self.linewidth = 3
        self.node_source = node_source
        self.node_dest = node_dest
        super(Edge, self).__init__(0, 0, 1, 1)
        self.setZValue(5)
        self.color = '0.5'

    def update_position(self):
        """Move the line so that it joins the current node positions."""
        x0, y0 = self.node_source.node_position
        x1, y1 = self.node_dest.node_position
        self.setLine(x0, y0, x1, y1)

    @property
    def color(self):
        """Line color, as returned by ``qt4_to_mpl_color``."""
        return qt4_to_mpl_color(self.pen().color())

    @color.setter
    def color(self, value):
        self.setPen(get_pen(value, self.linewidth))

    def add_to_scene(self, scene):
        """Add this line to ``scene``."""
        scene.addItem(self)

    def contains(self, point):
        """
        Return whether ``point`` is within ``2 * linewidth`` of the segment.

        Points whose perpendicular projection falls beyond either end of the
        segment are not considered to be on the edge.
        """
        x0, y0 = point.x(), point.y()
        x1, y1 = self.node_source.node_position
        x2, y2 = self.node_dest.node_position
        dx, dy = x2 - x1, y2 - y1
        length_sq = dx * dx + dy * dy
        if length_sq == 0:
            # Degenerate edge (both nodes at the same place): the segment is
            # a single point, so measure the distance to it instead of
            # dividing by zero.
            return np.hypot(x1 - x0, y1 - y0) < self.linewidth * 2
        # Fraction along the segment at which the projection of point lands.
        frac = ((x0 - x1) * dx + (y0 - y1) * dy) / length_sq
        if not 0 <= frac <= 1:
            return False
        x3 = x1 + dx * frac
        y3 = y1 + dy * frac
        return np.hypot(x3 - x0, y3 - y0) < self.linewidth * 2
class DataNode(object):
    """
    Graphical representation of one dataset in the graph.

    A node is made of a filled circle, a text label showing ``data.label``,
    and a two-segment leader line: a horizontal segment starting at the
    label, then a segment from its end to the circle.

    Parameters
    ----------
    data
        The dataset to represent. Only its ``label`` attribute is used here.
    radius : int or float, optional
        Radius of the circle, also used as the hit-test radius in `contains`.
    """

    # NOTE: inheriting from ``object`` matters on Python 2. On old-style
    # classes, assigning to a property (``node_position``, ``label_position``,
    # ``color``) bypasses the setter and just creates an instance attribute.
    # It makes no difference on Python 3.

    def __init__(self, data, radius=15):

        self.data = data
        self.radius = radius

        # Circle centred on the item origin, so the item position is its centre
        self.node = QGraphicsEllipseItem(-self.radius, -self.radius,
                                         2 * self.radius, 2 * self.radius)

        # Text label in a small font
        self.label = QGraphicsTextItem(data.label)
        font = self.label.font()
        font.setPointSize(8)
        self.label.setFont(font)

        # Leader line between label and node; real coordinates are set in
        # update_lines once positions are known.
        self.line1 = QGraphicsLineItem(0, 0, 1, 1)
        self.line2 = QGraphicsLineItem(0, 0, 1, 1)

        # Draw the circle above its label and leader line
        self.node.setZValue(20)
        self.label.setZValue(10)
        self.line1.setZValue(10)
        self.line2.setZValue(10)

        self.line1.setPen(get_pen('0.5'))
        self.line2.setPen(get_pen('0.5'))

        self.color = '0.8'

    def contains(self, point):
        """Return whether ``point`` lies inside (or on) the node's circle."""
        x, y = self.node_position
        return np.hypot(point.x() - x, point.y() - y) <= self.radius

    def update(self):
        """Request a repaint of the node's circle."""
        self.node.update()

    def add_to_scene(self, scene):
        """Add the circle, the label and both leader-line segments to ``scene``."""
        scene.addItem(self.node)
        scene.addItem(self.label)
        scene.addItem(self.line1)
        scene.addItem(self.line2)

    @property
    def node_position(self):
        """``(x, y)`` of the circle centre; setting it redraws the leader line."""
        pos = self.node.pos()
        return pos.x(), pos.y()

    @node_position.setter
    def node_position(self, value):
        self.node.setPos(value[0], value[1])
        self.update_lines()

    @property
    def label_position(self):
        """``(x, y)`` of the label item; setting it redraws the leader line."""
        pos = self.label.pos()
        return pos.x(), pos.y()

    @label_position.setter
    def label_position(self, value):
        self.label.setPos(value[0], value[1])
        self.update_lines()

    def update_lines(self):
        """
        Redraw the leader line.

        The first segment runs horizontally from the label to halfway
        (in x) towards the node; the second joins that point to the node.
        """
        x0, y0 = self.label_position
        x2, y2 = self.node_position
        x1 = 0.5 * (x0 + x2)
        y1 = y0
        self.line1.setLine(x0, y0, x1, y1)
        self.line2.setLine(x1, y1, x2, y2)

    @property
    def color(self):
        """Fill color of the circle, as returned by ``qt4_to_mpl_color``."""
        return qt4_to_mpl_color(self.node.brush().color())

    @color.setter
    def color(self, value):
        self.node.setBrush(mpl_to_qt4_color(value))
def get_connections(data_collection):
    """
    Return the pairs of distinct datasets joined by at least one link.

    Each unordered pair appears once, as a ``(data1, data2)`` tuple in the
    orientation in which it was first met. Links whose input and output
    belong to the same dataset are ignored.

    Parameters
    ----------
    data_collection
        Object with a ``links`` attribute. Each link provides
        ``get_to_id()`` and ``get_from_ids()``, returning ids whose
        ``parent`` is the owning dataset.

    Returns
    -------
    set of tuple
    """

    def _candidate_pairs():
        # Lazily yield (from-dataset, to-dataset) for every link input.
        for link in data_collection.links:
            target_id = link.get_to_id()
            for source_id in link.get_from_ids():
                yield source_id.parent, target_id.parent

    pairs = set()
    for origin, target in _candidate_pairs():
        if origin is target:
            continue
        if (origin, target) in pairs or (target, origin) in pairs:
            continue
        pairs.add((origin, target))
    return pairs
def layout_simple_circle(nodes, edges, center=None, radius=None):
    """
    Place nodes evenly around a circle.

    Nodes are placed counter-clockwise in angle (starting at angle 0) in the
    order returned by ``order_nodes_by_connections``.

    Parameters
    ----------
    nodes : iterable
        Nodes to position; their ``node_position`` is assigned.
    edges : sequence
        Edges between the nodes, used only to decide the ordering.
    center : tuple of float, optional
        ``(x, y)`` centre of the circle. Defaults to the origin.
    radius : float, optional
        Radius of the circle. Defaults to 1.
    """
    # Both arguments are declared optional; fall back to a unit circle at
    # the origin instead of failing on arithmetic with None.
    if center is None:
        center = (0, 0)
    if radius is None:
        radius = 1

    ordered = order_nodes_by_connections(nodes, edges)
    n_nodes = len(ordered)
    for index, node in enumerate(ordered):
        angle = 2 * np.pi * index / n_nodes
        node.node_position = (radius * np.cos(angle) + center[0],
                              radius * np.sin(angle) + center[1])
def order_nodes_by_connections(nodes, edges):
    """
    Return the nodes reordered so that connected nodes are placed together.

    On each pass, the remaining node(s) with the largest
    ``(len(indirect), len(direct))`` score from ``find_connections`` are
    selected. Each selected node's indirect and direct connections are
    appended to the result, skipping nodes already placed. Passes repeat
    until every input node has been placed.

    Parameters
    ----------
    nodes : iterable
        Nodes to order. Any iterable is accepted; it is read only once.
    edges : sequence
        Edges with ``node_source`` and ``node_dest`` attributes.

    Returns
    -------
    list
        The ordered nodes, without duplicates.
    """
    # Snapshot the input: it is consulted on every pass, and a one-shot
    # iterator (e.g. a generator) would otherwise be empty after the first.
    all_nodes = list(nodes)
    remaining = list(all_nodes)
    ordered = []

    while remaining:
        groups = []
        scores = []
        for node in remaining:
            direct, indirect = find_connections(node, remaining, edges)
            groups.append(indirect + direct)
            scores.append((len(indirect), len(direct)))

        best = max(scores)
        for score, group in zip(scores, groups):
            if score != best:
                continue
            for member in group:
                if member not in ordered:
                    ordered.append(member)

        remaining = [node for node in all_nodes if node not in ordered]

    return ordered
class DataGraphWidget(QGraphicsView):
    """
    Interactive view of the datasets in a data collection and their links.

    Each dataset is drawn as a node on a circle, with its label to the left
    or right of the circle. Each linked pair of datasets is joined by an
    edge. Hovering highlights how nodes are connected; clicking selects
    nodes or edges.

    Parameters
    ----------
    data_collection
        Iterable of datasets that also has a ``links`` attribute.
    """

    # Emitted after a mouse click changes the selected edge or nodes.
    selection_changed = Signal()

    def __init__(self, data_collection):

        super(DataGraphWidget, self).__init__()

        # Set up scene

        self.scene = QGraphicsScene(self)
        self.scene.setItemIndexMethod(QGraphicsScene.NoIndex)
        self.scene.setSceneRect(0, 0, 800, 400)

        self.setScene(self.scene)

        # Get data and initialize nodes

        self.nodes = dict((data, DataNode(data)) for data in data_collection)

        # Get links and set up edges

        self.edges = [Edge(self.nodes[data1], self.nodes[data2])
                      for data1, data2 in get_connections(data_collection)]

        # Figure out positions

        layout_simple_circle(self.nodes.values(), self.edges,
                             center=(400, 200), radius=120)

        # Update edge positions

        for edge in self.edges:
            edge.update_position()

        # Set up labels. Every node that is not strictly left of the centre
        # goes on the right, including nodes with x == 400 exactly (in floating
        # point this happens e.g. for the top and bottom nodes when the node
        # count is a multiple of four). Otherwise such nodes would get no label
        # position. paintEvent likewise treats every non-left node as right-hand.

        self.left_nodes = [node for node in self.nodes.values()
                           if node.node_position[0] < 400]
        self.left_nodes = sorted(self.left_nodes,
                                 key=lambda x: x.node_position[1], reverse=True)

        self.right_nodes = [node for node in self.nodes.values()
                            if node.node_position[0] >= 400]
        self.right_nodes = sorted(self.right_nodes,
                                  key=lambda x: x.node_position[1], reverse=True)

        self._place_labels(self.left_nodes, 200)
        self._place_labels(self.right_nodes, 600)

        # Add nodes and edges to graph

        for node in self.nodes.values():
            node.add_to_scene(self.scene)

        for edge in self.edges:
            edge.add_to_scene(self.scene)

        self.setMinimumSize(800, 400)
        self.setWindowTitle("Glue data graph")

        self.setRenderHint(QPainter.Antialiasing)
        self.setTransformationAnchor(QGraphicsView.AnchorUnderMouse)
        self.setResizeAnchor(QGraphicsView.AnchorViewCenter)

        self.text_adjusted = False

        self.selected_edge = None
        self.selected_node1 = None
        self.selected_node2 = None

    @staticmethod
    def _place_labels(nodes, x):
        """Spread the labels of ``nodes`` evenly over the scene height at ``x``."""
        # Float numerator so this is not integer division on Python 2.
        n_slots = len(nodes) + 1
        for index, node in enumerate(nodes):
            node.label_position = x, 400 - 400.0 * (index + 1) / n_slots

    def _edge_between(self, node1, node2):
        """Return the edge joining ``node1`` and ``node2`` (either way), or None."""
        for edge in self.edges:
            if ((edge.node_source is node1 and edge.node_dest is node2) or
                    (edge.node_source is node2 and edge.node_dest is node1)):
                return edge
        return None

    def paintEvent(self, event):
        """Paint the view, aligning the label text on the first paint."""

        super(DataGraphWidget, self).paintEvent(event)

        if self.text_adjusted:
            return

        for node in self.nodes.values():
            rect = node.label.boundingRect()
            width = rect.width()
            height = rect.height()
            transform = QTransform()
            # Labels are centred vertically on their anchor; left-hand labels
            # end at the anchor, the others start there.
            if node in self.left_nodes:
                transform.translate(-width, -height / 2)
            else:
                transform.translate(0, -height / 2)
            node.label.setTransform(transform)

        self.text_adjusted = True

    def find_object(self, event):
        """
        Return the node or edge under the mouse for ``event``, or None.

        Nodes are tested before edges, so a node wins when both match.
        """
        # Node and edge geometry lives in scene coordinates, while mouse event
        # positions are relative to the viewport. The two differ whenever the
        # view is larger than the scene (which is then centred) or scrolled,
        # so map explicitly.
        point = self.mapToScene(event.pos())
        for obj in list(self.nodes.values()) + self.edges:
            if obj.contains(point):
                return obj
        return None

    def mouseMoveEvent(self, event):
        """
        Highlight the node or edge under the mouse and how it is connected.

        For a hovered node, the node and its direct connections get
        COLOR_DIRECT, indirect ones COLOR_INDIRECT and all other nodes
        COLOR_DISCONNECTED. For a hovered edge, the edge and its end nodes
        get COLOR_DIRECT. Nothing happens while a node is selected, so the
        selection colors stay visible.
        """

        # TODO: Don't update until the end
        # TODO: Only select object on top

        if self.selected_node1 is not None or self.selected_node2 is not None:
            return

        selected = self.find_object(event)

        colors = {}

        if isinstance(selected, DataNode):

            colors[selected] = COLOR_DIRECT

            direct, indirect = find_connections(selected, self.nodes.values(), self.edges)

            for node in self.nodes.values():
                if node in direct:
                    colors[node] = COLOR_DIRECT
                elif node in indirect:
                    colors[node] = COLOR_INDIRECT
                else:
                    colors[node] = COLOR_DISCONNECTED

        elif isinstance(selected, Edge):

            colors[selected] = COLOR_DIRECT
            colors[selected.node_source] = COLOR_DIRECT
            colors[selected.node_dest] = COLOR_DIRECT

        self.set_colors(colors)

    def mousePressEvent(self, event):
        """
        Update the selection when a node or edge is clicked.

        Clicking a node toggles it as the first or second selected node; the
        edge joining the two selected nodes, if any, becomes the selected
        edge. Clicking an edge toggles selection of the edge together with
        both of its end nodes. ``selection_changed`` is emitted when the
        selection differs afterwards.
        """

        # TODO: Don't update until the end
        # TODO: Only select object on top

        previous = (self.selected_edge, self.selected_node1, self.selected_node2)

        selected = self.find_object(event)

        if isinstance(selected, DataNode):

            if self.selected_node1 is None:
                self.selected_node1 = selected
            elif self.selected_node1 is selected:
                self.selected_node1 = None
            elif self.selected_node2 is selected:
                self.selected_node2 = None
            else:
                self.selected_node2 = selected

            self.selected_edge = self._edge_between(self.selected_node1,
                                                    self.selected_node2)

        elif isinstance(selected, Edge):

            if self.selected_edge is selected:
                self.selected_edge = None
                self.selected_node1 = None
                self.selected_node2 = None
            else:
                self.selected_edge = selected
                self.selected_node1 = selected.node_source
                self.selected_node2 = selected.node_dest

        colors = {}

        if self.selected_edge is not None:
            colors[self.selected_edge] = COLOR_SELECTED

        if self.selected_node1 is not None:
            colors[self.selected_node1] = COLOR_SELECTED

        if self.selected_node2 is not None:
            colors[self.selected_node2] = COLOR_SELECTED

        self.set_colors(colors)

        # Refresh hover highlighting (a no-op while nodes are selected)
        self.mouseMoveEvent(event)

        current = (self.selected_edge, self.selected_node1, self.selected_node2)
        if any(old is not new for old, new in zip(previous, current)):
            self.selection_changed.emit()

    def set_colors(self, colors):
        """
        Recolor every node and edge and request a repaint.

        Parameters
        ----------
        colors : dict
            Maps nodes/edges to colors. Items that are not present get their
            default color: '0.8' for nodes, '0.5' for edges.
        """
        for obj in list(self.nodes.values()) + self.edges:
            default_color = '0.8' if isinstance(obj, DataNode) else '0.5'
            obj.color = colors.get(obj, default_color)
            obj.update()
def find_connections(node, nodes, edges):
    """
    Find the nodes reachable from ``node`` through ``edges``.

    The search is breadth-first over edges treated as undirected. Node
    membership is tested with ``in`` (equality), so nodes need not be
    hashable.

    Parameters
    ----------
    node
        The starting node.
    nodes
        Unused; kept so existing callers keep working.
    edges : iterable
        Edges with ``node_source`` and ``node_dest`` attributes.

    Returns
    -------
    direct : list
        ``node`` itself followed by its immediate neighbours.
    indirect : list
        Every other node reachable from ``node`` via two or more edges.
        The two lists never overlap and contain no duplicates.
    """
    edges = list(edges)  # traversed once per breadth-first layer

    direct = [node]
    indirect = []
    visited = [node]
    frontier = [node]
    layer_target = direct  # the first layer of neighbours is "direct"

    while frontier:
        next_frontier = []
        for edge in edges:
            ends = (edge.node_source, edge.node_dest)
            for here, there in (ends, ends[::-1]):
                # Only expand from the current layer so that nodes two hops
                # away are never classified as direct.
                if here in frontier and there not in visited:
                    visited.append(there)
                    next_frontier.append(there)
                    layer_target.append(there)
        layer_target = indirect
        frontier = next_frontier

    return direct, indirect
if __name__ == '__main__':

    import sys

    app = QApplication(sys.argv)
    app.setAttribute(Qt.AA_UseHighDpiPixmaps)

    from glue.core.state import load

    # Session to display: first command-line argument if given, otherwise
    # 'links.glu' in the current directory (the previous hard-coded path).
    session_file = sys.argv[1] if len(sys.argv) > 1 else 'links.glu'
    dc = load(session_file)

    widget = DataGraphWidget(dc)
    widget.show()

    sys.exit(app.exec_())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment