Skip to content

Instantly share code, notes, and snippets.

@rmorey
Created November 14, 2024 22:46
Show Gist options
  • Save rmorey/da08efd9371730513783abf0f7d61d66 to your computer and use it in GitHub Desktop.
Save rmorey/da08efd9371730513783abf0f7d61d66 to your computer and use it in GitHub Desktop.
from pathlib import Path
from typing import NamedTuple
import pandas as pd
import numpy as np
import cv2
import os
import boto3
from botocore.config import Config
from io import StringIO
from feabas import config
# Default fraction of overlap between neighbouring subtiles; used whenever a
# caller does not supply ``subtile_overlap`` explicitly.
DEFAULT_SUBTILE_OVERLAP = 0.2
# Default fraction of overlap between neighbouring supertiles; used whenever a
# caller does not supply ``supertile_overlap`` explicitly.
DEFAULT_SUPERTILE_OVERLAP = 0.2
# Generate supertile map from stage_positions.csv
# TODO: remove need for pandas
def generate_supertile_map(section_path):
    """Build the supertile grid for a section from its ``stage_positions.csv``.

    The CSV is looked up at ``<section_path>/metadata/stage_positions.csv``.
    When the resulting path string starts with ``matrix:/`` or
    ``tigerdata:/`` the file is fetched through the matching S3 client (the
    first path component after the prefix is the bucket, the remainder is the
    object key); any other path is read from the local filesystem.

    Args:
        section_path: Section directory, as a ``str`` or ``Path``.

    Returns:
        The 2D array produced by ``generate_supertile_map_from_df``.
    """
    # NOTE: the prefixes below carry a single slash because pathlib collapses
    # repeated slashes on POSIX ('matrix://bucket' becomes 'matrix:/bucket').
    csv_path = str(Path(section_path) / 'metadata' / 'stage_positions.csv')
    print(f"{csv_path=}")

    # Built at call time because the client factories are defined further
    # down in this module.
    remote_clients = {
        "matrix:/": get_matrix_client,
        "tigerdata:/": get_tigerdata_client,
    }
    for prefix, make_client in remote_clients.items():
        if csv_path.startswith(prefix):
            # Slice off only the leading prefix; a blanket str.replace would
            # also rewrite the same text if it appeared later in the key.
            df = _read_remote_csv(make_client(), csv_path[len(prefix):])
            break
    else:
        df = pd.read_csv(csv_path)
    return generate_supertile_map_from_df(df)


def _read_remote_csv(client, bucket_and_key):
    """Download ``<bucket>/<key>`` through an S3 client and parse it as CSV."""
    bucket_name, object_key = bucket_and_key.split("/", 1)
    csv_obj = client.get_object(Bucket=bucket_name, Key=object_key)
    return pd.read_csv(StringIO(csv_obj['Body'].read().decode('utf-8')))
def generate_supertile_map_from_df(df):
    """Arrange tile IDs on a 2D grid according to their stage positions.

    Column names are compared after stripping surrounding whitespace. The
    frame must provide ``stage_x_nm``, ``stage_y_nm`` and ``tile_id``.

    Args:
        df: DataFrame of stage positions. It is not modified.

    Returns:
        A 2D object array in which each cell holds a ``tile_id`` or ``None``.
        Columns follow increasing ``stage_x_nm``; rows follow *decreasing*
        ``stage_y_nm`` (the grid is flipped vertically before returning).

    Raises:
        ValueError: If ``df`` is empty.
    """
    if df.empty:
        raise ValueError("Cannot build a supertile map from an empty table.")

    # rename() returns a new frame, so the caller's column labels stay intact.
    positions = df.rename(columns=lambda name: name.strip())

    # Dense ranking maps each distinct coordinate to 1, 2, 3, ...; subtracting
    # one turns those ranks into zero-based grid indices.
    col_idx = positions['stage_x_nm'].rank(method='dense').astype(int).to_numpy() - 1
    row_idx = positions['stage_y_nm'].rank(method='dense').astype(int).to_numpy() - 1

    # Cells without a tile stay None.
    grid = np.full((int(row_idx.max()) + 1, int(col_idx.max()) + 1), None)
    grid[row_idx, col_idx] = positions['tile_id'].to_numpy()

    # Flip vertically so the largest stage_y ends up in the first row.
    return grid[::-1]
def generate_tile_id_map(supertile_map):
    """Expand a supertile grid into a grid of subtile ID strings.

    Every supertile contributes a 3x3 patch of IDs formatted as
    ``"<supertile, zero-padded to 4 digits>_<subtile>"``.

    Args:
        supertile_map: 2D iterable of supertile IDs, with ``None`` marking
            cells that hold no supertile.

    Returns:
        A numpy array with three rows per supertile row. Rows shorter than
        the longest one are right-padded with ``None``. An empty
        ``supertile_map`` yields an empty array.
    """
    # Cricket subtile order
    SUBTILE_MAP = [[6, 7, 8], [5, 0, 1], [4, 3, 2]]

    tile_id_map = []
    for supertile_row in supertile_map:
        # NOTE(review): missing supertiles are dropped rather than reserved,
        # so tiles to the right of a gap shift left - confirm this is intended.
        present = [supertile for supertile in supertile_row if supertile is not None]
        for subtile_row in SUBTILE_MAP:
            tile_id_map.append(
                [f"{supertile:04}_{subtile}" for supertile in present for subtile in subtile_row]
            )

    # default=0 keeps an empty supertile map from crashing max().
    max_length = max(map(len, tile_id_map), default=0)
    for row in tile_id_map:
        row.extend([None] * (max_length - len(row)))
    return np.array(tile_id_map)
def generate_tile_id_map_from_section_path(section_path):
    """Return the subtile ID map for the section stored at ``section_path``.

    Convenience wrapper: loads the supertile map for the section and expands
    it with ``generate_tile_id_map``.
    """
    return generate_tile_id_map(generate_supertile_map(section_path))
def get_subtile_pos(supertile_map, subtile_size, subtile_overlap=DEFAULT_SUBTILE_OVERLAP, supertile_overlap=DEFAULT_SUPERTILE_OVERLAP):
    """Compute the nominal pixel position of every subtile in a section.

    Args:
        supertile_map: 2D grid of supertile IDs, ``None`` where no supertile
            exists.
        subtile_size: Side length of one subtile, in pixels.
        subtile_overlap: Fraction of a subtile shared with its neighbour.
        supertile_overlap: Fraction of a supertile shared with its neighbour.

    Returns:
        Dict mapping each subtile ID string (as produced by
        ``generate_tile_id_map``) to an ``(x, y)`` tuple of ints. An empty
        ``supertile_map`` yields an empty dict.
    """
    if len(supertile_map) == 0:
        return {}

    # A supertile is three subtiles wide with two overlapping seams.
    supertile_size = 3 * subtile_size - 2 * subtile_size * subtile_overlap
    supertile_step = supertile_size * (1 - supertile_overlap)

    # Origin of every supertile. Positions are accumulated step by step
    # (rather than multiplied out) so floating-point results stay the same.
    supertile_origin = {}
    origin_y = 0
    for row in supertile_map:
        origin_x = 0
        for supertile in row:
            # Empty cells still occupy a grid slot but need no lookup entry.
            if supertile is not None:
                supertile_origin[supertile] = (origin_x, origin_y)
            origin_x += supertile_step
        origin_y += supertile_step

    # (column, row) of each subtile within its supertile's 3x3 patch.
    SUBTILE_ID_TO_XY = {
        0: (1, 1),
        1: (2, 1),
        2: (2, 2),
        3: (1, 2),
        4: (0, 2),
        5: (0, 1),
        6: (0, 0),
        7: (1, 0),
        8: (2, 0),
    }

    subtile_pos = {}
    for row in generate_tile_id_map(supertile_map):
        for tile in row:
            if not tile:
                # Padding cell added to even out row lengths.
                continue
            supertile_id, subtile_id = tile.split('_')
            base_x, base_y = supertile_origin[int(supertile_id)]
            dx, dy = SUBTILE_ID_TO_XY[int(subtile_id)]
            subtile_pos[tile] = (
                int(base_x + dx * subtile_size * (1 - subtile_overlap)),
                int(base_y + dy * subtile_size * (1 - subtile_overlap)),
            )
    return subtile_pos
class StitchConfig(NamedTuple):
"""Configuration for stitching a single TEM section"""
section_dir: str
"""Path to the section directory as a string, e.g., '/scratch/tem-data/bladeseq-2024.07.02-11.01.33/s013-2024.07.02-11.01.33'"""
resolution: int
"""Resolution in nanometers, e.g., 4"""
subtile_size: int
"""Size of subtiles in pixels, e.g., 6000"""
subtile_overlap: float
"""Fraction of overlap between subtiles, e.g., 0.08"""
supertile_overlap: float
"""Fraction of overlap between supertiles, e.g., 0.06"""
file_ext: str
"""File extension for image files, e.g., 'bmp'"""
def gen_stitch_coords(config: StitchConfig, output_file: str):
    """
    Write the stitch-coordinate file for one section.

    The file starts with three tab-separated header lines ({ROOT_DIR},
    {RESOLUTION}, {TILE_SIZE}) followed by one line per subtile holding the
    image file name and its x/y position.

    Args:
        config (StitchConfig): The configuration object containing the stitch parameters.
        output_file (str): The path to the output file where the stitch coordinates will be saved.

    Returns:
        None
    """
    section_root = Path(config.section_dir)
    positions = get_subtile_pos(
        generate_supertile_map(section_root),
        config.subtile_size,
        config.subtile_overlap,
        config.supertile_overlap,
    )

    # Going through Path leaves the remote schemes with a single slash;
    # put the double slash back before writing the root directory out.
    root_dir = str(section_root / 'subtiles')
    for collapsed, restored in (("matrix:/", "matrix://"), ("tigerdata:/", "tigerdata://")):
        root_dir = root_dir.replace(collapsed, restored)

    tile_size = (config.subtile_size, config.subtile_size)
    lines = [
        "{ROOT_DIR}\t" + f"{root_dir}",
        "{RESOLUTION}\t" + str(config.resolution),
        "{TILE_SIZE}\t" + "\t".join(str(side) for side in tile_size),
    ]
    for tile_id, (coord_x, coord_y) in positions.items():
        lines.append(f"tile_{tile_id}.{config.file_ext}\t{coord_x}\t{coord_y}")

    with open(output_file, 'w') as handle:
        handle.write("\n".join(lines))
def get_image_dimension(image_path):
    """Return the side length, in pixels, of the square image at ``image_path``.

    Args:
        image_path: Location of the image file, as a ``str`` or ``Path``.

    Returns:
        The image width, which is also its height.

    Raises:
        ValueError: If the image cannot be loaded or is not square.
    """
    # Coerce to str: callers in this module pass pathlib.Path objects, and
    # not every OpenCV release accepts those as a filename.
    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Failed to load image at path: {image_path}")

    height, width = image.shape[:2]
    # Raise explicitly instead of asserting: asserts vanish under ``python -O``.
    if width != height:
        raise ValueError(f"Expected a square image at {image_path}, got {width}x{height}.")
    return width
def make_stitch_coord_from_local_blade_path(blade_path: str|Path, stitch_coord_path: str|Path):
    """Write a stitch-coordinate file for a blade stored on the local filesystem.

    When the directory name contains ``bladeseq`` it must hold exactly one
    entry, which is then used as the blade directory. Subtile size and file
    extension are taken from one image found in ``<blade>/subtiles``.

    Args:
        blade_path: Blade directory, or a ``bladeseq`` directory wrapping one.
        stitch_coord_path: Destination file; it must not exist yet.

    Raises:
        ValueError: If an input path is missing, a ``bladeseq`` directory does
            not contain exactly one entry, the ``subtiles`` directory is
            empty, or ``stitch_coord_path`` already exists.
    """
    blade_path = Path(blade_path)
    if not blade_path.exists():
        raise ValueError(f"{blade_path} does not exist.")
    if "bladeseq" in blade_path.name:
        contents = os.listdir(blade_path)
        if len(contents) != 1:
            raise ValueError(f"Expected one directory in {blade_path}, but found {contents}.")
        blade_path = blade_path / contents[0]

    stage_positions_path = blade_path / 'metadata/stage_positions.csv'
    if not stage_positions_path.exists():
        raise ValueError(f"Path {stage_positions_path} does not exist.")

    stitch_coord_path = Path(stitch_coord_path)
    if stitch_coord_path.exists():
        raise ValueError(f"{stitch_coord_path} already exists.")

    subtile_dir = blade_path / 'subtiles'
    # os.listdir order is arbitrary; sort so the probe tile is reproducible.
    subtile_names = sorted(os.listdir(subtile_dir))
    if not subtile_names:
        raise ValueError(f"No subtiles found in {subtile_dir}.")
    some_tile = subtile_names[0]

    # Pass a plain string: the helper hands the path straight to OpenCV.
    subtile_size = get_image_dimension(str(subtile_dir / some_tile))
    file_ext = some_tile.split('.')[-1]

    stitch_config = StitchConfig(
        section_dir=blade_path,
        resolution=config.data_resolution(),
        subtile_size=subtile_size,
        subtile_overlap=DEFAULT_SUBTILE_OVERLAP,  # todo: configurable
        supertile_overlap=DEFAULT_SUPERTILE_OVERLAP,
        file_ext=file_ext,
    )
    gen_stitch_coords(stitch_config, stitch_coord_path)
def make_stitch_coord_from_matrix_blade_path(blade_path: Path|str, stitch_coord_path: Path|str, subtile_size: int = 6000, file_ext: str = 'bmp'):
    """Write a stitch-coordinate file for a blade addressed by a ``matrix`` path.

    No image is inspected here, so the subtile size and file extension come
    from the arguments instead.

    Args:
        blade_path: Path of a single blade directory (not a ``bladeseq``
            directory).
        stitch_coord_path: Local destination file; it must not exist yet.
        subtile_size: Side length of the subtiles in pixels. Defaults to 6000.
        file_ext: Extension of the subtile image files. Defaults to ``'bmp'``.

    Raises:
        NotImplementedError: If ``blade_path`` names a ``bladeseq`` directory.
        ValueError: If ``stitch_coord_path`` already exists.
    """
    blade_path = Path(blade_path)
    if "bladeseq" in blade_path.name:
        raise NotImplementedError(f"Matrix blade path must be a single blade directory, not a bladeseq directory.{blade_path=}")

    stitch_coord_path = Path(stitch_coord_path)
    if stitch_coord_path.exists():
        raise ValueError(f"{stitch_coord_path} already exists.")

    stitch_config = StitchConfig(
        section_dir=blade_path,
        resolution=config.data_resolution(),
        subtile_size=subtile_size,
        subtile_overlap=DEFAULT_SUBTILE_OVERLAP,  # todo: configurable
        supertile_overlap=DEFAULT_SUPERTILE_OVERLAP,
        file_ext=file_ext,
    )
    gen_stitch_coords(stitch_config, stitch_coord_path)
def get_matrix_client():
    """Create an S3 client for the endpoint named by ``MATRIX_ENDPOINT_URL``.

    Credentials are read from the ``MATRIX_ACCESS_KEY`` and
    ``MATRIX_SECRET_KEY`` environment variables.

    Returns:
        A boto3 S3 client configured for ``s3v4`` request signing.

    Raises:
        ValueError: If any of the three environment variables is unset or
            empty.
    """
    endpoint_url = os.getenv("MATRIX_ENDPOINT_URL")
    access_key = os.getenv("MATRIX_ACCESS_KEY")
    secret_key = os.getenv("MATRIX_SECRET_KEY")
    # An empty value (e.g. ``export MATRIX_ACCESS_KEY=``) is as unusable as a
    # missing one, so reject it here instead of failing later inside boto3.
    if not (endpoint_url and access_key and secret_key):
        raise ValueError("MATRIX_ENDPOINT_URL, MATRIX_ACCESS_KEY, MATRIX_SECRET_KEY must be set.")
    return boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        config=Config(signature_version="s3v4"),
    )
def make_stitch_coord_from_tigerdata_blade_path(blade_path: Path|str, stitch_coord_path: Path|str, subtile_size: int = 6000, file_ext: str = 'bmp'):
    """Write a stitch-coordinate file for a blade addressed by a ``tigerdata`` path.

    No image is inspected here, so the subtile size and file extension come
    from the arguments instead.

    Args:
        blade_path: Path of a single blade directory (not a ``bladeseq``
            directory).
        stitch_coord_path: Local destination file; it must not exist yet.
        subtile_size: Side length of the subtiles in pixels. Defaults to 6000.
        file_ext: Extension of the subtile image files. Defaults to ``'bmp'``.

    Raises:
        NotImplementedError: If ``blade_path`` names a ``bladeseq`` directory.
        ValueError: If ``stitch_coord_path`` already exists.
    """
    blade_path = Path(blade_path)
    if "bladeseq" in blade_path.name:
        raise NotImplementedError(f"Tigerdata blade path must be a single blade directory, not a bladeseq directory.{blade_path=}")

    stitch_coord_path = Path(stitch_coord_path)
    if stitch_coord_path.exists():
        raise ValueError(f"{stitch_coord_path} already exists.")

    stitch_config = StitchConfig(
        section_dir=blade_path,
        resolution=config.data_resolution(),
        subtile_size=subtile_size,
        subtile_overlap=DEFAULT_SUBTILE_OVERLAP,  # todo: configurable
        supertile_overlap=DEFAULT_SUPERTILE_OVERLAP,
        file_ext=file_ext,
    )
    gen_stitch_coords(stitch_config, stitch_coord_path)
def get_tigerdata_client():
    """Create an S3 client for the endpoint named by ``TIGERDATA_ENDPOINT_URL``.

    Credentials are read from the ``TIGERDATA_ACCESS_KEY`` and
    ``TIGERDATA_SECRET_KEY`` environment variables.

    Returns:
        A boto3 S3 client configured for ``s3v4`` request signing.

    Raises:
        ValueError: If any of the three environment variables is unset or
            empty.
    """
    endpoint_url = os.getenv("TIGERDATA_ENDPOINT_URL")
    access_key = os.getenv("TIGERDATA_ACCESS_KEY")
    secret_key = os.getenv("TIGERDATA_SECRET_KEY")
    # An empty value (e.g. ``export TIGERDATA_ACCESS_KEY=``) is as unusable as
    # a missing one, so reject it here instead of failing later inside boto3.
    if not (endpoint_url and access_key and secret_key):
        raise ValueError("TIGERDATA_ENDPOINT_URL, TIGERDATA_ACCESS_KEY, TIGERDATA_SECRET_KEY must be set.")
    return boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
        config=Config(signature_version="s3v4"),
    )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment