@yeiichi
Created September 11, 2025 00:03
Analyze and infer SQL types for columns in CSV files
#!/usr/bin/env python3
"""
This module provides a class and utility functions to analyze and infer SQL
types for columns in CSV files via statistical and datatype examination.
It handles various datatypes like integers, floats, booleans, dates, and
timestamps while accommodating dialect-specific SQL type mappings. Additional
functionality includes updating statistics across data chunks, robustness
against missing or invalid data, and generation of NULL/NOT NULL constraints.
"""
from __future__ import annotations
import argparse
import csv
import logging
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional
import pandas as pd
import numpy as np
# ----------------------------- Logging Setup ----------------------------- #
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
)
logger = logging.getLogger(__name__)

# Quiet pandas' engine warning just in case; we also avoid it by detecting delimiter
warnings.filterwarnings(
    "ignore",
    message=r"^Falling back to the 'python' engine.*",
    category=pd.errors.ParserWarning,
)

# Silence pandas' per-element datetime parse warning when no format is provided
warnings.filterwarnings(
    "ignore",
    message=r"^Could not infer format, so each element will be parsed individually, falling back to `dateutil`.*",
    category=UserWarning,
)
# ----------------------------- Helpers & Types ----------------------------- #
BOOL_TRUE = {"true", "t", "1", "yes", "y"}
BOOL_FALSE = {"false", "f", "0", "no", "n"}
INT32_MIN, INT32_MAX = -2147483648, 2147483647
# Common NA/NULL spellings for CSVs
NA_STRINGS = ["", "na", "n/a", "null", "none", "NaN", "NA"]
# ... existing code ...
@dataclass
class ColStats:
    """
    Per-column statistics accumulated across CSV chunks to infer SQL types.
    """
    name: str
    non_null_count: int = 0
    null_count: int = 0
    could_be_int: bool = True
    could_be_float: bool = True
    could_be_bool: bool = True
    could_be_date: bool = True
    could_be_timestamp: bool = True
    int_min: Optional[int] = None
    int_max: Optional[int] = None
    max_len: int = 0
    any_time_component: bool = False

    def update_from_series(self, s: pd.Series, *, date_format: Optional[str] = None, **_kw: object) -> None:
        """
        Update stats based on a pandas Series chunk for this column.
        """
        logger.debug(f"Updating stats for column '{self.name}'")
        nn_mask = s.notna()
        self.non_null_count += int(nn_mask.sum())
        self.null_count += int((~nn_mask).sum())
        if self.non_null_count == 0:
            return
        non_null_values = s[nn_mask]
        self._update_string_length(non_null_values)
        if self.could_be_bool:
            self._update_bool_candidate(non_null_values)
        if self.could_be_int or self.could_be_float:
            self._update_numeric_candidate(non_null_values)
        if self.could_be_date or self.could_be_timestamp:
            self._update_datetime_candidate(non_null_values, date_format=date_format)

    def _update_string_length(self, non_null_values: pd.Series) -> None:
        max_len_this = int(non_null_values.astype(str).str.len().max()) if not non_null_values.empty else 0
        self.max_len = max(self.max_len, max_len_this)

    def _update_bool_candidate(self, non_null_values: pd.Series) -> None:
        lowered = non_null_values.astype(str).str.strip().str.lower()
        is_bool_like = lowered.isin(BOOL_TRUE | BOOL_FALSE)
        if not bool(is_bool_like.all()):
            self.could_be_bool = False

    def _update_numeric_candidate(self, non_null_values: pd.Series) -> None:
        coerced = pd.to_numeric(non_null_values, errors="coerce")
        if coerced.isna().any():
            self.could_be_int = False
            self.could_be_float = False
            return
        # Robust integer-like check with tolerance to floating precision
        if self.could_be_int:
            values = coerced.to_numpy(dtype=float, copy=False)
            finite_mask = np.isfinite(values)
            if not finite_mask.all():
                # Shouldn't happen after isna() check, but be defensive
                self.could_be_int = False
            else:
                is_integer_like = np.isclose(values, np.round(values))
                if not bool(is_integer_like.all()):
                    self.could_be_int = False
        if self.could_be_int:
            # Update min/max using rounded integer-like values without narrowing dtype
            ints = np.round(coerced.to_numpy(dtype=float, copy=False)).astype(np.int64, copy=False)
            if ints.size > 0:
                mn = int(ints.min())
                mx = int(ints.max())
                self.int_min = mn if self.int_min is None else min(self.int_min, mn)
                self.int_max = mx if self.int_max is None else max(self.int_max, mx)
        if not self.could_be_int:
            # If int is ruled out but numeric parsing succeeded, float remains possible
            self.could_be_float = True
    def _update_datetime_candidate(self, non_null_values: pd.Series, *, date_format: Optional[str]) -> None:
        parsed = pd.to_datetime(
            non_null_values,
            errors="coerce",
            utc=False,
            format=date_format if date_format else None,
        )
        if parsed.isna().any():
            self.could_be_date = False
            self.could_be_timestamp = False
            return
        hours = parsed.dt.hour
        mins = parsed.dt.minute
        secs = parsed.dt.second
        # Accumulate across chunks: once any value carries a time component, keep it flagged
        self.any_time_component = self.any_time_component or bool(((hours != 0) | (mins != 0) | (secs != 0)).any())
    def inferred_sql_type(self, dialect: str = "postgres") -> str:
        if self.non_null_count == 0:
            return _sql_type_text(self.max_len, dialect)
        if self.could_be_bool and not (self.could_be_int or self.could_be_float):
            return _sql_type_bool(dialect)
        if self.could_be_int:
            if self.int_min is not None and self.int_max is not None:
                if INT32_MIN <= self.int_min <= self.int_max <= INT32_MAX:
                    return _sql_type_int32(dialect)
                else:
                    return _sql_type_int64(dialect)
            else:
                return _sql_type_int64(dialect)
        if self.could_be_float:
            return _sql_type_float(dialect)
        if self.could_be_date or self.could_be_timestamp:
            return _sql_type_timestamp(dialect) if self.any_time_component else _sql_type_date(dialect)
        return _sql_type_text(self.max_len, dialect)

    def nullability_sql(self) -> str:
        return "NULL" if self.null_count > 0 else "NOT NULL"
# ----------------------------- SQL Mappings ----------------------------- #
def _sql_type_bool(dialect: str) -> str:
    return {"postgres": "BOOLEAN", "sqlite": "INTEGER", "mysql": "BOOLEAN"}.get(dialect, "BOOLEAN")

def _sql_type_int32(dialect: str) -> str:
    return {"postgres": "INTEGER", "sqlite": "INTEGER", "mysql": "INT"}.get(dialect, "INTEGER")

def _sql_type_int64(dialect: str) -> str:
    return {"postgres": "BIGINT", "sqlite": "INTEGER", "mysql": "BIGINT"}.get(dialect, "BIGINT")

def _sql_type_float(dialect: str) -> str:
    return {"postgres": "DOUBLE PRECISION", "sqlite": "REAL", "mysql": "DOUBLE"}.get(dialect, "DOUBLE PRECISION")

def _sql_type_date(dialect: str) -> str:
    return {"postgres": "DATE", "sqlite": "TEXT", "mysql": "DATE"}.get(dialect, "DATE")

def _sql_type_timestamp(dialect: str) -> str:
    return {"postgres": "TIMESTAMP", "sqlite": "TEXT", "mysql": "DATETIME"}.get(dialect, "TIMESTAMP")
def _sql_type_text(max_len: int, dialect: str) -> str:
    """
    Return a textual SQL type with consideration for dialect and length.
    - For Postgres/MySQL, prefer VARCHAR(n) for moderate lengths, else TEXT.
    - For SQLite, TEXT is standard regardless of length.
    """
    if dialect == "sqlite":
        return "TEXT"
    # Use VARCHAR(n) sized to the observed maximum for smaller columns, TEXT for large/unbounded
    if max_len and max_len <= 65535:
        return f"VARCHAR({max_len})"
    if dialect == "mysql":
        if max_len <= 65535:
            return "TEXT"
        elif max_len <= 16777215:
            return "MEDIUMTEXT"
        else:
            return "LONGTEXT"
    return "TEXT"
# ----------------------------- Core Logic ----------------------------- #
def infer_schema(
    csv_path: Path,
    table_name: str,
    *,
    schema: Optional[str] = None,
    dialect: str = "postgres",
    chunksize: int = 100_000,
    encoding: Optional[str] = None,
    delimiter: Optional[str] = None,
    date_format: Optional[str] = None,
) -> Dict[str, ColStats]:
    """Stream the CSV in chunks and accumulate per-column statistics for type inference."""
    logger.info(f"Inferring schema from CSV: {csv_path}")
    col_stats: Dict[str, ColStats] = {}
    # Auto-detect delimiter if requested/unspecified and choose fastest engine
    delimiter_resolved = delimiter
    if delimiter_resolved in (None, "auto"):
        delimiter_resolved = _detect_delimiter(csv_path, encoding=encoding)
        logger.info(f"Detected delimiter: '{delimiter_resolved}'")
    read_kwargs = dict(
        chunksize=chunksize,
        encoding=encoding,
        sep=delimiter_resolved,
        engine="c",
        dtype=str,
        keep_default_na=True,
        na_values=NA_STRINGS,
    )
    first_chunk = True
    for chunk in pd.read_csv(csv_path, **read_kwargs):
        logger.debug(f"Processing new chunk of size {len(chunk)}")
        if first_chunk:
            for col in chunk.columns:
                col_stats[col] = ColStats(name=col)
            first_chunk = False
        for col in chunk.columns:
            try:
                col_stats[col].update_from_series(chunk[col], date_format=date_format)
            except Exception as e:
                logger.warning(f"Column '{col}': inference error ({e}); treating as TEXT.")
                cs = col_stats[col]
                cs.could_be_bool = cs.could_be_int = cs.could_be_float = cs.could_be_date = cs.could_be_timestamp = False
                nn = chunk[col].notna()
                cs.non_null_count += int(nn.sum())
                cs.null_count += int((~nn).sum())
                non_null_values = chunk[col][nn]
                max_len_this = int(non_null_values.astype(str).str.len().max()) if not non_null_values.empty else 0
                cs.max_len = max(cs.max_len, max_len_this)
    if first_chunk:
        logger.error("CSV appears empty or has no header.")
        raise ValueError("CSV appears empty or has no header.")
    logger.info("Schema inference complete.")
    return col_stats
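# Programmatic composition sketch (the CSV path and table name are illustrative);
# the resulting stats feed build_create_table_sql below:
#
#   stats = infer_schema(Path("data/orders.csv"), "orders", dialect="postgres")
#   ddl = build_create_table_sql("orders", stats, schema="public", dialect="postgres")
#   print(ddl)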
def build_create_table_sql(
    table_name: str,
    col_stats: Dict[str, ColStats],
    *,
    schema: Optional[str] = None,
    dialect: str = "postgres",
) -> str:
    """Render a CREATE TABLE statement from the accumulated column statistics."""
    fq_table = f"{schema}.{table_name}" if schema else table_name
    logger.info(f"Building CREATE TABLE for {fq_table}")
    lines: List[str] = [f"CREATE TABLE {fq_table} ("]
    col_defs: List[str] = []
    for col_name, stats in col_stats.items():
        col_ident = quote_ident(col_name, dialect)
        col_type = stats.inferred_sql_type(dialect)
        null_sql = stats.nullability_sql()
        logger.debug(f"Column {col_name}: {col_type} {null_sql}")
        col_defs.append(f" {col_ident} {col_type} {null_sql}")
    lines.append(",\n".join(col_defs))
    lines.append(");")
    return "\n".join(lines)
def quote_ident(ident: str, dialect: str) -> str:
    needs_quote = not ident.isidentifier() or ident[0].isdigit() or any(c in ident for c in ' -./:;"\'`')
    if dialect == "mysql":
        return f"`{ident}`" if needs_quote else ident
    return f'"{ident}"' if needs_quote else ident
# ----------------------------- CLI ----------------------------- #
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Infer a SQL CREATE TABLE schema from a CSV with header")
    p.add_argument("csv", type=Path, help="Path to input CSV file (with header)")
    p.add_argument("--table", required=True, help="Target table name")
    p.add_argument("--schema", default=None, help="SQL schema/namespace (e.g., public)")
    p.add_argument("--dialect", choices=["postgres", "sqlite", "mysql"], default="postgres")
    p.add_argument("--chunksize", type=int, default=100_000, help="Rows per chunk when streaming CSV")
    p.add_argument("--encoding", default=None, help="File encoding, e.g., utf-8, cp932")
    p.add_argument("--delimiter", default="auto", help="CSV delimiter: ',', '\\t', ';', or 'auto'")
    p.add_argument("--output", type=Path, default=None, help="Optional path to write the generated .sql file")
    # Percent signs are doubled so argparse's help formatter does not treat them as format specifiers
    p.add_argument("--date-format", default=None,
                   help="Optional strptime format (e.g., '%%Y-%%m-%%d' or '%%Y-%%m-%%d %%H:%%M:%%S') "
                        "to parse dates/timestamps deterministically and silence parser warnings")
    p.add_argument("--loglevel", default="INFO", help="Logging level (DEBUG, INFO, WARNING, ERROR)")
    return p.parse_args(argv)
def main(argv: Optional[List[str]] = None) -> None:
    args = parse_args(argv)
    logger.setLevel(args.loglevel.upper())
    stats = infer_schema(
        args.csv,
        args.table,
        schema=args.schema,
        dialect=args.dialect,
        chunksize=args.chunksize,
        encoding=args.encoding,
        delimiter=args.delimiter,
        date_format=args.date_format,
    )
    sql = build_create_table_sql(args.table, stats, schema=args.schema, dialect=args.dialect)
    print(sql)
    if args.output:
        logger.info(f"Writing CREATE TABLE statement to {args.output}")
        args.output.write_text(sql, encoding="utf-8")
# ----------------------------- Utilities ----------------------------- #
def _detect_delimiter(csv_path: Path, *, encoding: Optional[str] = None) -> str:
    """Detect a likely delimiter using csv.Sniffer; default to comma on failure."""
    try:
        with open(csv_path, "r", encoding=encoding or "utf-8", errors="replace") as f:
            sample = f.read(64 * 1024)
        dialect = csv.Sniffer().sniff(sample, delimiters=[",", "\t", ";", "|", ":"])
        delim = dialect.delimiter
        if delim:
            return delim
    except Exception as e:
        logger.debug(f"Delimiter sniff failed, defaulting to comma: {e}")
    return ","
if __name__ == "__main__":
    main()