Last active: February 11, 2025 06:14
-
-
Save mathematicalmichael/fe4ad7234c5ce3a58830e1457e8dd923 to your computer and use it in GitHub Desktop.
Semi-Automated Perspective Transform App in Streamlit
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Streamlit web interface for perspective warp correction. | |
Uses core functionality from correct.py. | |
""" | |
import pathlib | |
import tempfile | |
from io import BytesIO | |
import cv2 | |
import kornia | |
import kornia.geometry.transform as kgt | |
import numpy as np | |
import streamlit as st | |
import tifffile | |
import torch | |
from correct import ( | |
DEFAULT_PARAMS, | |
RAW_EXTENSIONS, | |
crop_image, | |
detect_document_corners, | |
warp_perspective_transform, | |
) | |
try: | |
import rawpy | |
except ImportError: | |
rawpy = None | |
# -----------------------------------------------------------
# Helper: Load image from Streamlit file uploader.
# Supports RAW files (using rawpy) and standard formats.
# For raw files, we call postprocess() with custom parameters
# to get a correctly color-balanced view.
# -----------------------------------------------------------
@st.cache_data(show_spinner=False)
def load_image_streamlit(
    uploaded_file,
    gamma: float = 1.0,
    curve: float = 1.0,
    use_auto_wb: bool = False,
    bright: float = 1.0,
):
    """Decode a Streamlit ``UploadedFile`` into a numpy image.

    RAW uploads (suffix in ``RAW_EXTENSIONS``) are developed with ``rawpy``
    and returned in RGB order. All other uploads are decoded with
    ``cv2.imdecode`` and returned in BGR order.

    Args:
        uploaded_file: The object returned by ``st.file_uploader``.
        gamma: First element of the ``gamma=(gamma, curve)`` tuple passed to
            ``raw.postprocess`` (RAW only).
        curve: Second element of that tuple (toe slope; RAW only).
        use_auto_wb: Forwarded to ``raw.postprocess`` (RAW only).
        bright: Brightness setting forwarded to ``raw.postprocess`` (RAW only).

    Returns:
        ``(image, color_order)`` with ``color_order`` ``"RGB"`` or ``"BGR"``,
        or ``(None, None)`` after showing an ``st.error`` message when the
        upload cannot be decoded.
    """
    extension = pathlib.Path(uploaded_file.name).suffix.lower()
    # getvalue() returns the whole upload regardless of the stream position;
    # read() would return only the bytes after the current position.
    data = uploaded_file.getvalue()

    if extension in RAW_EXTENSIONS:
        if rawpy is None:
            st.error(
                "rawpy module not found. Please install it via 'pip install rawpy'"
            )
            return None, None
        # rawpy is handed a path here. delete=False keeps the file after the
        # `with` closes it so rawpy can reopen it by name. It is removed below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=extension) as tmp_file:
            tmp_file.write(data)
            tmp_filename = tmp_file.name
        try:
            with rawpy.imread(tmp_filename) as raw:
                image = raw.postprocess(
                    gamma=(gamma, curve),
                    use_auto_wb=use_auto_wb,
                    bright=bright,
                    output_color=rawpy.ColorSpace.sRGB,
                )
        finally:
            # Previously every RAW upload left a temp file behind.
            pathlib.Path(tmp_filename).unlink(missing_ok=True)
        return image, "RGB"

    image = cv2.imdecode(np.frombuffer(data, dtype=np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        st.error("Error loading image.")
        return None, None
    return image, "BGR"
# -----------------------------------------------------------
# Helper: Order 4 points as [top-left, top-right, bottom-right, bottom-left]
# -----------------------------------------------------------
def order_points(pts: np.ndarray) -> np.ndarray:
    """Order four 2-D points as [top-left, top-right, bottom-right, bottom-left].

    The points are sorted by angle around their centroid. Image y grows
    downward, so ascending angle is clockwise on screen. The sequence is then
    rotated so the point with the smallest ``x + y`` comes first.

    The older approach chose each corner independently via argmin/argmax of
    ``x + y`` and ``y - x``. That can select the same point twice, for example
    for a quadrilateral rotated by about 45 degrees. This version always
    returns the four distinct input points.

    Args:
        pts: Four points, in any shape that reshapes to ``(4, 2)`` (e.g. the
            ``(4, 1, 2)`` output of ``cv2.approxPolyDP``).

    Returns:
        A new ``(4, 2)`` float32 array.

    Raises:
        ValueError: If ``pts`` does not contain exactly four points.
    """
    points = np.asarray(pts, dtype=np.float32).reshape(-1, 2)
    if points.shape[0] != 4:
        raise ValueError(f"expected 4 points, got {points.shape[0]}")
    center = points.mean(axis=0)
    angles = np.arctan2(points[:, 1] - center[1], points[:, 0] - center[0])
    clockwise = points[np.argsort(angles, kind="stable")]
    # Start the cycle at the top-left-most point (smallest x + y).
    start = int(np.argmin(clockwise.sum(axis=1)))
    return np.roll(clockwise, -start, axis=0)
# -----------------------------------------------------------
# Detect document/frame corners using Adaptive thresholding and/or Canny.
#
# Returns:
#   - corners: a 4x2 array if detected (or None)
#   - debug_info: a dictionary of intermediate images and info.
#
# This version downsizes the grayscale for processing and then scales
# the detected corners back up.
# -----------------------------------------------------------
def _find_quadrilateral(binary, approx_factor):
    """Return the (4, 2) vertices of the largest-area contour in ``binary``
    that ``cv2.approxPolyDP`` reduces to exactly four points, else ``None``."""
    contours, _ = cv2.findContours(
        binary.copy(), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
    )
    # Contours are tried from largest to smallest area; the first 4-vertex
    # approximation wins.
    for contour in sorted(contours, key=cv2.contourArea, reverse=True):
        perimeter = cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, approx_factor * perimeter, True)
        if len(approx) == 4:
            return approx.reshape(4, 2)
    return None


@st.cache_data(show_spinner=False)
def detect_document_corners_custom(image, color_order, detection_method, params):
    """Find the four corners of the dominant quadrilateral in ``image``.

    Args:
        image: Color (H, W, 3) or grayscale image.
        color_order: ``"BGR"`` or ``"RGB"`` (case-insensitive) for 3-channel input.
        detection_method: ``"Adaptive"``, ``"Canny"`` or ``"Auto"``
            (case-insensitive). ``"Auto"`` tries adaptive thresholding first
            and falls back to Canny.
        params: Dict of processing parameters. Missing keys use the defaults
            shown in the ``params.get`` calls below.

    Returns:
        ``(corners, debug_info)``. ``corners`` is a (4, 2) float32 array in
        full-resolution pixel coordinates, or ``None``. ``debug_info`` holds
        intermediate images, a corner overlay and ``"Method Used"``.
    """
    debug_info = {}
    # main() passes the radio value lower-cased, so compare case-insensitively.
    method = str(detection_method).lower()

    # Convert to grayscale.
    if len(image.shape) == 3 and image.shape[2] == 3:
        if color_order.upper() == "BGR":
            conversion = cv2.COLOR_BGR2GRAY
        else:
            conversion = cv2.COLOR_RGB2GRAY
        gray = cv2.cvtColor(image, conversion)
    else:
        gray = image.copy()
    debug_info["Grayscale"] = gray

    # Downscale for processing if required. Integer output sizes make the
    # real ratio differ slightly from downscale_factor, so record the exact
    # per-axis ratio for mapping corners back.
    downscale_factor = params.get("downscale_factor", 1)
    scale = np.ones(2, dtype=np.float32)
    if downscale_factor > 1:
        h, w = gray.shape
        small_w = max(1, int(w / downscale_factor))
        small_h = max(1, int(h / downscale_factor))
        gray = cv2.resize(gray, (small_w, small_h))
        scale = np.array([w / small_w, h / small_h], dtype=np.float32)

    # Blur the image; GaussianBlur needs an odd kernel size.
    blur_kernel = params.get("blur_kernel_size", 5)
    if blur_kernel % 2 == 0:
        blur_kernel += 1
    blurred = cv2.GaussianBlur(gray, (blur_kernel, blur_kernel), 0)
    debug_info["Blurred"] = blurred

    corners = None
    method_used = None

    # --- Adaptive threshold approach (primary) ---
    if method in ("adaptive", "auto"):
        block_size = params.get("adaptive_block_size", 11)
        if block_size % 2 == 0:  # adaptiveThreshold needs an odd block size
            block_size += 1
        thresh = cv2.adaptiveThreshold(
            blurred,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            params.get("adaptive_C", 2),
        )
        debug_info["Adaptive Threshold"] = thresh
        quad = _find_quadrilateral(thresh, params.get("adaptive_approx", 0.02))
        if quad is not None:
            corners = order_points(quad)
            method_used = "Adaptive"

    # --- Canny approach (forced, or as the "Auto" fallback) ---
    if corners is None and method in ("canny", "auto"):
        edges = cv2.Canny(
            blurred, params.get("canny_low", 50), params.get("canny_high", 150)
        )
        debug_info["Canny Edges"] = edges
        quad = _find_quadrilateral(edges, params.get("canny_approx", 0.1))
        if quad is not None:
            corners = order_points(quad)
            method_used = "Canny"

    # Map corners back to full resolution and draw them on a copy.
    corners_overlay = image.copy()
    if corners is not None:
        corners = corners * scale
        for x, y in corners:
            cv2.circle(corners_overlay, (int(x), int(y)), 10, (0, 255, 0), -1)
    debug_info["Detected Corners Overlay"] = corners_overlay
    debug_info["Method Used"] = method_used
    return corners, debug_info
# -----------------------------------------------------------
# Warp perspective using Kornia.
# -----------------------------------------------------------
@st.cache_data(show_spinner=False)
def warp_perspective_transform(image, src_pts):
    """Rectify the quadrilateral ``src_pts`` of ``image`` into an upright rectangle.

    NOTE: this definition shadows the ``warp_perspective_transform`` imported
    from ``correct`` at the top of this file; ``main`` calls this cached copy.

    The output width is the longer of the top and bottom edges, and the
    height is the longer of the left and right edges (both truncated to int).

    Args:
        image: (H, W) or (H, W, C) image array.
        src_pts: Four source corners, in any order accepted by ``order_points``.

    Returns:
        The warped image with the same number of dimensions as ``image``.
        Integer dtypes are preserved and clipped to their range; other dtypes
        are clipped to [0, 255] and returned as uint8. Returns ``None`` if the
        quadrilateral is degenerate (zero width or height).
    """
    src = order_points(src_pts)
    max_width = int(max(np.linalg.norm(src[2] - src[3]), np.linalg.norm(src[1] - src[0])))
    max_height = int(max(np.linalg.norm(src[1] - src[2]), np.linalg.norm(src[0] - src[3])))
    if max_width <= 0 or max_height <= 0:
        return None
    dst = np.array(
        [[0, 0], [max_width - 1, 0], [max_width - 1, max_height - 1], [0, max_height - 1]],
        dtype="float32",
    )
    M = kgt.get_perspective_transform(
        torch.from_numpy(src).unsqueeze(0), torch.from_numpy(dst).unsqueeze(0)
    )

    was_2d = image.ndim == 2
    if was_2d:
        image = image[:, :, np.newaxis]
    # Convert to float32 in numpy. This yields a contiguous buffer
    # (torch.from_numpy rejects negative strides) and avoids dtypes such as
    # uint16 that older torch versions cannot wrap.
    pixels = np.ascontiguousarray(image, dtype=np.float32)
    image_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    warped_tensor = kgt.warp_perspective(image_tensor, M, dsize=(max_height, max_width))
    warped = warped_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    if was_2d:
        warped = warped[:, :, 0]
    if np.issubdtype(image.dtype, np.integer):
        info = np.iinfo(image.dtype)
        return np.clip(warped, info.min, info.max).astype(image.dtype)
    return np.clip(warped, 0, 255).astype(np.uint8)
def _to_rgb(img: np.ndarray, color_order: str) -> np.ndarray:
    """Return an RGB copy of ``img``: converted from BGR when ``color_order``
    is ``"BGR"`` (case-insensitive), otherwise copied unchanged."""
    if color_order.upper() == "BGR":
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img.copy()


def main():
    """Main Streamlit app.

    Renders the upload widget and sidebar controls for RAW development,
    cropping and corner detection. Shows the detected corners and the
    intermediate debug images, then the perspective-corrected result with an
    RGB TIFF download button.
    """
    st.title("Perspective Warp Debugger")
    # File upload
    uploaded_file = st.file_uploader(
        "Upload an image",
        type=["jpg", "jpeg", "png", "nef", "cr2", "arw", "raf", "raw", "dng", "rw2"],
    )
    if not uploaded_file:
        return
    # Image-loading and crop controls
    with st.sidebar.expander("Image Processing Parameters", expanded=True):
        gamma = st.slider(
            "Gamma",
            0.1,
            5.0,
            DEFAULT_PARAMS["gamma"],
            0.1,
            help="Gamma correction for the image.",
        )
        curve = st.slider(
            "Curve",
            0.1,
            5.0,
            DEFAULT_PARAMS["curve"],
            0.1,
            help="Curve correction for the image.",
        )
        use_auto_wb = st.checkbox(
            "Use Auto White Balance",
            value=DEFAULT_PARAMS["use_auto_wb"],
            help="Use auto white balance for the image.",
        )
        bright = st.slider("Brightness", 0.1, 5.0, DEFAULT_PARAMS["bright"], 0.1)
        st.markdown("---")
        st.markdown("### Crop Settings")
        # Use columns for a more compact layout
        col1, col2 = st.columns(2)
        with col1:
            crop_top = st.slider(
                "Crop Top",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_top"],
                0.01,
                help="Percentage to crop from top edge",
            )
            crop_bottom = st.slider(
                "Crop Bottom",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_bottom"],
                0.01,
                help="Percentage to crop from bottom edge",
            )
        with col2:
            crop_left = st.slider(
                "Crop Left",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_left"],
                0.01,
                help="Percentage to crop from left edge",
            )
            crop_right = st.slider(
                "Crop Right",
                0.0,
                0.4,
                DEFAULT_PARAMS["crop_right"],
                0.01,
                help="Percentage to crop from right edge",
            )
    image, color_order = load_image_streamlit(
        uploaded_file, gamma=gamma, curve=curve, use_auto_wb=use_auto_wb, bright=bright
    )
    if image is None:
        return
    # Apply cropping
    image = crop_image(
        image,
        crop_top=crop_top,
        crop_bottom=crop_bottom,
        crop_left=crop_left,
        crop_right=crop_right,
    )
    # Display original
    disp_image = _to_rgb(image, color_order)
    st.image(
        disp_image, caption="Original Image (After Cropping)", use_container_width=True
    )
    # Parameter controls
    st.sidebar.header("Processing Parameters")
    params = {}
    params["downscale_factor"] = st.sidebar.number_input(
        "Downscale Factor (for processing)",
        min_value=1,
        value=DEFAULT_PARAMS["downscale_factor"],
        step=1,
        help="Downscale factor to speed up processing on very high-res images.",
    )
    params["blur_kernel_size"] = st.sidebar.slider(
        "Gaussian Blur Kernel Size (odd)",
        3,
        21,
        DEFAULT_PARAMS["blur_kernel_size"],
        step=2,
    )
    params["canny_low"] = st.sidebar.slider(
        "Canny Low Threshold", 0, 255, DEFAULT_PARAMS["canny_low"]
    )
    params["canny_high"] = st.sidebar.slider(
        "Canny High Threshold", 0, 255, DEFAULT_PARAMS["canny_high"]
    )
    params["canny_approx"] = st.sidebar.slider(
        "Contour Approx Factor (Canny)",
        0.01,
        0.2,
        DEFAULT_PARAMS["canny_approx"],
        step=0.01,
    )
    params["adaptive_block_size"] = st.sidebar.slider(
        "Adaptive Threshold Block Size (odd)",
        3,
        25,
        DEFAULT_PARAMS["adaptive_block_size"],
        step=2,
    )
    params["adaptive_C"] = st.sidebar.slider(
        "Adaptive Threshold Constant", 0, 10, DEFAULT_PARAMS["adaptive_C"]
    )
    params["adaptive_approx"] = st.sidebar.slider(
        "Contour Approx Factor (Adaptive)",
        0.01,
        0.2,
        DEFAULT_PARAMS["adaptive_approx"],
        step=0.01,
    )
    detection_method = st.sidebar.radio(
        "Detection Method", ["Adaptive", "Canny", "Auto"], index=0
    ).lower()
    # Process image
    corners, debug_info = detect_document_corners(
        image, color_order, method=detection_method, params=params
    )
    # Show results
    method_used = debug_info.get("Method Used")
    st.write(f"Detection method used: **{method_used}**")
    if corners is None:
        st.error("Could not detect a 4-corner contour. Try adjusting parameters.")
        return
    st.success("Detected corners!")
    # Mark the detected corners on the displayed (cropped) image. The radius
    # scales with image size so the dots stay visible on high-res photos.
    overlay = disp_image.copy()
    radius = max(5, min(overlay.shape[:2]) // 100)
    for x, y in corners:
        cv2.circle(overlay, (int(x), int(y)), radius, (0, 255, 0), -1)
    st.image(overlay, caption="Detected Corners Overlay", use_container_width=True)
    # Show debug images (2-D images are grayscale and displayed as-is)
    for key in ["Grayscale", "Blurred", "Canny Edges", "Adaptive Threshold"]:
        if key not in debug_info:
            continue
        img = debug_info[key]
        shown = img if len(img.shape) == 2 else _to_rgb(img, color_order)
        st.image(shown, caption=key, use_container_width=True)
    # Apply warp
    warped = warp_perspective_transform(image, corners)
    if warped is None:
        st.error("Warping failed due to invalid dimensions.")
        return
    # Show warped image
    warped_disp = _to_rgb(warped, color_order)
    st.image(
        warped_disp,
        caption="Warped (Perspective Corrected) Image",
        use_container_width=True,
    )
    # Save/download functionality. Export the RGB array: writing the raw
    # `warped` swapped red and blue for OpenCV-decoded (BGR) uploads.
    base_name = pathlib.Path(uploaded_file.name).stem
    output_filename = f"{base_name}.tiff"
    buf = BytesIO()
    is_color = len(warped_disp.shape) == 3 and warped_disp.shape[2] == 3
    tifffile.imwrite(buf, warped_disp, photometric="rgb" if is_color else "minisblack")
    buf.seek(0)
    st.sidebar.download_button(
        label="Save Output Image",
        data=buf,
        file_name=output_filename,
        mime="image/tiff",
        key="download",
    )
# Script entry point. The bash launcher starts this file with
# `streamlit run app.py`, and the UI is only built by main().
# NOTE(review): this relies on Streamlit executing the script with
# __name__ == "__main__" — confirm if launching it another way.
if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
""" | |
Core functionality for perspective warp correction using Kornia. | |
Provides both programmatic API and CLI interface. | |
Supports RAW files (using rawpy) and standard image formats. | |
Usage: | |
python correct.py /path/to/your/image.jpg [--method {adaptive,canny,auto}] [--debug] [--params key=value...] | |
""" | |
import argparse | |
import ast | |
import pathlib | |
from typing import Dict, Optional, Tuple, Union | |
import cv2 | |
import kornia.geometry.transform as kgt | |
import numpy as np | |
import torch | |
from tifffile import TiffWriter | |
# List of common RAW file extensions (lowercase). load_image compares the
# lower-cased file suffix against this set: matches are decoded with rawpy,
# everything else with OpenCV.
RAW_EXTENSIONS = {".nef", ".cr2", ".arw", ".raf", ".raw", ".dng", ".rw2"}
# rawpy is optional. When it is missing, load_image raises ImportError for
# RAW files and still handles other formats.
try:
    import rawpy
except ImportError:
    rawpy = None
# Default processing parameters. detect_document_corners overlays the
# caller's params on top of this dict. The Streamlit app also uses these
# values as its widget defaults.
DEFAULT_PARAMS = {
    # --- Corner detection (detect_document_corners) ---
    "downscale_factor": 5,  # grayscale is shrunk by this factor before detection; corners are scaled back
    "blur_kernel_size": 5,  # GaussianBlur kernel; incremented to the next odd value if even
    "canny_low": 50,  # low threshold passed to cv2.Canny
    "canny_high": 150,  # high threshold passed to cv2.Canny
    "canny_approx": 0.1,  # approxPolyDP epsilon as a fraction of contour perimeter (Canny path)
    "adaptive_block_size": 11,  # adaptiveThreshold block size; incremented to odd if even
    "adaptive_C": 2,  # constant passed to cv2.adaptiveThreshold
    "adaptive_approx": 0.02,  # approxPolyDP epsilon fraction (adaptive-threshold path)
    # --- RAW development (rawpy postprocess), used as app widget defaults ---
    # NOTE(review): load_image's own keyword defaults are separate values.
    "gamma": 1.0,
    "curve": 1.0,
    "use_auto_wb": True,
    "bright": 1.0,
    # --- Fraction of height/width trimmed from each edge (app crop sliders) ---
    "crop_top": 0.0,
    "crop_bottom": 0.0,
    "crop_left": 0.0,
    "crop_right": 0.0,
}
def load_image(
    image_path: Union[str, pathlib.Path],
    gamma: float = 1.0,
    curve: float = 1.0,
    use_auto_wb: bool = True,
    bright: float = 1.0,
    output_color: Optional[str] = "sRGB",
) -> Tuple[np.ndarray, str]:
    """
    Load an image from a file path, supporting both RAW and standard formats.

    RAW files (suffix in RAW_EXTENSIONS) are developed with rawpy and returned
    in RGB order. Other files are read with cv2.imread and returned in BGR
    order.

    Args:
        image_path: Path to the image file
        gamma: First element of the (gamma, curve) tuple passed to rawpy
            postprocess (default: 1.0)
        curve: Toe slope for dark areas, second element of that tuple
            (default: 1.0)
        use_auto_wb: Whether to use auto white balance (default: True)
        bright: Brightness adjustment (default: 1.0)
        output_color: Name of a rawpy.ColorSpace member for RAW output, e.g.
            "sRGB", "Adobe", "ProPhoto" (default: "sRGB"). A falsy value also
            selects sRGB. A non-string value is passed to rawpy as-is.
    Returns:
        Tuple of (image array, color order string ['RGB' or 'BGR'])
    Raises:
        ImportError: If RAW file support is needed but rawpy isn't installed
        ValueError: If the image cannot be loaded, or output_color does not
            name a rawpy color space
    """
    image_path = pathlib.Path(image_path)
    ext = image_path.suffix.lower()
    if ext in RAW_EXTENSIONS:
        if rawpy is None:
            raise ImportError("rawpy module required for RAW file support")
        # Resolve the color space before opening the file, so a typo fails
        # fast with a clear ValueError instead of an AttributeError.
        if not output_color:
            color_space = rawpy.ColorSpace.sRGB
        elif isinstance(output_color, str):
            color_space = getattr(rawpy.ColorSpace, output_color, None)
            if color_space is None:
                raise ValueError(f"Unknown RAW output color space: {output_color!r}")
        else:
            # Assumed to already be a rawpy.ColorSpace value.
            color_space = output_color
        with rawpy.imread(str(image_path)) as raw:
            image = raw.postprocess(
                gamma=(gamma, curve),
                use_auto_wb=use_auto_wb,
                bright=bright,
                output_color=color_space,
            )
        return image, "RGB"

    image = cv2.imread(str(image_path))
    if image is None:
        raise ValueError(f"Failed to load image: {image_path}")
    return image, "BGR"
def order_points(pts: np.ndarray) -> np.ndarray:
    """
    Order 4 points as [top-left, top-right, bottom-right, bottom-left].

    The points are sorted by angle around their centroid. Image y grows
    downward, so ascending angle is clockwise on screen. The sequence is then
    rotated so the point with the smallest x + y comes first.

    The older approach chose each corner independently via argmin/argmax of
    x + y and y - x. That can select the same point twice, for example for a
    quadrilateral rotated by about 45 degrees, which yields a degenerate
    homography. This version always returns the four distinct input points.

    Args:
        pts: Array of 4 points, any shape that reshapes to (4, 2)
            (e.g. the (4, 1, 2) output of cv2.approxPolyDP)
    Returns:
        Ordered (4, 2) array of points as float32
    Raises:
        ValueError: If pts does not contain exactly four points
    """
    points = np.asarray(pts, dtype=np.float32).reshape(-1, 2)
    if points.shape[0] != 4:
        raise ValueError(f"expected 4 points, got {points.shape[0]}")
    center = points.mean(axis=0)
    angles = np.arctan2(points[:, 1] - center[1], points[:, 0] - center[0])
    clockwise = points[np.argsort(angles, kind="stable")]
    # Start the cycle at the top-left-most point (smallest x + y).
    start = int(np.argmin(clockwise.sum(axis=1)))
    return np.roll(clockwise, -start, axis=0)
def _find_quadrilateral(binary: np.ndarray, approx_factor: float) -> Optional[np.ndarray]:
    """Return the (4, 2) vertices of the largest-area contour in ``binary``
    that ``cv2.approxPolyDP`` reduces to exactly four points, else ``None``."""
    contours, _ = cv2.findContours(
        binary.copy(), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE
    )
    # Contours are tried from largest to smallest area; the first 4-vertex
    # approximation wins.
    for contour in sorted(contours, key=cv2.contourArea, reverse=True):
        perimeter = cv2.arcLength(contour, True)
        approx = cv2.approxPolyDP(contour, approx_factor * perimeter, True)
        if len(approx) == 4:
            return approx.reshape(4, 2)
    return None


def detect_document_corners(
    image: np.ndarray,
    color_order: str,
    method: str = "adaptive",
    params: Optional[Dict] = None,
) -> Tuple[Optional[np.ndarray], Dict]:
    """
    Detect document corners using adaptive thresholding and/or Canny edge detection.

    "adaptive" and "canny" run a single detector. "auto" runs adaptive
    thresholding first and falls back to Canny if no quadrilateral is found.
    Any other method string runs neither detector and yields None corners.

    Args:
        image: Input image array
        color_order: Color channel order ('RGB' or 'BGR')
        method: Detection method ('adaptive', 'canny', or 'auto'; case-insensitive)
        params: Optional dictionary of processing parameters, overlaid on
            DEFAULT_PARAMS
    Returns:
        Tuple of (corners, debug info dictionary). corners is a (4, 2) float32
        array in full-resolution pixel coordinates ordered [top-left,
        top-right, bottom-right, bottom-left], or None. debug_info holds the
        intermediate images and "Method Used".
    """
    params = {**DEFAULT_PARAMS, **(params or {})}
    method = method.lower()
    debug_info = {}

    # Convert to grayscale
    if len(image.shape) == 3 and image.shape[2] == 3:
        gray = cv2.cvtColor(
            image,
            cv2.COLOR_BGR2GRAY if color_order.upper() == "BGR" else cv2.COLOR_RGB2GRAY,
        )
    else:
        gray = image.copy()
    debug_info["Grayscale"] = gray

    # Downscale for processing if needed. Integer output sizes make the real
    # ratio differ slightly from downscale_factor (and a non-integer factor
    # from --params would otherwise break cv2.resize), so compute sizes
    # explicitly and keep the exact per-axis ratio for mapping corners back.
    downscale_factor = params["downscale_factor"]
    scale = np.ones(2, dtype=np.float32)
    if downscale_factor > 1:
        h, w = gray.shape
        small_w = max(1, int(w / downscale_factor))
        small_h = max(1, int(h / downscale_factor))
        gray = cv2.resize(gray, (small_w, small_h))
        scale = np.array([w / small_w, h / small_h], dtype=np.float32)

    # Apply blur (GaussianBlur needs an odd kernel size)
    blur_kernel = params["blur_kernel_size"]
    if blur_kernel % 2 == 0:
        blur_kernel += 1
    blurred = cv2.GaussianBlur(gray, (blur_kernel, blur_kernel), 0)
    debug_info["Blurred"] = blurred

    corners = None
    method_used = None

    # Try adaptive threshold method
    if method in ("adaptive", "auto"):
        block_size = params["adaptive_block_size"]
        if block_size % 2 == 0:  # adaptiveThreshold needs an odd block size
            block_size += 1
        thresh = cv2.adaptiveThreshold(
            blurred,
            255,
            cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
            cv2.THRESH_BINARY,
            block_size,
            params["adaptive_C"],
        )
        debug_info["Adaptive Threshold"] = thresh
        quad = _find_quadrilateral(thresh, params["adaptive_approx"])
        if quad is not None:
            corners = order_points(quad)
            method_used = "Adaptive"

    # Try Canny method if needed
    if corners is None and method in ("canny", "auto"):
        edges = cv2.Canny(blurred, params["canny_low"], params["canny_high"])
        debug_info["Canny Edges"] = edges
        quad = _find_quadrilateral(edges, params["canny_approx"])
        if quad is not None:
            corners = order_points(quad)
            method_used = "Canny"

    # Scale corners back to full resolution (scale is 1 when not downscaled)
    if corners is not None:
        corners = corners * scale
    debug_info["Method Used"] = method_used
    return corners, debug_info
def warp_perspective_transform(
    image: np.ndarray, src_pts: np.ndarray
) -> Optional[np.ndarray]:
    """
    Apply perspective warp transformation using Kornia.

    The output width is the longer of the top and bottom edges of the source
    quadrilateral, and the height is the longer of its left and right edges
    (both truncated to int).

    Args:
        image: Input image array, (H, W) or (H, W, C)
        src_pts: Source points (4x2 array), in any order accepted by order_points
    Returns:
        Warped image array with the same number of dimensions as image, or
        None if the quadrilateral is degenerate. Integer dtypes are preserved
        and clipped to their range. Other dtypes are clipped to [0, 255] and
        returned as uint8.
    """
    src = order_points(src_pts)
    # Calculate output dimensions
    max_width = int(max(np.linalg.norm(src[2] - src[3]), np.linalg.norm(src[1] - src[0])))
    max_height = int(max(np.linalg.norm(src[1] - src[2]), np.linalg.norm(src[0] - src[3])))
    if max_width <= 0 or max_height <= 0:
        return None
    # Define destination points
    dst = np.array(
        [[0, 0], [max_width - 1, 0], [max_width - 1, max_height - 1], [0, max_height - 1]],
        dtype="float32",
    )
    M = kgt.get_perspective_transform(
        torch.from_numpy(src).unsqueeze(0), torch.from_numpy(dst).unsqueeze(0)
    )

    # Ensure image has a channels dimension; remember to drop it again below.
    was_2d = image.ndim == 2
    if was_2d:
        image = image[:, :, np.newaxis]
    # Convert to float32 in numpy. This yields a contiguous buffer
    # (torch.from_numpy rejects negative strides) and avoids dtypes such as
    # uint16 that older torch versions cannot wrap.
    pixels = np.ascontiguousarray(image, dtype=np.float32)
    image_tensor = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
    warped_tensor = kgt.warp_perspective(image_tensor, M, dsize=(max_height, max_width))
    warped = warped_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()
    if was_2d:
        warped = warped[:, :, 0]
    if np.issubdtype(image.dtype, np.integer):
        info = np.iinfo(image.dtype)
        return np.clip(warped, info.min, info.max).astype(image.dtype)
    return np.clip(warped, 0, 255).astype(np.uint8)
def save_image(
    image: np.ndarray,
    output_path: Union[str, pathlib.Path],
    suffix: str = "",
    format: str = "tiff",
) -> pathlib.Path:
    """
    Save an image as an LZW-compressed TIFF using TiffWriter.

    The file is written to output_path's directory as "<stem><suffix>.tiff".
    NOTE: with an empty suffix and a ".tiff" output_path, this overwrites
    that file.

    Args:
        image: numpy array of the image. (H, W, >=3) arrays are written as RGB;
            (H, W) and (H, W, 1) arrays are written as grayscale.
        output_path: original image path or desired output path
        suffix: suffix to append to the filename stem (default: "")
        format: only TIFF output is implemented; accepted for API
            compatibility and otherwise unused (default: "tiff")
    Returns:
        Path to saved file
    """
    output_path = pathlib.Path(output_path)
    save_path = output_path.parent / f"{output_path.stem}{suffix}.tiff"
    if image.ndim == 3 and image.shape[2] == 1:
        image = image[:, :, 0]
    # photometric="rgb" is only valid for multi-channel data; previously a
    # grayscale array was also tagged "rgb".
    photometric = "rgb" if image.ndim == 3 and image.shape[2] >= 3 else "minisblack"
    with TiffWriter(str(save_path)) as tif:
        tif.write(
            image,
            photometric=photometric,
            software="correct.py",
            datetime=True,
            compression="lzw",
        )
    return save_path
def crop_image(
    image: np.ndarray,
    crop_top: float = 0.0,
    crop_bottom: float = 0.0,
    crop_left: float = 0.0,
    crop_right: float = 0.0,
) -> np.ndarray:
    """
    Crop a fraction of the image away from each edge.

    Each fraction is clamped to [0.0, 1.0] and converted to pixels with int()
    truncation.

    Args:
        image: Input image array, (H, W) or (H, W, C)
        crop_top: Fraction of the height removed from the top (0.0-1.0)
        crop_bottom: Fraction of the height removed from the bottom (0.0-1.0)
        crop_left: Fraction of the width removed from the left (0.0-1.0)
        crop_right: Fraction of the width removed from the right (0.0-1.0)
    Returns:
        A slice (view) of image with the edges removed, or image itself if
        the requested crop would leave nothing.
    """

    def _clamp(fraction: float) -> float:
        # Negative fractions would otherwise become negative slice indices
        # (counted from the far edge) and silently produce the wrong crop.
        return min(max(float(fraction), 0.0), 1.0)

    h, w = image.shape[:2]
    top = int(h * _clamp(crop_top))
    bottom = int(h * (1 - _clamp(crop_bottom)))
    left = int(w * _clamp(crop_left))
    right = int(w * (1 - _clamp(crop_right)))
    # Ensure we don't crop everything
    if bottom <= top or right <= left:
        return image
    return image[top:bottom, left:right]
def main():
    """CLI entry point: load, crop, detect corners, warp and save as TIFF.

    Returns:
        Exit status: 0 on success, 1 if loading, detection or warping fails.
        A malformed --params entry is reported through parser.error, which
        exits immediately.
    """
    import sys  # local: only the CLI writes diagnostics to stderr

    parser = argparse.ArgumentParser(description=__doc__)
    # Input path
    parser.add_argument("image_path", help="Path to input image")
    # Image loading parameters
    parser.add_argument(
        "-g",
        "--gamma",
        type=float,
        default=1.0,
        help="Gamma correction value (default: 1.0)",
    )
    parser.add_argument(
        "-c",
        "--curve",
        type=float,
        default=1.0,
        help="Toe slope for dark areas (default: 1.0)",
    )
    parser.add_argument(
        "-b",
        "--bright",
        type=float,
        default=1.0,
        help="Brightness adjustment (default: 1.0)",
    )
    parser.add_argument(
        "-w",
        "--use-auto-wb",
        # store_true with default=True could never be turned off.
        # BooleanOptionalAction also provides --no-use-auto-wb.
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use auto white balance (disable with --no-use-auto-wb)",
    )
    parser.add_argument(
        "--output-color",
        type=str,
        default="sRGB",
        choices=["sRGB", "Adobe", "Wide", "ProPhoto", "XYZ"],
        help="Color space for RAW output (default: sRGB)",
    )
    # Cropping parameters
    parser.add_argument(
        "-ct",
        "--crop-top",
        type=float,
        default=0.0,
        help="Percentage to crop from top (0.0-1.0)",
    )
    parser.add_argument(
        "-cb",
        "--crop-bottom",
        type=float,
        default=0.0,
        help="Percentage to crop from bottom (0.0-1.0)",
    )
    parser.add_argument(
        "-cl",
        "--crop-left",
        type=float,
        default=0.0,
        help="Percentage to crop from left (0.0-1.0)",
    )
    parser.add_argument(
        "-cr",
        "--crop-right",
        type=float,
        default=0.0,
        help="Percentage to crop from right (0.0-1.0)",
    )
    # Detection parameters
    parser.add_argument(
        "-m",
        "--method",
        choices=["adaptive", "canny", "auto"],
        default="adaptive",
        help="Corner detection method",
    )
    parser.add_argument("--debug", action="store_true", help="Show debug information")
    parser.add_argument(
        "-p",
        "--params",
        nargs="*",
        help="Override default parameters (e.g. --params blur_kernel_size=7 adaptive_C=3)",
    )
    # Output format
    parser.add_argument(
        "-f",
        "--format",
        choices=["tiff"],
        default="tiff",
        help="Output format (default: tiff)",
    )
    args = parser.parse_args()

    # Parse custom parameters. Split on the first "=" only, so values may
    # themselves contain "=".
    custom_params = {}
    for param in args.params or []:
        key, sep, value = param.partition("=")
        if not sep or not key:
            parser.error(f"invalid --params entry {param!r}; expected key=value")
        try:
            custom_params[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError):
            # Not a Python literal (e.g. a bare word): keep the raw string.
            custom_params[key] = value

    # Load image with all parameters
    try:
        image, color_order = load_image(
            args.image_path,
            gamma=args.gamma,
            curve=args.curve,
            use_auto_wb=args.use_auto_wb,
            bright=args.bright,
            output_color=args.output_color,
        )
        # Apply cropping
        image = crop_image(
            image,
            crop_top=args.crop_top,
            crop_bottom=args.crop_bottom,
            crop_left=args.crop_left,
            crop_right=args.crop_right,
        )
    except Exception as e:  # top-level CLI boundary: report and fail
        print(f"Error loading/processing image: {e}", file=sys.stderr)
        return 1

    # Detect corners
    corners, debug_info = detect_document_corners(
        image, color_order, method=args.method, params=custom_params
    )
    if args.debug:
        print(f"Detection method used: {debug_info.get('Method Used')}", file=sys.stderr)
        if corners is not None:
            print(f"Corners (x, y): {corners.tolist()}", file=sys.stderr)
    if corners is None:
        print("No document corners detected", file=sys.stderr)
        return 1

    # Apply warp
    warped = warp_perspective_transform(image, corners)
    if warped is None:
        print("Warping failed", file=sys.stderr)
        return 1

    # TIFF output is written as RGB; OpenCV-loaded images are BGR.
    if color_order == "BGR":
        warped = cv2.cvtColor(warped, cv2.COLOR_BGR2RGB)

    # save_image writes "<stem><suffix>.tiff" next to the input. Add a suffix
    # when that would overwrite a .tiff input file.
    input_path = pathlib.Path(args.image_path)
    suffix = "-w" if input_path.suffix.lower() == ".tiff" else ""
    save_path = save_image(warped, input_path, suffix=suffix, format=args.format)
    print(f"Saved image to: {save_path}")
    return 0  # Return 0 for success
if __name__ == "__main__":
    # Raising SystemExit passes main()'s return value on as the process exit
    # status. Unlike the site-provided exit(), it is always available and
    # needs no import.
    raise SystemExit(main())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/bin/bash
# Launch the Streamlit UI (app.py) inside a throwaway, isolated uv
# environment with pinned dependency versions. The server runs headless on
# port 8501 with usage-statistics collection disabled.
# NOTE(review): browser.serverAddress sets the address Streamlit advertises to
# browsers; binding to all interfaces is controlled by server.address —
# confirm which one was intended here.
uv run --isolated \
--with streamlit==1.42.0 \
--with kornia==0.7.0 \
--with torch==2.1.0 \
--with rawpy==0.21.0 \
--with opencv-python-headless==4.8.1.78 \
--with tifffile==2025.1.10 \
streamlit run \
--server.headless true \
--browser.gatherUsageStats false \
--browser.serverAddress 0.0.0.0 \
--server.port 8501 \
app.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.