Last active
August 19, 2025 19:06
-
-
Save Diyago/961eba6ff68c43874d7d83de04458cfe to your computer and use it in GitHub Desktop.
Tabm benchmark
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
LightGBM vs TabM vs RealMLP binary-classification benchmark with optimized training.

Requirements:
    pip install lightgbm tabm torch pandas scikit-learn tqdm "pytabkit[models]"

Data setup (notebook cell syntax; the target directory must exist before `mv`):
    !git clone https://github.com/Diyago/Tabular-data-generation.git
    !mkdir -p data
    !mv Tabular-data-generation/Research/data/* data/
"""
| import gc | |
| import warnings | |
| import time | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| import pandas as pd | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.preprocessing import OrdinalEncoder | |
| from sklearn.metrics import roc_auc_score | |
| import lightgbm as lgb | |
| from tabm import TabM | |
| from pytabkit import RealMLP_TD_Classifier | |
warnings.filterwarnings("ignore")

# cuDNN tuning: let cuDNN pick the fastest kernels, giving up bit-exact
# reproducibility for speed. Only touched when a CUDA device is present.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

# LightGBM parameters shared by every dataset in the benchmark.
LGB_PARAMS = dict(
    objective='binary',
    metric='auc',
    learning_rate=0.1,
    bagging_fraction=0.8,
    bagging_freq=5,
    verbose=-1,                 # silence LightGBM logging ...
    early_stopping_rounds=50,   # stop after 50 rounds without validation-AUC gain
    verbosity=-1,               # ... (verbose / verbosity are aliases)
    max_depth=3,
)

# Datasets to benchmark; each is read from ./data/<name>/<name>.gz.
# Other datasets available in the same repository:
# "employee", "credit", "adult", "mortgages"
DATASETS = [
    "telecom",
    "poverty_A",
]

EPOCHS = 110        # maximum number of TabM training epochs
TRAIN_PROP = 0.2    # fraction of the training split actually used for fitting
| # Helper Functions | |
def cast_cats(df):
    """Convert categorical columns to ordered pandas categories, in place.

    A column counts as categorical when its name starts with ``"cat"``. This is
    the same rule ``prepare_data`` uses to build ``cat_cols``. Matching the
    substring anywhere in the name (as ``DataFrame.filter(like="cat")`` does)
    would also catch numeric columns such as ``"location"`` or ``"duplicates"``.
    Those columns would then be imputed and modelled as categories, while
    still being listed as numerical downstream.

    Args:
        df: DataFrame to modify.

    Returns:
        The same DataFrame object, for chaining.
    """
    cat_like = [col for col in df.columns if str(col).startswith("cat")]
    for col in cat_like:
        df[col] = df[col].astype("category").cat.as_ordered()
    return df
def fillna(df):
    """Return a copy of ``df`` with missing values imputed column by column.

    Numeric columns are filled with their mean. Every other column is filled
    with its most frequent value. A column that has no mode (e.g. it is
    entirely missing) is left untouched. The input frame is not modified.
    """
    filled = df.copy()
    for name in filled.columns:
        series = filled[name]
        if pd.api.types.is_numeric_dtype(series):
            replacement = series.mean()
        else:
            modes = series.mode()
            if modes.empty:
                continue
            replacement = modes.iloc[0]
        filled[name] = series.fillna(replacement)
    return filled
def ensure_series(y):
    """Coerce a target into a pandas Series.

    An existing Series is returned unchanged (the same object). A DataFrame
    contributes its first column. Anything else is wrapped with ``pd.Series``.
    """
    if isinstance(y, pd.Series):
        return y
    if isinstance(y, pd.DataFrame):
        return y.iloc[:, 0]
    return pd.Series(y)
def prepare_data(dataset_name):
    """Load one benchmark dataset and separate features from the target.

    The file is read from ``./data/<name>/<name>.gz``. Columns are then cast
    with ``cast_cats`` and imputed with ``fillna``. Feature columns whose name
    starts with ``"cat"`` are reported as categorical; all others are
    reported as numerical.

    NOTE(review): imputation runs on the whole table before any split, so
    validation/test rows influence the fill values (mild leakage), and the
    ``target`` column is imputed as well.

    Returns:
        (data, X, target, cat_cols, num_cols), where ``data`` is the full
        processed table, ``X`` is ``data`` without ``target``, and ``target``
        is a Series.

    Raises:
        FileNotFoundError: if the dataset file does not exist.
    """
    print(f"\nProcessing dataset: {dataset_name}")
    # Expected layout, e.g. ./data/telecom/telecom.gz — adjust if your
    # directory structure differs.
    dataset_path = Path(f"./data/{dataset_name}/{dataset_name}.gz")
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset not found at {dataset_path}. Please ensure it's downloaded and in the correct directory.")

    data = fillna(cast_cats(pd.read_csv(dataset_path)))
    target = ensure_series(data["target"])
    features = data.drop(columns=["target"])

    cat_cols, num_cols = [], []
    for col in features.columns:
        (cat_cols if col.startswith("cat") else num_cols).append(col)

    print(f"Categorical columns: {len(cat_cols)}, Numerical columns: {len(num_cols)}")
    return data, features, target, cat_cols, num_cols
def create_splits(X, y, train_prop=TRAIN_PROP):
    """Split (X, y) into train / validation / test parts plus a reduced train set.

    * test: the last 40% of rows, in their original order (no shuffling).
    * train / val: the remaining rows, shuffled with seed 1 and split 60/40.
    * X_train_small / y_train_small: the first ``train_prop`` fraction of the
      shuffled train part.

    Every returned frame/series has a fresh 0..n-1 index.
    """
    # Sequential hold-out for the test set (sklearn ignores random_state
    # when shuffle=False).
    X_rest, X_test, y_rest, y_test = train_test_split(
        X, y, test_size=0.4, shuffle=False, random_state=42
    )
    # Shuffle what is left and carve 40% of it off for validation.
    X_train, X_val, y_train, y_val = train_test_split(
        X_rest, y_rest, test_size=0.4, random_state=1
    )

    def _reindexed(obj):
        return obj.reset_index(drop=True)

    y_train, y_val, y_test = [_reindexed(ensure_series(t)) for t in (y_train, y_val, y_test)]
    X_train, X_val, X_test = [_reindexed(f) for f in (X_train, X_val, X_test)]

    # Reduced training set: a prefix of the already-shuffled train rows.
    n_small = int(len(X_train) * train_prop)
    return {
        'X_train': X_train, 'X_val': X_val, 'X_test': X_test,
        'y_train': y_train, 'y_val': y_val, 'y_test': y_test,
        'X_train_small': _reindexed(X_train.iloc[:n_small]),
        'y_train_small': _reindexed(ensure_series(y_train.iloc[:n_small])),
    }
def train_evaluate_lgb(splits, params=LGB_PARAMS):
    """Fit LightGBM on the reduced training set and score it with ROC-AUC.

    The validation split is passed to ``lgb.train`` as the monitored set, so
    early stopping from ``params`` applies to it. Only dataset construction
    and ``lgb.train`` are timed.

    Returns:
        dict with ``model``, ``training_time`` (seconds) and ``auc_train``
        (on the reduced training set), ``auc_val``, ``auc_test``.
    """
    print("Training LightGBM...")
    tic = time.time()
    train_set = lgb.Dataset(splits['X_train_small'], label=splits['y_train_small'], params={"verbosity": -1})
    valid_set = lgb.Dataset(splits['X_val'], label=splits['y_val'], params={"verbosity": -1})
    booster = lgb.train(params, train_set, valid_sets=[valid_set])
    elapsed = time.time() - tic

    report = {'model': booster, 'training_time': elapsed}
    for part, x_key, y_key in (
        ('train', 'X_train_small', 'y_train_small'),
        ('val', 'X_val', 'y_val'),
        ('test', 'X_test', 'y_test'),
    ):
        report[f'auc_{part}'] = roc_auc_score(splits[y_key], booster.predict(splits[x_key]))
    return report
def train_evaluate_realmlp(splits, cat_cols):
    """Fit pytabkit's RealMLP (tuned defaults) and score it with ROC-AUC.

    The model is fitted on the reduced training set. The validation split is
    passed explicitly so that early stopping / model selection uses it, with
    ``1-auc_ovr`` (i.e. maximise AUC) as the validation metric. Categorical
    columns are named explicitly via ``cat_col_names``. Construction plus
    ``fit`` are timed.

    Returns:
        dict with ``model``, ``training_time`` (seconds) and ``auc_train``,
        ``auc_val``, ``auc_test`` computed on positive-class probabilities.
    """
    print("Training RealMLP...")
    tic = time.time()
    clf = RealMLP_TD_Classifier(
        val_metric_name='1-auc_ovr',  # minimise 1 - AUC
        verbosity=0,                  # no training logs
    )
    clf.fit(
        splits['X_train_small'],
        splits['y_train_small'],
        X_val=splits['X_val'],
        y_val=splits['y_val'],
        cat_col_names=cat_cols,
    )
    elapsed = time.time() - tic

    report = {'model': clf, 'training_time': elapsed}
    for part, x_key, y_key in (
        ('train', 'X_train_small', 'y_train_small'),
        ('val', 'X_val', 'y_val'),
        ('test', 'X_test', 'y_test'),
    ):
        positive_proba = clf.predict_proba(splits[x_key])[:, 1]
        report[f'auc_{part}'] = roc_auc_score(splits[y_key], positive_proba)
    return report
def prepare_tabm_data(data, splits, cat_cols, num_cols, train_prop=0.2, scale_num=True):
    """
    Prepare tensors for TabM with proper handling of unknown categorical values.

    Training rows: when ``splits`` contains ``X_train_small`` / ``y_train_small``
    (as produced by ``create_splits``), TabM is trained on exactly those rows,
    so it sees the same data as the other benchmarked models. Only when these
    keys are absent is a stratified ``train_prop`` subsample of
    ``splits['X_train']`` drawn instead.

    Categorical columns are ordinal-encoded with an encoder fitted on
    ``splits['X_train']``. Column j has codes 0..n_j-1 for seen categories.
    The extra code n_j is shared by categories unseen at fit time and by
    missing values, so the cardinality is n_j + 1.

    Numerical columns are standardised (zero mean, unit variance) with
    statistics from the training rows only when ``scale_num`` is true. MLPs
    such as TabM are sensitive to raw feature scale.

    Args:
        data: Unused; kept so existing callers keep working.
        splits: dict with ``X_train``, ``X_val``, ``X_test``, ``y_val``,
            ``y_test`` and either ``X_train_small``/``y_train_small`` or
            ``y_train``.
        cat_cols: categorical column names.
        num_cols: numerical column names (must be numeric dtype).
        train_prop: subsample fraction, used only in the fallback path.
        scale_num: standardise numerical features.

    Returns:
        {'tensors': {xnum_*/xcat_*/y_* for train/val/test},
         'cat_cardinalities': list with one cardinality per categorical column}
    """
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import OrdinalEncoder, StandardScaler

    cat_cols = list(cat_cols)
    num_cols = list(num_cols)

    X_full = splits['X_train'].reset_index(drop=True)
    X_val = splits['X_val'].reset_index(drop=True)
    X_test = splits['X_test'].reset_index(drop=True)
    y_val = ensure_series(splits['y_val']).reset_index(drop=True)
    y_test = ensure_series(splits['y_test']).reset_index(drop=True)

    # 1) Training rows: reuse the subset the other models are trained on.
    if 'X_train_small' in splits and 'y_train_small' in splits:
        X_tr = splits['X_train_small'].reset_index(drop=True)
        y_tr = ensure_series(splits['y_train_small']).reset_index(drop=True)
    else:
        y_full = ensure_series(splits['y_train']).reset_index(drop=True)
        X_tr, _, y_tr, _ = train_test_split(
            X_full, y_full,
            train_size=train_prop,
            stratify=y_full if y_full.nunique() > 1 else None,
            random_state=42,
        )
        X_tr = X_tr.reset_index(drop=True)
        y_tr = y_tr.reset_index(drop=True)

    # 2) Categorical features -> non-negative integer codes.
    if cat_cols:
        encoder = OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)
        encoder.fit(X_full[cat_cols].to_numpy())
        n_seen = [len(categories) for categories in encoder.categories_]

        def _encode(frame):
            codes = encoder.transform(frame[cat_cols].to_numpy())
            for j, n in enumerate(n_seen):
                column = codes[:, j]  # view: edits write back into `codes`
                # Unknown categories (-1) and missing values (NaN) share slot n.
                column[np.isnan(column) | (column < 0)] = n
            return codes.astype(np.int64)

        cat_tr, cat_val, cat_test = _encode(X_tr), _encode(X_val), _encode(X_test)
        cat_cards = [n + 1 for n in n_seen]
    else:
        cat_tr, cat_val, cat_test = (
            np.empty((len(frame), 0), dtype=np.int64) for frame in (X_tr, X_val, X_test)
        )
        cat_cards = []

    # 3) Numerical features -> float32 matrices (optionally standardised).
    def _numeric(frame):
        if not num_cols:
            return np.empty((len(frame), 0), dtype=np.float32)
        return frame[num_cols].to_numpy(dtype=np.float32, copy=True)

    num_tr, num_val, num_test = _numeric(X_tr), _numeric(X_val), _numeric(X_test)
    if scale_num and num_cols:
        # Fit on the training rows only so validation/test stay unseen.
        scaler = StandardScaler().fit(num_tr)
        num_tr, num_val, num_test = (
            scaler.transform(m).astype(np.float32) for m in (num_tr, num_val, num_test)
        )

    # 4) Tensors.
    def _to_tensor(array, dtype):
        return torch.tensor(np.ascontiguousarray(array), dtype=dtype)

    def _target(y):
        return torch.tensor(y.to_numpy(copy=True), dtype=torch.float32)

    print(f"Categorical cardinalities: {cat_cards}")
    print(f"Categorical columns processed: {len(cat_cols)}")

    return {
        'tensors': {
            'xnum_train': _to_tensor(num_tr, torch.float32),
            'xcat_train': _to_tensor(cat_tr, torch.long),
            'xnum_val': _to_tensor(num_val, torch.float32),
            'xcat_val': _to_tensor(cat_val, torch.long),
            'xnum_test': _to_tensor(num_test, torch.float32),
            'xcat_test': _to_tensor(cat_test, torch.long),
            'y_train': _target(y_tr),
            'y_val': _target(y_val),
            'y_test': _target(y_test),
        },
        'cat_cardinalities': cat_cards,
    }
# --- TabM training and evaluation helpers (prepare_tensor_safe, train_tabm_model, tabm_predict, evaluate_tabm) ---
def prepare_tensor_safe(tensor, device):
    """Move a 2-D (rows, features) tensor to ``device``, or return None.

    None comes back when ``tensor`` is None or has zero feature columns. This
    is how the TabM helpers signal that a feature group is absent.
    """
    if tensor is None or int(tensor.shape[1]) == 0:
        return None
    return tensor.to(device)
def train_tabm_model(xnum_tr, xcat_tr, y_tr, xnum_val, xcat_val, y_val, cat_cards,
                     d_out=1, epochs=100, bs=512, lr=0.002, wd=0.0003, patience=10,
                     eval_every=5):
    """Train a TabM binary classifier with a k-ensemble loss and early stopping.

    Each of TabM's k ensemble heads receives its own BCE term: predictions of
    shape (batch, k) are flattened and the targets are repeated k times.
    Validation ROC-AUC is computed every ``eval_every`` epochs and after the
    final epoch. A snapshot of the weights with the best validation AUC is
    kept and restored at the end. Training stops once ``patience`` epochs
    have passed since the last improvement (checked at evaluation epochs).

    Args:
        xnum_tr, xnum_val: float tensors (n, n_num); a zero-width second
            dimension means "no numerical features".
        xcat_tr, xcat_val: integer tensors (n, n_cat); zero width means "no
            categorical features".
        y_tr, y_val: 1-D float 0/1 targets (assumed; required by the BCE loss).
        cat_cards: cardinality of each categorical column.
        d_out: output size per head; the loss here assumes d_out == 1.
        epochs: maximum number of epochs.
        bs: mini-batch size.
        lr, wd: AdamW learning rate and weight decay.
        patience: epochs without validation-AUC improvement before stopping.
        eval_every: validation frequency in epochs (values < 1 become 1).

    Returns:
        (model, training_time_seconds, device_description, torch.device)

    Raises:
        ValueError: if there are neither numerical nor categorical features.
    """
    import copy

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device_str = f"GPU ({torch.cuda.get_device_name()})" if device.type == "cuda" else "CPU"
    print(f"Training TabM on: {device_str}")

    # A zero-width feature tensor means "absent": TabM expects None instead.
    if xnum_tr.shape[1] == 0:
        xnum_tr = xnum_val = None
        n_num_features = 0
    else:
        xnum_tr, xnum_val = xnum_tr.to(device), xnum_val.to(device)
        n_num_features = int(xnum_tr.shape[1])
    if xcat_tr.shape[1] == 0:
        xcat_tr = xcat_val = None
        cat_cardinalities = None
    else:
        xcat_tr, xcat_val = xcat_tr.to(device), xcat_val.to(device)
        cat_cardinalities = cat_cards if cat_cards else None
    if xnum_tr is None and xcat_tr is None:
        raise ValueError("TabM needs at least one numerical or categorical feature")

    y_tr = y_tr.to(device)
    y_val = y_val.to(device)
    y_val_np = y_val.cpu().numpy()  # labels are fixed; convert once

    model = TabM.make(
        n_num_features=n_num_features,
        cat_cardinalities=cat_cardinalities,
        d_out=d_out,
        k=32,          # number of ensemble members
        n_blocks=4,    # number of MLP blocks
        d_block=512,   # width of each MLP block
        dropout=0.1,   # regularisation
    ).to(device)

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    base_loss_fn = nn.BCEWithLogitsLoss()

    def tabm_loss_fn(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """BCE over every ensemble member: y_pred (batch, k), y_true (batch,)."""
        k = y_pred.shape[1]
        # flatten(0, 1) is row-major: element [b, j] lands at b*k + j, matching
        # repeat_interleave, which places y_true[b] at positions b*k .. b*k+k-1.
        return base_loss_fn(y_pred.flatten(0, 1), y_true.repeat_interleave(k))

    def forward(batch_idx):
        if xnum_tr is None:
            return model(x_cat=xcat_tr[batch_idx])
        if xcat_tr is None:
            return model(x_num=xnum_tr[batch_idx])
        return model(xnum_tr[batch_idx], xcat_tr[batch_idx])

    eval_every = max(1, int(eval_every))
    best_auc, best_epoch, best_state = float("-inf"), 0, None
    n = int(y_tr.size(0))
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        epoch_loss, n_batches = 0.0, 0
        perm = torch.randperm(n, device=device)
        for i in range(0, n, bs):
            batch_idx = perm[i:i + bs]
            optimizer.zero_grad()
            logits = forward(batch_idx)  # all k predictions are kept
            if logits.dim() == 3:
                logits = logits.squeeze(-1)  # (batch, k, 1) -> (batch, k)
            loss = tabm_loss_fn(logits, y_tr[batch_idx])
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            epoch_loss += loss.item()
            n_batches += 1
        avg_loss = epoch_loss / max(n_batches, 1)

        if epoch % eval_every != 0 and epoch != epochs - 1:
            continue

        val_probs = tabm_predict(model, xnum_val, xcat_val, device)
        val_auc = roc_auc_score(y_val_np, val_probs)
        current_lr = optimizer.param_groups[0]['lr']
        print(f"TabM epoch {epoch+1:02d} | loss: {avg_loss:.4f} | val_auc: {val_auc:.4f} | lr: {current_lr:.6f}")

        if val_auc > best_auc:
            best_auc, best_epoch = val_auc, epoch
            # state_dict() holds live references to the parameters; without a
            # deep copy, later optimizer steps would overwrite the "best" snapshot.
            best_state = copy.deepcopy(model.state_dict())
        elif epoch - best_epoch >= patience:
            print(f"TabM early stopping at epoch {epoch+1:02d} (best val_auc: {best_auc:.4f})")
            break

    training_time = time.time() - start_time

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
    return model, training_time, device_str, device
def tabm_predict(model, x_num, x_cat, device):
    """Return positive-class probabilities from a TabM model.

    TabM produces one prediction per ensemble member, with shape
    (batch, k, 1) or (batch, k). For classification, each member's logits
    are first converted to probabilities, and the probabilities are then
    averaged over k. Averaging logits and applying the sigmoid afterwards is
    a different (geometric-style) ensemble.

    Args:
        model: a TabM-like module accepting ``x_num`` / ``x_cat``.
        x_num, x_cat: feature tensors on the model's device, or None when
            that feature group is absent.
        device: unused; kept for call compatibility.

    Returns:
        numpy array of shape (batch,) when each head has one output.

    Raises:
        ValueError: if both ``x_num`` and ``x_cat`` are None.
    """
    model.eval()
    if x_num is None and x_cat is None:
        raise ValueError("Both x_num and x_cat cannot be None")
    with torch.no_grad():
        if x_num is None:
            logits = model(x_cat=x_cat)
        elif x_cat is None:
            logits = model(x_num=x_num)
        else:
            logits = model(x_num, x_cat)
        probs = torch.sigmoid(logits)
        if probs.dim() == 3 and probs.shape[-1] == 1:
            probs = probs.squeeze(-1)  # (batch, k, 1) -> (batch, k)
        probs = probs.mean(1)          # average over the k ensemble members
    return probs.cpu().numpy()
def evaluate_tabm(model, tensors, splits, device):
    """Score a trained TabM model with ROC-AUC on train, validation and test.

    Feature tensors are moved to ``device`` through ``prepare_tensor_safe``,
    so zero-width groups become None. Train/val labels come from
    ``tensors``; test labels come from ``splits['y_test']``.

    Returns:
        (auc_train, auc_val, auc_test)
    """
    def _features(part):
        return (
            prepare_tensor_safe(tensors[f'xnum_{part}'], device),
            prepare_tensor_safe(tensors[f'xcat_{part}'], device),
        )

    on_device = {part: _features(part) for part in ('test', 'train', 'val')}
    predictions = {
        part: tabm_predict(model, *on_device[part], device)
        for part in ('train', 'val', 'test')
    }
    labels = {
        'train': tensors['y_train'].cpu().numpy(),
        'val': tensors['y_val'].cpu().numpy(),
        'test': ensure_series(splits['y_test']).to_numpy(),
    }
    return tuple(
        roc_auc_score(labels[part], predictions[part])
        for part in ('train', 'val', 'test')
    )
| # --- MAIN BENCHMARK LOGIC --- | |
def run_single_benchmark(dataset_name):
    """Train LightGBM, RealMLP and TabM on one dataset and collect their scores.

    LightGBM and RealMLP consume the splits from ``create_splits`` directly.
    TabM consumes the tensors built by ``prepare_tabm_data`` and trains for
    at most ``EPOCHS`` epochs. A per-model summary is printed, and GPU
    caches / garbage are released before returning.

    Returns:
        dict with ``dataset``; ``{lgb,realmlp}_{train,val,test,time_sec}``;
        and ``tabm_train``, ``tabm_val``, ``tabm_test``, ``tabm_time_sec``,
        ``tabm_device``.
    """
    data, X, target, cat_cols, num_cols = prepare_data(dataset_name)
    splits = create_splits(X, target)

    lgb_res = train_evaluate_lgb(splits)
    mlp_res = train_evaluate_realmlp(splits, cat_cols)

    tabm_data = prepare_tabm_data(data, splits, cat_cols, num_cols)
    tensors = tabm_data['tensors']
    tabm_model, tabm_secs, tabm_device, device = train_tabm_model(
        xnum_tr=tensors['xnum_train'],
        xcat_tr=tensors['xcat_train'],
        y_tr=tensors['y_train'],
        xnum_val=tensors['xnum_val'],
        xcat_val=tensors['xcat_val'],
        y_val=tensors['y_val'],
        cat_cards=tabm_data['cat_cardinalities'],
        epochs=EPOCHS,
    )
    tabm_train, tabm_val, tabm_test = evaluate_tabm(tabm_model, tensors, splits, device)

    results = {"dataset": dataset_name}
    for prefix, res in (("lgb", lgb_res), ("realmlp", mlp_res)):
        results.update({
            f"{prefix}_train": res['auc_train'],
            f"{prefix}_val": res['auc_val'],
            f"{prefix}_test": res['auc_test'],
            f"{prefix}_time_sec": res['training_time'],
        })
    results.update({
        "tabm_train": tabm_train,
        "tabm_val": tabm_val,
        "tabm_test": tabm_test,
        "tabm_time_sec": tabm_secs,
        "tabm_device": tabm_device,
    })

    print(f"\n--- Results for {dataset_name} ---")
    print(f"LGB - Train: {lgb_res['auc_train']:.4f}, Val: {lgb_res['auc_val']:.4f}, Test: {lgb_res['auc_test']:.4f}, Time: {lgb_res['training_time']:.2f}s")
    print(f"RealMLP - Train: {mlp_res['auc_train']:.4f}, Val: {mlp_res['auc_val']:.4f}, Test: {mlp_res['auc_test']:.4f}, Time: {mlp_res['training_time']:.2f}s")
    print(f"TabM - Train: {tabm_train:.4f}, Val: {tabm_val:.4f}, Test: {tabm_test:.4f}, Time: {tabm_secs:.2f}s, Device: {tabm_device}")

    # Drop model references and release cached GPU memory before the next dataset.
    del lgb_res, mlp_res, tabm_model, tabm_data
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
    return results
def display_results(results):
    """Print a comparison table and summary statistics across datasets.

    Args:
        results: list of per-dataset dicts (as built by run_single_benchmark).
            Table columns that are missing are skipped. The summary needs the
            ``*_test`` and ``*_time_sec`` columns for all three models.
    """
    frame = pd.DataFrame(results)
    preferred = [
        "dataset",
        "lgb_test", "realmlp_test", "tabm_test",
        "lgb_time_sec", "realmlp_time_sec", "tabm_time_sec",
        "tabm_device",
    ]
    shown = [name for name in preferred if name in frame.columns]

    rule = "=" * 120
    print("\n" + rule)
    print("FINAL RESULTS COMPARISON WITH TIMING AND DEVICE INFO")
    print(rule)
    print(frame[shown].to_string())

    models = (("LGB", "lgb"), ("RealMLP", "realmlp"), ("TabM", "tabm"))
    print("\nSUMMARY (Test Set):")
    for label, key in models:
        auc = frame[f"{key}_test"]
        print(f"Average {label} Test AUC: {auc.mean():.4f} ± {auc.std():.4f}")
    for position, (label, key) in enumerate(models):
        lead = "\n" if position == 0 else ""
        print(f"{lead}Average {label} Time: {frame[f'{key}_time_sec'].mean():.2f}s")
def main():
    """Run the benchmark on every dataset in ``DATASETS`` and print a summary.

    A failure on one dataset is reported together with its full traceback,
    and the loop moves on to the next dataset (best effort).

    Returns:
        list of per-dataset result dicts; failed datasets are omitted.
    """
    import traceback

    if torch.cuda.is_available():
        print(f"CUDA available: {torch.cuda.get_device_name()}")
        print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    else:
        print("CUDA not available, using CPU")

    all_results = []
    for dataset_name in tqdm(DATASETS, desc="Overall Progress"):
        try:
            result = run_single_benchmark(dataset_name)
        except Exception as e:
            # Keep going with the remaining datasets, but keep the traceback:
            # the message alone rarely tells where the failure happened.
            print(f"\nERROR processing dataset {dataset_name}: {e}")
            traceback.print_exc()
            continue
        all_results.append(result)

    if all_results:
        display_results(all_results)
    else:
        print("\nNo benchmarks were successfully completed.")
    return all_results


if __name__ == "__main__":
    results = main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment