Skip to content

Instantly share code, notes, and snippets.

@Diyago
Last active August 19, 2025 19:06
Show Gist options
  • Select an option

  • Save Diyago/961eba6ff68c43874d7d83de04458cfe to your computer and use it in GitHub Desktop.

Select an option

Save Diyago/961eba6ff68c43874d7d83de04458cfe to your computer and use it in GitHub Desktop.
Tabm benchmark
"""
LightGBM vs TabM binary-classification benchmark with optimized training
Requirements: pip install lightgbm tabm torch pandas scikit-learn tqdm
!git clone https://github.com/Diyago/Tabular-data-generation.git
!mv Tabular-data-generation/Research/data/* data/
"""
"""
LightGBM vs TabM vs RealMLP binary-classification benchmark with optimized training
Requirements: pip install lightgbm tabm torch pandas scikit-learn tqdm "pytabkit[models]"
"""
import gc
import warnings
import time
from pathlib import Path
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OrdinalEncoder
from sklearn.metrics import roc_auc_score
import lightgbm as lgb
from tabm import TabM
from pytabkit import RealMLP_TD_Classifier
warnings.filterwarnings("ignore")
# CUDA optimizations: enable cuDNN benchmarking and do not force
# deterministic algorithms when a GPU is present.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Configuration
# LightGBM parameters passed to lgb.train() for every dataset.
# NOTE(review): 'verbose' and 'verbosity' look like aliases of the same
# LightGBM option -- confirm and keep only one.
LGB_PARAMS = {
    'objective': 'binary',
    'metric': 'auc',
    'learning_rate': 0.1,
    'bagging_fraction': 0.8,
    'bagging_freq': 5,
    'verbose': -1,
    'early_stopping_rounds': 50,
    'verbosity': -1,
    'max_depth': 3,
}

# Dataset names; each one is resolved to ./data/<name>/<name>.gz.
DATASETS = [
    "telecom",
    "poverty_A"
    #"employee", "credit", "adult", "mortgages",
]

# Number of TabM training epochs requested by run_single_benchmark().
EPOCHS = 110

# Default fraction of the training split kept by create_splits().
TRAIN_PROP = 0.2
# Helper Functions
def cast_cats(df):
    """Convert categorical columns to ordered categories.

    A column counts as categorical when its name starts with ``"cat"``,
    which is the same rule ``prepare_data`` uses to build its list of
    categorical columns.  (A substring match would also convert numeric
    columns such as ``"location"``.)

    Args:
        df: DataFrame to convert.  It is modified in place.

    Returns:
        The same DataFrame object, with every ``cat*`` column cast to an
        ordered ``category`` dtype.
    """
    for col in df.columns:
        if str(col).startswith("cat"):
            df[col] = df[col].astype("category").cat.as_ordered()
    return df
def fillna(df):
    """Return a copy of ``df`` with missing values imputed column by column.

    Numeric columns are filled with their mean; every other column is
    filled with its first mode.  A non-numeric column that has no mode
    (for example, one that is entirely missing) is left unchanged.
    """
    result = df.copy()
    for name in result.columns:
        column = result[name]
        if pd.api.types.is_numeric_dtype(column):
            fill_value = column.mean()
        else:
            modes = column.mode()
            if modes.empty:
                # Nothing to impute with; keep the column as it is.
                continue
            fill_value = modes.iloc[0]
        result[name] = column.fillna(fill_value)
    return result
def ensure_series(y):
    """Return ``y`` as a pandas Series.

    A Series is returned unchanged, a DataFrame is reduced to its first
    column, and anything else is wrapped with ``pd.Series``.
    """
    if isinstance(y, pd.Series):
        return y
    if isinstance(y, pd.DataFrame):
        return y.iloc[:, 0]
    return pd.Series(y)
def prepare_data(dataset_name):
    """Load one dataset from disk and split it into features and target.

    The file is read from ``./data/<dataset_name>/<dataset_name>.gz`` as
    CSV, passed through ``cast_cats`` and ``fillna``, and its ``"target"``
    column is separated from the features.

    Returns:
        A tuple ``(data, X, target_col, cat_cols, num_cols)`` where
        ``cat_cols`` are the feature names starting with ``"cat"`` and
        ``num_cols`` are all remaining feature names.

    Raises:
        FileNotFoundError: If the dataset file does not exist.
    """
    print(f"\nProcessing dataset: {dataset_name}")

    # Assuming data is in a subfolder, e.g., './data/telecom/telecom.gz'.
    # Adjust this path if your directory structure differs.
    dataset_path = Path(f"./data/{dataset_name}/{dataset_name}.gz")
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset not found at {dataset_path}. Please ensure it's downloaded and in the correct directory.")

    data = fillna(cast_cats(pd.read_csv(dataset_path)))
    target_col = ensure_series(data["target"])
    X = data.drop(columns=["target"])

    # Partition the feature names in their original column order.
    cat_cols, num_cols = [], []
    for name in X.columns:
        (cat_cols if name.startswith("cat") else num_cols).append(name)

    print(f"Categorical columns: {len(cat_cols)}, Numerical columns: {len(num_cols)}")
    return data, X, target_col, cat_cols, num_cols
def create_splits(X, y, train_prop=TRAIN_PROP):
    """Create train/validation/test splits plus a reduced training subset.

    First 40% of the rows are held out as the test set without shuffling;
    then 40% of the remainder is split off (shuffled, ``random_state=1``)
    as the validation set.  All pieces get a fresh 0..n-1 index.

    Args:
        X: Feature DataFrame.
        y: Target aligned with ``X``.
        train_prop: Fraction of the training split kept (its leading rows)
            as ``X_train_small`` / ``y_train_small``.

    Returns:
        Dict with keys ``X_train``, ``X_val``, ``X_test``, ``y_train``,
        ``y_val``, ``y_test``, ``X_train_small`` and ``y_train_small``.
    """
    def _fresh_index(obj):
        return obj.reset_index(drop=True)

    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=0.4, shuffle=False, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=0.4, random_state=1
    )

    # Targets are normalised to Series; everything is re-indexed from zero.
    y_train, y_val, y_test = (
        _fresh_index(ensure_series(part)) for part in (y_train, y_val, y_test)
    )
    X_train, X_val, X_test = (
        _fresh_index(part) for part in (X_train, X_val, X_test)
    )

    # Reduced training set: the leading `train_prop` share of the train rows.
    cut = int(len(X_train) * train_prop)
    X_train_small = _fresh_index(X_train.iloc[:cut])
    y_train_small = _fresh_index(ensure_series(y_train.iloc[:cut]))

    return {
        'X_train': X_train, 'X_val': X_val, 'X_test': X_test,
        'y_train': y_train, 'y_val': y_val, 'y_test': y_test,
        'X_train_small': X_train_small, 'y_train_small': y_train_small
    }
def train_evaluate_lgb(splits, params=LGB_PARAMS):
    """Train LightGBM on the reduced training set and score it with ROC AUC.

    Args:
        splits: Dict produced by ``create_splits``; the model is fitted on
            ``X_train_small`` / ``y_train_small`` with ``X_val`` / ``y_val``
            as the validation set.
        params: LightGBM parameters forwarded to ``lgb.train``.

    Returns:
        Dict with the trained ``model``, the wall-clock ``training_time``
        in seconds, and ``auc_train`` / ``auc_val`` / ``auc_test``.
    """
    print("Training LightGBM...")
    started = time.time()
    quiet = {"verbosity": -1}
    train_set = lgb.Dataset(splits['X_train_small'], label=splits['y_train_small'], params=quiet)
    val_set = lgb.Dataset(splits['X_val'], label=splits['y_val'], params=quiet)
    booster = lgb.train(params, train_set, valid_sets=[val_set])
    elapsed = time.time() - started

    def _auc(x_key, y_key):
        # ROC AUC of the booster's predictions on one split.
        return roc_auc_score(splits[y_key], booster.predict(splits[x_key]))

    return {
        'model': booster,
        'training_time': elapsed,
        'auc_train': _auc('X_train_small', 'y_train_small'),
        'auc_val': _auc('X_val', 'y_val'),
        'auc_test': _auc('X_test', 'y_test')
    }
def train_evaluate_realmlp(splits, cat_cols):
    """Train pytabkit's RealMLP_TD_Classifier and score it with ROC AUC.

    The classifier is fitted on ``X_train_small`` / ``y_train_small`` and
    is given ``X_val`` / ``y_val`` plus the categorical column names.

    Args:
        splits: Dict produced by ``create_splits``.
        cat_cols: Names of the categorical feature columns.

    Returns:
        Dict with the fitted ``model``, the wall-clock ``training_time``
        in seconds, and ``auc_train`` / ``auc_val`` / ``auc_test``.
    """
    print("Training RealMLP...")
    started = time.time()

    # Per the original author's notes: '1-auc_ovr' selects on AUC
    # (minimising 1-AUC) and verbosity=0 suppresses output.
    # NOTE(review): confirm against the pytabkit documentation.
    classifier = RealMLP_TD_Classifier(
        val_metric_name='1-auc_ovr',
        verbosity=0
    )
    # Explicit validation data and categorical column names are passed in
    # rather than relying on auto-detection.
    classifier.fit(
        splits['X_train_small'],
        splits['y_train_small'],
        X_val=splits['X_val'],
        y_val=splits['y_val'],
        cat_col_names=cat_cols
    )
    elapsed = time.time() - started

    def _auc(x_key, y_key):
        # Column 1 of predict_proba is the positive-class probability.
        positive = classifier.predict_proba(splits[x_key])[:, 1]
        return roc_auc_score(splits[y_key], positive)

    return {
        'model': classifier,
        'training_time': elapsed,
        'auc_train': _auc('X_train_small', 'y_train_small'),
        'auc_val': _auc('X_val', 'y_val'),
        'auc_test': _auc('X_test', 'y_test')
    }
def prepare_tabm_data(data, splits, cat_cols, num_cols, train_prop=0.2):
    """Prepare tensors for TabM with explicit handling of unseen categories.

    Categorical columns are ordinal-encoded by an encoder fitted once on
    the full training split.  Categories that were not seen during fitting
    are mapped to a single shared index equal to the largest per-column
    category count plus one, so it can never collide with a real code.

    Args:
        data: Unused; kept so existing callers keep working.
        splits: Dict with ``X_train``, ``X_val``, ``X_test``, ``y_train``,
            ``y_val`` and ``y_test``.
        cat_cols: Names of the categorical feature columns.
        num_cols: Names of the numerical feature columns.
        train_prop: Fraction of the training split kept for fitting
            (forwarded to ``train_test_split`` as ``train_size``).

    Returns:
        Dict with ``'tensors'`` (numeric/categorical/target tensors for the
        train, val and test parts) and ``'cat_cardinalities'`` (one entry
        per categorical column).
    """
    # 1) Work on copies so the caller's splits are not modified.
    Xo_train = splits['X_train'].copy().reset_index(drop=True)
    Xo_val = splits['X_val'].copy().reset_index(drop=True)
    Xo_test = splits['X_test'].copy().reset_index(drop=True)

    # 2) Ordinal-encode the categorical columns.
    if len(cat_cols) > 0:
        # Plain numpy arrays avoid feature-name checks inside the encoder.
        cat_train_np = Xo_train[cat_cols].to_numpy()
        cat_val_np = Xo_val[cat_cols].to_numpy()
        cat_test_np = Xo_test[cat_cols].to_numpy()

        # Single fit: unseen categories are first tagged with -1 ...
        encoder = OrdinalEncoder(
            handle_unknown='use_encoded_value',
            unknown_value=-1
        )
        encoder.fit(cat_train_np)

        # ... and then moved to an index above every real code.  Real codes
        # of a column run from 0 to len(categories) - 1.
        safe_unknown_value = max(len(cats) for cats in encoder.categories_) + 1

        def encode(values):
            codes = encoder.transform(values)
            codes[codes == -1] = safe_unknown_value
            return codes

        Xo_train[cat_cols] = encode(cat_train_np)
        Xo_val[cat_cols] = encode(cat_val_np)
        Xo_test[cat_cols] = encode(cat_test_np)

    # 3) Targets (align indices, no transform)
    yo_train = ensure_series(splits['y_train']).reset_index(drop=True)
    yo_val = ensure_series(splits['y_val']).reset_index(drop=True)
    yo_test = ensure_series(splits['y_test']).reset_index(drop=True)

    # 4) Reduce training size (stratified when the target has >1 class).
    # NOTE(review): this random subsample is not the same set of rows as
    # the 'X_train_small' subset built by create_splits(), so TabM and the
    # other models are fitted on different rows -- confirm this is intended.
    X_train_small, _, y_train_small, _ = train_test_split(
        Xo_train, yo_train,
        train_size=train_prop,
        stratify=yo_train if yo_train.nunique() > 1 else None,
        random_state=42
    )

    # 5) Convert to tensors with validation
    def to_tensors(df, cat_cols, num_cols):
        if len(num_cols) > 0:
            x_num = torch.tensor(df[num_cols].to_numpy(copy=True), dtype=torch.float32).contiguous()
        else:
            # Zero-width placeholder keeps the row count available.
            x_num = torch.empty((len(df), 0), dtype=torch.float32)

        if len(cat_cols) > 0:
            cat_data = df[cat_cols].to_numpy(copy=True).astype(np.int64)
            # Category indices must be non-negative.
            if (cat_data < 0).any():
                print(f"WARNING: Found negative values, min: {cat_data.min()}")
                cat_data = np.clip(cat_data, 0, None)
            x_cat = torch.tensor(cat_data, dtype=torch.long).contiguous()
        else:
            x_cat = torch.empty((len(df), 0), dtype=torch.long)
        return x_num, x_cat

    def make_y_tensor(y):
        return torch.tensor(y.to_numpy(copy=True), dtype=torch.float32)

    xnum_train, xcat_train = to_tensors(X_train_small, cat_cols, num_cols)
    y_train_tensor = make_y_tensor(y_train_small)
    xnum_val, xcat_val = to_tensors(Xo_val, cat_cols, num_cols)
    y_val_tensor = make_y_tensor(yo_val)
    xnum_test, xcat_test = to_tensors(Xo_test, cat_cols, num_cols)
    y_test_tensor = make_y_tensor(yo_test)

    # 6) Cardinalities: largest code observed in any part, plus one, so the
    # shared unknown index is covered when it occurs.
    cat_cards = [
        int(max(Xo_train[col].max(), Xo_val[col].max(), Xo_test[col].max())) + 1
        for col in cat_cols
    ]

    print(f"Categorical cardinalities: {cat_cards}")
    print(f"Categorical columns processed: {len(cat_cols)}")

    return {
        'tensors': {
            'xnum_train': xnum_train, 'xcat_train': xcat_train,
            'xnum_val': xnum_val, 'xcat_val': xcat_val,
            'xnum_test': xnum_test, 'xcat_test': xcat_test,
            'y_train': y_train_tensor,
            'y_val': y_val_tensor,
            'y_test': y_test_tensor,
        },
        'cat_cardinalities': cat_cards
    }
# --- TabM helpers: device placement, training, prediction and evaluation ---
def prepare_tensor_safe(tensor, device):
    """Move ``tensor`` to ``device``, or return None when there is nothing to move.

    None is returned both for a ``None`` input and for a tensor whose
    second dimension (the feature dimension) has size zero.
    """
    if tensor is None or int(tensor.shape[1]) == 0:
        return None
    return tensor.to(device)
def train_tabm_model(xnum_tr, xcat_tr, y_tr, xnum_val, xcat_val, y_val, cat_cards,
                     d_out=1, epochs=100, bs=512, lr=0.002, wd=0.0003, patience=10):
    """Train a TabM model and restore the weights with the best validation AUC.

    Validation AUC is computed every fifth epoch (epochs 0, 5, 10, ...).
    Whenever it improves, a deep copy of the model weights is stored; the
    stored copy is loaded back before returning.

    Args:
        xnum_tr, xnum_val: Numeric feature tensors; a zero-width tensor
            means "no numeric features".
        xcat_tr, xcat_val: Categorical index tensors; a zero-width tensor
            means "no categorical features".
        y_tr, y_val: Binary targets as float tensors.
        cat_cards: Cardinality of each categorical column.
        d_out: Output dimension passed to ``TabM.make``.
        epochs: Maximum number of training epochs.
        bs: Mini-batch size.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
        patience: Number of consecutive validation checks without an AUC
            improvement after which training stops early.

    Returns:
        Tuple ``(model, training_time, device_str, device)``.

    Raises:
        ValueError: If both feature tensors are zero-width.
    """
    import copy  # local import: only needed to snapshot the best weights

    eval_every = 5  # validate on every fifth epoch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_str = f"GPU ({torch.cuda.get_device_name()})" if device.type == "cuda" else "CPU"
    print(f"Training TabM on: {device_str}")

    has_num = xnum_tr.shape[1] > 0
    has_cat = xcat_tr.shape[1] > 0
    if not has_num and not has_cat:
        raise ValueError("TabM needs at least one numerical or categorical feature column")

    # Zero-width feature tensors are replaced by None.
    if has_num:
        xnum_tr = xnum_tr.to(device)
        xnum_val = xnum_val.to(device)
        n_num_features = xnum_tr.shape[1]
    else:
        xnum_tr = None
        xnum_val = None
        n_num_features = 0

    if has_cat:
        xcat_tr = xcat_tr.to(device)
        xcat_val = xcat_val.to(device)
        cat_cardinalities = cat_cards if cat_cards else None
    else:
        xcat_tr = None
        xcat_val = None
        cat_cardinalities = None

    # Move targets to device
    y_tr = y_tr.to(device)
    y_val = y_val.to(device)

    model = TabM.make(
        n_num_features=n_num_features,
        cat_cardinalities=cat_cardinalities,
        d_out=d_out,
        k=32,         # Number of ensemble members
        n_blocks=4,   # Number of MLP blocks
        d_block=512,  # Width of each MLP block
        dropout=0.1,  # Dropout for regularization
    ).to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=wd,
    )

    # Base loss function for binary classification
    base_loss_fn = nn.BCEWithLogitsLoss()

    def tabm_loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Train each of the k ensemble predictions against the same target.

        Args:
            y_pred: Shape (batch_size, k)
            y_true: Shape (batch_size,)
        """
        # k is read from the prediction itself, so it always matches.
        k = y_pred.shape[1]
        # (batch_size, k) -> (batch_size * k,), targets repeated k times.
        return base_loss_fn(y_pred.flatten(0, 1), y_true.repeat_interleave(k))

    # -inf guarantees that the first validation check stores a snapshot.
    best_auc, wait = float("-inf"), 0
    best_state = None
    n = int(y_tr.size(0))
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0.0
        n_batches = 0
        idx = torch.randperm(n, device=device)

        for i in range(0, n, bs):
            batch_idx = idx[i:i + bs]
            optimizer.zero_grad()

            # Keep all k predictions of the ensemble.
            if xnum_tr is None:
                logits = model(x_cat=xcat_tr[batch_idx])
            elif xcat_tr is None:
                logits = model(x_num=xnum_tr[batch_idx])
            else:
                logits = model(xnum_tr[batch_idx], xcat_tr[batch_idx])

            # (batch_size, k, 1) -> (batch_size, k)
            if logits.dim() == 3:
                logits = logits.squeeze(-1)

            loss = tabm_loss_fn(logits, y_tr[batch_idx])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

        # max() avoids a ZeroDivisionError when the training set is empty.
        avg_loss = epoch_loss / max(n_batches, 1)
        current_lr = optimizer.param_groups[0]['lr']

        if epoch % eval_every == 0:
            val_probs = tabm_predict(model, xnum_val, xcat_val, device)
            val_auc = roc_auc_score(y_val.cpu().numpy(), val_probs)
            print(f"TabM epoch {epoch+1:02d} | loss: {avg_loss:.4f} | val_auc: {val_auc:.4f} | lr: {current_lr:.6f}")

            if val_auc > best_auc:
                best_auc, wait = val_auc, 0
                # state_dict() returns references to the live tensors, so
                # it must be copied or later epochs would overwrite it.
                best_state = copy.deepcopy(model.state_dict())
            else:
                wait += 1
                if wait >= patience:
                    print(f"TabM early stopping at epoch {epoch+1:02d} (best val_auc: {best_auc:.4f})")
                    break

    training_time = time.time() - start_time

    # Load best model
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    return model, training_time, device_str, device
def tabm_predict(model, x_num, x_cat, device):
    """Return positive-class probabilities from a TabM model.

    The model produces one logit per ensemble member.  Each logit is
    converted to a probability first and the probabilities are then
    averaged over the members, i.e. the ensemble is averaged in
    probability space rather than in logit space.

    Args:
        model: Callable module returning logits; the code assumes shape
            (batch, k, 1).
        x_num: Numeric features, or None when there are none.
        x_cat: Categorical indices, or None when there are none.
        device: Unused; inputs are expected to already be on the model's
            device.  Kept for interface compatibility.

    Returns:
        1-D numpy array with one probability per row.

    Raises:
        ValueError: If both ``x_num`` and ``x_cat`` are None.
    """
    model.eval()
    with torch.no_grad():
        if x_num is None and x_cat is None:
            raise ValueError("Both x_num and x_cat cannot be None")
        elif x_num is None:
            logits = model(x_cat=x_cat)
        elif x_cat is None:
            logits = model(x_num=x_num)
        else:
            logits = model(x_num, x_cat)
        # (batch, k, 1) -> sigmoid -> mean over k -> (batch, 1) -> (batch,)
        probs = torch.sigmoid(logits).mean(1).squeeze(1)
    return probs.cpu().numpy()
def evaluate_tabm(model, tensors, splits, device):
    """Score a trained TabM model on the train, validation and test parts.

    Args:
        model: Trained model accepted by ``tabm_predict``.
        tensors: The ``'tensors'`` dict built by ``prepare_tabm_data``.
        splits: Dict from ``create_splits``; only ``'y_test'`` is read.
        device: Device the feature tensors are moved to.

    Returns:
        Tuple ``(auc_train, auc_val, auc_test)``.
    """
    def _predict(part):
        # Zero-width feature tensors become None before prediction.
        x_num = prepare_tensor_safe(tensors[f'xnum_{part}'], device)
        x_cat = prepare_tensor_safe(tensors[f'xcat_{part}'], device)
        return tabm_predict(model, x_num, x_cat, device)

    pred_train = _predict('train')
    pred_val = _predict('val')
    pred_test = _predict('test')

    # Train/val labels come from the tensors (they match the subsampled
    # rows); the test labels are read from the splits.
    y_train_np = tensors['y_train'].cpu().numpy()
    y_val_np = tensors['y_val'].cpu().numpy()
    y_test_np = ensure_series(splits['y_test']).to_numpy()

    return (
        roc_auc_score(y_train_np, pred_train),
        roc_auc_score(y_val_np, pred_val),
        roc_auc_score(y_test_np, pred_test),
    )
# --- MAIN BENCHMARK LOGIC ---
def run_single_benchmark(dataset_name):
    """Run the LightGBM, RealMLP and TabM benchmark on one dataset.

    Returns:
        Dict with the dataset name, the train/val/test AUC and training
        time of each model, and the device TabM was trained on.
    """
    # Load the data once; all three models share the same splits.
    data, X, target_col, cat_cols, num_cols = prepare_data(dataset_name)
    splits = create_splits(X, target_col)

    lgb_results = train_evaluate_lgb(splits)
    realmlp_results = train_evaluate_realmlp(splits, cat_cols)

    # TabM needs its own tensor representation of the same splits.
    tabm_data = prepare_tabm_data(data, splits, cat_cols, num_cols)
    t = tabm_data['tensors']
    tabm_model, tabm_training_time, tabm_device, device = train_tabm_model(
        xnum_tr=t['xnum_train'],
        xcat_tr=t['xcat_train'],
        y_tr=t['y_train'],
        xnum_val=t['xnum_val'],
        xcat_val=t['xcat_val'],
        y_val=t['y_val'],
        cat_cards=tabm_data['cat_cardinalities'],
        epochs=EPOCHS,
    )
    tabm_auc_train, tabm_auc_val, tabm_auc_test = evaluate_tabm(
        tabm_model, t, splits, device
    )

    # One flat record per dataset; key order defines the column order.
    results = {
        "dataset": dataset_name,
        "lgb_train": lgb_results['auc_train'],
        "lgb_val": lgb_results['auc_val'],
        "lgb_test": lgb_results['auc_test'],
        "lgb_time_sec": lgb_results['training_time'],
        "realmlp_train": realmlp_results['auc_train'],
        "realmlp_val": realmlp_results['auc_val'],
        "realmlp_test": realmlp_results['auc_test'],
        "realmlp_time_sec": realmlp_results['training_time'],
        "tabm_train": tabm_auc_train,
        "tabm_val": tabm_auc_val,
        "tabm_test": tabm_auc_test,
        "tabm_time_sec": tabm_training_time,
        "tabm_device": tabm_device,
    }

    print(f"\n--- Results for {dataset_name} ---")
    print(f"LGB - Train: {lgb_results['auc_train']:.4f}, Val: {lgb_results['auc_val']:.4f}, Test: {lgb_results['auc_test']:.4f}, Time: {lgb_results['training_time']:.2f}s")
    print(f"RealMLP - Train: {realmlp_results['auc_train']:.4f}, Val: {realmlp_results['auc_val']:.4f}, Test: {realmlp_results['auc_test']:.4f}, Time: {realmlp_results['training_time']:.2f}s")
    print(f"TabM - Train: {tabm_auc_train:.4f}, Val: {tabm_auc_val:.4f}, Test: {tabm_auc_test:.4f}, Time: {tabm_training_time:.2f}s, Device: {tabm_device}")

    # Drop the model references, then release cached GPU memory.
    del lgb_results, realmlp_results, tabm_model, tabm_data
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

    return results
def display_results(results):
    """Print the per-dataset comparison table and the summary statistics.

    Args:
        results: List of per-dataset result dicts as returned by
            ``run_single_benchmark``.
    """
    results_df = pd.DataFrame(results)

    # Preferred table columns; any that are missing are skipped.
    wanted = (
        "dataset",
        "lgb_test", "realmlp_test", "tabm_test",
        "lgb_time_sec", "realmlp_time_sec", "tabm_time_sec",
        "tabm_device",
    )
    display_cols = [name for name in wanted if name in results_df.columns]

    separator = "=" * 120
    print("\n" + separator)
    print("FINAL RESULTS COMPARISON WITH TIMING AND DEVICE INFO")
    print(separator)
    print(results_df[display_cols].to_string())

    # Mean and standard deviation of the test AUC across datasets.
    print("\nSUMMARY (Test Set):")
    print(f"Average LGB Test AUC: {results_df['lgb_test'].mean():.4f} ± {results_df['lgb_test'].std():.4f}")
    print(f"Average RealMLP Test AUC: {results_df['realmlp_test'].mean():.4f} ± {results_df['realmlp_test'].std():.4f}")
    print(f"Average TabM Test AUC: {results_df['tabm_test'].mean():.4f} ± {results_df['tabm_test'].std():.4f}")

    # Mean training time across datasets.
    print(f"\nAverage LGB Time: {results_df['lgb_time_sec'].mean():.2f}s")
    print(f"Average RealMLP Time: {results_df['realmlp_time_sec'].mean():.2f}s")
    print(f"Average TabM Time: {results_df['tabm_time_sec'].mean():.2f}s")
def main():
    """Run the benchmark on every configured dataset and print a summary.

    A dataset that raises is reported and skipped so the remaining
    datasets still run.

    Returns:
        List with one result dict per dataset that completed.
    """
    # Report which device is available.
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU")
    else:
        print(f"CUDA available: {torch.cuda.get_device_name()}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    all_results = []
    for dataset_name in tqdm(DATASETS, desc="Overall Progress"):
        try:
            all_results.append(run_single_benchmark(dataset_name))
        except Exception as e:
            # Top-level boundary: report the failure and continue.
            print(f"\nERROR processing dataset {dataset_name}: {e}")

    if not all_results:
        print("\nNo benchmarks were successfully completed.")
        return all_results

    display_results(all_results)
    return all_results


if __name__ == "__main__":
    results = main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment