|
#!/usr/bin/env python3 |
|
import sys |
|
import argparse |
|
import os |
|
import sqlite3 |
|
try: |
|
from lxml import etree as ET |
|
except ImportError: |
|
import xml.etree.ElementTree as ET |
|
from collections import defaultdict |
|
from datetime import datetime, timedelta |
|
|
|
# HealthKit record type whose samples carry the sleep-stage values we parse.
SLEEP_TYPE = "HKCategoryTypeIdentifierSleepAnalysis"

MIN_REM_CYCLE_MIN = 3  # minimum total REM to count as a cycle boundary
|
|
|
|
|
def main():
    """CLI entry point: print per-night sleep-cycle tables and summary stats.

    Nights are labeled by wake-up date (YYYY-MM-DD). With --from/--to, every
    calendar day in the range is reported, including days with no data.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("export", metavar="export.xml")
    parser.add_argument("--from", dest="date_from", metavar="DATE")
    parser.add_argument("--to", dest="date_to", metavar="DATE")
    parser.add_argument("--source", help="Filter by source name (e.g. \"Vlad's Apple Watch\")")
    args = parser.parse_args()

    records_by_date = load_all_sleep_records(args.export, args.source)

    if args.date_from:
        # Walk every calendar day in the requested range. Previously only
        # dates present in the export were iterated, which made the
        # "No sleep data" message below unreachable.
        date_to = args.date_to or args.date_from
        start = datetime.strptime(args.date_from, "%Y-%m-%d").date()
        end = datetime.strptime(date_to, "%Y-%m-%d").date()
        dates = [
            (start + timedelta(days=offset)).strftime("%Y-%m-%d")
            for offset in range((end - start).days + 1)
        ]
    else:
        dates = sorted(records_by_date)

    all_cycle_durations = []
    for date in dates:
        records = records_by_date.get(date)
        if not records:
            # Only announce gaps when the user asked for an explicit range.
            if args.date_from:
                print(f"No sleep data for {date}")
            continue
        cycles = print_cycles(date, records)
        all_cycle_durations.extend(dur for _, _, dur, _ in cycles)

    if all_cycle_durations:
        all_cycle_durations.sort()
        n = len(all_cycle_durations)
        # Nearest-rank-style picks, clamped to the last element.
        p90 = all_cycle_durations[min(int(n * 0.9), n - 1)]
        median = all_cycle_durations[n // 2]

        def fmt(minutes):
            # Render a minute count as e.g. "1h 05m".
            h, m = divmod(minutes, 60)
            return f"{h}h {m:02d}m"

        print(f"Cycle duration ({n} cycles): median {fmt(median)}, P90 {fmt(p90)}")
|
|
|
|
|
def db_path(xml_path):
    """Return the cache-database path: the export path with a ``.db`` suffix."""
    stem, _ext = os.path.splitext(xml_path)
    return stem + ".db"
|
|
|
|
|
def build_db(xml_path, path):
    """Parse the Health export XML and cache its sleep records into SQLite.

    Creates table sleep(startDate, endDate, stage, source). "InBed" records
    are dropped; HealthKit value prefixes are stripped so ``stage`` becomes
    e.g. "Core", "Deep", "REM", "Awake" (or "" for legacy plain-Asleep rows).

    The database is built in a temp file and moved into place with
    ``os.replace`` so an interrupted build never leaves a partial cache —
    the staleness check in the caller is mtime-based and would otherwise
    trust a half-written file.
    """
    print(f"Building cache {path} ...", file=sys.stderr)
    root = ET.parse(xml_path).getroot()

    tmp_path = path + ".tmp"
    if os.path.exists(tmp_path):
        os.remove(tmp_path)  # stale leftover from a previous failed build
    con = sqlite3.connect(tmp_path)
    try:
        con.execute("""
    CREATE TABLE sleep (
        startDate TEXT,
        endDate TEXT,
        stage TEXT,
        source TEXT
    )
    """)
        rows = []
        for r in root.findall("Record"):
            if r.get("type") != SLEEP_TYPE:
                continue
            if "InBed" in r.get("value", ""):
                continue
            # Strip the longer prefix first so "...AsleepCore" -> "Core"
            # while plain "...Awake" still becomes "Awake".
            stage = (
                r.get("value", "")
                .replace("HKCategoryValueSleepAnalysisAsleep", "")
                .replace("HKCategoryValueSleepAnalysis", "")
            )
            rows.append((r.get("startDate"), r.get("endDate"), stage, r.get("sourceName", "")))
        con.executemany("INSERT INTO sleep VALUES (?,?,?,?)", rows)
        con.execute("CREATE INDEX idx_end ON sleep (endDate)")
        con.commit()
    finally:
        # Close even on failure so the temp file isn't held open/locked.
        con.close()
    os.replace(tmp_path, path)
|
|
|
|
|
def load_all_sleep_records(xml_path, source=None):
    """Load every sleep record from the cache, grouped by night.

    A night is keyed by its wake-up date: records ending before noon belong
    to that morning's night; records ending at noon or later are treated as
    onset records for the following night. Returns a mapping of
    "YYYY-MM-DD" -> [(start, end, stage), ...], each list sorted by start.
    """
    cache = db_path(xml_path)

    # Rebuild the cache when it is missing or older than the XML export.
    stale = (
        not os.path.exists(cache)
        or os.path.getmtime(cache) < os.path.getmtime(xml_path)
    )
    if stale:
        build_db(xml_path, cache)

    where = " WHERE source = ?" if source else ""
    params = (source,) if source else ()
    con = sqlite3.connect(cache)
    rows = con.execute("SELECT startDate, endDate, stage FROM sleep" + where, params).fetchall()
    con.close()

    by_date = defaultdict(list)
    for start_str, end_str, stage in rows:
        start = parse_dt(start_str)
        end = parse_dt(end_str)
        # End at/after noon => label with the next day (onset for the coming
        # night); end before noon => this morning's night.
        label_day = end + timedelta(days=1) if end.hour >= 12 else end
        by_date[label_day.strftime("%Y-%m-%d")].append((start, end, stage))

    for night_records in by_date.values():
        night_records.sort()
    return by_date
|
|
|
|
|
def print_cycles(date, records):
    """Print a formatted cycle table for one night and return the cycles.

    Each printed row shows the cycle's start/end clock time, its duration,
    and a stage-composition summary; a total-sleep line follows the table.
    """
    cycles = find_cycles(records)

    print(f"Sleep cycles for {date}")
    print()
    print(f"{'':8} {'Start':>5} {'End':>5} {'Duration':>10} Composition")
    print("-" * 70)

    total = 0
    for number, (start, end, dur, recs) in enumerate(cycles, start=1):
        label = f"Cycle {number}"
        h, m = divmod(dur, 60)
        dur_str = f"{h}h {m:02d}m" if h else f"{m}m"
        composition = summarize_stages(recs)
        print(f"{label:8} {start.strftime('%H:%M'):>5} {end.strftime('%H:%M'):>5} {dur_str:>10} {composition}")
        total += dur

    h, m = divmod(total, 60)
    print(f"\n{'Total sleep':>32} {h}h {m:02d}m")
    print()
    return cycles
|
|
|
|
|
def find_cycles(records):
    """Split records into cycles using REM blocks as cycle boundaries.

    records: list of (start, end, stage) tuples, assumed sorted by start
    (the caller sorts them in load_all_sleep_records).
    Returns a list of (cycle_start, cycle_end, duration_minutes, records)
    tuples. A cycle runs from the end of the previous qualifying REM block
    (or the first record) to the end of the next qualifying REM block; any
    records after the last REM block form one trailing partial cycle.
    """
    cycles = []
    current_start = records[0][0] if records else None

    i = 0
    while i < len(records):
        s, e, stage = records[i]
        if stage == "REM":
            # Merge consecutive REM blocks (possibly separated by tiny Core/Awake)
            rem_end = e
            rem_total = (e - s).total_seconds() / 60
            j = i + 1
            while j < len(records):
                ns, ne, nstage = records[j]
                # Gap between this record and the merged REM span so far;
                # negative when records overlap, which still merges.
                gap = (ns - rem_end).total_seconds() / 60
                if nstage == "REM" or (gap <= 5 and nstage in ("Core", "Awake")):
                    if nstage == "REM":
                        # NOTE(review): rem_end is overwritten, not maxed — a
                        # merged REM record ending earlier than rem_end would
                        # shrink the span. Assumes same-source records don't
                        # overlap; confirm against real exports.
                        rem_end = ne
                        rem_total += (ne - ns).total_seconds() / 60
                    j += 1
                else:
                    break

            # Too little total REM to count as a boundary: skip the whole
            # merged span and keep accumulating the current cycle.
            if rem_total < MIN_REM_CYCLE_MIN:
                i = j
                continue

            # Collect all records up to and including this REM block
            cycle_records = [r for r in records if current_start <= r[0] < rem_end]
            dur = int((rem_end - current_start).total_seconds() / 60)
            cycles.append((current_start, rem_end, dur, cycle_records))
            # Find the next record starting at or after rem_end (don't skip
            # records that the merge loop may have advanced j past).
            j = next((k for k in range(i + 1, len(records)) if records[k][0] >= rem_end), len(records))
            current_start = records[j][0] if j < len(records) else None
            i = j
        else:
            i += 1

    # Tail: remaining records after last REM
    if current_start is not None:
        tail = [r for r in records if r[0] >= current_start]
        if tail:
            tail_end = tail[-1][1]
            dur = int((tail_end - current_start).total_seconds() / 60)
            cycles.append((current_start, tail_end, dur, tail))

    return cycles
|
|
|
|
|
def summarize_stages(cycle_records):
    """Return a composition string like ``"Deep 20m → Core 45m → REM 12m"``.

    Overlapping records are resolved by letting the later-starting record
    take precedence: the previous segment is clipped at the later one's
    start. Stages totalling under one minute per segment are dropped, and
    stages are always listed in Deep/Core/REM/Awake order.
    """
    # Clip overlaps into non-overlapping segments (later start wins).
    clipped = []
    for segment in sorted(cycle_records):
        seg_start = segment[0]
        if clipped:
            prev_start, prev_end, prev_stage = clipped[-1]
            if seg_start < prev_end:
                clipped[-1] = (prev_start, seg_start, prev_stage)
        clipped.append(segment)

    minutes = defaultdict(float)
    for seg_start, seg_end, stage in clipped:
        length = (seg_end - seg_start).total_seconds() / 60
        if length >= 1:
            minutes[stage] += length

    return " → ".join(
        f"{name} {int(minutes[name])}m"
        for name in ("Deep", "Core", "REM", "Awake")
        if name in minutes
    )
|
|
|
|
|
def parse_dt(s):
    """Parse an export timestamp such as ``'2024-01-01 23:30:00 +0000'``."""
    fmt = "%Y-%m-%d %H:%M:%S %z"
    return datetime.strptime(s, fmt)
|
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed directly, not when imported.
    main()