Skip to content

Instantly share code, notes, and snippets.

@h-mayorquin
Created August 21, 2025 02:02
Show Gist options
  • Select an option

  • Save h-mayorquin/d4292a59c54e3a6cd43b9da3374895e5 to your computer and use it in GitHub Desktop.

Select an option

Save h-mayorquin/d4292a59c54e3a6cd43b9da3374895e5 to your computer and use it in GitHub Desktop.
This script creates a stub (truncated copy) of an EDF file containing both electrode and analog data.
# /// script
# requires-python = ">=3.10"
# dependencies = [
# "pyedflib>=0.1.36",
# "numpy>=1.23"
# ]
# ///
"""
One-off script to create a short "stub" EDF/EDF+ from a full file by copying
the first few data records (whole-record truncation) while keeping per-signal
sampling layout intact. Also copies annotations overlapping the kept window
and clips them.
Key EDF timing terms used below:
- record_duration_sec: duration (in seconds) of one EDF data record (aka block).
- datarecords_in_file (R_new here): number of data records written in the stub.
- Total file duration = record_duration_sec * R_new.
Why this is not general (intentionally kept simple):
- Assumes a pyedflib-compatible EDF(+) with consistent headers.
- Truncates from the start only; no resampling or complex reconciliation.
- Scales very large physical ranges to satisfy EDF+ 8-char header limits.
Usage: run directly (no CLI args). Adjust the variables below.
"""
from __future__ import annotations
from pathlib import Path
from typing import List
import math
import numpy as np
import pyedflib
def _compute_record_count_for_target(ns_per_record: List[int], available_records: int, target: int) -> int:
"""Compute how many data records to keep (R_new) under an "at least" policy.
Parameters
- ns_per_record: per-signal samples-per-record, length = number of channels.
For channel i, total samples written will be ns_per_record[i] * R_new.
- available_records: total records available in the source file.
- target: minimum samples desired per channel in the stub.
Returns
- R_new: smallest integer number of records so that ns[i] * R_new >= target
for all channels i, clamped to [1, available_records].
"""
# at_least policy: choose the smallest R such that ns[i]*R >= target for all i
need = [math.ceil(target / max(1, int(n))) for n in ns_per_record]
R = max(need) if need else 1
return max(1, min(int(R), int(available_records)))
def _make_stub(src_path: Path, out_path: Path, target_samples: int) -> dict:
    """Create a truncated EDF stub by copying whole data records from the start.

    Variables used (aligned with EDF terminology):
    - record_duration_sec: duration (seconds) of one EDF data record (block).
    - ns_per_record[i]: samples per record for channel i (layout from source).
    - R_new: number of records to write so every channel reaches target samples.
    - n_samples_to_write[i] = ns_per_record[i] * R_new: samples written for channel i.
    - Total stub duration = record_duration_sec * R_new.

    Parameters
    - src_path: existing EDF(+) file to truncate.
    - out_path: path where the stub EDF is written.
    - target_samples: minimum samples to keep per channel ("at least" policy).

    Returns
    - dict summary: signal count, record layout, output path, and an
      "annotation_summary" sub-dict describing kept/clipped annotations.

    Raises
    - ValueError: if the source file reports zero data records.
    """
    f = pyedflib.EdfReader(str(src_path))
    writer = None
    try:
        n_sig = int(f.signals_in_file)
        labels = list(f.getSignalLabels())
        total_records = int(getattr(f, "datarecords_in_file", 0))
        if total_records <= 0:
            # Guard: a malformed/empty file would otherwise divide by zero below.
            raise ValueError(f"Source EDF reports no data records: {src_path}")
        file_duration_sec = float(f.getFileDuration())
        record_duration_sec = file_duration_sec / total_records
        # Per-record sample counts derived from each channel's declared fs.
        ns_per_record: List[int] = [
            int(round(float(f.getSampleFrequency(i)) * record_duration_sec))
            for i in range(n_sig)
        ]
        # at_least rounding: smallest R with ns[i]*R >= target for all i
        R_new = _compute_record_count_for_target(ns_per_record, total_records, int(target_samples))
        n_samples_to_write = [int(n * R_new) for n in ns_per_record]

        file_type = getattr(f, "file_type", pyedflib.FILETYPE_EDFPLUS)
        writer = pyedflib.EdfWriter(str(out_path), n_sig, file_type=file_type)

        # File-level header (keep what matters; skip missing/None fields).
        src_header = f.getHeader() or {}
        file_header = {}
        for key in (
            "technician",
            "recording_additional",
            "patientname",
            "patient_additional",
            "patientcode",
            "equipment",
            "admincode",
            "gender",
            "startdate",
            "birthdate",
        ):
            if key in src_header and src_header[key] is not None:
                file_header[key] = src_header[key]
        # Newer pyedflib versions use "sex"; mirror the legacy "gender" value.
        if "sex" not in file_header and "gender" in src_header:
            file_header["sex"] = src_header.get("gender", 0)
        # Set the record duration under every key spelling pyedflib has used,
        # plus the explicit setter, to cover version differences.
        file_header["datarecord_duration"] = record_duration_sec
        file_header["record_duration"] = record_duration_sec
        file_header["duration"] = record_duration_sec
        writer.setHeader(file_header)
        writer.setDatarecordDuration(record_duration_sec)

        # Signal headers + data
        sig_headers = []
        signals = []
        for i in range(n_sig):
            src_sh = f.getSignalHeader(i) or {}
            fs = float(f.getSampleFrequency(i))
            pmin = float(src_sh.get("physical_min", -100.0))
            pmax = float(src_sh.get("physical_max", 100.0))
            if pmin > pmax:  # normalize inverted physical range
                pmin, pmax = pmax, pmin
            dmin = int(src_sh.get("digital_min", -32768))
            dmax = int(src_sh.get("digital_max", 32767))
            if dmin > dmax:  # normalize inverted digital range
                dmin, dmax = dmax, dmin
            # Scale overly large physical ranges to satisfy the EDF+ 8-char
            # header field limit; the applied factor is noted in the prefilter.
            m_abs = max(abs(pmin), abs(pmax))
            scale_pow = 0
            if m_abs and m_abs >= 1e8:
                scale_pow = max(0, int(math.floor(math.log10(m_abs)) - 6))
            scale_factor = float(10 ** scale_pow)
            sh = {
                "label": src_sh.get("label", labels[i] if i < len(labels) else f"ch{i}"),
                "dimension": src_sh.get("dimension", ""),
                "transducer": src_sh.get("transducer", ""),
                "prefilter": src_sh.get("prefilter", ""),
                "physical_min": (pmin / scale_factor) if scale_pow > 0 else pmin,
                "physical_max": (pmax / scale_factor) if scale_pow > 0 else pmax,
                "digital_min": dmin,
                "digital_max": dmax,
                "sample_frequency": fs,
                "samples_per_record": int(ns_per_record[i]),
            }
            sig_headers.append(sh)
            n_to_read = int(n_samples_to_write[i])
            data_list = list(f.readSignal(i, start=0, n=n_to_read)) if n_to_read > 0 else []
            if scale_pow > 0 and data_list:
                # Rescale samples to match the shrunken physical range and
                # record the factor so the data can be un-scaled later.
                data_list = [x / scale_factor for x in data_list]
                pf_note = f"STUB_SCALE=1e{scale_pow}"
                prev = sh.get("prefilter", "") or ""
                sh["prefilter"] = (prev + (" | " if prev else "") + pf_note)
            signals.append(data_list)
        writer.setSignalHeaders(sig_headers)
        for i in range(n_sig):
            # Re-derive fs from the record layout so writer and headers agree.
            writer.setSampleFrequency(i, float(ns_per_record[i]) / float(record_duration_sec))
        writer.writeSamples([np.asarray(sig, dtype=float) for sig in signals])

        # Copy annotations overlapping [0, D_new], clipping to the window.
        onsets, durations, descriptions = f.readAnnotations()
        D_new = float(R_new) * float(record_duration_sec)
        kept_total = 0
        kept_events = 0
        kept_epochs = 0
        clipped = 0
        min_onset = None
        max_onset = None
        for onset, dur, desc in zip(onsets, durations, descriptions):
            onset_f = float(onset)
            dur_f = float(dur)
            start = onset_f
            end = onset_f + max(0.0, dur_f)
            # Keep instantaneous events inside the window, and any interval
            # that overlaps it (even if it started before time 0).
            if (dur_f == 0.0 and 0.0 <= start < D_new) or (end > 0.0 and start < D_new):
                start_clipped = max(0.0, start)
                if dur_f == 0.0:
                    dur_clipped = 0.0
                else:
                    dur_clipped = max(0.0, min(end, D_new) - start_clipped)
                writer.writeAnnotation(start_clipped, dur_clipped, str(desc))
                kept_total += 1
                if dur_f == 0.0:
                    kept_events += 1
                else:
                    kept_epochs += 1
                if (start_clipped != start) or (dur_clipped != dur_f):
                    clipped += 1
                if min_onset is None or start_clipped < min_onset:
                    min_onset = start_clipped
                if max_onset is None or start_clipped > max_onset:
                    max_onset = start_clipped

        return {
            "n_signals": n_sig,
            "record_duration_sec": record_duration_sec,
            "R_new": R_new,
            "samples_per_record": ns_per_record,
            "samples_written_per_signal": n_samples_to_write,
            "out_path": str(out_path),
            "annotation_summary": {
                "source_total": int(len(onsets)),
                "kept_total": kept_total,
                "kept_events": kept_events,
                "kept_epochs": kept_epochs,
                "clipped": clipped,
                "min_onset": min_onset,
                "max_onset": max_onset,
                "window_sec": D_new,
            },
        }
    finally:
        # Always release both handles, even on error; close the writer first
        # so annotations/data are flushed before the reader goes away.
        try:
            if writer is not None:
                writer.close()
        finally:
            f.close()
# --- Adjust these for your one-off run ---
full_edf_file_path = Path("X~ X_ea8077ef-0800-4c23-bfa3-d317b55be08f_0002.edf")
stub_out_path = Path("stub_edf.edf")
target_samples_per_channel = 100  # minimum target per channel (at_least)

if __name__ == "__main__":
    # Guard so importing this module for its helpers does not run the one-off
    # job; direct execution behaves exactly as before.
    if not full_edf_file_path.exists():
        raise SystemExit(f"Source EDF not found: {full_edf_file_path}")
    info = _make_stub(full_edf_file_path, stub_out_path, target_samples_per_channel)
    annotation_summary = info.get("annotation_summary") or {}
    print("Stub created:")
    print({
        "file": info.get("out_path"),
        "signals": info.get("n_signals"),
        "record_duration_sec": info.get("record_duration_sec"),
        "R_new": info.get("R_new"),
        "kept_annotations": annotation_summary.get("kept_total"),
        "window_sec": annotation_summary.get("window_sec"),
    })
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment