Analyze and infer SQL types for columns in CSV files
#!/usr/bin/env python3
"""
Analyze a CSV file and infer SQL column types.

This module provides a class and utility functions that infer SQL types for
CSV columns by examining their values and simple per-column statistics. It
recognizes integers, floats, booleans, dates, and timestamps, and maps them
to dialect-specific SQL types (PostgreSQL, SQLite, MySQL). Statistics are
accumulated across data chunks, missing or invalid data is handled
gracefully, and NULL/NOT NULL constraints are derived from observed nulls.
"""

from __future__ import annotations

import argparse
import csv
import logging
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import pandas as pd

# ----------------------------- Logging Setup ----------------------------- #
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Quiet pandas' engine warning just in case; we also avoid it by detecting delimiter
warnings.filterwarnings(
    "ignore",
    message=r"^Falling back to the 'python' engine.*",
    category=pd.errors.ParserWarning,
)
# Silence pandas' per-element datetime parse warning when no format is provided
warnings.filterwarnings(
    "ignore",
    message=r"^Could not infer format, so each element will be parsed individually, falling back to `dateutil`.*",
    category=UserWarning,
)

# ----------------------------- Helpers & Types ----------------------------- #
BOOL_TRUE = {"true", "t", "1", "yes", "y"}
BOOL_FALSE = {"false", "f", "0", "no", "n"}
INT32_MIN, INT32_MAX = -2147483648, 2147483647
# Common NA/NULL spellings for CSVs
NA_STRINGS = ["", "na", "n/a", "null", "none", "NaN", "NA"]

@dataclass
class ColStats:
    """
    Per-column statistics accumulated across CSV chunks to infer SQL types.
    """
    name: str
    non_null_count: int = 0
    null_count: int = 0
    could_be_int: bool = True
    could_be_float: bool = True
    could_be_bool: bool = True
    could_be_date: bool = True
    could_be_timestamp: bool = True
    int_min: Optional[int] = None
    int_max: Optional[int] = None
    max_len: int = 0
    any_time_component: bool = False

    def update_from_series(self, s: pd.Series, *, date_format: Optional[str] = None, **_kw: object) -> None:
        """
        Update stats based on a pandas Series chunk for this column.
        """
        logger.debug(f"Updating stats for column '{self.name}'")
        nn_mask = s.notna()
        self.non_null_count += int(nn_mask.sum())
        self.null_count += int((~nn_mask).sum())
        if self.non_null_count == 0:
            return
        non_null_values = s[nn_mask]
        self._update_string_length(non_null_values)
        if self.could_be_bool:
            self._update_bool_candidate(non_null_values)
        if self.could_be_int or self.could_be_float:
            self._update_numeric_candidate(non_null_values)
        if self.could_be_date or self.could_be_timestamp:
            self._update_datetime_candidate(non_null_values, date_format=date_format)

    def _update_string_length(self, non_null_values: pd.Series) -> None:
        max_len_this = int(non_null_values.astype(str).str.len().max()) if not non_null_values.empty else 0
        self.max_len = max(self.max_len, max_len_this)

    def _update_bool_candidate(self, non_null_values: pd.Series) -> None:
        lowered = non_null_values.astype(str).str.strip().str.lower()
        is_bool_like = lowered.isin(BOOL_TRUE | BOOL_FALSE)
        if not bool(is_bool_like.all()):
            self.could_be_bool = False

    def _update_numeric_candidate(self, non_null_values: pd.Series) -> None:
        coerced = pd.to_numeric(non_null_values, errors="coerce")
        if coerced.isna().any():
            self.could_be_int = False
            self.could_be_float = False
            return
        # Robust integer-like check with tolerance to floating precision
        if self.could_be_int:
            values = coerced.to_numpy(dtype=float, copy=False)
            finite_mask = np.isfinite(values)
            if not finite_mask.all():
                # Shouldn't happen after isna() check, but be defensive
                self.could_be_int = False
            else:
                is_integer_like = np.isclose(values, np.round(values))
                if not bool(is_integer_like.all()):
                    self.could_be_int = False
        if self.could_be_int:
            # Update min/max using rounded integer-like values without narrowing dtype
            ints = np.round(coerced.to_numpy(dtype=float, copy=False)).astype(np.int64, copy=False)
            if ints.size > 0:
                mn = int(ints.min())
                mx = int(ints.max())
                self.int_min = mn if self.int_min is None else min(self.int_min, mn)
                self.int_max = mx if self.int_max is None else max(self.int_max, mx)
        if not self.could_be_int:
            # If int is ruled out but numeric parsing succeeded, float remains possible
            self.could_be_float = True

    def _update_datetime_candidate(self, non_null_values: pd.Series, *, date_format: Optional[str]) -> None:
        parsed = pd.to_datetime(
            non_null_values,
            errors="coerce",
            utc=False,
            format=date_format if date_format else None,
        )
        if parsed.isna().any():
            self.could_be_date = False
            self.could_be_timestamp = False
            return
        hours = parsed.dt.hour
        mins = parsed.dt.minute
        secs = parsed.dt.second
        has_time = bool(((hours != 0) | (mins != 0) | (secs != 0)).any())
        self.any_time_component = self.any_time_component or has_time  # accumulate across chunks

    def inferred_sql_type(self, dialect: str = "postgres") -> str:
        if self.non_null_count == 0:
            return _sql_type_text(self.max_len, dialect)
        if self.could_be_bool and not (self.could_be_int or self.could_be_float):
            return _sql_type_bool(dialect)
        if self.could_be_int:
            if self.int_min is not None and self.int_max is not None:
                if INT32_MIN <= self.int_min <= self.int_max <= INT32_MAX:
                    return _sql_type_int32(dialect)
                return _sql_type_int64(dialect)
            return _sql_type_int64(dialect)
        if self.could_be_float:
            return _sql_type_float(dialect)
        if self.could_be_date or self.could_be_timestamp:
            return _sql_type_timestamp(dialect) if self.any_time_component else _sql_type_date(dialect)
        return _sql_type_text(self.max_len, dialect)

    def nullability_sql(self) -> str:
        return "NULL" if self.null_count > 0 else "NOT NULL"

# ----------------------------- SQL Mappings ----------------------------- #
def _sql_type_bool(dialect: str) -> str:
    return {"postgres": "BOOLEAN", "sqlite": "INTEGER", "mysql": "BOOLEAN"}.get(dialect, "BOOLEAN")


def _sql_type_int32(dialect: str) -> str:
    return {"postgres": "INTEGER", "sqlite": "INTEGER", "mysql": "INT"}.get(dialect, "INTEGER")


def _sql_type_int64(dialect: str) -> str:
    return {"postgres": "BIGINT", "sqlite": "INTEGER", "mysql": "BIGINT"}.get(dialect, "BIGINT")


def _sql_type_float(dialect: str) -> str:
    return {"postgres": "DOUBLE PRECISION", "sqlite": "REAL", "mysql": "DOUBLE"}.get(dialect, "DOUBLE PRECISION")


def _sql_type_date(dialect: str) -> str:
    return {"postgres": "DATE", "sqlite": "TEXT", "mysql": "DATE"}.get(dialect, "DATE")


def _sql_type_timestamp(dialect: str) -> str:
    return {"postgres": "TIMESTAMP", "sqlite": "TEXT", "mysql": "DATETIME"}.get(dialect, "TIMESTAMP")


def _sql_type_text(max_len: int, dialect: str) -> str:
    """
    Return a textual SQL type with consideration for dialect and length.

    - For Postgres/MySQL, prefer VARCHAR(n) for moderate lengths, else TEXT.
    - For SQLite, TEXT is standard regardless of length.
    """
    if dialect == "sqlite":
        return "TEXT"
    # Use VARCHAR sized to the observed max length for smaller columns, TEXT for large/unbounded
    if max_len and max_len <= 65535:
        return f"VARCHAR({max_len})"
if dialect == "mysql": | |
if max_len <= 65535: | |
return "TEXT" | |
elif max_len <= 16777215: | |
return "MEDIUMTEXT" | |
else: | |
return "LONGTEXT" | |
return "TEXT" | |

# ----------------------------- Core Logic ----------------------------- #
def infer_schema(
    csv_path: Path,
    table_name: str,
    *,
    schema: Optional[str] = None,
    dialect: str = "postgres",
    chunksize: int = 100_000,
    encoding: Optional[str] = None,
    delimiter: Optional[str] = None,
    date_format: Optional[str] = None,
) -> Dict[str, ColStats]:
    logger.info(f"Inferring schema from CSV: {csv_path}")
    col_stats: Dict[str, ColStats] = {}
    # Auto-detect delimiter if requested/unspecified and choose fastest engine
    delimiter_resolved = delimiter
    if delimiter_resolved in (None, "auto"):
        delimiter_resolved = _detect_delimiter(csv_path, encoding=encoding)
        logger.info(f"Detected delimiter: '{delimiter_resolved}'")
    read_kwargs = dict(
        chunksize=chunksize,
        encoding=encoding,
        sep=delimiter_resolved,
        engine="c",
        dtype=str,
        keep_default_na=True,
        na_values=NA_STRINGS,
    )
    first_chunk = True
    for chunk in pd.read_csv(csv_path, **read_kwargs):
        logger.debug(f"Processing new chunk of size {len(chunk)}")
        if first_chunk:
            for col in chunk.columns:
                col_stats[col] = ColStats(name=col)
            first_chunk = False
        for col in chunk.columns:
            try:
                col_stats[col].update_from_series(chunk[col], date_format=date_format)
            except Exception as e:
                logger.warning(f"Column '{col}': inference error ({e}); treating as TEXT.")
                cs = col_stats[col]
                cs.could_be_bool = cs.could_be_int = cs.could_be_float = cs.could_be_date = cs.could_be_timestamp = False
                nn = chunk[col].notna()
                cs.non_null_count += int(nn.sum())
                cs.null_count += int((~nn).sum())
                non_null_values = chunk[col][nn]
                max_len_this = int(non_null_values.astype(str).str.len().max()) if not non_null_values.empty else 0
                cs.max_len = max(cs.max_len, max_len_this)
    if first_chunk:
        logger.error("CSV appears empty or has no header.")
        raise ValueError("CSV appears empty or has no header.")
    logger.info("Schema inference complete.")
    return col_stats

def build_create_table_sql(
    table_name: str,
    col_stats: Dict[str, ColStats],
    *,
    schema: Optional[str] = None,
    dialect: str = "postgres",
) -> str:
    fq_table = f"{schema}.{table_name}" if schema else table_name
    logger.info(f"Building CREATE TABLE for {fq_table}")
    lines: List[str] = [f"CREATE TABLE {fq_table} ("]
    col_defs: List[str] = []
    for col_name, stats in col_stats.items():
        col_ident = quote_ident(col_name, dialect)
        col_type = stats.inferred_sql_type(dialect)
        null_sql = stats.nullability_sql()
        logger.debug(f"Column {col_name}: {col_type} {null_sql}")
        col_defs.append(f" {col_ident} {col_type} {null_sql}")
    lines.append(",\n".join(col_defs))
    lines.append(");")
    return "\n".join(lines)

def quote_ident(ident: str, dialect: str) -> str:
    needs_quote = not ident.isidentifier() or ident[0].isdigit() or any(c in ident for c in ' -./:;"\'`')
    if dialect == "mysql":
        return f"`{ident}`" if needs_quote else ident
    return f'"{ident}"' if needs_quote else ident

# ----------------------------- CLI ----------------------------- #
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Infer a SQL CREATE TABLE schema from a CSV with a header row")
    p.add_argument("csv", type=Path, help="Path to the input CSV file (with header)")
    p.add_argument("--table", required=True, help="Target table name")
    p.add_argument("--schema", default=None, help="SQL schema/namespace (e.g., public)")
    p.add_argument("--dialect", choices=["postgres", "sqlite", "mysql"], default="postgres")
    p.add_argument("--chunksize", type=int, default=100_000, help="Rows per chunk when streaming the CSV")
    p.add_argument("--encoding", default=None, help="File encoding, e.g., utf-8, cp932")
    p.add_argument("--delimiter", default="auto", help="CSV delimiter: ',', '\\t', ';', or 'auto'")
    p.add_argument("--output", type=Path, default=None, help="Optional path to write the generated .sql file")
p.add_argument("--date-format", default=None, | |
help="Optional strptime format (e.g., '%Y-%m-%d' or '%Y-%m-%d %H:%M:%S') to parse dates/timestamps deterministically and silence parser warnings") | |
p.add_argument("--loglevel", default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR)") | |
return p.parse_args(argv) | |
def main(argv: Optional[List[str]] = None) -> None: | |
args = parse_args(argv) | |
logger.setLevel(args.loglevel.upper()) | |
stats = infer_schema( | |
args.csv, | |
args.table, | |
schema=args.schema, | |
dialect=args.dialect, | |
chunksize=args.chunksize, | |
encoding=args.encoding, | |
delimiter=args.delimiter, | |
date_format=args.date_format, | |
) | |
sql = build_create_table_sql(args.table, stats, schema=args.schema, dialect=args.dialect) | |
print(sql) | |
if args.output: | |
logger.info(f"Writing CREATE TABLE statement to {args.output}") | |
args.output.write_text(sql, encoding="utf-8") | |

# ----------------------------- Utilities ----------------------------- #
def _detect_delimiter(csv_path: Path, *, encoding: Optional[str] = None) -> str:
    """Detect a likely delimiter using csv.Sniffer; default to comma on failure."""
    try:
        with open(csv_path, "r", encoding=encoding or "utf-8", errors="replace") as f:
            sample = f.read(64 * 1024)
        dialect = csv.Sniffer().sniff(sample, delimiters=",\t;|:")
        delim = dialect.delimiter
        if delim:
            return delim
    except Exception as e:
        logger.debug(f"Delimiter sniff failed, defaulting to comma: {e}")
    return ","

if __name__ == "__main__":
    main()