Skip to content

Instantly share code, notes, and snippets.

@alisterburt
Created May 28, 2025 17:48
Show Gist options
  • Save alisterburt/1b2ae72a5b4607d4c365c0a81edd6baf to your computer and use it in GitHub Desktop.
Save alisterburt/1b2ae72a5b4607d4c365c0a81edd6baf to your computer and use it in GitHub Desktop.
Tilt-series (ts) simulation for marten
from typing import Tuple
import mmdf
import numpy as np
import torch
from fast_histogram import histogramdd
import einops
from scipy.stats import special_ortho_group
from torch_affine_utils.transforms_3d import Rz, Ry, T
from torch_affine_utils import homogenise_coordinates
import torch.nn.functional as F
def load_atoms_zyx(filepath: str, pixel_spacing: float, ca_only: bool = True) -> np.ndarray:
    """Load atomic coordinates in pixels, centred on the origin.

    Parameters
    ----------
    filepath : str
        Path to a structure file readable by ``mmdf.read``.
    pixel_spacing : float
        Size of one pixel in the file's coordinate units (angstroms per the
        conversion step below).
    ca_only : bool
        If truthy, keep only rows whose ``atom`` column equals ``'CA'``.

    Returns
    -------
    np.ndarray
        ``(n_atoms, 3)`` array of ``z, y, x`` coordinates in pixels whose
        column-wise mean is zero.

    Raises
    ------
    ValueError
        If no atoms remain after (optional) CA filtering.
    """
    df = mmdf.read(filepath)
    # Truthiness, not identity: ``ca_only is True`` silently skipped the
    # filter for truthy values such as ``np.True_`` or ``1``.
    if ca_only:
        df = df[df['atom'] == 'CA']
    zyx = df[['z', 'y', 'x']].to_numpy()
    if len(zyx) == 0:
        raise ValueError(f"no atoms selected from {filepath!r} (ca_only={ca_only!r})")
    # convert angstroms to pixels
    zyx = zyx / pixel_spacing
    # center atoms around [0, 0, 0]
    zyx -= zyx.mean(axis=0)
    return zyx
def generate_particle_orientations(n_particles: int, seed: "int | None" = None) -> np.ndarray:
    """Draw random 3D rotation matrices, one per particle.

    Parameters
    ----------
    n_particles : int
        Number of rotations to draw; must be non-negative.
    seed : int or None, optional
        Forwarded to ``special_ortho_group.rvs`` as ``random_state`` so runs
        can be made reproducible. ``None`` (the default) keeps the previous
        unseeded behaviour.

    Returns
    -------
    np.ndarray
        Array of shape ``(n_particles, 3, 3)`` -- always 3-D, including for
        ``n_particles`` of 0 or 1.

    Raises
    ------
    ValueError
        If ``n_particles`` is negative.
    """
    if n_particles < 0:
        raise ValueError(f"n_particles must be non-negative, got {n_particles}")
    if n_particles == 0:
        # scipy would return a single (3, 3) matrix for size=0.
        return np.empty((0, 3, 3))
    rotations = special_ortho_group.rvs(dim=3, size=n_particles, random_state=seed)
    # scipy drops the leading axis when size == 1; restore it so callers
    # always receive a stack of matrices.
    return rotations.reshape(n_particles, 3, 3)
def generate_particle_positions(
    n_particles: int,
    volume_size: Tuple[int, int, int],
    seed: int = 42
) -> np.ndarray:
    """Sample particle centres inside the middle third of the volume's depth.

    Parameters
    ----------
    n_particles : int
        Number of positions to draw.
    volume_size : tuple of int
        ``(depth, height, width)`` of the volume in pixels.
    seed : int
        Seed for ``np.random.default_rng``; the same seed gives the same
        positions.

    Returns
    -------
    np.ndarray
        ``(n_particles, 3)`` array of ``z, y, x`` positions drawn uniformly with
        z in ``[depth // 3, 2 * depth // 3)``, y in ``[0, height)`` and x in
        ``[0, width)``.
    """
    depth, height, width = volume_size
    generator = np.random.default_rng(seed=seed)
    lower_bounds = (depth // 3, 0, 0)
    upper_bounds = (2 * depth // 3, height, width)
    return generator.uniform(low=lower_bounds, high=upper_bounds, size=(n_particles, 3))
def rotate_and_place_particles(
    atom_zyx: np.ndarray,  # (n_atoms, 3)
    rotation_matrices: np.ndarray,  # (n_particles, 3, 3)
    positions: np.ndarray  # (n_particles, 3)
) -> torch.Tensor:
    """Rotate a template by each particle's rotation, then translate to its position.

    Parameters
    ----------
    atom_zyx : np.ndarray
        ``(n_atoms, 3)`` template coordinates.
    rotation_matrices : np.ndarray
        ``(n_particles, 3, 3)`` rotations, applied as ``R @ atom``.
    positions : np.ndarray
        ``(n_particles, 3)`` translation added to every atom of each particle.

    Returns
    -------
    torch.Tensor
        ``float32`` tensor of shape ``(n_particles, n_atoms, 3)``.
    """
    atom_zyx = np.asarray(atom_zyx)
    rotation_matrices = np.asarray(rotation_matrices)
    positions = np.asarray(positions)
    # rotated[p, a, i] = sum_j R[p, i, j] * atom[a, j], i.e. R @ atom for every
    # (particle, atom) pair, without materialising a copy of each matrix per atom.
    rotated = np.einsum('pij,aj->pai', rotation_matrices, atom_zyx)
    # Broadcast each particle's position over all of its atoms.
    placed = rotated + positions[:, np.newaxis, :]
    return torch.tensor(placed, dtype=torch.float32)
def rasterize_2d(
    coordinates_yx: np.ndarray,
    image_shape: Tuple[int, int]
) -> np.ndarray:
    """Bin 2D (y, x) points into a histogram image of shape ``image_shape``.

    One bin per pixel; each axis spans ``[-0.5, n - 0.5]`` so that a point at
    integer coordinate ``k`` lands in the middle of bin ``k``.
    """
    n_rows, n_cols = image_shape
    bin_range = [
        [-0.5, n_rows - 1 + 0.5],
        [-0.5, n_cols - 1 + 0.5],
    ]
    return histogramdd(
        sample=coordinates_yx,
        bins=(n_rows, n_cols),
        range=bin_range,
    )
def rasterize_3d(
    coordinates_zyx: np.ndarray,
    image_shape: Tuple[int, int, int]
) -> np.ndarray:
    """Bin 3D (z, y, x) points into a histogram volume of shape ``image_shape``.

    One bin per voxel; each axis spans ``[-0.5, n - 0.5]`` so that a point at
    integer coordinate ``k`` lands in the middle of bin ``k``.
    """
    n_planes, n_rows, n_cols = image_shape
    bin_range = [
        [-0.5, n_planes - 1 + 0.5],
        [-0.5, n_rows - 1 + 0.5],
        [-0.5, n_cols - 1 + 0.5],
    ]
    return histogramdd(
        sample=coordinates_zyx,
        bins=(n_planes, n_rows, n_cols),
        range=bin_range,
    )
def main(structure_path: str = "/Users/aburt/Data/4v6x.cif") -> None:
    """Simulate a tilt series of randomly placed particles and show it in napari.

    CA atoms are loaded from ``structure_path`` and placed ``N_PARTICLES``
    times in a ``VOLUME_SIZE`` box with random orientations and positions.
    They are then projected at ``N_TILTS`` tilt angles and rasterised into
    ``IMAGE_SHAPE`` images. The rasterised 3D ground-truth volume is
    displayed alongside the tilt series.
    """
    # Constants
    VOLUME_SIZE = (512, 512, 512)
    IMAGE_SHAPE = (512, 512)
    SIMULATION_PIXEL_SPACING = 15  # angstroms per pixel
    N_PARTICLES = 100
    N_TILTS = 41
    TILT_RANGE = (-60, 60)  # presumably degrees -- confirm against Ry
    TILT_AXIS_ANGLE = 85  # presumably degrees -- confirm against Rz
    w = VOLUME_SIZE[-1]

    # load [0, 0, 0] centered particle positions in px
    atom_zyx = load_atoms_zyx(
        structure_path,
        ca_only=True,
        pixel_spacing=SIMULATION_PIXEL_SPACING
    )

    # Generate particle orientations and positions
    particle_rotation_matrices = generate_particle_orientations(N_PARTICLES)
    particle_positions = generate_particle_positions(N_PARTICLES, VOLUME_SIZE)

    # rotate centered atoms and place at particle position in volume
    atoms_zyx = rotate_and_place_particles(
        atom_zyx, particle_rotation_matrices, particle_positions
    )  # (n_particles, n_atoms, 3)
    atoms_zyx = einops.rearrange(atoms_zyx, 'p a zyx -> (p a) zyx')

    # Set up projection geometry
    tilt_min, tilt_max = TILT_RANGE
    tilt_angles = torch.linspace(tilt_min, tilt_max, N_TILTS)

    # Random per-tilt in-plane (y, x) shifts, std = 2% of the volume width.
    rng = np.random.default_rng(seed=42)
    translations = rng.normal(size=(N_TILTS, 2), scale=0.02 * w)
    translations = torch.tensor(translations).float()
    # Transforms are in zyx order (see t2 below), so the zero z-component must
    # be PREPENDED. Appending it (pad=(0, 1)) moved the y-shift into z, where
    # the projection discards it, and never shifted x.
    translations = F.pad(translations, pad=(1, 0), value=0)

    # Build transformation matrices
    volume_center = torch.tensor(VOLUME_SIZE) // 2
    image_center = torch.tensor(IMAGE_SHAPE) // 2
    t0 = T(-1 * volume_center)  # center volume at [0, 0, 0]
    r0 = Ry(tilt_angles, zyx=True)  # tilt volume around y axis
    r1 = Rz(TILT_AXIS_ANGLE, zyx=True)  # rotate volume in plane
    t1 = T(translations)  # 2D shifts
    t2 = T((0, *image_center))  # [0, 0] to center of tilt image

    # Combined transformation matrix
    M_zyxw = t2 @ t1 @ r1 @ r0 @ t0

    # Keep only the y and x output rows: dropping z is the projection.
    M_yx = M_zyxw[..., [1, 2], :]  # (n_tilts, 2, 4)
    atoms_zyxw = homogenise_coordinates(atoms_zyx)  # (n_atoms, 4)
    # Apply every tilt's matrix to every atom without materialising a copy of
    # each matrix per atom.
    atoms_yx = torch.einsum('tij,aj->tai', M_yx, atoms_zyxw)  # (n_tilts, n_atoms, 2)

    # Render tilt series on the same grid whose centre was used for t2.
    tilt_series = np.stack(
        [rasterize_2d(np.asarray(tilt_yx), image_shape=IMAGE_SHAPE) for tilt_yx in atoms_yx],
        axis=0,
    )

    # gt volume
    volume = rasterize_3d(coordinates_zyx=np.asarray(atoms_zyx), image_shape=VOLUME_SIZE)

    # viz
    import napari
    viewer = napari.Viewer()
    viewer.add_image(tilt_series)
    viewer.add_image(volume)
    napari.run()


if __name__ == "__main__":
    import sys

    # Optional first CLI argument overrides the default structure path.
    if len(sys.argv) > 1:
        main(sys.argv[1])
    else:
        main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment