Created
July 26, 2022 17:44
-
-
Save ctralie/979b420570dc6ea65dee2ab9f8a49705 to your computer and use it in GitHub Desktop.
Merge trees on time series with simplification
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Programmer: Chris Tralie
Purpose: To provide a basic ordered merge tree class for interval and circular domains,
along with methods to construct the merge tree from a time series, to plot it and
its associated persistence diagram, and to simplify the merge trees by persistence
"""
import matplotlib.pyplot as plt
import numpy as np
def plot_diagrams(
    diagrams,
    plot_only=None,
    title=None,
    xy_range=None,
    labels=None,
    markers=None,
    sizes=None,
    colors=None,
    colormap="default",
    ax_color=np.array([0.0, 0.0, 0.0]),
    diagonal=True,
    lifetime=False,
    equal=True,
    legend=True,
    show=False,
    ax=None
):
    """A helper function to plot persistence diagrams.

    Parameters
    ----------
    diagrams: ndarray (n_pairs, 2) or list of diagrams
        A diagram or list of diagrams. If diagram is a list of diagrams,
        then plot all on the same plot using different colors.
    plot_only: list of numeric
        If specified, indices of the only diagrams that should be plotted.
    title: string, default is None
        If title is defined, add it as title of the plot.
    xy_range: list of numeric [xmin, xmax, ymin, ymax]
        User provided range of axes. This is useful for comparing
        multiple persistence diagrams.
    labels: string or list of strings
        Legend labels for each diagram.
        If none are specified, we use H_0, H_1, H_2,... by default.
    markers: string or list of strings
        Markers for each diagram.
        If none are specified, we use dots by default.
    sizes: int or list of ints
        Sizes of each marker.
        If none are specified, use 20 by default.
    colors: string or list of strings
        Colors for each diagram.
        If none are specified, use the default sequence from matplotlib.
    colormap: string, default is 'default'
        Name of a matplotlib style; it is passed to ``plt.style.use``, which
        changes the global matplotlib style.
        See all available styles with

        .. code:: python

            import matplotlib as mpl
            print(mpl.style.available)

    ax_color: any valid matplotlib color type.
        Color of the diagonal / horizon line.
        See https://matplotlib.org/api/colors_api.html for complete API.
    diagonal: bool, default is True
        Plot the diagonal x=y line.
    lifetime: bool, default is False. If True, diagonal is turned to False.
        Plot life time of each point instead of birth and death.
        Essentially, visualize (x, y-x).
    equal: bool, default is True. If True, plot axes equal
    legend: bool, default is True
        If true, show the legend.
    show: bool, default is False
        Call plt.show() after plotting. If you are using self.plot() as part
        of a subplot, set show=False and call plt.show() only once at the end.
    ax: matplotlib axes, default is None
        Axes to draw on. If None, the current axes (``plt.gca()``) are used.
    """
    if ax is None:
        ax = plt.gca()
    plt.style.use(colormap)

    xlabel, ylabel = "Birth", "Death"

    if not isinstance(diagrams, list):
        # Must have diagrams as a list for processing downstream
        diagrams = [diagrams]
    n_all = len(diagrams)

    # Expand every per-diagram option to a full-length list BEFORE applying
    # plot_only, so that a scalar label is never indexed character-wise.
    if labels is None:
        labels = ["$H_{{{}}}$".format(i) for i in range(n_all)]
    elif not isinstance(labels, list):
        labels = [labels] * n_all
    if markers is None:
        markers = ["o"] * n_all
    elif not isinstance(markers, list):
        markers = [markers] * n_all
    if sizes is None:
        sizes = [20] * n_all
    elif not isinstance(sizes, list):
        sizes = [sizes] * n_all
    if colors is None:
        colors = ["C{}".format(i) for i in range(n_all)]
    elif not isinstance(colors, list):
        colors = [colors] * n_all

    if plot_only is not None and len(plot_only) > 0:
        diagrams = [diagrams[i] for i in plot_only]
        labels = [labels[i] for i in plot_only]

    # Construct copy with proper type of each diagram
    # so we can freely edit them.
    diagrams = [np.array(dgm, dtype=np.float32) for dgm in diagrams]

    # find min and max of all visible diagrams
    concat_dgms = np.concatenate(diagrams).flatten()
    has_inf = np.any(np.isinf(concat_dgms))
    finite_dgms = concat_dgms[np.isfinite(concat_dgms)]

    if xy_range is None or len(xy_range) == 0:
        # define bounds of diagram
        if finite_dgms.size > 0:
            ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
        else:
            # Nothing finite to bound; fall back to the unit interval
            ax_min, ax_max = 0.0, 1.0
        x_r = ax_max - ax_min

        # Give plot a nice buffer on all sides.
        # x_r is 0 when all finite values coincide (e.g. only one point),
        # in which case a fixed buffer keeps the axes from collapsing.
        buffer = 1 if x_r == 0 else x_r / 5

        x_down = ax_min - buffer / 2
        x_up = ax_max + buffer

        y_down, y_up = x_down, x_up
    else:
        x_down, x_up, y_down, y_up = xy_range

    yr = y_up - y_down

    if lifetime:
        # Don't plot lifetime view and diagonal at the same time.
        diagonal = False

        # reset y axis so it doesn't go much below zero
        y_down = -yr * 0.05
        y_up = y_down + yr

        # set custom ylabel
        ylabel = "Lifetime"

        # set diagrams to be (x, y-x)
        for dgm in diagrams:
            dgm[:, 1] -= dgm[:, 0]

        # plot horizon line
        ax.plot([x_down, x_up], [0, 0], c=ax_color)

    # Plot diagonal
    if diagonal:
        ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)

    # Plot inf line
    if has_inf:
        # put inf line slightly below top
        b_inf = y_down + yr * 0.95
        ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")

        # replace each inf in each diagram with b_inf
        for dgm in diagrams:
            dgm[np.isinf(dgm)] = b_inf

    # Plot each diagram
    for dgm, label, marker, size, color in zip(diagrams, labels, markers, sizes, colors):
        # plot persistence pairs
        ax.scatter(dgm[:, 0], dgm[:, 1], size, c=color, label=label, marker=marker)

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xlim([x_down, x_up])
    ax.set_ylim([y_down, y_up])
    if equal:
        ax.set_aspect('equal', 'box')

    if title is not None:
        ax.set_title(title)

    if legend is True:
        ax.legend(loc="lower right")

    if show is True:
        plt.show()
def poly_fit(X, xs, do_plot = False):
    """
    Given an Nx2 array X of 2D coordinates, fit the degree N-1 polynomial
    that interpolates all N points and evaluate it at the coordinates in xs.
    This function assumes that all of the points have a unique X position
    (otherwise the linear system is singular and numpy raises LinAlgError).

    Parameters
    ----------
    X: array-like (N, 2)
        Points to interpolate; column 0 is x, column 1 is y
    xs: array-like (M)
        1D x coordinates at which to evaluate the polynomial
    do_plot: boolean
        If True, plot the fitted curve and the input points and
        call plt.show()

    Returns
    -------
    Y: ndarray(M, 2)
        Column 0 holds xs, column 1 holds the polynomial values
    """
    X = np.asarray(X, dtype=float)
    xs = np.asarray(xs, dtype=float)
    N = X.shape[0]
    # Vandermonde system A[j, i] = x_j**i; solving it directly is more
    # accurate than forming an explicit inverse
    A = np.vander(X[:, 0], N, increasing=True)
    coeffs = np.linalg.solve(A, X[:, 1])
    Y = np.zeros((xs.size, 2))
    Y[:, 0] = xs
    Y[:, 1] = np.vander(xs, N, increasing=True).dot(coeffs)
    if do_plot:
        # plt.hold was removed from matplotlib; successive plot calls
        # draw on the same axes without it
        plt.plot(Y[:, 0], Y[:, 1], 'b')
        plt.scatter(X[:, 0], X[:, 1], 20, 'r')
        plt.show()
    return Y
def draw_curve(X, Y, linewidth):
    """
    Draw a parabolic curve between two 2D points on the current
    matplotlib axes. If the two points share the same x coordinate,
    a parabola y(x) through them does not exist, so a straight
    segment is drawn instead.

    Parameters
    ----------
    X: list of [x, y]
        First point
    Y: list of [x, y]
        Second point
    linewidth: int
        Thickness of line
    """
    # Order the endpoints so that (x1, y1) is the lower one
    if Y[1] < X[1]:
        X, Y = Y, X
    x1, y1 = X[0], X[1]
    x3, y3 = Y[0], Y[1]
    if x1 == x3:
        # Vertical edge: the interpolation system would be singular
        plt.plot([x1, x3], [y1, y3], 'k', linewidth=linewidth)
        return
    # Middle control point: halfway in x, three quarters of the way up in y
    x2 = 0.5*x1 + 0.5*x3
    y2 = 0.25*y1 + 0.75*y3
    xs = np.linspace(x1, x3, 50)
    control = np.array([[x1, y1], [x2, y2], [x3, y3]])
    curve = poly_fit(control, xs, do_plot=False)
    plt.plot(curve[:, 0], curve[:, 1], 'k', linewidth=linewidth)
class MergeNode(object):
    """
    A node of an ordered merge tree. Each node stores a height y, an
    optional x position, its list of children, an inorder index, an
    optional (birth, death) pair and a flag marking the global min.
    """

    def __init__(self, y, x=None):
        """
        Parameters
        ----------
        y: float
            Height of node
        x: float
            x position of node (optional)
        """
        self.children = []
        self.x = x
        self.y = y
        self.idx = -1 # Inorder index
        self.birth_death = []
        self.is_globalmin = False

    @staticmethod
    def _x_sort_key(node):
        """
        Sort key on x that tolerates nodes without an x position:
        such nodes are placed after all positioned nodes, keeping
        their relative list order (sorted() is stable)
        """
        return (node.x is None, 0 if node.x is None else node.x)

    def get_coords(self, use_inorder):
        """
        Return an array of the [x, y] coordinates of this node

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x. If false,
            use a prespecified x coordinate if it exists

        Returns
        -------
        ndarray(2) of float: [x, y]
        """
        # Force a float array so that a fractional x is not truncated
        # when idx and y both happen to be integers
        coords = np.array([self.idx, self.y], dtype=float)
        if not use_inorder and self.x is not None:
            coords[0] = self.x
        return coords

    def inorder(self, idx):
        """
        Perform a generalized inorder traversal, assigning consecutive
        indices to self.idx in order of x position
        NOTE: Nodes whose x coordinates have not been specified are
        visited after the positioned ones, in list order

        Parameters
        ----------
        idx: list[1]
            A count, by reference
        """
        for node in sorted(self.children + [self], key=MergeNode._x_sort_key):
            if node is self:
                self.idx = idx[0]
                idx[0] += 1
            else:
                node.inorder(idx)

    def get_rep_timeseries(self, xs, ys, signs):
        """
        Create a piecewise linear function that is
        obtained from an inorder traversal of the y
        coordinates of the nodes in this tree

        Parameters
        ----------
        xs: list of float
            X coordinates of time series that I'm building
        ys: list of float
            Time series that I'm building
        signs: list of [-1, 1]
            A parallel list indicating local min (-1) or local max (+1)
        """
        if len(self.children) == 0:
            xs.append(self.x)
            ys.append(self.y)
            signs.append(-1)
        ordered = sorted(self.children, key=MergeNode._x_sort_key)
        for i, child in enumerate(ordered):
            child.get_rep_timeseries(xs, ys, signs)
            if i < len(ordered)-1:
                # Put the max in between every adjacent pair of children
                xs.append(self.x)
                ys.append(self.y)
                signs.append(1)

    def persistence_simplify(self, eps):
        """
        Remove all leaves that are under a certain persistence threshold

        Parameters
        ----------
        eps: Persistence threshold

        Returns
        -------
        boolean: False if this node should be removed by its parent
        """
        survived = True
        if not self.is_globalmin and len(self.birth_death) == 2: # Leaf node
            if self.birth_death[1] - self.birth_death[0] < eps:
                survived = False
        elif len(self.children) > 0:
            self.children = [c for c in self.children if c.persistence_simplify(eps)]
            if len(self.children) == 0:
                # An internal node that lost every child is removed too
                survived = False
        return survived

    def delete_singletons(self):
        """
        Delete nodes with a single child

        Returns
        -------
        MergeNode: The node that should take the place of this node
        (this node itself unless it has exactly one child)
        """
        if len(self.children) == 1:
            return self.children[0].delete_singletons()
        self.children = [c.delete_singletons() for c in self.children]
        return self

    def plot(self, use_inorder, params):
        """
        Recursive helper method for plotting

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x. If false,
            use a prespecified x coordinate if it exists
        params: dict {
            offset: [x, y]: Offset by which to plot this
            draw_curved: boolean: If true, draw parabolic curved lines between nodes
            linewidth: int: How thick to draw the edges
            pointsize: int: How big to draw the nodes
            plot_birthdeath: boolean: Whether to plot (birth, death) at leaf nodes
        }
        """
        offset = params.get('offset', np.array([0, 0]))
        draw_curved = params.get('draw_curved', True)
        linewidth = params.get('linewidth', 3)
        pointsize = params.get('pointsize', 200)
        plot_birthdeath = params.get('plot_birthdeath', False)

        X = self.get_coords(use_inorder) + offset
        plt.scatter(X[0], X[1], pointsize, 'k')
        if len(self.birth_death) > 0 and plot_birthdeath:
            plt.text(X[0], X[1], "{:.2f}, {:.2f}".format(*self.birth_death), c='r')
        for child in self.children:
            Y = child.get_coords(use_inorder) + offset
            if draw_curved:
                draw_curve(X, Y, linewidth)
            else:
                # Keyword must be lowercase: mixed-case property names
                # are rejected by current matplotlib
                plt.plot([X[0], Y[0]], [X[1], Y[1]], 'k', linewidth=linewidth)
            child.plot(use_inorder, params)
def unionfind_root(pointers, u):
    """
    Union find root operation with path-compression.
    Implemented iteratively so that long pointer chains cannot
    exceed Python's recursion limit.

    Parameters
    ----------
    pointers: list
        A list of pointers to representative nodes
    u: int
        Index of the node to find

    Returns
    -------
    Index of the representative of the component of u
    """
    # First pass: walk up to the representative
    root = u
    while pointers[root] != root:
        root = pointers[root]
    # Second pass: point every node on the path directly at the root
    while pointers[u] != root:
        parent = pointers[u]
        pointers[u] = root
        u = parent
    return pointers[root]
def unionfind_union(pointers, u, v, idxorder):
    """
    Merge the components containing u and v. The representative that
    appears earlier in idxorder becomes the representative of the merged
    component (birth-based merging, similar in spirit to rank-based
    merging). Does nothing if u and v are already in the same component.

    Parameters
    ----------
    pointers: list
        A list of pointers to representative nodes
    u: int
        Index of first node
    v: int
        Index of the second node
    idxorder: list
        List of order in which each point shows up
    """
    root_u = unionfind_root(pointers, u)
    root_v = unionfind_root(pointers, v)
    if root_u == root_v:
        return
    if idxorder[root_v] < idxorder[root_u]:
        # v's representative was born first, so it absorbs u's
        pointers[root_u] = root_v
    else:
        pointers[root_v] = root_u
class MergeTree(object):
    """
    An ordered merge tree built from a 1D time series, together with
    its H0 persistence diagram (PD) and the indices of the paired
    samples (PDIdx).
    """

    def __init__(self, x=np.array([])):
        """
        Construct a new merge tree

        Parameters
        ----------
        x: ndarray(N)
            Time series with which to initialize a merge tree.
            If left blank, initialize an empty merge tree.
        """
        self.root = None
        self.PD = np.zeros((0, 2))
        self.PDIdx = np.zeros((0, 2), dtype=int)
        x = np.asarray(x)
        if x.size > 0:
            self.init_from_timeseries(x)

    def get_rep_timeseries(self):
        """
        Return a piecewise linear function that is
        obtained from an inorder traversal of the y
        coordinates of the nodes in this tree, as well as
        a parallel array that indicates whether the points
        are mins or maxes

        Returns
        -------
        {
            xs: ndarray(N): Coordinates of time series
            ys: ndarray(N): Time series representing piecewise linear function
            signs: ndarray(N): A parallel array of signs,
                -1 for a local min and +1 for a local max
        }
        """
        ys = []
        xs = []
        signs = []
        if self.root is not None:
            self.root.get_rep_timeseries(xs, ys, signs)
        return dict(xs=np.array(xs), ys=np.array(ys), signs=np.array(signs))

    def persistence_simplify(self, eps):
        """
        Remove all leaves that are under a certain persistence threshold,
        then splice out nodes left with a single child. PD and PDIdx are
        not modified.

        Parameters
        ----------
        eps: Persistence threshold
        """
        if self.root is not None:
            self.root.persistence_simplify(eps)
            # delete_singletons returns the node that replaces the one it
            # was called on, so the result must be stored back as the root
            self.root = self.root.delete_singletons()

    def plot(self, use_inorder, params=None):
        """
        Draw this tree

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x. If false,
            use a prespecified x coordinate if it exists
        params: dict {
            offset: [x, y]: Offset by which to plot this
            draw_curved: boolean: If true, draw parabolic curved lines between nodes
            linewidth: int: How thick to draw the edges
            pointsize: int: How big to draw the nodes
            plot_birthdeath: boolean: Whether to plot (birth, death) at leaf nodes
        }
        """
        if params is None:
            params = {}
        if self.root is not None:
            if use_inorder:
                idx = [0]
                self.root.inorder(idx)
            self.root.plot(use_inorder, params)

    def plot_with_pd(self, use_inorder, params=None):
        """
        Draw this tree alongside its persistence diagram

        Parameters
        ----------
        use_inorder: boolean
            If True, use the inorder coordinate for x. If false,
            use a prespecified x coordinate if it exists
        params: dict {
            offset: [x, y]: Offset by which to plot this
            draw_curved: boolean: If true, draw parabolic curved lines between nodes
            linewidth: int: How thick to draw the edges
            pointsize: int: How big to draw the nodes
            plot_birthdeath: boolean: Whether to plot (birth, death) at leaf nodes
            use_grid: boolean: Whether to draw grid lines
            show_merge_xticks: Whether to show the x ticks for the merge tree
        }
        """
        if params is None:
            params = {}
        if self.root is None:
            return
        use_grid = params.get('use_grid', False)
        show_merge_xticks = params.get('show_merge_xticks', False)
        yvals = np.sort(np.unique(self.get_rep_timeseries()['ys']))
        dy = yvals[-1] - yvals[0]
        ylims = (yvals[0]-0.1*dy, yvals[-1]+0.1*dy)

        # Left panel: the merge tree
        plt.subplot(121)
        self.plot(use_inorder, params)
        plt.gca().set_yticks(yvals)
        plt.ylim(*ylims)
        if not show_merge_xticks:
            plt.gca().set_xticks([])
        if use_grid:
            plt.grid()

        # Right panel: the persistence diagram on the same value range
        plt.subplot(122)
        PD = np.asarray(self.PD).reshape(-1, 2)
        if PD.size > 0:
            plot_diagrams([PD])
        plt.gca().set_yticks(np.unique(PD[:, 1]))
        plt.ylim(*ylims)
        plt.gca().set_xticks(np.unique(PD[:, 0]))
        plt.xlim(*ylims)
        if use_grid:
            plt.grid()

    def init_from_timeseries(self, y, include_essential=False, circular=False):
        """
        Uses union find to make a merge tree object from the time series y
        (NOTE: This code is pretty general and could work to create merge trees
        on any domain if the neighbor set was updated)

        Parameters
        ----------
        y: ndarray(N)
            1D array representing the time series
        include_essential: bool
            Whether to include the essential class, paired as
            (global min, global max)
        circular: boolean
            Whether to assume that the domain wraps around circularly

        Returns
        -------
        PD: ndarray(K, 2)
            H0 persistence diagram for this merge tree (also stored locally
            as self.PD as a side effect)
        PDIdx: ndarray(K, 2)
            Indices into y of the (birth, death) samples of each pair
            (also stored locally as self.PDIdx)
        """
        y = np.asarray(y)
        N = len(y)
        self.root = None
        if N == 0:
            self.PD = np.zeros((0, 2))
            self.PDIdx = np.zeros((0, 2), dtype=int)
            return self.PD, self.PDIdx

        # Add points from the bottom up
        order = np.argsort(y)
        idxorder = np.zeros(N)
        idxorder[order] = np.arange(N)
        pointers = np.arange(N) # Pointer to oldest indices
        representatives = {} # Nodes that represent a connected component
        leaves = {} # Leaf nodes
        I = [] # Persistence diagram
        IIdx = [] # Paired indices
        for i in order: # Go through each point in the time series in height order
            neighbs = []
            # Find the oldest representatives of the neighbors that
            # are already alive
            for di in [-1, 1]: # Neighbor set is simply left/right
                j = i + di
                if circular:
                    j = j % N
                elif j < 0 or j >= N:
                    continue
                if idxorder[j] < idxorder[i]:
                    neighbs.append(unionfind_root(pointers, j))
            if len(neighbs) == 0:
                # If none of this point's neighbors are alive yet, this
                # point will become alive with its own class
                leaves[i] = MergeNode(y[i], i)
                representatives[i] = leaves[i]
            else:
                # Find the oldest class, merge earlier classes with this class,
                # and record the merge events and birth/death times
                oldest_neighb = neighbs[np.argmin([idxorder[n] for n in neighbs])]
                # No matter what, the current node becomes part of the
                # oldest class to which it is connected
                unionfind_union(pointers, oldest_neighb, i, idxorder)
                if len(neighbs) == 2: # A nontrivial merge
                    for n in neighbs:
                        if n == oldest_neighb:
                            continue
                        # Create node and record persistence event if it's nontrivial
                        if y[i] > y[n]:
                            # Record persistence information
                            I.append([y[n], y[i]])
                            IIdx.append([n, i])
                            leaves[n].birth_death = (y[n], y[i])
                            # Create new node
                            node = MergeNode(y[i], i)
                            self.root = node
                            left_right = [representatives[m] for m in neighbs]
                            if left_right[0].x > left_right[1].x:
                                left_right = left_right[::-1]
                            node.children = left_right
                            # Change the representative for this class to be the new node
                            representatives[oldest_neighb] = node
                        unionfind_union(pointers, oldest_neighb, n, idxorder)

        # The first point processed has no live neighbors, so it is always
        # a leaf; np.argmin(y) may pick a different index under ties, one
        # that never became a leaf
        globalmin = order[0]
        leaves[globalmin].is_globalmin = True
        if self.root is None:
            # No merge happened (e.g. monotone series): the tree is the
            # single leaf at the global min
            self.root = leaves[globalmin]
        # Add the essential class
        if include_essential:
            idx_max = np.argmax(y)
            b, d = y[globalmin], y[idx_max]
            I.append([b, d])
            IIdx.append([globalmin, idx_max])
            leaves[globalmin].birth_death = (b, d)
        # reshape keeps a consistent (K, 2) shape even when K == 0
        self.PD = np.array(I).reshape(-1, 2)
        self.PDIdx = np.array(IIdx, dtype=int).reshape(-1, 2)
        return self.PD, self.PDIdx
if __name__ == '__main__':
    # Demo: build a merge tree on a noisy rising cosine, then save one
    # frame per persistence threshold showing the simplified tree, the
    # persistence diagram and the representative time series
    circular = False
    np.random.seed(0)
    N = 200
    t = np.linspace(0.01, 0.98, N)
    x = np.cos(2*np.pi*t*10) + t*10
    x += 0.3*np.random.randn(N)
    MT = MergeTree()
    # Forward the flag so that changing it above actually takes effect
    MT.init_from_timeseries(x, circular=circular)

    # Common value range, padded by 10% on each side
    rg = [np.min(x), np.max(x)]
    pad = 0.1*(rg[1]-rg[0])
    rg[0] -= pad
    rg[1] += pad

    fac = 0.6
    plt.figure(figsize=(fac*20, fac*6))
    for i, eps in enumerate(np.linspace(0, 1, 200)):
        plt.clf()

        # Simplification is applied cumulatively to the same tree;
        # eps only increases from frame to frame
        plt.subplot(131)
        MT.persistence_simplify(eps)
        MT.plot(False, {'pointsize':10, 'linewidth':1})
        plt.title("Simplified $\\epsilon = {:.3f}$".format(eps))
        plt.ylim(rg)
        plt.xlim([-1, x.size+1])

        # Full persistence diagram, with the points under the current
        # threshold crossed out
        plt.subplot(132)
        PD = MT.PD
        plot_diagrams(PD, sizes=4)
        plt.plot([rg[0], rg[1]], [rg[0]+eps, rg[1]+eps], c='C3', linestyle='--')
        PDEps = PD[PD[:, 1]-PD[:, 0] < eps, :]
        plt.scatter(PDEps[:, 0], PDEps[:, 1], marker='x', c='C3')
        plt.xlim(rg)
        plt.ylim(rg)
        plt.title("Persistence diagram")

        plt.subplot(133)
        res = MT.get_rep_timeseries()
        plt.plot(res['xs'], res['ys'])
        plt.title("Representative Time Series")
        plt.ylim(rg)
        plt.xlim([-1, x.size+1])
        plt.savefig("MT{}.png".format(i))
Author
ctralie
commented
Jul 26, 2022
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment