Created June 9, 2022 14:57
-
-
Save Micky774/12f83783e578fb4fb4e30209f72bd251 to your computer and use it in GitHub Desktop.
Benchmark script for logistic regression
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
# Data preparation: load OpenML dataset 41162 (judging by the column names,
# presumably the "Don't Get Kicked!" used-car auction data -- verify) and
# build the design matrix ``X`` and the float 0/1 target ``y`` consumed by
# the benchmark cell below.
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import StandardScaler

df = fetch_openml(data_id=41162, as_frame=True, parser="auto").frame

# Numeric features. Every name must appear exactly once: a repeated column
# would be emitted twice by the ColumnTransformer, producing perfectly
# collinear features that make the unpenalized fit ill-conditioned and
# distort the solver comparison.
# NOTE(review): an earlier version listed the two MMRCurrentRetail* columns
# twice; the MMRAcquisitionRetail* columns may have been intended there --
# confirm the exact names against df.columns before adding them.
numeric_columns = [
    "MMRAcquisitionAuctionAveragePrice",
    "MMRAcquisitionAuctionCleanPrice",
    "MMRCurrentAuctionAveragePrice",
    "MMRCurrentAuctionCleanPrice",
    "MMRCurrentRetailAveragePrice",
    "MMRCurrentRetailCleanPrice",
    "VehBCost",
    "VehYear",
    "VehicleAge",
    "WarrantyCost",
]

# Categorical features, one-hot encoded; categories seen fewer than 10 times
# are grouped together by ``min_frequency``.
categorical_columns = [
    "Auction",
    "Color",
    "IsOnlineSale",
    "Make",
    "Model",
    "Nationality",
    "Size",
    "SubModel",
    "Transmission",
    "Trim",
    "WheelType",
]

linear_model_preprocessor = ColumnTransformer(
    [
        (
            # Despite the name, these columns are mean-imputed and
            # standardized, not passed through unchanged.
            "passthrough_numeric",
            make_pipeline(SimpleImputer(), StandardScaler()),
            numeric_columns,
        ),
        (
            "onehot_categorical",
            OneHotEncoder(min_frequency=10),
            categorical_columns,
        ),
    ],
    # Any column not listed above is discarded.
    remainder="drop",
)

# Binary target: 1.0 where IsBadBuy is "1", else 0.0. Casting to str first
# makes the comparison work whether the column was parsed as strings or as
# integers.
y = np.asarray(df["IsBadBuy"].astype(str) == "1", dtype=float)
X = linear_model_preprocessor.fit_transform(df)
# %%
# Benchmark: time ``LogisticRegression.fit`` for every
# (multi_class, penalty, solver) combination, N_REPEATS times each. One CSV
# row per fit (duration and training-set NLL) is written to
# ``{results_path}{branch}.csv``. Relies on ``X`` and ``y`` from the data
# cell above.
from functools import partial
from time import perf_counter
from statistics import mean, stdev
from itertools import product
import csv
import os
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import log_loss
import warnings

results_path = 'local_artifacts/benchmarks/trust-ncg/'
branch = "kicks"
alpha = 1e-4
N_REPEATS = 7

benchmark_config = [
    (
        # C = 1 / (alpha * n_samples) turns scikit-learn's
        # ``C * sum(loss) + ||w||^2 / 2`` objective into
        # ``mean(loss) + alpha / 2 * ||w||^2``.
        LogisticRegression(C=1 / alpha / X.shape[0], tol=1e-4, max_iter=1000),
        # Data factory: every repeat reuses the matrices from the data cell;
        # the ``random_state`` keyword is accepted but ignored.
        lambda *args, **kwargs: (X, y),
        # Materialized as a list so the grid can be iterated more than once.
        # NOTE(review): "trust-ncg" is not a released scikit-learn solver;
        # this presumably targets a development branch -- confirm.
        # scikit-learn >= 1.2 prefers penalty=None over the string "none";
        # the string is kept because the CSV and the plots group by it.
        list(
            product(
                ["ovr", "multinomial"],
                ["l2", "none"],
                ["lbfgs", "trust-ncg"],
            )
        ),
    ),
]

# open(..., 'w') fails if the directory is missing, so create it first.
os.makedirs(results_path, exist_ok=True)
with open(f'{results_path}{branch}.csv', 'w', newline='') as csvfile:
    writer = csv.DictWriter(
        csvfile,
        fieldnames=[
            "multi_class",
            "penalty",
            "solver",
            "duration",
            "NLL",
        ],
    )
    writer.writeheader()
    for est, make_data, items in benchmark_config:
        for multi_class, penalty, solver in items:
            time_results = []
            est.set_params(multi_class=multi_class, solver=solver, penalty=penalty)
            for n_repeat in range(N_REPEATS):
                X_rep, y_rep = make_data(random_state=n_repeat)
                # Rescale columns to unit variance without centering. The
                # ColumnTransformer output may be sparse, and centering would
                # densify it. This is done outside the timed region.
                X_c = StandardScaler(with_mean=False).fit_transform(X_rep)
                start = perf_counter()
                est.fit(X_c, y_rep)
                duration = perf_counter() - start
                time_results.append(duration)
                row = {
                    "multi_class": multi_class,
                    "penalty": penalty,
                    "solver": solver,
                    "duration": duration,
                    # Score on the same rescaled matrix the model was fit on;
                    # the unscaled matrix would give a meaningless NLL.
                    "NLL": log_loss(y_rep, est.predict_proba(X_c)),
                }
                writer.writerow(row)
                print(f"{row}")
            # stdev needs at least two samples, i.e. N_REPEATS >= 2.
            results_mean, results_stdev = mean(time_results), stdev(time_results)
            print(
                f" {multi_class=} {penalty=} {solver=} |"
                f" {results_mean:.3f} +/- {results_stdev:.3f}"
            )
# %%
# Visualization: read the CSV written by the benchmark cell for each branch
# and draw one violin plot of fit duration per solver in a 2x2 grid, with one
# panel per (penalty, multi_class) combination.
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

plt.rc('font', size=12)
results_path = 'local_artifacts/benchmarks/trust-ncg/'
_branches = ("kicks",)

# Branch name -> its results table.
branches = {
    name: pd.read_csv(results_path + name + '.csv') for name in _branches
}
# Stack all branches into one frame, tagging each row with its branch.
df = pd.concat([branches[name].assign(branch=name) for name in _branches])

# Must stay a list: pandas treats a tuple passed to groupby differently.
group_by_attrs = [
    "penalty",
    "multi_class",
]
grouped = list(df.groupby(group_by_attrs))

fig, axis = plt.subplots(2, 2, figsize=(9, 8), constrained_layout=True)
fig.patch.set_facecolor('white')

panels = axis.reshape(-1)
# Only as many groups as there are panels are drawn; extra groups are skipped.
for panel_idx, (group_key, group_rows) in enumerate(grouped[: len(panels)]):
    panel = panels[panel_idx]
    sns.violinplot(data=group_rows, y="duration", x="solver", ax=panel)
    panel.set_title(
        "|".join(
            f"{attr}={val}" for attr, val in zip(group_by_attrs, group_key)
        )
    )
    panel.set_xlabel("")

# Drop the y-axis label on every column except the leftmost one.
if axis.ndim > 1:
    for right_panel in axis[:, 1:].flat:
        right_panel.set_ylabel("")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.