Created
May 28, 2025 17:48
-
-
Save alisterburt/1b2ae72a5b4607d4c365c0a81edd6baf to your computer and use it in GitHub Desktop.
Tilt-series (TS) simulation for Marten
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Tuple | |
import mmdf | |
import numpy as np | |
import torch | |
from fast_histogram import histogramdd | |
import einops | |
from scipy.stats import special_ortho_group | |
from torch_affine_utils.transforms_3d import Rz, Ry, T | |
from torch_affine_utils import homogenise_coordinates | |
import torch.nn.functional as F | |
def load_atoms_zyx(filepath: str, pixel_spacing: float, ca_only: bool = True) -> np.ndarray:
    """Load atom coordinates as a centered (n_atoms, 3) array in pixel units.

    Parameters
    ----------
    filepath : str
        Path of the structure file; passed unchanged to ``mmdf.read``.
    pixel_spacing : float
        Divisor applied to the raw coordinates to convert angstroms to
        pixels. Must be strictly positive.
    ca_only : bool
        If truthy, keep only rows whose ``atom`` column equals ``'CA'``.

    Returns
    -------
    np.ndarray
        Coordinates in (z, y, x) column order, divided by ``pixel_spacing``
        and shifted so that their mean lies at [0, 0, 0].

    Raises
    ------
    ValueError
        If ``pixel_spacing`` is not strictly positive.
    """
    if pixel_spacing <= 0:
        raise ValueError(f"pixel_spacing must be > 0, got {pixel_spacing}")

    df = mmdf.read(filepath)

    # Plain truthiness instead of `is True`: an identity test silently skips
    # the filter for truthy values that are not the `True` singleton
    # (e.g. `np.True_` or `1`).
    if ca_only:
        df = df[df['atom'] == 'CA']

    zyx = df[['z', 'y', 'x']].to_numpy()

    # convert angstroms to pixels; the division allocates a new array, so the
    # in-place centering below never writes into the dataframe's buffer
    zyx = zyx / pixel_spacing

    # center atoms around [0, 0, 0]
    zyx -= einops.reduce(zyx, 'b zyx -> zyx', reduction='mean')
    return zyx
def generate_particle_orientations(n_particles: int) -> np.ndarray:
    """Generate random rotation matrices for particle orientations.

    Parameters
    ----------
    n_particles : int
        Number of rotation matrices to draw. Must be non-negative.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_particles, 3, 3)`` sampled with
        ``scipy.stats.special_ortho_group``. The leading axis is always
        present, including for ``n_particles`` of 0 or 1.

    Raises
    ------
    ValueError
        If ``n_particles`` is negative.
    """
    if n_particles < 0:
        raise ValueError(f"n_particles must be >= 0, got {n_particles}")
    if n_particles == 0:
        # scipy would still hand back a single (3, 3) matrix for size=0
        return np.empty((0, 3, 3))

    rotations = special_ortho_group.rvs(dim=3, size=n_particles)
    # scipy drops the sample axis when size == 1 and returns (3, 3); restore
    # it so callers can always rely on a (p, 3, 3) stack.
    return np.reshape(rotations, (n_particles, 3, 3))
def generate_particle_positions(
    n_particles: int,
    volume_size: Tuple[int, int, int],
    seed: int = 42
) -> np.ndarray:
    """Draw uniformly distributed particle centers inside a volume.

    Parameters
    ----------
    n_particles : int
        Number of positions to draw.
    volume_size : Tuple[int, int, int]
        Volume dimensions as (depth, height, width).
    seed : int
        Seed for ``np.random.default_rng``; the same seed gives the same
        positions.

    Returns
    -------
    np.ndarray
        ``(n_particles, 3)`` array of (z, y, x) positions. z is confined to
        the central third of the depth, ``[depth // 3, 2 * depth // 3)``;
        y and x span ``[0, height)`` and ``[0, width)``.
    """
    depth, height, width = volume_size

    # per-axis sampling bounds; only the depth axis is restricted
    lower_bounds = (depth // 3, 0, 0)
    upper_bounds = (2 * depth // 3, height, width)

    generator = np.random.default_rng(seed=seed)
    return generator.uniform(
        low=lower_bounds,
        high=upper_bounds,
        size=(n_particles, 3),
    )
def rotate_and_place_particles(
    atom_zyx: np.ndarray,  # (n_atoms, 3)
    rotation_matrices: np.ndarray,  # (n_particles, 3, 3)
    positions: np.ndarray  # (n_particles, 3)
) -> torch.Tensor:
    """Rotate the atom template once per particle and move it to that particle's position.

    Parameters
    ----------
    atom_zyx : np.ndarray
        ``(n_atoms, 3)`` template coordinates, centered on the origin.
    rotation_matrices : np.ndarray
        ``(n_particles, 3, 3)`` matrices applied as ``R @ atom``.
    positions : np.ndarray
        ``(n_particles, 3)`` translation added to every atom of the
        corresponding particle after rotation.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape ``(n_particles, n_atoms, 3)``.
    """
    atom_zyx = np.asarray(atom_zyx)
    rotation_matrices = np.asarray(rotation_matrices)
    positions = np.asarray(positions)

    # out[p, a, i] = sum_j R[p, i, j] * atom[a, j]
    # einsum broadcasts over particles and atoms, so the rotation matrices are
    # never copied out to a (n_particles, n_atoms, 3, 3) array.
    rotated = np.einsum('pij,aj->pai', rotation_matrices, atom_zyx)

    # Translate rotated atoms to particle positions: (p, 1, 3) broadcasts over atoms
    placed = rotated + positions[:, np.newaxis, :]

    # single conversion to float32
    return torch.as_tensor(placed, dtype=torch.float32)
def rasterize_2d(
    coordinates_yx: np.ndarray,
    image_shape: Tuple[int, int]
) -> np.ndarray:
    """Bin 2D (y, x) coordinates into an image with one bin per pixel.

    Parameters
    ----------
    coordinates_yx : np.ndarray
        Point coordinates passed to ``histogramdd`` as ``sample``.
    image_shape : Tuple[int, int]
        Output size as (height, width); also the number of bins per axis.

    Returns
    -------
    np.ndarray
        The histogram returned by ``histogramdd``.
    """
    n_rows, n_cols = image_shape
    # Each axis spans [-0.5, n - 0.5], so bin k is centered on integer
    # coordinate k.
    axis_ranges = [[-0.5, n - 1 + 0.5] for n in (n_rows, n_cols)]
    return histogramdd(
        sample=coordinates_yx,
        bins=(n_rows, n_cols),
        range=axis_ranges,
    )
def rasterize_3d(
    coordinates_zyx: np.ndarray,
    image_shape: Tuple[int, int, int]
) -> np.ndarray:
    """Bin 3D (z, y, x) coordinates into a volume with one bin per voxel.

    Parameters
    ----------
    coordinates_zyx : np.ndarray
        Point coordinates passed to ``histogramdd`` as ``sample``.
    image_shape : Tuple[int, int, int]
        Output size as (depth, height, width); also the number of bins per
        axis.

    Returns
    -------
    np.ndarray
        The histogram returned by ``histogramdd``.
    """
    n_planes, n_rows, n_cols = image_shape
    # Each axis spans [-0.5, n - 0.5], so bin k is centered on integer
    # coordinate k.
    axis_ranges = [[-0.5, n - 1 + 0.5] for n in (n_planes, n_rows, n_cols)]
    return histogramdd(
        sample=coordinates_zyx,
        bins=(n_planes, n_rows, n_cols),
        range=axis_ranges,
    )
if __name__ == "__main__":
    # Simulate a tilt series: scatter copies of one structure through a
    # volume, project every atom into each tilt image, rasterize the images
    # and the ground-truth volume, then show both in napari.

    # Constants
    VOLUME_SIZE = (512, 512, 512)  # (depth, height, width)
    IMAGE_SHAPE = (512, 512)  # (height, width) of each tilt image
    SIMULATION_PIXEL_SPACING = 15  # passed to load_atoms_zyx as pixel_spacing
    N_PARTICLES = 100
    N_TILTS = 41
    TILT_RANGE = (-60, 60)  # presumably degrees; confirm against torch_affine_utils
    TILT_AXIS_ANGLE = 85

    # grab dimensions
    d, h, w = VOLUME_SIZE

    # load [0, 0, 0] centered atom coordinates in px
    # NOTE: machine-specific absolute path; change it before running elsewhere
    atom_zyx = load_atoms_zyx(
        "/Users/aburt/Data/4v6x.cif",
        ca_only=True,
        pixel_spacing=SIMULATION_PIXEL_SPACING
    )

    # Generate particle orientations and positions
    particle_rotation_matrices = generate_particle_orientations(N_PARTICLES)
    particle_positions = generate_particle_positions(N_PARTICLES, VOLUME_SIZE)

    # rotate centered atoms and place at particle position in volume
    atoms_zyx = rotate_and_place_particles(
        atom_zyx, particle_rotation_matrices, particle_positions
    )  # (n_particles, n_atoms, 3)
    atoms_zyx = einops.rearrange(atoms_zyx, 'p a zyx -> (p a) zyx')

    # Set up projection geometry
    tilt_min, tilt_max = TILT_RANGE
    tilt_angles = torch.linspace(tilt_min, tilt_max, N_TILTS)

    # Generate random in-plane translations for each tilt; F.pad appends one
    # zero column so each shift has three components
    rng = np.random.default_rng(seed=42)
    translations = rng.normal(size=(N_TILTS, 2), scale=0.02 * w)
    translations = torch.tensor(translations).float()
    translations = F.pad(translations, pad=(0, 1), value=0)

    # Build transformation matrices
    volume_center = torch.tensor(VOLUME_SIZE) // 2
    image_center = torch.tensor(IMAGE_SHAPE) // 2
    t0 = T(-1 * volume_center)  # center volume at [0, 0, 0]
    r0 = Ry(tilt_angles, zyx=True)  # tilt volume around y axis
    r1 = Rz(TILT_AXIS_ANGLE, zyx=True)  # rotate volume in plane
    t1 = T(translations)  # 2D shifts
    t2 = T((0, *image_center))  # [0, 0] to center of tilt image

    # Combined transformation matrix (applied right to left)
    M_zyxw = t2 @ t1 @ r1 @ r0 @ t0

    # Keep only the rows that produce the y and x output coordinates
    M_yx = M_zyxw[..., [1, 2], :]  # (n_tilts, 2, 4)

    # Project every atom into every tilt image:
    #   atoms_yx[t, a, i] = sum_j M_yx[t, i, j] * atoms_zyxw[a, j]
    # einsum broadcasts over atoms, so the projection matrices are not copied
    # out to a (n_tilts, n_atoms, 2, 4) tensor.
    atoms_zyxw = homogenise_coordinates(atoms_zyx)  # (n_particles * n_atoms, 4)
    atoms_yx = torch.einsum('tij,aj->tai', M_yx, atoms_zyxw)  # (n_tilts, n_atoms, 2)

    # Render tilt series; rasterize with IMAGE_SHAPE so the image size always
    # matches the image_center used in the projection above
    tilt_series = [
        rasterize_2d(np.asarray(tilt_yx), image_shape=IMAGE_SHAPE)
        for tilt_yx in atoms_yx
    ]
    tilt_series = np.stack(tilt_series, axis=0)

    # ground-truth volume
    volume = rasterize_3d(coordinates_zyx=np.array(atoms_zyx), image_shape=(d, h, w))

    # viz (napari is imported here so it is only needed when the script is run)
    import napari

    viewer = napari.Viewer()
    viewer.add_image(tilt_series)
    viewer.add_image(volume)
    napari.run()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment