Created
June 1, 2016 10:24
-
-
Save thomasaarholt/f85ef8360682a256de260d0343d77f83 to your computer and use it in GitHub Desktop.
Functions to remove high-energy cosmic ray spikes in EELS data
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def median_from_neighbors(energy_slice, x, y=""): | |
"""Takes the median of the value of the list neighbors""" | |
import numpy as np | |
return(np.median(find_cell_neighbors(energy_slice, x,y))) | |
def find_cell_neighbors(data, X, Y, r=1): | |
"""Finds value of neighbors of a index in the SI, excluding the centre and any values outside the edges of the SI""" | |
adjacent = [] | |
if Y == "": | |
# Line spectrum | |
for x in range(X-r, X+r+1): | |
if (x == X): # Do not include the spike itself | |
pass | |
elif (x < 0) or (x >= data.shape[0]): # Do not include indices outside the edges | |
pass | |
else: | |
adjacent.append(data[x]) | |
else: | |
# map | |
for x in range(X-r, X+r+1): | |
for y in range(Y-r, Y+r+1): | |
if (x == X) and (y == Y): # Do not include the spike itself | |
pass | |
elif (x < 0) or (x >= data.shape[0]): # Do not include indices outside the edges | |
pass | |
elif (y < 0) or (y >= data.shape[1]): # Do not include indices outside the edges | |
pass | |
else: | |
adjacent.append(data[x,y]) | |
return(adjacent) | |
def plot_spike_histogram(diff, diff_after): | |
"""Plots a single plot of the histogram of data differential before and after spikes removal""" | |
import matplotlib.pyplot as plt | |
import numpy as np | |
hist_before, bins_before = np.histogram(diff, bins="auto") | |
hist_after, bins_after = np.histogram(diff_after, bins="auto") | |
center_before = (bins_before[:-1] + bins_before[1:]) / 2 | |
center_after = (bins_after[:-1] + bins_after[1:]) / 2 | |
width = 5.0 # Width of column | |
plt.bar(center_before, hist_before, width=width, color = "red", label='before', log=True) | |
plt.bar(center_after, hist_after, width=width, color = "blue", label='after', log=True) | |
ax = plt.gca() | |
(ymin, ymax) = ax.get_ylim() | |
ax.set_ylim([1e-1,ymax]) | |
plt.legend() | |
plt.title("Histogram of signal derivative") | |
plt.xlabel("Derivative Magnitude") | |
plt.ylabel("Counts") | |
return | |
def remove_spikes(s=None, nMAD = 8, plot_difference=False): | |
"""Removes spikes outside nMad Median Absolute Deviations of the median of the differential of the spectrum""" | |
from statsmodels.robust import scale | |
import numpy as np | |
shape = s.data.shape | |
if len(shape) == 1: | |
# Case single spectrum | |
raise AttributeError("Remove spikes requires a Spectrum Image with non-zero navigation dimension") | |
elif len(shape) == 2: | |
# Case Line profile | |
print("Recognised as Line Profile") | |
diff = np.diff(s.data, axis=0) | |
threshold = nMAD * scale.mad(diff.flatten()) # MAD - nMad is the number of deviations away from the median are included | |
print("Gradient threshold is " + str(threshold)) | |
spike_positions = [] | |
positive = diff > threshold # Position of any inclining spikes | |
(x,e) = np.nonzero(positive) # Get index of spikes | |
x += 1 # Position of spike is the array index ahead of the gradient | |
for i in range(len(x)): | |
spike_positions.append([x[i], e[i]]) | |
negative = diff < -threshold # Position of any declining spikes | |
(x,e) = np.nonzero(negative) # Get index of spikes | |
for i in range(len(x)): | |
spike_positions.append([x[i], e[i]]) | |
print("Found " + str(len(spike_positions)) + " spikes!") | |
for (x,e) in spike_positions: | |
s.data[x,e] = median_from_neighbors(s.data[:,e], x) | |
if plot_difference == True: | |
diff_after = np.diff(s.data, axis=0) | |
plot_spike_histogram(diff, diff_after) | |
elif len(shape) == 3: | |
# Case EELS Map | |
print("Recognised as 2D Spectrum Image") | |
diff = np.diff(s.data, axis=1) # Get differential across data | |
threshold = nMAD * scale.mad(diff.flatten()) # MAD | |
print("Gradient threshold is " + str(threshold)) | |
positive = diff > threshold # Position of any inclining spikes | |
negative = diff < -threshold # Position of any declining spikes | |
spike_positions = [] | |
(x,y,e) = np.nonzero(positive) # Get index of spikes | |
y += 1 # Position of spike is the array index ahead of the gradient | |
for i in range(len(x)): | |
spike_positions.append([x[i], y[i], e[i]]) | |
(x,y,e) = np.nonzero(negative) # Get index of spikes | |
for i in range(len(x)): | |
spike_positions.append([x[i], y[i], e[i]]) | |
print("Found " + str(len(spike_positions)) + " spikes!") | |
for (x,y,e) in spike_positions: | |
# Spike intensity replaced by median of neighbors | |
s.data[x,y,e] = median_from_neighbors(s.data[:,:,e], x,y) | |
if plot_difference == True: | |
diff_after = np.diff(s.data, axis=1) | |
plot_spike_histogram(diff, diff_after) | |
else: | |
print("The signal shape does not match an expected value (X, S or X, Y, S). Signal data shape is " + str(s.data.shape)) | |
return s |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment