Skip to content

Instantly share code, notes, and snippets.

@ZGainsforth
Last active July 7, 2023 17:24
Show Gist options
  • Save ZGainsforth/9b9171d7e8b01faad07907c018d7d42c to your computer and use it in GitHub Desktop.
Save ZGainsforth/9b9171d7e8b01faad07907c018d7d42c to your computer and use it in GitHub Desktop.
FFT an HRTEM image stack, radially integrate it, and then fit the integrated profile with a background and Gaussians.
import numpy as np
import hyperspy.api as hs
from scipy.ndimage import map_coordinates
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from tifffile import imwrite, TiffWriter, TiffFile
from lmfit import Model, CompositeModel, Parameters
from lmfit.models import GaussianModel, QuadraticModel, PseudoVoigtModel
from functools import reduce
import operator
def cart2pol(x, y):
    """Convert Cartesian coordinates to polar coordinates.

    Args:
        x: Abscissa value(s); scalar or array-like.
        y: Ordinate value(s); scalar or array-like.

    Returns:
        A ``(theta, r)`` tuple -- angle first, radius second -- where theta is
        ``arctan2(y, x)`` in radians and r is ``hypot(x, y)``.
    """
    return np.arctan2(y, x), np.hypot(x, y)
# Adapted from https://stackoverflow.com/questions/21242011/most-efficient-way-to-calculate-radial-profile
def radial_profile(data, center):
    """Average ``data`` over rings of integer radius around ``center``.

    Args:
        data: 2-D array whose two dimensions must both be even.
        center: ``(x, y)`` pair; ``center[0]`` is compared against column
            indices and ``center[1]`` against row indices. It is used exactly
            as given (no sub-pixel offset is applied here).

    Returns:
        ``(radius, radialprofile)`` where ``radius`` is ``0, 1, ..., max`` and
        ``radialprofile[k]`` is the mean of all pixels whose distance to
        ``center`` truncates to ``k``. Rings without pixels yield 0.

    Raises:
        AssertionError: if either dimension of ``data`` is odd.
    """
    # Row (y) and column (x) index of every pixel.
    rows, cols = np.indices(data.shape)
    assert not data.shape[0] % 2 and not data.shape[1] % 2, 'Input shape is not even so pixel alignment will be off. radial_profile expects the center of the image to be between the center four pixels.'

    # Euclidean distance of each pixel to the center, truncated to an integer
    # ring index.
    distance = np.sqrt((cols - center[0]) ** 2 + (rows - center[1]) ** 2)
    ring = distance.astype(int).ravel()

    # Per-ring intensity sum and per-ring pixel count.
    ring_total = np.bincount(ring, weights=data.ravel())
    ring_count = np.bincount(ring)

    # Mean per ring; NaN from an empty ring (0/0) is mapped to zero.
    radialprofile = np.nan_to_num(ring_total / ring_count)

    # One radius value per ring index.
    radius = np.arange(0, ring.max() + 1, 1)
    return radius, radialprofile
# Keep only the part of a profile whose abscissa lies strictly inside (rmin, rmax).
def truncate_profile(radius, radialprofile, rmin, rmax):
    """Restrict a profile to the open interval ``rmin < radius < rmax``.

    Args:
        radius: 1-D NumPy array of abscissa values (the caller in this file
            passes d-spacings).
        radialprofile: 1-D NumPy array of the same length as ``radius``.
        rmin: Exclusive lower limit.
        rmax: Exclusive upper limit.

    Returns:
        ``(radius, radialprofile)`` filtered with the same mask, original
        order preserved.
    """
    inside = (radius < rmax) & (radius > rmin)
    return radius[inside], radialprofile[inside]
def print_and_write(message, file):
    """Print ``message`` to standard output and also to ``file``.

    Args:
        message: Object to print; formatted by ``print``.
        file: Writable text stream that receives a second copy.
    """
    # ``file=None`` makes print() fall back to standard output.
    for destination in (None, file):
        print(message, file=destination)
def integrate_one_dm3(filename, truncateangstroms):
    """Load one image, FFT it and return its radially averaged FFT magnitude.

    Side effect: writes the untruncated two-column profile to
    ``'{filename} profile invA.txt'``.

    Args:
        filename: Path passed to ``hs.load``.
        truncateangstroms: Indexable pair ``(lower, upper)`` handed to
            ``truncate_profile`` after the radial axis has been scaled by 10
            (nm to Angstrom).

    Returns:
        ``(f, fft_image, r, radialprofile)``: the loaded signal, the shifted
        FFT magnitude image, the truncated radial axis and the truncated
        profile.

    Raises:
        AssertionError: if the image is not square or its x-axis units are
            not ``'nm'``.
    """
    # Load the image and work on a float32 copy of the pixel data.
    f = hs.load(filename)
    img = f.data.astype('float32')
    assert img.shape[0] == img.shape[1], 'The FFT for this code assumes a square image.'

    # Magnitude of the 2-D FFT, with the zero-frequency term moved to the middle.
    fft_image = np.abs(np.fft.fftshift(np.fft.fft2(img)))

    # Radially average the magnitude around the middle pixel.
    middle = fft_image.shape[0] // 2, fft_image.shape[1] // 2
    pixel_radius, profile = radial_profile(fft_image, middle)

    # Turn FFT pixel radius into a real-space length: (pixel size * N) / radius.
    # The radius-0 bin divides by zero; nan_to_num below makes it finite.
    r = float(f.axes_manager['x'].scale) * img.shape[0] / pixel_radius
    assert f.axes_manager['x'].units == 'nm', 'Image units are not nm. Is this an HRTEM image?'
    r *= 10  # nm -> Angstrom
    r = np.nan_to_num(r)

    # Save the full profile before truncation.
    np.savetxt(f'{filename} profile invA.txt', np.column_stack((r, profile)))

    lower, upper = truncateangstroms[0], truncateangstroms[1]
    r, profile = truncate_profile(r, profile, lower, upper)
    return f, fft_image, r, profile
def fit_one_profile(background_model, peak_model, model, radialprofile, r, lock_background=False):
    """Fit a radial profile in stages: background, then peaks, then everything.

    Stages:
        1. Fit ``background_model`` alone to the profile.
        2. Freeze every parameter whose name contains ``"bkg"`` and fit
           ``peak_model`` to the background-subtracted profile to obtain
           starting values for the peaks.
        3. Fit the full ``model``. Unless ``lock_background`` is true, all
           independent parameters (background included) are released first.

    Args:
        background_model: lmfit model for the background; its parameter
            names are expected to contain ``"bkg"``.
        peak_model: lmfit model containing only the peaks.
        model: The complete model (background + peaks) used for the final fit.
        radialprofile: 1-D array of intensities to fit.
        r: 1-D array of abscissa values, passed to the models as ``x``.
        lock_background: If true, the background found in stage 1 stays
            fixed during the final fit.

    Returns:
        The lmfit result of the final fit of ``model``.
    """
    # Stage 1: background only.
    params = background_model.make_params()
    bkg_result = background_model.fit(radialprofile, params, x=r)
    params.update(bkg_result.params)

    # Stage 2: hold the background fixed ...
    for name, param in params.items():
        if "bkg" in name:
            param.set(vary=False)

    peak_params = peak_model.make_params()
    peak_params.update(params)
    # ... and fit the peaks to what is left after removing it. peak_model has
    # no background term of its own, so fitting it to the raw profile would
    # ignore the background that was just determined.
    residual = radialprofile - bkg_result.best_fit
    peak_result = peak_model.fit(residual, peak_params, x=r)
    params.update(peak_result.params)

    # Stage 3: optionally release everything for the combined fit.
    if not lock_background:
        for param in params.values():
            # Parameters tied to an expression (e.g. fwhm, height) must stay
            # derived: lmfit's Parameter.set(vary=True) clears the expression,
            # which would turn them into free variables the model ignores.
            if not param.expr:
                param.set(vary=True)

    return model.fit(radialprofile, params, x=r)
def show_fit(r, radialprofile, fit_result, filename):
    """Plot a fit, save the figure and write the fit report.

    Produces two stacked panels: the data with the model and summed
    background on top, and the background-subtracted data with the model and
    each peak component below.

    Side effects: saves ``'{filename} radial integration fit.png'`` and
    writes (and prints) the fit report to ``'{filename} fit parameters.txt'``.
    The figure is left open.

    Args:
        r: 1-D array of abscissa values used for the fit.
        radialprofile: 1-D array of fitted intensities.
        fit_result: lmfit result providing ``eval_components``, ``best_fit``,
            ``params``, ``model.components`` and ``fit_report``.
        filename: Plot title and stem of the two output file names.
    """
    # Raw string: the previous non-raw literal relied on the invalid escape
    # "\m", which newer Python versions warn about. The runtime text is the same.
    angstrom = r'$\mathrm{\AA}$'

    # Evaluate every model component; the background is the sum of all
    # components whose prefix contains 'bkg_'.
    components = fit_result.eval_components(params=fit_result.params)
    background = sum(components[comp_name] for comp_name in components if 'bkg_' in comp_name)

    fig, (ax1, ax2) = plt.subplots(2, 1)
    fig.set_size_inches(10, 6)

    # Top panel: data, full model and background.
    ax1.set_xlabel(angstrom, fontsize=18)
    ax1.set_ylabel('Integrated Intensity', fontsize=18)
    ax1.plot(r, radialprofile, color='blue', linewidth=4)
    ax1.plot(r, fit_result.best_fit, color='red', linewidth=4, alpha=0.8)
    ax1.plot(r, background, color='orange', linewidth=4, alpha=0.8)
    ax1.tick_params(axis='both', labelsize=12)
    ax1.legend(['Data', 'Model', 'Background'], fontsize=12, loc='upper left')

    # Bottom panel: everything with the background removed, plus each peak.
    ax2.set_xlabel(angstrom, fontsize=18)
    ax2.set_ylabel('Integrated Intensity\n - background', fontsize=18)
    ax2.plot(r, radialprofile - background, color='blue', linewidth=4)
    ax2.plot(r, fit_result.best_fit - background, color='red', linewidth=4, alpha=0.8)
    legend_labels = ['Data-background', 'Model']
    # eval_components() is keyed by component prefix; pair each entry with the
    # model component at the same position to label it with its model type.
    for component, comp_name in zip(fit_result.model.components, components):
        if 'peak' not in comp_name:
            continue
        ax2.plot(r, components[comp_name], color='orange', linewidth=4, alpha=0.8)
        center = fit_result.params[comp_name + 'center'].value
        legend_labels.append(f'{type(component).__name__} at {center:.2f} ' + angstrom)
    ax2.legend(legend_labels, fontsize=12, loc='upper left', ncol=2)
    ax2.tick_params(axis='both', labelsize=12)

    ax1.set_title(filename, fontsize=18)
    fig.tight_layout()
    fig.savefig(f'{filename} radial integration fit.png', dpi=300)

    # Write the parameter values and uncertainties, echoing them to stdout.
    with open(f'{filename} fit parameters.txt', 'w') as f:
        print_and_write(fit_result.fit_report(), f)
def ome_to_resolution_cm(metadata):
    """Convert OME-style physical pixel sizes to a per-centimeter resolution.

    The unit is read from ``metadata['PhysicalSizeXUnit']`` only and applied
    to both axes. A reciprocal unit (``'1/nm'``) is treated exactly like its
    direct counterpart (``'nm'``).

    Args:
        metadata: Mapping with the keys ``'PhysicalSizeXUnit'``,
            ``'PhysicalSizeX'`` and ``'PhysicalSizeY'``.

    Returns:
        ``(xval, yval)`` where each value is ``units_per_cm / PhysicalSize``.

    Raises:
        ValueError: if the unit is not recognised (previously this surfaced
            as an UnboundLocalError on ``scale``).
    """
    # Number of each length unit contained in one centimeter.
    units_per_cm = {
        'A': 1e8, 'Å': 1e8,
        'nm': 1e7,
        'um': 1e4, 'µm': 1e4, 'μm': 1e4,  # micro sign and Greek mu
        'mm': 10,
        'cm': 1,
        'm': 0.01,
    }
    unit = metadata['PhysicalSizeXUnit']
    # Reciprocal units share the scale factor of the direct unit.
    base_unit = unit[2:] if isinstance(unit, str) and unit.startswith('1/') else unit
    if not isinstance(base_unit, str) or base_unit not in units_per_cm:
        raise ValueError(f'Unsupported PhysicalSizeXUnit: {unit!r}')
    scale = units_per_cm[base_unit]

    xval = scale / metadata['PhysicalSizeX']
    yval = scale / metadata['PhysicalSizeY']
    return (xval, yval)
def write_ome_tif_image(fileName, img, metadata):
    """Write a 2-D float image to ``'{fileName}.ome.tif'`` with scale metadata.

    Besides the pixel data, the file receives the physical pixel sizes from
    ``metadata``, a centimeter resolution derived from them, and
    MinSampleValue/MaxSampleValue tags set to mean -/+ one standard deviation
    of the image.

    Args:
        fileName: Output path without the ``.ome.tif`` suffix.
        img: 2-D array indexed as ``[y, x]`` (matching ``axes='YX'``).
        metadata: Mapping with ``'PhysicalSizeX'``, ``'PhysicalSizeXUnit'``,
            ``'PhysicalSizeY'`` and ``'PhysicalSizeYUnit'``.
    """
    ome_metadata = {
        'axes': 'YX',
        'PixelType': 'float32',
        'BigEndian': False,
        # With axes 'YX' the first dimension is Y and the second is X.
        'SizeX': img.shape[1],
        'SizeY': img.shape[0],
        'PhysicalSizeX': metadata['PhysicalSizeX'],
        'PhysicalSizeXUnit': metadata['PhysicalSizeXUnit'],
        'PhysicalSizeY': metadata['PhysicalSizeY'],
        'PhysicalSizeYUnit': metadata['PhysicalSizeYUnit'],
        'ranges': [0.0, 1.0],
    }
    resolution = ome_to_resolution_cm(ome_metadata)

    # Suggested display range: one standard deviation either side of the mean.
    mean = np.mean(img)
    std = np.std(img)
    # Extra tags are (code, dtype, count, value, writeonce) tuples.
    # 280 = MinSampleValue, 281 = MaxSampleValue; dtype 11 = float32.
    # See https://www.loc.gov/preservation/digital/formats/content/tiff_tags.shtml
    # NOTE(review): tifffile documents writeonce=True as "first page of a
    # series only"; False is kept here as in the original call -- confirm.
    minValTag = (280, 11, 1, mean - std, False)
    maxValTag = (281, 11, 1, mean + std, False)

    with TiffWriter(f'{fileName}.ome.tif') as tif:
        tif.write(
            img,
            photometric='minisblack',
            metadata=ome_metadata,
            resolution=resolution,
            resolutionunit='CENTIMETER',
            extratags=[minValTag, maxValTag],
        )
if __name__ == '__main__':
    # Base names (without extension) of the images in the stack.
    filenames = [f'{x:04d}' for x in range(71, 76)]

    # Window passed to the truncation step. Staying away from the center of
    # the FFT avoids the strongly nonlinear region that would otherwise upset
    # the background fit.
    truncateangstroms = [1.0, 10.0]

    # True: fit the background, then the peaks with the background held fixed.
    # False: release everything for the final combined fit.
    lock_background = False

    # ---- Background components -------------------------------------------
    background = []

    scatter = QuadraticModel(prefix='bkg_scatter_')
    for coefficient, guess in (('a', 43.03543), ('b', -14899.0078), ('c', -50482.6120)):
        scatter.set_param_hint(f'bkg_scatter_{coefficient}', value=guess)
    background.append(scatter)

    ctf = PseudoVoigtModel(prefix='bkg_ctf_')
    ctf.set_param_hint('bkg_ctf_amplitude', min=1.0, max=1e9)
    ctf.set_param_hint('bkg_ctf_center', min=5.0, max=10.0)
    ctf.set_param_hint('bkg_ctf_sigma', min=0.5, max=10.0)
    ctf.set_param_hint('bkg_ctf_fraction', min=0, max=1)
    background.append(ctf)

    ctf2 = GaussianModel(prefix='bkg_ctf2_')
    ctf2.set_param_hint('bkg_ctf2_amplitude', min=1.0, max=1e9)
    ctf2.set_param_hint('bkg_ctf2_center', min=5.0, max=10.0)
    ctf2.set_param_hint('bkg_ctf2_sigma', min=0.5, max=10.0)
    background.append(ctf2)

    # ---- Peak components (not part of the background) ---------------------
    # One Gaussian per (center_min, center_max) window. A third peak with a
    # center window of (5.0, 7.0) and sigma max 1.5 was previously tried and
    # is currently disabled.
    peaks = []
    for center_min, center_max in ((1.9, 2.5), (3.5, 4.5)):
        peakname = f'peak{len(peaks)+1}_'
        peak = GaussianModel(prefix=peakname)
        peak.set_param_hint(f'{peakname}amplitude', min=1.0, max=1e6)
        peak.set_param_hint(f'{peakname}center', min=center_min, max=center_max)
        peak.set_param_hint(f'{peakname}sigma', min=0.01, max=0.5)
        peaks.append(peak)

    # ---- Composite models ---------------------------------------------------
    background_model = reduce(operator.add, background)
    peak_model = reduce(operator.add, peaks)
    model = background_model + peak_model

    # ---- Per-image integration and fit --------------------------------------
    radialprofiles = []
    fft_images = []
    for filename in filenames:
        print(f'Fitting {filename}.')
        image, fft_image, r, radialprofile = integrate_one_dm3(f'{filename}.dm3', truncateangstroms)
        fft_images.append(fft_image)
        fit_result = fit_one_profile(background_model, peak_model, model, radialprofile, r, lock_background)
        show_fit(r, radialprofile, fit_result, filename)
        radialprofiles.append(radialprofile)

    # ---- Fit of the stack-averaged profile ----------------------------------
    stack_label = f'{filenames[0]}-{filenames[-1]}'
    radialprofile_avg = np.mean(radialprofiles, axis=0)
    fit_result = fit_one_profile(background_model, peak_model, model, radialprofile_avg, r, lock_background)
    show_fit(r, radialprofile_avg, fit_result, stack_label)

    # ---- Summed FFT magnitude image -------------------------------------------
    sum_image = np.sum(fft_images, axis=0)
    log_sum = np.log10(sum_image)
    magmean = np.mean(log_sum)
    magstd = np.std(log_sum)
    sum_title = f'{stack_label} FFT magnitude sum log'

    plt.figure()
    plt.imshow(log_sum, vmin=magmean - magstd * 2, vmax=magmean + magstd * 4)
    plt.title(sum_title)

    # Scale metadata comes from the last image loaded in the loop above.
    x_axis = image.axes_manager['x']
    y_axis = image.axes_manager['y']
    metadata = {
        "PhysicalSizeX": 1 / float(x_axis.scale),
        "PhysicalSizeXUnit": f"1/{x_axis.units}",
        "PhysicalSizeY": 1 / float(y_axis.scale),
        "PhysicalSizeYUnit": f"1/{y_axis.units}",
    }
    write_ome_tif_image(sum_title, log_sum.astype('float32'), metadata)
    plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment