Last active
March 28, 2023 19:32
-
-
Save alisterburt/88133c823c70ef6f7375ff7db15320ad to your computer and use it in GitHub Desktop.
Rotational average 3D for Pranav
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from pathlib import Path | |
from typing import List, Sequence, Tuple | |
import einops | |
import mrcfile | |
import numpy as np | |
import torch | |
import typer | |
# Typer application holding this script's command. With no arguments the
# program prints its help text, and the shell-completion install options are
# turned off.
cli = typer.Typer(
    name='raps_3d',
    add_completion=False,
    no_args_is_help=True,
)
def rfft_shape_from_signal_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
    """Get the output shape of an rfft on an input with input_shape.

    Every dimension is kept except the last, which is reduced to its
    non-redundant half of ``n // 2 + 1`` elements.

    Parameters
    ----------
    input_shape: Sequence[int]
        Shape of the real-valued input signal.

    Returns
    -------
    Tuple[int, ...]
        Shape of the rfft output, with the same number of dimensions as
        ``input_shape``.
    """
    rfft_shape = list(input_shape)
    # Integer floor division avoids a float round-trip. int() keeps the entry a
    # plain Python int even if an element arrives as a 0-d tensor.
    rfft_shape[-1] = int(rfft_shape[-1] // 2 + 1)
    return tuple(rfft_shape)
def fft_center(
    grid_shape: Tuple[int, ...], fftshifted: bool, rfft: bool
) -> torch.Tensor:
    """Return the index of the DFT center (DC component) along each dimension.

    Parameters
    ----------
    grid_shape: Tuple[int, ...]
        Shape of the grid.
    fftshifted: bool
        Whether the DFT is fftshifted. If False the DC component is at index
        0 in every dimension.
    rfft: bool
        Whether the DFT is the output of an rfft. The last (half) dimension
        of an rfft is not shifted, so its DC index is always 0.

    Returns
    -------
    torch.Tensor
        Float tensor of shape ``(len(grid_shape),)``.
    """
    n_dims = len(grid_shape)
    if not fftshifted:
        return torch.zeros(size=(n_dims,))
    shape = torch.as_tensor(grid_shape).float()
    center = torch.div(shape, 2, rounding_mode='floor')
    if rfft:
        # The rfft half-dimension keeps DC at index 0, so its exact length
        # (and hence the rfft output shape) is irrelevant here.
        center[-1] = 0
    return center
def _indices_centered_on_dc_for_shifted_rfft(
    rfft_shape: Sequence[int]
) -> torch.Tensor:
    """Integer index vectors of every rfft element, measured from DC.

    The layout is an rfft whose leading dimensions are fftshifted (DC at
    ``size // 2``) and whose last dimension is the unshifted half-dimension
    (DC at 0). The result has shape ``(*rfft_shape, len(rfft_shape))``.
    """
    shape_tensor = torch.tensor(rfft_shape)
    # np.indices yields (ndim, *shape); put the coordinate axis last.
    per_voxel_index = einops.rearrange(
        torch.tensor(np.indices(shape_tensor)), 'c ... -> ... c'
    )
    # Leading dims are shifted (DC at floor(size / 2)); the last dim is not.
    dc_index = torch.div(shape_tensor, 2, rounding_mode='floor')
    dc_index[-1] = 0
    return per_voxel_index - dc_index
def _distance_from_dc_for_shifted_rfft(rfft_shape: Sequence[int]) -> torch.Tensor:
    """Euclidean distance (index units) from DC for each element of a shifted rfft.

    The output has shape ``rfft_shape``.
    """
    offsets = _indices_centered_on_dc_for_shifted_rfft(rfft_shape)
    squared_norm = einops.reduce(offsets ** 2, '... c -> ...', reduction='sum')
    return squared_norm ** 0.5
def _indices_centered_on_dc_for_shifted_dft(
    dft_shape: Sequence[int], rfft: bool
) -> torch.Tensor:
    """Index vectors measured from DC for an fftshifted DFT (or shifted rfft).

    Returns a tensor of shape ``(*dft_shape, len(dft_shape))``.
    - For ``rfft=True`` the rfft-specific helper is used.
    - Otherwise every dimension is treated as fftshifted, with DC at
      ``size // 2``, and the result is float.
    """
    if rfft is not True:
        grid = torch.tensor(np.indices(dft_shape)).float()
        grid = einops.rearrange(grid, 'c ... -> ... c')
        return grid - fft_center(dft_shape, fftshifted=True, rfft=False)
    return _indices_centered_on_dc_for_shifted_rfft(dft_shape)
def _distance_from_dc_for_shifted_dft(
    dft_shape: Sequence[int], rfft: bool
) -> torch.Tensor:
    """Euclidean distance (index units) from DC for an fftshifted DFT/rfft.

    The output has shape ``dft_shape``.
    """
    centered = _indices_centered_on_dc_for_shifted_dft(dft_shape, rfft=rfft)
    sum_of_squares = einops.reduce(centered ** 2, '... c -> ...', reduction='sum')
    return sum_of_squares ** 0.5
def indices_centered_on_dc_for_dft(
    dft_shape: Sequence[int], rfft: bool, fftshifted: bool
) -> torch.Tensor:
    """Index vectors measured from the DC component for every DFT element.

    Parameters
    ----------
    dft_shape: Sequence[int]
        Shape of the DFT array. For ``rfft=True`` this is the shape of the
        rfft output itself.
    rfft: bool
        Whether the DFT is the output of an rfft. The last (half) dimension
        of an rfft is never shifted.
    fftshifted: bool
        Whether the DFT is fftshifted. If False, the shifted-layout indices
        are ifftshifted back to the standard layout, with DC at index 0.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(*dft_shape, len(dft_shape))``.
    """
    dft_indices = _indices_centered_on_dc_for_shifted_dft(dft_shape, rfft=rfft)
    if fftshifted:
        return dft_indices
    # Put the coordinate axis first so the spatial dims are the trailing ones.
    dft_indices = einops.rearrange(dft_indices, '... c -> c ...')
    # `dim` takes plain ints. The rfft half-dimension (last) is never shifted.
    dims_to_shift = tuple(range(-len(dft_shape), 0))
    if rfft:
        dims_to_shift = dims_to_shift[:-1]
    # A 1-D rfft leaves nothing to shift. ifftshift(dim=()) would raise,
    # because roll requires non-empty shifts.
    if dims_to_shift:
        dft_indices = torch.fft.ifftshift(dft_indices, dim=dims_to_shift)
    return einops.rearrange(dft_indices, 'c ... -> ... c')
def distance_from_dc_for_dft(
    dft_shape: Sequence[int], rfft: bool, fftshifted: bool
) -> torch.Tensor:
    """Euclidean distance (index units) of each DFT element from the DC component.

    Parameters
    ----------
    dft_shape: Sequence[int]
        Shape of the DFT array.
    rfft: bool
        Whether the DFT is the output of an rfft.
    fftshifted: bool
        Whether the DFT is fftshifted.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``dft_shape``.
    """
    offsets = indices_centered_on_dc_for_dft(
        dft_shape, rfft=rfft, fftshifted=fftshifted
    )
    sum_of_squares = einops.reduce(offsets ** 2, '... c -> ...', reduction='sum')
    return sum_of_squares ** 0.5
def _find_shell_indices_1d( | |
distances: torch.Tensor, n_shells: int | |
) -> List[torch.Tensor]: | |
"""Find indices into a vector of distances for shells 1 unit apart.""" | |
sorted, sort_idx = torch.sort(distances, descending=False) | |
split_points = torch.linspace(start=0.5, end=n_shells - 0.5, steps=n_shells) | |
split_idx = torch.searchsorted(sorted, split_points) | |
return torch.tensor_split(sort_idx, split_idx)[:-1] | |
def _split_into_shells_3d(
    image: torch.Tensor, n_shells: int, rfft: bool = False, fftshifted: bool = True
) -> List[torch.Tensor]:
    """Group the voxels of a 3D DFT into shells 1 index unit apart.

    Distances are measured from the DC component over the last three
    dimensions of ``image``. Any leading dimensions are carried along, so
    each returned tensor has shape ``(*leading, n_voxels_in_shell)``.
    """
    depth, height, width = image.shape[-3:]
    radius = distance_from_dc_for_dft(
        dft_shape=(depth, height, width), rfft=rfft, fftshifted=fftshifted
    )
    flat_radius = einops.rearrange(radius, 'd h w -> (d h w)')
    voxel_groups = _find_shell_indices_1d(flat_radius, n_shells=n_shells)
    flat_image = einops.rearrange(image, '... d h w -> ... (d h w)')
    return [flat_image[..., voxel_idx] for voxel_idx in voxel_groups]
def rotational_average_3d(
    image: torch.Tensor, rfft: bool = False, fftshifted: bool = True
) -> torch.Tensor:
    """Average a 3D DFT over shells of constant distance from DC.

    ``image.shape[-3] // 2`` shells, 1 index unit apart, are averaged.
    Leading dimensions of ``image`` are preserved.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(*leading, n_shells)`` holding each shell's mean.
    """
    shells = _split_into_shells_3d(
        image,
        n_shells=image.shape[-3] // 2,
        rfft=rfft,
        fftshifted=fftshifted,
    )
    shell_means = [shell.mean(dim=-1) for shell in shells]
    return einops.rearrange(shell_means, 'shells ... -> ... shells')
@cli.command(no_args_is_help=True)
def main(
    volume_file: Path = typer.Option(..., '--volume-file', '-i'),
    output_file: Path = typer.Option(..., '--output-file', '-o', help='text file'),
):
    """Compute, save and plot the rotationally averaged power spectrum of a volume.

    The steps are:
    1. Read the volume and its voxel size from an MRC file.
    2. Average its power spectrum over Fourier shells.
    3. Write the shell averages to ``output_file`` with ``np.savetxt``.
    4. Plot log(power) against spatial frequency.
    """
    volume = torch.tensor(mrcfile.read(volume_file)).float()
    with mrcfile.open(volume_file, permissive=True, header_only=True) as mrc:
        apix = float(mrc.voxel_size.x)

    # The averaging below is told rfft=True, so compute an actual rfft.
    # A full fftn would put negative frequencies in the second half of the
    # last axis, where they would be treated as large positive frequencies.
    dft = torch.fft.rfftn(volume, dim=(-3, -2, -1)).abs().square()
    raps = rotational_average_3d(dft, rfft=True, fftshifted=False).numpy()

    # Shell i lies i Fourier voxels from DC, i.e. at i / (N * apix) 1/Å for a
    # box of N voxels. Nyquist (1 / (2 * apix)) corresponds to i = N / 2.
    box_size = volume.shape[-1]
    freqs = np.arange(len(raps)) / (box_size * apix)
    del volume, dft

    np.savetxt(output_file, raps)
    typer.echo(f'file with data saved to {output_file}')

    import matplotlib.pyplot as plt

    def _resolution_label(x, _pos):
        # Label a frequency tick by its resolution; x <= 0 has no finite one.
        if x > 0:
            return f'1/{1 / x:.2f}'
        return '0' if x == 0 else ''

    fig, ax = plt.subplots()
    ax.set(
        title='log(power) vs spatial frequency',
        xlabel='spatial frequency (1/Å)',
        ylabel='log(power)',
    )
    ax.xaxis.set_major_formatter(_resolution_label)
    ax.plot(freqs, np.log(raps))
    plt.show()
if __name__ == '__main__':
    # Script entry point: hand control to the Typer application.
    cli()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment