Last active
March 21, 2019 13:50
-
-
Save ipashchenko/90e1220b6517e301f7a8fe1c8001bf65 to your computer and use it in GitHub Desktop.
Get the skeleton of an image
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import operator | |
import string | |
import copy | |
import itertools | |
import numpy as np | |
import networkx as nx | |
import scipy.ndimage as nd | |
import image_ops | |
from from_fits import create_image_from_fits_file | |
from skimage.morphology import medial_axis | |
from scipy import nanmean | |
import matplotlib.pyplot as plt | |
# Structuring elements used by skeleton_length with binary hit-or-miss to
# find connections between 4-connected and 8-connected pieces of a
# skeleton. struct2 is the left-right mirror image of struct1.
struct1 = np.array([[1, 0, 0],
                    [0, 1, 1],
                    [0, 0, 0]])
struct2 = struct1[:, ::-1].copy()

# Patterns that the struct1/struct2 passes would count twice; skeleton_length
# subtracts their matches again. check2 is the mirror image of check1.
check1 = np.array([[1, 1, 0, 0],
                   [0, 0, 1, 1]])
check2 = check1[:, ::-1].copy()
check3 = np.array([[1, 1, 0],
                   [0, 0, 1],
                   [0, 0, 1]])
def eight_con():
    '''
    Return a 3x3 structuring element filled with ones.

    Passed to ``scipy.ndimage.label`` it makes diagonal neighbours count as
    connected (8-connectivity).
    '''
    full_neighbourhood = np.ones((3, 3))
    return full_neighbourhood
def _fix_small_holes(mask_array, rel_size=0.1): | |
''' | |
Helper function to remove only small holes within a masked region. | |
Parameters | |
---------- | |
mask_array : numpy.ndarray | |
Array containing the masked region. | |
rel_size : float, optional | |
If < 1.0, sets the minimum size a hole must be relative to the area | |
of the mask. Otherwise, this is the maximum number of pixels the hole | |
must have to be deleted. | |
Returns | |
------- | |
mask_array : numpy.ndarray | |
Altered array. | |
''' | |
if rel_size <= 0.0: | |
raise ValueError("rel_size must be positive.") | |
elif rel_size > 1.0: | |
pixel_flag = True | |
else: | |
pixel_flag = False | |
# Find the region area | |
reg_area = len(np.where(mask_array == 1)[0]) | |
# Label the holes | |
holes = np.logical_not(mask_array).astype(float) | |
lab_holes, n_holes = nd.label(holes, eight_con()) | |
# If no holes, return | |
if n_holes == 1: | |
return mask_array | |
# Ignore area outside of the region. | |
out_label = lab_holes[0, 0] | |
# Set size to be just larger than the region. Thus it can never be | |
# deleted. | |
holes[np.where(lab_holes == out_label)] = reg_area + 1. | |
# Sum up the regions and find holes smaller than the threshold. | |
sums = nd.sum(holes, lab_holes, range(1, n_holes + 1)) | |
if pixel_flag: # Use number of pixels | |
delete_holes = np.where(sums < rel_size)[0] | |
else: # Use relative size of holes. | |
delete_holes = np.where(sums / reg_area < rel_size)[0] | |
# Return if there is nothing to delete. | |
if delete_holes == []: | |
return mask_array | |
# Add one to take into account 0 in list if object label 1. | |
delete_holes += 1 | |
for label in delete_holes: | |
mask_array[np.where(lab_holes == label)] = 1 | |
return mask_array | |
def isolateregions(binary_array, size_threshold=0, pad_size=5,
                   fill_hole=False, rel_size=0.1, morph_smooth=False):
    '''
    Labels regions in a binary array and returns individual arrays for each
    region. Regions at or below a size threshold can optionally be removed.
    Small holes may also be filled in.

    Parameters
    ----------
    binary_array : numpy.ndarray
        A binary array of regions. When ``size_threshold`` is non-zero the
        removed regions are zeroed in this array in place.
    size_threshold : int, optional
        Regions with this many pixels or fewer are removed (0 disables).
    pad_size : int, optional
        Padding added on every edge of the individual arrays.
    fill_hole : bool, optional
        Enables hole filling (see ``_fix_small_holes``).
    rel_size : float or int, optional
        If <= 1.0, holes smaller than this fraction of the region area are
        filled. Otherwise, holes with fewer than this many pixels are filled.
    morph_smooth : bool, optional
        Morphologically smooth each region using a binary opening and
        closing.

    Returns
    -------
    output_arrays : list
        Regions separated into individual (padded) arrays.
    num : int
        Number of regions.
    corners : list
        ``[lower, upper]`` per region such that
        ``original[lower[0]:upper[0], lower[1]:upper[1]]`` is the box each
        array was cut from. ``lower`` can be negative when a region lies
        within ``pad_size`` of the image edge.
    '''
    output_arrays = []
    corners = []

    # 8-connectivity, identical to eight_con().
    structure = np.ones((3, 3))

    # Label regions
    labels, num = nd.label(binary_array, structure)

    # Remove regions which have too few pixels, then relabel.
    if size_threshold != 0:
        sums = nd.sum(binary_array, labels, range(1, num + 1))
        remove_fils = np.where(sums <= size_threshold)[0]
        for lab in remove_fils:
            binary_array[labels == lab + 1] = 0
        labels, num = nd.label(binary_array, structure)

    # Split each region into its own array.
    for n in range(1, num + 1):
        x, y = np.where(labels == n)
        # The +1 is needed because a region spanning rows
        # x.min()..x.max() has x.max() - x.min() + 1 rows. With it the array
        # matches the [lower, upper) box stored in ``corners`` exactly, and
        # pad_size=0 no longer indexes out of bounds.
        shapes = (x.max() - x.min() + 1 + 2 * pad_size,
                  y.max() - y.min() + 1 + 2 * pad_size)
        eachfil = np.zeros(shapes)
        eachfil[x - x.min() + pad_size, y - y.min() + pad_size] = 1

        # Fill in small holes
        if fill_hole:
            eachfil = _fix_small_holes(eachfil, rel_size=rel_size)
        if morph_smooth:
            eachfil = nd.binary_opening(eachfil, np.ones((3, 3)))
            eachfil = nd.binary_closing(eachfil, np.ones((3, 3)))
        output_arrays.append(eachfil)

        # Keep the coordinates from the original image
        lower = (x.min() - pad_size, y.min() - pad_size)
        upper = (x.max() + pad_size + 1, y.max() + pad_size + 1)
        corners.append([lower, upper])

    return output_arrays, num, corners
def shifter(l, n):
    '''
    Rotate the sequence ``l`` left by ``n`` positions (negative ``n``
    rotates right). Works on any sliceable sequence that supports ``+``.
    '''
    head = l[:n]
    tail = l[n:]
    return tail + head
def distance(x, x1, y, y1):
    '''
    Euclidean distance between the points (x, y) and (x1, y1).

    Note the argument order: both x coordinates first, then both y.
    '''
    dx = x - x1
    dy = y - y1
    return np.sqrt(dx ** 2.0 + dy ** 2.0)
def find_filpix(branches, labelfil, final=True):
    '''
    Identifies the types of pixels in the given skeleton. Identification is
    based on the connectivity of each pixel's 8-neighbourhood.

    Parameters
    ----------
    branches : int
        Number of labels to examine; pixels of ``labelfil`` with values
        1..branches are classified (pix_identify passes 1).
    labelfil : numpy.ndarray
        Labeled skeleton array. Modified in place: every pixel assigned to
        an intersection is set to 0. Pixels in the last row/column are
        skipped, and neighbours of pixels in row/column 0 are read with
        index -1 (NumPy wrap-around), so the skeleton presumably needs to be
        padded away from the array edges (as isolateregions does).
    final : bool, optional
        If True, every pixel with exactly two neighbour transitions is
        labeled a body point, skipping the corner/block classification (for
        skeletons that have already been cleaned). Pixels with three or more
        transitions are still treated as intersection candidates.

    Returns
    -------
    fila_pts : list
        Lists of end points + body points (plus unmatched corner points).
        One list is appended per label when there are no corner points,
        otherwise one list per remaining corner point.
    inters : list
        One list of (row, col) pixels per 8-connected intersection.
    labelfil : numpy.ndarray
        The input array with all intersection pixels set to 0.
    endpts_return : list
        The (row, col) end points found for every label.
    '''
    initslices = []
    initlist = []
    shiftlist = []
    sublist = []
    endpts = []
    blockpts = []
    bodypts = []
    slices = []
    vallist = []
    shiftvallist = []
    cornerpts = []
    subvallist = []
    subslist = []
    pix = []
    filpix = []
    intertemps = []
    fila_pts = []
    inters = []
    repeat = []
    temp_group = []
    all_pts = []
    pairs = []
    endpts_return = []

    # For every pixel of each label, record its position and its 3x3
    # neighbourhood with the centre zeroed.
    for k in range(1, branches + 1):
        x, y = np.where(labelfil == k)
        # pixel_slices = np.empty((len(x)+1,8))
        for i in range(len(x)):
            if x[i] < labelfil.shape[0] - 1 and y[i] < labelfil.shape[1] - 1:
                pix.append((x[i], y[i]))
                initslices.append(np.array([[labelfil[x[i] - 1, y[i] + 1],
                                             labelfil[x[i], y[i] + 1],
                                             labelfil[x[i] + 1, y[i] + 1]],
                                            [labelfil[x[i] - 1, y[i]], 0,
                                             labelfil[x[i] + 1, y[i]]],
                                            [labelfil[x[i] - 1, y[i] - 1],
                                             labelfil[x[i], y[i] - 1],
                                             labelfil[x[i] + 1, y[i] - 1]]]))
        filpix.append(pix)
        slices.append(initslices)
        initslices = []
        pix = []

    # Flatten each neighbourhood into its 8 neighbours in ring order around
    # the centre: [0,0] [0,1] [0,2] [1,2] [2,2] [2,1] [2,0] [1,0].
    for i in range(len(slices)):
        for k in range(len(slices[i])):
            initlist.append([slices[i][k][0, 0],
                             slices[i][k][0, 1],
                             slices[i][k][0, 2],
                             slices[i][k][1, 2],
                             slices[i][k][2, 2],
                             slices[i][k][2, 1],
                             slices[i][k][2, 0],
                             slices[i][k][1, 0]])
        vallist.append(initlist)
        initlist = []

    # Rotate every ring by one position.
    for i in range(len(slices)):
        for k in range(len(slices[i])):
            shiftlist.append(shifter(vallist[i][k], 1))
        shiftvallist.append(shiftlist)
        shiftlist = []

    # Element-wise ring[j] - ring[j + 1]. An entry equal to the label marks
    # a labelled neighbour followed by an empty one while walking the ring,
    # so counting them counts the separate runs of neighbours.
    for k in range(len(slices)):
        for i in range(len(vallist[k])):
            for j in range(8):
                sublist.append(
                    int(vallist[k][i][j]) - int(shiftvallist[k][i][j]))
            subslist.append(sublist)
            sublist = []
        subvallist.append(subslist)
        subslist = []

    # x is the number of transitions in the subtracted list and y is the
    # number of labelled neighbours. The categories of pixels are
    # ENDPTS (x<=1), BODYPTS (x=2, y=2), CORNERPTS (x=2, y=3),
    # BLOCKPTS (x=2, y>=4), and INTERPTS (x>=3).
    # A cornerpt looks like    [*,0,0]   It is associated with an
    #                          [1,*,0]   intersection; excluding it would
    #                          [0,1,0]   break eight-connectivity, so it is
    #                                    included in the intersection.
    # A blockpt looks like     [1,0,1]   Blockpts are typically found in a
    #                          [0,*,*]   group of four, where all four
    #                          [1,*,*]   constitute a single intersection.
    # The "final" designation is used when finding the final branch lengths.
    # At this point, blockpts and cornerpts should be eliminated.
    for k in range(branches):
        for l in range(len(filpix[k])):
            x = [j for j, y in enumerate(subvallist[k][l]) if y == k + 1]
            y = [j for j, z in enumerate(vallist[k][l]) if z == k + 1]
            if len(x) <= 1:
                endpts.append(filpix[k][l])
                endpts_return.append(filpix[k][l])
            elif len(x) == 2:
                if final:
                    bodypts.append(filpix[k][l])
                else:
                    if len(y) == 2:
                        bodypts.append(filpix[k][l])
                    elif len(y) == 3:
                        cornerpts.append(filpix[k][l])
                    elif len(y) >= 4:
                        blockpts.append(filpix[k][l])
            elif len(x) >= 3:
                intertemps.append(filpix[k][l])
        endpts = list(set(endpts))
        bodypts = list(set(bodypts))
        # A pixel classified as both end and body point is kept as an end.
        dups = set(endpts) & set(bodypts)
        if len(dups) > 0:
            for i in dups:
                bodypts.remove(i)
        # Cornerpts without a partner diagonally attached can be included as a
        # bodypt.
        if len(cornerpts) > 0:
            deleted_cornerpts = []
            # NOTE(review): zip(cornerpts, cornerpts) pairs every point with
            # itself, so ``i != j`` is never true and no pairs are ever
            # recorded here. Distinct pairs (e.g. itertools.combinations)
            # were probably intended; left unchanged pending confirmation.
            for i, j in zip(cornerpts, cornerpts):
                if i != j:
                    if distance(i[0], j[0], i[1], j[1]) == np.sqrt(2.0):
                        proximity = [(i[0], i[1] - 1),
                                     (i[0], i[1] + 1),
                                     (i[0] - 1, i[1]),
                                     (i[0] + 1, i[1]),
                                     (i[0] - 1, i[1] + 1),
                                     (i[0] + 1, i[1] + 1),
                                     (i[0] - 1, i[1] - 1),
                                     (i[0] + 1, i[1] - 1)]
                        match = set(intertemps) & set(proximity)
                        if len(match) == 1:
                            pairs.append([i, j])
                            deleted_cornerpts.append(i)
                            deleted_cornerpts.append(j)
            cornerpts = list(set(cornerpts).difference(set(deleted_cornerpts)))

        # A remaining cornerpt touching exactly one intersection pixel joins
        # the intersection; otherwise it is kept with the filament points.
        if len(cornerpts) > 0:
            for l in cornerpts:
                proximity = [(l[0], l[1] - 1),
                             (l[0], l[1] + 1),
                             (l[0] - 1, l[1]),
                             (l[0] + 1, l[1]),
                             (l[0] - 1, l[1] + 1),
                             (l[0] + 1, l[1] + 1),
                             (l[0] - 1, l[1] - 1),
                             (l[0] + 1, l[1] - 1)]
                match = set(intertemps) & set(proximity)
                if len(match) == 1:
                    intertemps.append(l)
                    fila_pts.append(endpts + bodypts)
                else:
                    fila_pts.append(endpts + bodypts + [l])
                    # cornerpts.remove(l)
        else:
            fila_pts.append(endpts + bodypts)

        # Reset lists
        cornerpts = []
        endpts = []
        bodypts = []

        # all_pts is not reset between labels, so intersections found for
        # earlier labels are rebuilt here; duplicates are removed after the
        # loop.
        if len(pairs) > 0:
            for i in range(len(pairs)):
                for j in pairs[i]:
                    all_pts.append(j)
        if len(blockpts) > 0:
            for i in blockpts:
                all_pts.append(i)
        if len(intertemps) > 0:
            for i in intertemps:
                all_pts.append(i)
        # Pairs of cornerpts, blockpts, and interpts are combined into an
        # array. If there is eight connectivity between them, they are labelled
        # as a single intersection.
        arr = np.zeros((labelfil.shape))
        for z in all_pts:
            labelfil[z[0], z[1]] = 0
            arr[z[0], z[1]] = 1
        lab, nums = nd.label(arr, eight_con())
        # This inner loop reuses the name k; the outer loop re-binds it from
        # its own range on the next iteration.
        for k in range(1, nums + 1):
            objs_pix = np.where(lab == k)
            for l in range(len(objs_pix[0])):
                temp_group.append((objs_pix[0][l], objs_pix[1][l]))
            inters.append(temp_group)
            temp_group = []

    # Drop an intersection group when it equals the next group in the list
    # (only adjacent duplicates are compared).
    for i in range(len(inters) - 1):
        if inters[i] == inters[i + 1]:
            repeat.append(inters[i])
    for i in repeat:
        inters.remove(i)

    return fila_pts, inters, labelfil, endpts_return
def pix_identify(isolatefilarr, num):
    '''
    Wrapper on find_filpix that collects its outputs for every skeleton in
    the form used during the analysis.

    Parameters
    ----------
    isolatefilarr : list
        Contains individual arrays of each skeleton. For each n < num,
        ``isolatefilarr[n]`` is replaced by the array returned from
        find_filpix, which has its intersection pixels set to 0.
    num : int
        The number of skeletons.

    Returns
    -------
    interpts : list
        Contains lists of all intersections points in each skeleton.
    hubs : list
        Contains the number of intersections in each filament. This is
        useful for identifying those with no intersections as their analysis
        is straight-forward.
    ends : list
        Contains the positions of all end points in each skeleton.
    filbranches : list
        Contains the number of branches in each skeleton.
    labelisofil : list
        Contains individual arrays for each skeleton where the
        branches are labeled and the intersections have been removed.
    '''
    interpts, hubs, ends, filbranches, labelisofil = [], [], [], [], []

    for idx in range(num):
        _, inter_groups, cleaned_arr, end_pts = find_filpix(
            1, isolatefilarr[idx], final=False)

        interpts.append(inter_groups)
        hubs.append(len(inter_groups))
        # Swap in the intersection-free array for this skeleton.
        isolatefilarr[idx] = cleaned_arr
        ends.append(end_pts)

        # With intersections removed, each 8-connected piece is a branch.
        branch_labels, branch_count = nd.label(cleaned_arr, eight_con())
        filbranches.append(branch_count)
        labelisofil.append(branch_labels)

    return interpts, hubs, ends, filbranches, labelisofil
def skeleton_length(skeleton):
    '''
    Length of a skeleton found via morphological operators.

    The skeleton is split using the difference between 4- and
    8-connectivity:

    * every 4-connected run with more than one pixel contributes
      (pixels - 1) unit steps;
    * the pixels left after removing those runs are linked only diagonally
      and contribute (pixels - objects) * sqrt(2);
    * connections between 4- and 8-connected pieces are counted with a
      series of hit-or-miss transforms (struct1/struct2 matches minus the
      double-counted check1/check2/check3 matches), each worth sqrt(2).

    The inputted skeleton MUST have no intersections otherwise the returned
    length will not be correct!

    Parameters
    ----------
    skeleton : numpy.ndarray
        Array containing the skeleton.

    Returns
    -------
    length : float
        Length of the skeleton.
    '''
    sqrt2 = np.sqrt(2)

    # 4-connected pieces (scipy's default cross-shaped structure).
    four_labels = nd.label(skeleton)[0]
    four_sizes = nd.sum(skeleton, four_labels,
                        range(np.max(four_labels) + 1))
    is_run = four_sizes > 1
    four_length = np.sum(four_sizes[is_run]) - len(four_sizes[is_run])

    # Strip the multi-pixel runs; whatever remains is only 8-connected.
    skel_copy = copy.copy(skeleton)
    for run_label in np.where(is_run)[0]:
        skel_copy[four_labels == run_label] = 0

    eight_labels = nd.label(skel_copy, eight_con())[0]
    n_eight = np.max(eight_labels)
    eight_sizes = nd.sum(skel_copy, eight_labels, range(n_eight + 1))
    eight_length = (np.sum(eight_sizes) - n_eight) * sqrt2

    if four_length == 0.0:
        # No 4-connected runs, so there is nothing for hit-or-miss to join.
        conn_length = 0.0
    else:
        hits = np.zeros(skeleton.shape)
        for rot in range(4):
            hits += nd.binary_hit_or_miss(
                skeleton, structure1=np.rot90(struct1, k=rot))
            hits += nd.binary_hit_or_miss(
                skeleton, structure1=np.rot90(struct2, k=rot))
            hits -= nd.binary_hit_or_miss(
                skeleton, structure1=np.rot90(check3, k=rot))
            # Only two rotations of check1/check2 are applied: their
            # 180-degree rotations reproduce the originals.
            if rot <= 1:
                hits -= nd.binary_hit_or_miss(
                    skeleton, structure1=np.rot90(check1, k=rot))
                hits -= nd.binary_hit_or_miss(
                    skeleton, structure1=np.rot90(check2, k=rot))
        conn_length = sqrt2 * np.sum(hits)

    return conn_length + eight_length + four_length
def init_lengths(labelisofil, filbranches, array_offsets, img):
    '''
    This is a wrapper on skeleton_length for running on the branches of the
    skeletons. It also finds the average image intensity along each branch.

    Parameters
    ----------
    labelisofil : list
        Contains individual arrays for each skeleton where the
        branches are labeled and the intersections have been removed.
    filbranches : list
        Contains the number of branches in each skeleton (not used here).
    array_offsets : list
        The indices of where each filament array fits in the
        original image; ``array_offsets[n][0]`` is the (row, column) of the
        array's lower corner.
    img : numpy.ndarray
        Original image.

    Returns
    -------
    branch_properties: dict
        Contains the lengths and intensities of the branches.
        Keys are *length* and *intensity*; each holds one list per skeleton
        ordered by branch label (label 1 first). The intensity is NaN for a
        branch with no finite, non-negative pixel.
    '''
    lengths = []
    av_branch_intensity = []

    for n, label_arr in enumerate(labelisofil):
        leng = []
        av_intensity = []
        # Lower corner of this skeleton's array within ``img``.
        x_corner = array_offsets[n][0][0]
        y_corner = array_offsets[n][0][1]

        # find_objects returns one bounding box per label, label 1 first.
        for branch_label, obj in enumerate(nd.find_objects(label_arr), 1):
            # A bounding box can also contain pixels of neighbouring
            # branches, so keep only the pixels carrying this label.
            branch_array = (label_arr[obj] == branch_label).astype(int)
            branch_pts = np.nonzero(branch_array)

            # Now find the length on the branch
            branch_length = skeleton_length(branch_array)
            if branch_length == 0.0:
                # For use in longest path algorithm, will be set to zero for
                # final analysis
                branch_length = 0.5
            leng.append(branch_length)

            # Average intensity along the branch, using only finite,
            # non-negative image values. The offsets combine the skeleton
            # array's position in the image with the bounding box's
            # position in the skeleton array.
            x_offset = obj[0].start + x_corner
            y_offset = obj[1].start + y_corner
            pix_values = [img[x + x_offset, y + y_offset]
                          for x, y in zip(*branch_pts)]
            good_values = [val for val in pix_values
                           if np.isfinite(val) and not val < 0.0]
            # np.nanmean replaces the ``scipy.nanmean`` name, which recent
            # SciPy releases no longer provide. An empty list gives NaN.
            av_intensity.append(np.nanmean(good_values))

        lengths.append(leng)
        av_branch_intensity.append(av_intensity)

    branch_properties = {
        "length": lengths, "intensity": av_branch_intensity}

    return branch_properties
def product_gen(n):
    '''
    Endlessly yield strings built from the symbols of ``n``: all strings of
    length 1, then all of length 2, and so on, each length in
    ``itertools.product`` order (e.g. "A", ..., "Z", "AA", "AB", ...).
    '''
    length = 1
    while True:
        for word in map("".join, itertools.product(n, repeat=length)):
            yield word
        length += 1
def pre_graph(labelisofil, branch_properties, interpts, ends):
    '''
    This function converts the skeletons into graph data compatible with
    networkx. The graphs have nodes corresponding to end and intersection
    points and edges defining the connectivity as the branches, weighted by
    a mix of branch length and intensity.

    Parameters
    ----------
    labelisofil : list
        Contains individual arrays for each skeleton where the
        branches are labeled and the intersections have been removed.
    branch_properties : dict
        Contains the lengths and intensities of all branches under the keys
        "length" and "intensity" (one list per skeleton, indexed by
        branch label - 1).
    interpts : list
        Contains the pixels which belong to each intersection.
    ends : list
        Contains the end pixels for each skeleton.

    Returns
    -------
    edge_list : list
        For each skeleton, a list of ``(node_a, node_b, branch)`` tuples
        where ``branch`` is ``(label, path_weight, length, intensity)``.
    nodes : list
        For each skeleton, the branch labels found at the end points
        followed by the intersection labels ("A", "B", ..., "Z", "AA", ...).
    '''
    num = len(labelisofil)
    end_nodes = []
    inter_nodes = []
    nodes = []
    edge_list = []

    def path_weighting(idx, length, intensity, w=0.5):
        '''
        Relative weighting for the shortest path algorithm using the branch
        lengths and the average intensity along the branch.
        '''
        if w > 1.0 or w < 0.0:
            raise ValueError(
                "Relative weighting w must be between 0.0 and 1.0.")
        return (1 - w) * (length[idx] / np.sum(length)) + \
            w * (intensity[idx] / np.sum(intensity))

    lengths = branch_properties["length"]
    branch_intensity = branch_properties["intensity"]

    for n in range(num):
        skel = labelisofil[n]
        inter_nodes_temp = []

        # End nodes: (branch label, path weight, length, intensity) of the
        # branch each end point lies on.
        end_nodes.append([(skel[i[0], i[1]],
                           path_weighting(int(skel[i[0], i[1]] - 1),
                                          lengths[n],
                                          branch_intensity[n]),
                           lengths[n][int(skel[i[0], i[1]] - 1)],
                           branch_intensity[n][int(skel[i[0], i[1]] - 1)])
                          for i in ends[n]])
        nodes.append([skel[i[0], i[1]] for i in ends[n]])

        # Intersection nodes: the branches whose labels appear in the
        # 8-neighbourhood of any pixel of the intersection.
        for intersec in interpts[n]:
            uniqs = []
            for i in intersec:  # Intersections can contain multiple pixels
                int_arr = np.array([[skel[i[0] - 1, i[1] + 1],
                                     skel[i[0], i[1] + 1],
                                     skel[i[0] + 1, i[1] + 1]],
                                    [skel[i[0] - 1, i[1]], 0,
                                     skel[i[0] + 1, i[1]]],
                                    [skel[i[0] - 1, i[1] - 1],
                                     skel[i[0], i[1] - 1],
                                     skel[i[0] + 1, i[1] - 1]]]).astype(int)
                for x in np.unique(int_arr[np.nonzero(int_arr)]):
                    uniqs.append((x,
                                  path_weighting(x - 1, lengths[n],
                                                 branch_intensity[n]),
                                  lengths[n][x - 1],
                                  branch_intensity[n][x - 1]))
            # Intersections with multiple pixels can give the same branches.
            uniqs = list(set(uniqs))
            inter_nodes_temp.append(uniqs)

        # Label the intersections alphabetically. The pairs must be a list:
        # the edge loops below iterate them more than once (nested), which
        # would exhaust a Python 3 ``zip`` iterator after the first pass.
        labelled = list(zip(product_gen(string.ascii_uppercase),
                            inter_nodes_temp))
        inter_nodes.append(labelled)
        for alpha, _ in labelled:
            nodes[n].append(alpha)

        # Edges are created from the information contained in the nodes.
        edge_list_temp = []
        for i, inters in enumerate(inter_nodes[n]):
            # Intersection -> end point, for every branch they share.
            end_match = list(set(inters[1]) & set(end_nodes[n]))
            for k in end_match:
                edge_list_temp.append((inters[0], k[0], k))
            # Intersection -> intersection. When several branches are shared
            # keep the one with the smallest path weight.
            for j, inters_2 in enumerate(inter_nodes[n]):
                if i == j:
                    continue
                match = list(set(inters[1]) & set(inters_2[1]))
                if not match:
                    continue
                shared = min(match, key=operator.itemgetter(1))
                new_edge = (inters[0], inters_2[0], shared)
                reverse_edge = (new_edge[1], new_edge[0], new_edge[2])
                # Skip an edge already recorded in either direction.
                if reverse_edge not in edge_list_temp and \
                        new_edge not in edge_list_temp:
                    edge_list_temp.append(new_edge)

        edge_list.append(edge_list_temp)

    return edge_list, nodes
def try_mkdir(name):
    '''
    Make the folder ``name`` (relative to the current working directory, or
    absolute) unless it already exists as a directory.

    Creating first and checking afterwards avoids the race in which another
    process creates the folder between an existence check and ``mkdir``.
    An ``OSError`` is still raised when the path exists but is not a
    directory, or when the parent directory is missing.
    '''
    path = os.path.join(os.getcwd(), name)
    try:
        os.mkdir(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def longest_path(edge_list, nodes, verbose=False,
                 skeleton_arrays=None, save_png=False, save_name=None):
    '''
    Takes the output of pre_graph and finds, for each skeleton, the pair of
    nodes with the largest weighted shortest-path distance.

    Parameters
    ----------
    edge_list : list
        Contains the connectivity information for the graphs.
    nodes : list
        A complete list of all of the nodes. The other nodes lists have
        been separated as they are labeled differently.
    verbose : bool, optional
        If True, enables the plotting of the graph.
    skeleton_arrays : list, optional
        List of the skeleton arrays. Required when verbose=True.
    save_png : bool, optional
        Saves the plot made in verbose mode. Disabled by default.
    save_name : str, optional
        For use when ``save_png`` is enabled.
        **MUST be specified when ``save_png`` is enabled.**

    Returns
    -------
    max_path : list
        Contains the paths corresponding to the longest lengths for
        each skeleton.
    extremum : list
        Contains the starting and ending points of max_path
    graphs : list
        The networkx Graph built for each skeleton.
    '''
    import warnings

    num = len(nodes)

    # Initialize lists
    max_path = []
    extremum = []
    graphs = []

    for n in range(num):
        G = nx.Graph()
        G.add_nodes_from(nodes[n])
        for edge in edge_list[n]:
            G.add_edge(edge[0], edge[1], weight=edge[2][1])

        # dict() accepts both the networkx 1.x dict and the 2.x iterator of
        # (source, {target: length}) pairs.
        paths = dict(nx.shortest_path_length(G, weight='weight'))

        # For each source node, its farthest node and that distance.
        values = []
        node_extrema = []
        for source, dists in paths.items():
            farthest, dist = max(dists.items(), key=operator.itemgetter(1))
            node_extrema.append((farthest, source))
            values.append(dist)

        start, finish = node_extrema[values.index(max(values))]
        extremum.append([start, finish])
        # Retrieve the path with the same weights used for the distances
        # above; an unweighted search can pick a different route when the
        # graph has loops.
        max_path.append(nx.shortest_path(G, start, finish, weight='weight'))
        graphs.append(G)

        if verbose or save_png:
            # ``Warning(...)`` alone only builds an exception object;
            # warnings.warn actually reports the problem.
            if not skeleton_arrays:
                warnings.warn("Must input skeleton arrays if verbose or save_png is"
                              " enabled. No plots will be created.")
            elif save_png and save_name is None:
                warnings.warn("Must give a save_name when save_png is enabled. No"
                              " plots will be created.")
            else:
                # Check if skeleton_arrays is a list
                assert isinstance(skeleton_arrays, list)
                import matplotlib.pyplot as p
                if verbose:
                    print("Filament: %s / %s" % (n + 1, num))
                p.subplot(1, 2, 1)
                p.imshow(skeleton_arrays[n], interpolation="nearest",
                         origin="lower")

                p.subplot(1, 2, 2)
                elist = [(u, v) for (u, v, d) in G.edges(data=True)]
                pos = nx.spring_layout(G)
                nx.draw_networkx_nodes(G, pos, node_size=200)
                nx.draw_networkx_edges(G, pos, edgelist=elist, width=2)
                nx.draw_networkx_labels(
                    G, pos, font_size=10, font_family='sans-serif')
                p.axis('off')

                if save_png:
                    try_mkdir(save_name)
                    p.savefig(os.path.join(save_name,
                                           save_name + "_longest_path_" + str(n) + ".png"))
                if verbose:
                    p.show()
                p.clf()

    return max_path, extremum, graphs
def prune_graph(G, nodes, edge_list, max_path, labelisofil, branch_properties,
                length_thresh, relintens_thresh=0.2):
    '''
    Function to remove unnecessary branches, while maintaining connectivity
    in the graph. Also updates edge_list, nodes, labelisofil and
    branch_properties in place.

    Parameters
    ----------
    G : list
        Contains the networkx Graph objects.
    nodes : list
        A complete list of all of the nodes. The other nodes lists have
        been separated as they are labeled differently.
    edge_list : list
        Contains the connectivity information for the graphs.
    max_path : list
        Contains the paths corresponding to the longest lengths for
        each skeleton.
    labelisofil : list
        Contains individual arrays for each skeleton where the
        branches are labeled and the intersections have been removed.
    branch_properties : dict
        Contains the lengths and intensities of all branches. It must also
        hold a "number" key (per-skeleton branch counts), which is
        decremented for every deleted branch; init_lengths does not create
        that key.
    length_thresh : int or float
        Minimum length a branch must be to be kept. Can be overridden if the
        branch is bright relative to the entire skeleton.
    relintens_thresh : float between 0 and 1, optional.
        Threshold for how bright the branch must be relative to the entire
        skeleton. Can be overridden by length.

    Returns
    -------
    labelisofil : list
        Updated from input.
    edge_list : list
        Updated from input.
    nodes : list
        Updated from input.
    branch_properties : dict
        Updated from input.
    '''
    num = len(labelisofil)

    for n in range(num):
        # dict() accepts both the networkx 1.x degree dict and the 2.x
        # DegreeView (an iterable of (node, degree) pairs).
        degree = dict(G[n].degree())
        single_connect = [node for node, deg in degree.items() if deg == 1]

        # Candidates: nodes off the longest path with a single connection.
        delete_candidate = list(
            (set(nodes[n]) - set(max_path[n])) & set(single_connect))

        if not delete_candidate:  # Nothing to delete!
            continue

        edge_candidates = [edge for edge in edge_list[n]
                           if edge[0] in delete_candidate or
                           edge[1] in delete_candidate]
        # Branch intensities, taken before any deletion in this skeleton.
        intensities = [edge[2][3] for edge in edge_list[n]]
        total_intensity = np.sum(intensities)

        for edge in edge_candidates:
            # In the odd case where a loop meets at the same intersection,
            # ensure that edge is kept.
            if isinstance(edge[0], str) and isinstance(edge[1], str):
                continue
            # If its too short and relatively not as intense, delete it
            length = edge[2][2]
            av_intensity = edge[2][3]
            if length < length_thresh and \
                    (av_intensity / total_intensity) < relintens_thresh:
                labelisofil[n][labelisofil[n] == edge[2][0]] = 0
                edge_list[n].remove(edge)
                # NOTE(review): assumes the end-point node is edge[1], as in
                # the end-point edges built by pre_graph.
                nodes[n].remove(edge[1])
                # NOTE(review): removal is by value, so another branch with
                # an identical length/intensity could be removed instead.
                branch_properties["length"][n].remove(length)
                branch_properties["intensity"][n].remove(av_intensity)
                branch_properties["number"][n] -= 1

    return labelisofil, edge_list, nodes, branch_properties
def extremum_pts(labelisofil, extremum, ends):
    '''
    Return the end points of each skeleton that lie on the extreme nodes
    chosen by the shortest path search. This is useful for judging how well
    that search worked.

    Parameters
    ----------
    labelisofil : list
        Contains individual arrays for each skeleton.
    extremum : list
        Contains the extents as determined by the shortest
        path algorithm; only the first two entries per skeleton are used.
    ends : list
        Contains the (row, col) positions of each end point in each filament.

    Returns
    -------
    extrem_pts : list
        For each skeleton, the ``[row, col]`` end points whose label matches
        either extremum.
    '''
    extrem_pts = []
    for n, label_arr in enumerate(labelisofil):
        targets = extremum[n]
        extrem_pts.append([[i, j] for i, j in ends[n]
                           if label_arr[i, j] == targets[0] or
                           label_arr[i, j] == targets[1]])
    return extrem_pts
def main_length(max_path, edge_list, labelisofil, interpts, branch_lengths,
                img_scale, verbose=False, save_png=False, save_name=None):
    '''
    Wraps previous functionality together for all of the skeletons in the
    image. To find the overall length for each skeleton, intersections are
    added back in, and any extraneous pixels they bring with them are deleted.

    Parameters
    ----------
    max_path : list
        Contains the paths corresponding to the longest lengths for
        each skeleton.
    edge_list : list
        Contains the connectivity information for the graphs.
    labelisofil : list
        Contains individual arrays for each skeleton where the
        branches are labeled and the intersections have been removed.
    interpts : list
        Contains the pixels which belong to each intersection.
    branch_lengths : list
        Lengths of individual branches in each skeleton.
    img_scale : float
        Conversion from pixel to physical units.
    verbose : bool, optional
        Returns plots of the longest path skeletons.
    save_png : bool, optional
        Saves the plot made in verbose mode. Disabled by default.
    save_name : str, optional
        For use when ``save_png`` is enabled.
        **MUST be specified when ``save_png`` is enabled**; otherwise a
        warning is issued and no plots are made.

    Returns
    -------
    main_lengths : list
        Lengths of the skeletons.
    longpath_arrays : list
        Arrays of the longest paths in the skeletons.
    '''
    import warnings

    main_lengths = []
    longpath_arrays = []

    for num, (path, edges, inters, skel_arr, lengths) in \
            enumerate(zip(max_path, edge_list, interpts, labelisofil,
                          branch_lengths)):

        if len(path) == 1:
            main_lengths.append(lengths[0] * img_scale)
            skeleton = skel_arr  # for viewing purposes when verbose
        else:
            skeleton = np.zeros(skel_arr.shape)

            # Paint every branch whose edge joins consecutive path nodes.
            good_edge_list = [(path[i], path[i + 1])
                              for i in range(len(path) - 1)]
            for i in good_edge_list:
                for j in edges:
                    if (i[0] == j[0] and i[1] == j[1]) or \
                            (i[0] == j[1] and i[1] == j[0]):
                        label = j[2][0]
                        skeleton[np.where(skel_arr == label)] = 1

            # Add intersections along longest path
            intersec_pts = []
            for label in path:
                try:
                    label = int(label)
                except ValueError:
                    pass
                if not isinstance(label, int):
                    # Intersection nodes are named by product_gen in
                    # pre_graph ("A", ..., "Z", "AA", ...); k is the 1-based
                    # position of this name in that sequence. Indexing a
                    # ``zip`` object (as before) fails on Python 3.
                    for k, alpha in enumerate(
                            product_gen(string.ascii_uppercase), 1):
                        if alpha == label:
                            break
                    group = inters[k - 1]
                    intersec_pts.extend(group)
                    # Explicit row/column lists: a bare ``zip`` object is
                    # not a valid NumPy index on Python 3.
                    skeleton[[pt[0] for pt in group],
                             [pt[1] for pt in group]] = 2

            # Remove intersection pixels that are not needed to keep the
            # skeleton in one 8-connected piece; repeat until a full pass
            # removes nothing.
            while True:
                removed = 0
                for pt in intersec_pts:
                    # If we have already eliminated the point, continue
                    if skeleton[pt] == 0:
                        continue
                    skeleton[pt] = 0
                    n_pieces = nd.label(skeleton, eight_con())[1]
                    if n_pieces > 1:
                        skeleton[pt] = 1
                    else:
                        removed += 1
                if removed == 0:
                    break

            main_lengths.append(skeleton_length(skeleton) * img_scale)

        longpath_arrays.append(skeleton.astype(int))

        if verbose or save_png:
            if save_png and save_name is None:
                # Skip plotting entirely (as the message says) instead of
                # continuing into try_mkdir(None).
                warnings.warn("Must give a save_name when save_png is enabled. No"
                              " plots will be created.")
            else:
                import matplotlib.pyplot as p
                if verbose:
                    print("Filament: %s / %s" % (num + 1, len(labelisofil)))

                p.subplot(121)
                p.imshow(skeleton, origin='lower', interpolation="nearest")
                p.subplot(122)
                p.imshow(labelisofil[num], origin='lower',
                         interpolation="nearest")

                if save_png:
                    try_mkdir(save_name)
                    p.savefig(os.path.join(save_name,
                                           save_name + "_main_length_" + str(num) + ".png"))
                if verbose:
                    p.show()
                p.clf()

    return main_lengths, longpath_arrays
def find_extran(branches, labelfil):
    '''
    Remove pixels that are not needed to keep each skeleton connected.

    Extraneous pixels tend to be those from former intersections, whose
    attached branch was eliminated in the cleaning process.

    For every pixel carrying label ``k`` (``1 <= k <= branches``) its eight
    neighbours are read as a closed ring. The pixel is removed when

    * no step with ``ring[j] - ring[j + 1] == k`` occurs around the ring
      (e.g. an isolated pixel, or one whose eight neighbours all carry
      ``k``), or
    * exactly one such step occurs and at least two neighbours carry ``k``
      (its neighbours form a single contiguous run, so they stay connected
      to each other without it).

    All decisions are made on the array as it was on entry, so several
    adjacent removable pixels are removed together. Positions outside the
    array are treated as background (0).

    Parameters
    ----------
    branches : int
        Highest label to examine; labels ``1..branches`` are processed.
    labelfil : numpy.ndarray
        2D labeled skeleton array. Modified in place.

    Returns
    -------
    labelfil : numpy.ndarray
        The same array object, with extraneous pixels set to 0.
    '''
    # (row, col) offsets of the 8 neighbours, walked as a closed ring.
    ring_offsets = [(-1, 1), (0, 1), (1, 1), (1, 0),
                    (1, -1), (0, -1), (-1, -1), (-1, 0)]
    n_ring = len(ring_offsets)

    # Zero-padded snapshot: border pixels now see background instead of
    # wrapping around via negative indices (rows/cols 0) or being skipped
    # (last row/col), and later removals cannot influence earlier decisions.
    padded = np.pad(labelfil, 1, mode='constant')

    to_remove = []
    for label in range(1, branches + 1):
        rows, cols = np.where(labelfil == label)
        for row, col in zip(rows, cols):
            ring = [int(padded[row + 1 + dr, col + 1 + dc])
                    for dr, dc in ring_offsets]
            # Steps around the ring where the value drops by exactly
            # ``label``; for a 0/label neighbourhood this counts the
            # separate runs of labeled neighbours.
            transitions = sum(1 for j in range(n_ring)
                              if ring[j] - ring[(j + 1) % n_ring] == label)
            neighbours = sum(1 for value in ring if value == label)
            if transitions == 0 or (transitions == 1 and neighbours >= 2):
                to_remove.append((row, col))

    for row, col in to_remove:
        labelfil[row, col] = 0
    return labelfil
def in_ipynb():
    '''
    Report whether the code is running inside an IPython notebook kernel.

    Returns
    -------
    bool
        True only when an IPython shell is active and its config names
        ``ipython-notebook`` as the kernel's parent application. False
        otherwise: plain Python, terminal IPython, or a kernel whose config
        does not carry that entry.
    '''
    try:
        cfg = get_ipython().config
        return cfg['IPKernelApp']['parent_appname'] == 'ipython-notebook'
    except (NameError, KeyError, AttributeError):
        # NameError: ``get_ipython`` only exists as a builtin inside IPython.
        # KeyError: terminal IPython and newer kernels do not set
        # 'parent_appname'; previously this escaped and crashed the caller.
        return False
def make_final_skeletons(labelisofil, inters, verbose=False, save_png=False,
                         save_name=None):
    '''
    Creates the final skeletons outputted by the algorithm.

    Each skeleton is binarised and then passed through ``find_extran`` to
    strip pixels that are not needed for connectivity. In the binary array,
    every pixel with a label >= 1 and every intersection pixel is set to 1.

    Parameters
    ----------
    labelisofil : list
        List of labeled skeleton arrays.
    inters : list
        Positions of the intersections in each skeleton. There is one entry
        per skeleton; each entry is a collection of intersections, and each
        intersection is an iterable of ``(row, col)`` pixel positions.
    verbose : bool, optional
        Enables plotting of the final skeleton.
    save_png : bool, optional
        Saves the plot made in verbose mode. Disabled by default.
    save_name : str, optional
        Directory name and file-name prefix for saved plots.
        **MUST be specified when ``save_png`` is enabled**; otherwise a
        warning is issued and nothing is saved.

    Returns
    -------
    filament_arrays : list
        List of the final (binary, integer) skeletons.
    '''
    import warnings

    if save_png and save_name is None:
        # ``Warning(...)`` alone only builds an exception object; it was
        # never emitted, and the code then crashed in os.path.join(None, ...).
        warnings.warn("Must give a save_name when save_png is enabled. No"
                      " plots will be saved.")
        save_png = False

    filament_arrays = []
    for n, (skel_array, intersec) in enumerate(zip(labelisofil, inters)):
        copy_array = np.zeros(skel_array.shape, dtype=int)
        # Intersection pixels are added back so that branches meeting at
        # them stay joined in the binary skeleton.
        for inter in intersec:
            for x, y in inter:
                copy_array[x, y] = 1
        copy_array[skel_array >= 1] = 1

        cleaned_array = find_extran(1, copy_array)
        filament_arrays.append(cleaned_array)

        if verbose or save_png:
            plt.clf()
            plt.imshow(cleaned_array, origin='lower', interpolation='nearest')
            if save_png:
                try_mkdir(save_name)
                plt.savefig(os.path.join(
                    save_name,
                    save_name + "_final_skeleton_" + str(n) + ".png"))
            if verbose:
                plt.show()
            if in_ipynb():
                plt.clf()
    return filament_arrays
def recombine_skeletons(skeletons, offsets, orig_size, pad_size,
                        verbose=False):
    '''
    Takes a list of skeleton arrays and combines them back into
    the original array.

    Parameters
    ----------
    skeletons : list
        2D arrays of each skeleton; pixels >= 1 are treated as skeleton.
    offsets : list
        For each skeleton, ``offsets[n][0]`` is ``(x_off, y_off)``, the
        bottom-left position in the image. ``offsets[n][1]`` is
        ``(x_top, y_top)``, the top-right bound. Parts of a skeleton lying
        below 0 or beyond ``orig_size`` are cropped away.
    orig_size : tuple
        Shape of the image.
    pad_size : int
        Size of the array padding. Currently unused; kept for backward
        compatibility.
    verbose : bool, optional
        Enables printing when a skeleton array needs to be cropped to fit
        into the image.

    Returns
    -------
    master_array : numpy.ndarray
        Float array of shape ``orig_size`` with 1 at every skeleton pixel.
    '''
    num = len(skeletons)
    master_array = np.zeros(orig_size)
    for n, skeleton in enumerate(skeletons):
        # Bottom-left position of this skeleton in the master array.
        x_off, y_off = offsets[n][0]
        x_top, y_top = offsets[n][1]

        # Crop whatever the padding pushed outside the original array.
        excess_x_top = x_top - orig_size[0]
        excess_y_top = y_top - orig_size[1]
        cropped = skeleton  # only read below, so a view is sufficient
        size_change_flag = False
        if excess_x_top > 0:
            cropped = cropped[:-excess_x_top, :]
            size_change_flag = True
        if excess_y_top > 0:
            cropped = cropped[:, :-excess_y_top]
            size_change_flag = True
        if x_off < 0:
            cropped = cropped[-x_off:, :]
            x_off = 0
            size_change_flag = True
        if y_off < 0:
            cropped = cropped[:, -y_off:]
            y_off = 0
            size_change_flag = True

        # Logical ``and``: the former bitwise ``&`` evaluated to 0 for
        # truthy non-bool values such as ``verbose=2``.
        if verbose and size_change_flag:
            print("REDUCED FILAMENT %s/%s TO FIT IN ORIGINAL ARRAY" % (n, num))

        rows, cols = np.where(cropped >= 1)
        master_array[rows + x_off, cols + y_off] = 1
    return master_array
if __name__ == "__main__":
    # Demo pipeline: smooth and threshold an image, skeletonize the mask,
    # analyse the skeleton graph and plot the longest path.
    #
    # Fill in both values below before running the script.
    # 2D numpy array with image
    image = None
    # rms of the image
    rms = None
    if image is None or rms is None:
        raise SystemExit("Set `image` (2D numpy array) and `rms` (image rms) "
                         "in the __main__ block before running this script.")

    def _plot_skeleton(mask_data, overlay, fname):
        """Plot the mask beside ``overlay`` (with mask contour) and save it."""
        fig, (ax1, ax2) = plt.subplots(
            1, 2, figsize=(8, 4), sharex=True, sharey=True,
            # NOTE: 'box-forced' was deprecated in Matplotlib 2.2 and is
            # rejected by later releases; drop subplot_kw there.
            subplot_kw={'adjustable': 'box-forced'})
        ax1.imshow(mask_data, cmap=plt.cm.gray, interpolation='nearest')
        ax1.axis('off')
        # ``nipy_spectral`` is the surviving name of the removed ``spectral``
        # colormap.
        ax2.imshow(overlay, cmap=plt.cm.nipy_spectral, interpolation='nearest')
        ax2.contour(mask_data, [0.5], colors='w')
        ax2.axis('off')
        fig.tight_layout()
        # Save before show(): once the interactive window is closed the
        # figure can be empty for a later savefig().
        fig.savefig(fname)
        plt.show()
        plt.close(fig)

    # Smooth, then binarise at 3 * rms (1 = emission, 0 = background).
    data = nd.gaussian_filter(image, 5)
    mask = data < 3. * rms
    data[mask] = 0
    data[~mask] = 1

    # ``dist_map`` rather than ``distance``: avoids rebinding a generic
    # module-level name that helper functions might rely on.
    skel, dist_map = medial_axis(data, return_distance=True)
    dist_on_skel = dist_map * skel
    _plot_skeleton(data, dist_on_skel, 'skeleton_orig.png')

    isolated_filaments, num, offsets = isolateregions(skel)
    interpts, hubs, ends, filbranches, labeled_fil_arrays = \
        pix_identify(isolated_filaments, num)
    branch_properties = init_lengths(labeled_fil_arrays, filbranches, offsets,
                                     data)
    branch_properties["number"] = filbranches
    edge_list, nodes = pre_graph(labeled_fil_arrays, branch_properties,
                                 interpts, ends)
    max_path, extremum, G = longest_path(edge_list, nodes, verbose=True,
                                         save_png=False,
                                         skeleton_arrays=labeled_fil_arrays)
    updated_lists = prune_graph(G, nodes, edge_list, max_path,
                                labeled_fil_arrays, branch_properties,
                                length_thresh=20, relintens_thresh=0.1)
    labeled_fil_arrays, edge_list, nodes, branch_properties = updated_lists
    filament_extents = extremum_pts(labeled_fil_arrays, extremum, ends)
    length_output = main_length(max_path, edge_list, labeled_fil_arrays,
                                interpts, branch_properties["length"], 1,
                                verbose=True)
    filament_arrays = {}
    lengths, filament_arrays["long path"] = length_output
    lengths = np.asarray(lengths)
    filament_arrays["final"] = make_final_skeletons(labeled_fil_arrays,
                                                    interpts, verbose=True)
    skeleton = recombine_skeletons(filament_arrays["final"], offsets,
                                   data.shape, 0, verbose=True)
    skeleton_longpath = recombine_skeletons(filament_arrays["long path"],
                                            offsets, data.shape, 1)
    skeleton_longpath_dist = skeleton_longpath * dist_map
    _plot_skeleton(data, skeleton_longpath_dist, 'skeleton.png')
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment