Skip to content

Instantly share code, notes, and snippets.

@atadams
Created February 1, 2026 23:31
Show Gist options
  • Select an option

  • Save atadams/afa32291caa0216cad7f69492075dacc to your computer and use it in GitHub Desktop.

Select an option

Save atadams/afa32291caa0216cad7f69492075dacc to your computer and use it in GitHub Desktop.
PTC (variance) Analysis
#!/usr/bin/env python3
"""
Batch Photon Transfer Curve (PTC) Analysis for CR2 RAW files
Recursively processes directories of CR2 files and generates:
- Individual PTC plots for each file
- Summary CSV with all statistics
- Overview comparison plot
Usage:
python ptc_batch_analysis.py /path/to/cr2/directory [--output /path/to/output]
Requirements:
pip install rawpy numpy scipy matplotlib pandas
"""
import os
import sys
import argparse
from pathlib import Path
from datetime import datetime
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import csv

# rawpy is the only hard third-party requirement beyond the scientific stack;
# attempt a one-shot pip install if it is missing so the script is runnable
# on a fresh machine. NOTE(review): installing at import time is a side
# effect — consider failing with a message instead in shared environments.
try:
    import rawpy
except ImportError:
    print("Installing rawpy...")
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "rawpy", "--break-system-packages", "-q"])
    import rawpy

# pandas is optional; HAS_PANDAS records availability for any later use.
try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False

# Upper variance cutoff used by the "variance < limit" analyses and plots.
variance_limit = 20000
# variance_limit = 20536
def analyze_ptc(raw_data, block_size=8):
    """
    Analyze Photon Transfer Curve statistics from raw sensor data.

    The image is tiled into non-overlapping block_size x block_size blocks
    (any right/bottom remainder that does not fill a whole block is
    discarded) and per-block mean and population variance are computed.

    Args:
        raw_data: 2-D array of raw sensor values.
        block_size: Edge length of the square analysis blocks.

    Returns:
        means: 1-D float64 array of per-block mean brightness values,
            in row-major block order (top-left to bottom-right).
        variances: 1-D float64 array of per-block variances, same order.
    """
    height, width = raw_data.shape
    n_blocks_y = height // block_size
    n_blocks_x = width // block_size
    # Crop to a whole number of blocks, then reshape so each block's pixels
    # land on the trailing axis. This vectorizes the per-block statistics
    # in NumPy instead of looping over millions of blocks in Python,
    # while producing the identical values and ordering.
    cropped = raw_data[:n_blocks_y * block_size,
                       :n_blocks_x * block_size].astype(np.float64)
    blocks = (cropped
              .reshape(n_blocks_y, block_size, n_blocks_x, block_size)
              .transpose(0, 2, 1, 3)
              .reshape(n_blocks_y * n_blocks_x, block_size * block_size))
    return blocks.mean(axis=1), blocks.var(axis=1)
def process_cr2_file(cr2_path, output_dir=None, generate_plot=True):
    """
    Process a single CR2 file and return statistics.

    Args:
        cr2_path: Path to the CR2 RAW file.
        output_dir: Directory for the per-file plot; plot is only written
            when both generate_plot and output_dir are truthy.
        generate_plot: When True, also save a 4-panel PTC plot.

    Returns:
        dict with all statistics; on failure 'success' is False and
        'error' holds the exception message (partial keys may be present).
    """
    # Base record; keys added below become CSV columns in main().
    result = {
        'filepath': str(cr2_path),
        'filename': os.path.basename(cr2_path),
        'success': False,
        'error': None,
    }
    try:
        with rawpy.imread(str(cr2_path)) as raw:
            # Copy so the pixel data survives closing the RAW handle.
            raw_data = raw.raw_image.copy()
            # raw_data = raw.raw_image_visible.astype(np.float32)
            black_level = raw.black_level_per_channel[0]
            white_level = raw.white_level
            result['image_height'] = raw_data.shape[0]
            result['image_width'] = raw_data.shape[1]
            result['black_level'] = black_level
            result['white_level'] = white_level
            result['raw_min'] = int(raw_data.min())
            result['raw_max'] = int(raw_data.max())
        # raw_data is now available outside the with block
        # Full image analysis
        means, variances = analyze_ptc(raw_data, block_size=8)
        result['total_blocks'] = len(means)
        result['mean_min'] = float(means.min())
        result['mean_max'] = float(means.max())
        result['var_min'] = float(variances.min())
        result['var_max'] = float(variances.max())
        # Full correlation of per-block mean vs variance (the PTC signal).
        corr_full, pval_full = stats.pearsonr(means, variances)
        slope_full, intercept_full, r_value, _, _ = stats.linregress(means, variances)
        result['corr_full'] = float(corr_full)
        result['pval_full'] = float(pval_full)
        result['slope_full'] = float(slope_full)
        result['r_squared_full'] = float(r_value ** 2)
        # 95th percentile filtered (removes high-variance outliers)
        var_p95 = np.percentile(variances, 95)
        mask_p95 = variances <= var_p95
        # Require a minimum sample size before fitting; otherwise record None.
        if mask_p95.sum() > 10:
            corr_p95, _ = stats.pearsonr(means[mask_p95], variances[mask_p95])
            slope_p95, _, _, _, _ = stats.linregress(means[mask_p95], variances[mask_p95])
            result['corr_p95'] = float(corr_p95)
            result['slope_p95'] = float(slope_p95)
        else:
            result['corr_p95'] = None
            result['slope_p95'] = None
        # Analysis in brightness range 1600-2200 (the range claimed in original)
        mask_1600_2200 = (means >= 1600) & (means <= 2200)
        result['blocks_in_1600_2200'] = int(mask_1600_2200.sum())
        if mask_1600_2200.sum() > 10:
            m_range = means[mask_1600_2200]
            v_range = variances[mask_1600_2200]
            corr_range, _ = stats.pearsonr(m_range, v_range)
            slope_range, _, _, _, _ = stats.linregress(m_range, v_range)
            result['corr_1600_2200'] = float(corr_range)
            result['slope_1600_2200'] = float(slope_range)
            result['var_min_in_1600_2200'] = float(v_range.min())
            result['var_max_in_1600_2200'] = float(v_range.max())
            result['blocks_var_lt_20k_in_1600_2200'] = int((v_range < variance_limit).sum())
        else:
            result['corr_1600_2200'] = None
            result['slope_1600_2200'] = None
            result['var_min_in_1600_2200'] = None
            result['var_max_in_1600_2200'] = None
            result['blocks_var_lt_20k_in_1600_2200'] = 0
        # Analysis with variance < 20000 only
        # NOTE: key names say "20k" but the threshold is the module-level
        # variance_limit, which may be changed.
        mask_var_lt_20k = variances < variance_limit
        result['blocks_var_lt_20k'] = int(mask_var_lt_20k.sum())
        if mask_var_lt_20k.sum() > 10:
            m_low = means[mask_var_lt_20k]
            v_low = variances[mask_var_lt_20k]
            corr_low, _ = stats.pearsonr(m_low, v_low)
            slope_low, _, _, _, _ = stats.linregress(m_low, v_low)
            result['corr_var_lt_20k'] = float(corr_low)
            result['slope_var_lt_20k'] = float(slope_low)
            result['mean_range_when_var_lt_20k'] = f"{m_low.min():.1f}-{m_low.max():.1f}"
        else:
            result['corr_var_lt_20k'] = None
            result['slope_var_lt_20k'] = None
            result['mean_range_when_var_lt_20k'] = None
        # Verdict: shot-noise physics predicts a positive mean-variance
        # correlation for a real sensor.
        result['shows_positive_correlation'] = corr_full > 0
        result['consistent_with_physics'] = corr_full > 0
        result['success'] = True
        # Generate plot if requested
        if generate_plot and output_dir:
            plot_path = os.path.join(output_dir, f"{Path(cr2_path).stem}_ptc.png")
            create_ptc_plot(means, variances, result, plot_path, raw_data=raw_data)
            result['plot_path'] = plot_path
    except Exception as e:
        # Any failure (unreadable file, plotting error, ...) marks this
        # file as failed but does not abort the batch.
        result['error'] = str(e)
        result['success'] = False
    return result
def create_ptc_plot(means, variances, stats_dict, output_path, raw_data=None):
    """Create a 4-panel PTC plot for a single file.

    Panels: image preview, full mean-vs-variance scatter, spatial map of
    low-variance blocks, and the variance-limited scatter. The figure is
    saved to output_path and closed.

    Args:
        means: Per-block means from analyze_ptc.
        variances: Per-block variances from analyze_ptc.
        stats_dict: Result dict from process_cr2_file (keys read below).
        output_path: PNG path to write.
        raw_data: Optional 2-D raw image for the preview / spatial map.
    """
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    # NOTE(review): 'filename' is assigned but never used; titles below show
    # the literal '(unknown)' — possibly meant to interpolate the filename.
    filename = stats_dict['filename']
    corr_full = stats_dict['corr_full']
    slope_full = stats_dict['slope_full']
    # Top-left: Image preview
    ax = axes[0, 0]
    if raw_data is not None:
        # Downsample for display
        display_img = raw_data[::4, ::4].astype(np.float32)
        # Simple normalization for display
        # NOTE(review): black_level is fetched but unused here.
        black_level = stats_dict.get('black_level', 0)
        p_low, p_high = np.percentile(display_img, [1, 99])
        display_img = np.clip((display_img - p_low) / (p_high - p_low), 0, 1)
        ax.imshow(display_img, cmap='gray', aspect='auto')
        ax.set_title(f'(unknown)\nImage Preview', fontweight='bold')
        ax.axis('off')
    else:
        ax.text(0.5, 0.5, 'Image preview\nnot available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title(f'(unknown)', fontweight='bold')
        ax.axis('off')
    # Top-right: Full data analysis
    ax = axes[0, 1]
    # Green = physically expected (positive correlation), red = anomalous.
    color = 'green' if corr_full > 0 else 'red'
    ax.scatter(means, variances, c=color, s=3, alpha=0.3, label='8x8 blocks')
    x_fit = np.linspace(means.min(), means.max(), 100)
    # Reconstruct the least-squares intercept from the means (the fit line
    # passes through the centroid of the data).
    y_fit = slope_full * x_fit + (np.mean(variances) - slope_full * np.mean(means))
    ax.plot(x_fit, y_fit, 'k--', lw=2, label=f'slope={slope_full:.1f}')
    ax.set_xlabel('Mean Brightness (Signal)')
    ax.set_ylabel('Variance (Noise²)')
    ax.set_title(f'Full Analysis\nCorrelation: {corr_full:+.3f}',
                 color=color, fontweight='bold')
    ax.legend(loc='upper left')
    # Cap Y axis for visibility
    y_cap = np.percentile(variances, 98)
    ax.set_ylim(0, y_cap * 1.1)
    # Bottom-left: Low variance regions spatial map
    ax = axes[1, 0]
    mask_v20k = variances < variance_limit
    if raw_data is not None:
        # Create a map showing low-variance block locations
        # Block geometry must match analyze_ptc's block_size=8 tiling so
        # the flat mask index lines up with (by, bx).
        block_size = 8
        n_by = raw_data.shape[0] // block_size
        n_bx = raw_data.shape[1] // block_size
        var_map = np.zeros((n_by, n_bx))
        idx = 0
        for by in range(n_by):
            for bx in range(n_bx):
                if mask_v20k[idx]:
                    var_map[by, bx] = 1
                idx += 1
        ax.imshow(var_map, cmap='Blues', aspect='auto')
        ax.set_title(f'Low variance regions (<{variance_limit})\n{mask_v20k.sum()} blocks', fontweight='bold')
        ax.set_xlabel('Block X')
        ax.set_ylabel('Block Y')
    else:
        ax.text(0.5, 0.5, 'Spatial map\nnot available', ha='center', va='center',
                transform=ax.transAxes, fontsize=12)
        ax.set_title(f'Low variance regions (<{variance_limit})', fontweight='bold')
    # Bottom-right: Variance < variance_limit (matching original claim's Y-axis)
    ax = axes[1, 1]
    mask_v20k = variances < variance_limit
    if mask_v20k.sum() > 10:
        m_v20k = means[mask_v20k]
        v_v20k = variances[mask_v20k]
        # Fall back to 0 when the stat is missing or None so the color
        # choice below cannot raise.
        corr_v20k = stats_dict.get('corr_var_lt_20k', 0) or 0
        color = 'green' if corr_v20k > 0 else 'red'
        ax.scatter(m_v20k, v_v20k, c=color, s=10, alpha=0.5, label='8x8 blocks')
        slope_v20k, intercept, _, _, _ = stats.linregress(m_v20k, v_v20k)
        x_fit = np.linspace(m_v20k.min(), m_v20k.max(), 100)
        ax.plot(x_fit, slope_v20k * x_fit + intercept, 'k--', lw=2, label=f'Fit: slope={slope_v20k:.2f}')
        ax.set_xlabel('Mean Brightness (Signal)')
        ax.set_ylabel('Variance (Noise²)')
        ax.set_title(f'Variance < {variance_limit} (Original Style)\nCorrelation: {corr_v20k:+.3f}',
                     color=color, fontweight='bold')
        ax.legend(loc='upper right')
        # Add note about what this region actually is (moved to bottom left, reformatted)
        mean_range = stats_dict.get('mean_range_when_var_lt_20k', 'N/A')
        black_level = stats_dict.get('black_level', 'N/A')
        ax.text(0.02, 0.02, f"Mean range: {mean_range}\nBlack level: {black_level}\nBlocks: {mask_v20k.sum()}",
                transform=ax.transAxes, fontsize=8, verticalalignment='bottom',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    else:
        ax.text(0.5, 0.5, f'No blocks with\nvariance < {variance_limit}\n({mask_v20k.sum()} blocks)',
                ha='center', va='center', transform=ax.transAxes, fontsize=12)
        ax.set_xlabel('Mean Brightness (Signal)')
        ax.set_ylabel('Variance (Noise²)')
        ax.set_title(f'Variance < {variance_limit} (Original Style)\nInsufficient data', fontweight='bold')
    fig.suptitle('Photon Transfer Curve Analysis\nShot noise physics: Brighter pixels MUST have more noise',
                 fontsize=12, fontweight='bold', y=0.98)
    plt.tight_layout()
    plt.savefig(output_path, dpi=120, bbox_inches='tight')
    # Close to free figure memory during long batches.
    plt.close(fig)
def find_cr2_files(directory):
    """Recursively find all CR2 files in a directory.

    Matching is case-insensitive on the '.cr2' extension. Returns a
    sorted list of full paths.
    """
    matches = []
    for dirpath, _dirnames, filenames in os.walk(directory):
        matches.extend(
            os.path.join(dirpath, name)
            for name in filenames
            if name.lower().endswith('.cr2')
        )
    return sorted(matches)
def create_summary_plot(results, output_path):
    """Create a 4-panel summary plot comparing all processed files.

    Args:
        results: List of result dicts from process_cr2_file.
        output_path: PNG path to write; nothing is written if no result
            has a valid 'corr_full'.
    """
    # Filter to successful results with valid correlation
    valid = [r for r in results if r['success'] and r.get('corr_full') is not None]
    if not valid:
        print("No valid results to plot.")
        return
    fig, axes = plt.subplots(2, 2, figsize=(14, 12))
    # Extract data
    filenames = [r['filename'] for r in valid]
    corr_full = [r['corr_full'] for r in valid]
    # Missing/None p95 correlations plot as 0.
    corr_p95 = [r.get('corr_p95') or 0 for r in valid]
    # Plot 1: Correlation distribution
    ax = axes[0, 0]
    colors = ['green' if c > 0 else 'red' for c in corr_full]
    bars = ax.barh(range(len(filenames)), corr_full, color=colors, alpha=0.7)
    ax.axvline(0, color='black', linestyle='-', linewidth=1)
    ax.set_yticks(range(len(filenames)))
    ax.set_yticklabels(filenames, fontsize=8)
    ax.set_xlabel('Correlation (Full)')
    ax.set_title('Full Image Correlation by File')
    ax.set_xlim(-1, 1)
    # Plot 2: Correlation comparison (full vs 95th percentile)
    ax = axes[0, 1]
    x = np.arange(len(filenames))
    width = 0.35
    ax.bar(x - width / 2, corr_full, width, label='Full', color='steelblue', alpha=0.7)
    ax.bar(x + width / 2, corr_p95, width, label='95th %ile', color='darkorange', alpha=0.7)
    ax.axhline(0, color='black', linestyle='-', linewidth=1)
    ax.set_xticks(x)
    ax.set_xticklabels(filenames, rotation=45, ha='right', fontsize=8)
    ax.set_ylabel('Correlation')
    ax.set_title('Correlation: Full vs 95th Percentile Filtered')
    ax.legend()
    ax.set_ylim(-1, 1)
    # Plot 3: Scatter of correlations
    ax = axes[1, 0]
    ax.scatter(corr_full, corr_p95, c='steelblue', s=50, alpha=0.7)
    ax.axhline(0, color='gray', linestyle='--', linewidth=0.5)
    ax.axvline(0, color='gray', linestyle='--', linewidth=0.5)
    ax.plot([-1, 1], [-1, 1], 'k--', alpha=0.3, label='y=x')
    ax.set_xlabel('Full Correlation')
    ax.set_ylabel('95th Percentile Correlation')
    ax.set_title('Correlation Comparison')
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.legend()
    # Add filename labels
    for i, fn in enumerate(filenames):
        ax.annotate(fn, (corr_full[i], corr_p95[i]), fontsize=6, alpha=0.7)
    # Plot 4: Summary statistics
    ax = axes[1, 1]
    ax.axis('off')
    n_positive = sum(1 for c in corr_full if c > 0)
    n_negative = sum(1 for c in corr_full if c < 0)
    n_strong_pos = sum(1 for c in corr_full if c > 0.5)
    summary_text = f"""
BATCH ANALYSIS SUMMARY
══════════════════════════════════════
Total files analyzed: {len(valid)}
Full correlation results:
• Positive correlation: {n_positive} ({100 * n_positive / len(valid):.1f}%)
• Negative correlation: {n_negative} ({100 * n_negative / len(valid):.1f}%)
• Strongly positive (>0.5): {n_strong_pos}
Mean correlation (full): {np.mean(corr_full):+.3f}
Mean correlation (95th): {np.mean(corr_p95):+.3f}
══════════════════════════════════════
PHYSICS EXPECTATION:
Real camera sensors should show POSITIVE
correlation due to Poisson shot noise
(variance proportional to signal).
"""
    ax.text(0.1, 0.9, summary_text, transform=ax.transAxes, fontsize=11,
            verticalalignment='top', fontfamily='monospace',
            bbox=dict(boxstyle='round', facecolor='white', edgecolor='gray'))
    fig.suptitle('Photon Transfer Curve Batch Analysis Summary', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)
    print(f"Summary plot saved to: {output_path}")
def main():
    """CLI entry point: parse args, process all CR2 files, write outputs."""
    parser = argparse.ArgumentParser(
        description='Batch PTC analysis for CR2 files',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
python ptc_batch_analysis.py /path/to/photos
python ptc_batch_analysis.py /path/to/photos --output ./results
python ptc_batch_analysis.py /path/to/photos --no-plots
"""
    )
    parser.add_argument('directory', help='Directory containing CR2 files (searched recursively)')
    parser.add_argument('--output', '-o', default=None,
                        help='Output directory for results (default: ./ptc_results_TIMESTAMP)')
    parser.add_argument('--no-plots', action='store_true', help='Skip generating individual plots')
    parser.add_argument('--csv', default='ptc_results.csv', help='CSV filename for results')
    args = parser.parse_args()
    # Validate input directory
    if not os.path.isdir(args.directory):
        print(f"Error: '{args.directory}' is not a valid directory")
        sys.exit(1)
    # Setup output directory (timestamped default avoids clobbering runs)
    if args.output is None:
        timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
        args.output = f"./ptc_results_{timestamp}"
    os.makedirs(args.output, exist_ok=True)
    # Find CR2 files
    print(f"Searching for CR2 files in: {args.directory}")
    cr2_files = find_cr2_files(args.directory)
    if not cr2_files:
        print("No CR2 files found!")
        sys.exit(1)
    print(f"Found {len(cr2_files)} CR2 file(s)")
    print(f"Output directory: {args.output}")
    print("=" * 60)
    # Process each file; per-file failures are recorded, not fatal.
    results = []
    for i, cr2_path in enumerate(cr2_files, 1):
        print(f"[{i}/{len(cr2_files)}] Processing: {os.path.basename(cr2_path)}...", end=' ')
        result = process_cr2_file(
            cr2_path,
            output_dir=args.output if not args.no_plots else None,
            generate_plot=not args.no_plots
        )
        results.append(result)
        if result['success']:
            corr = result['corr_full']
            symbol = '✓' if corr > 0 else '⚠'
            print(f"{symbol} Correlation: {corr:+.3f}")
        else:
            print(f"✗ Error: {result['error']}")
    # Write CSV results
    csv_path = os.path.join(args.output, args.csv)
    # Determine all unique keys (failed results have fewer keys than
    # successful ones, so union them for the header)
    all_keys = set()
    for r in results:
        all_keys.update(r.keys())
    all_keys = sorted(all_keys)
    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=all_keys)
        writer.writeheader()
        writer.writerows(results)
    print("=" * 60)
    print(f"Results saved to: {csv_path}")
    # Create summary plot
    summary_path = os.path.join(args.output, 'ptc_summary.png')
    create_summary_plot(results, summary_path)
    # Print summary
    successful = [r for r in results if r['success']]
    failed = [r for r in results if not r['success']]
    print("=" * 60)
    print("SUMMARY")
    print("=" * 60)
    print(f"Total files: {len(results)}")
    print(f"Successful: {len(successful)}")
    print(f"Failed: {len(failed)}")
    if successful:
        correlations = [r['corr_full'] for r in successful]
        positive = sum(1 for c in correlations if c > 0)
        negative = sum(1 for c in correlations if c < 0)
        print(f"\nCorrelation results:")
        print(f" Positive (expected): {positive}")
        print(f" Negative (anomalous): {negative}")
        print(f" Mean correlation: {np.mean(correlations):+.3f}")
        if negative > 0:
            print(f"\n⚠ Files with negative correlation:")
            for r in successful:
                if r['corr_full'] < 0:
                    print(f" {r['filename']}: {r['corr_full']:+.3f}")
    if failed:
        print(f"\n✗ Failed files:")
        for r in failed:
            print(f" {r['filename']}: {r['error']}")
    print("=" * 60)
    print(f"Full results: {csv_path}")
    print(f"Summary plot: {summary_path}")
    if not args.no_plots:
        print(f"Individual plots: {args.output}/*.png")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment