Skip to content

Instantly share code, notes, and snippets.

@patcon
Created March 4, 2025 18:58
Show Gist options
  • Save patcon/d3f096483156dd63f7836291a8ba8097 to your computer and use it in GitHub Desktop.
Save patcon/d3f096483156dd63f7836291a8ba8097 to your computer and use it in GitHub Desktop.
from tabpfn_extensions import interpretability
from tabpfn import TabPFNClassifier
import numpy as np
# NOTE(review): `PolisClient` is used below but is never imported in this
# snippet; it must be imported from the Polis data-loading package in use
# (TODO confirm the exact module path) or the next line raises NameError.
# Initialize and load data
client = PolisClient()
client.load_data(report_id="r8xhmkwp6shm9yfermteh")

# Generate and process the vote matrix
vote_matrix = client.get_matrix(is_filtered=True)
client.run_pca()
client.scale_projected_data()

# Feature names come from the vote matrix's column labels, not positional
# indices: with is_filtered=True some columns may be dropped, so a positional
# "Vote_{i}" would no longer match the column (statement) it describes.
feature_names = [f"Vote_{col}" for col in vote_matrix.columns]

# Number of held-out rows to explain with SHAP.
n_samples = 50

# Build the design matrix and a simple example binary target: 1 when a row's
# first PCA coordinate is above that coordinate's mean, else 0.
# NOTE(review): assumes client.projected_data rows are aligned one-to-one with
# vote_matrix rows — TODO confirm.
X = vote_matrix.values
first_component = client.projected_data[:, 0]
y = (first_component > first_component.mean()).astype(int)

# Shuffled 50-50 train/test split with a fixed seed for reproducibility.
# A plain X[:n] / X[n:] split inherits whatever row order the vote matrix has,
# so the halves need not be representative and y_train could contain only
# one class.
rng = np.random.default_rng(seed=0)
order = rng.permutation(len(X))
n_train = len(X) // 2
train_idx, test_idx = order[:n_train], order[n_train:]
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# A classifier cannot be fit on a single class; fail with a clear message
# rather than an opaque error from inside the estimator.
if np.unique(y_train).size < 2:
    raise ValueError(
        "Training split contains a single class; "
        "need at least two classes to fit the classifier."
    )

# Initialize and train model
clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Calculate SHAP values for at most the first n_samples test rows
# (slicing past the end simply yields fewer rows).
shap_values = interpretability.shap.get_shap_values(
    estimator=clf,
    test_x=X_test[:n_samples],
    attribute_names=feature_names,
    algorithm="permutation",
)

# Create visualization
fig = interpretability.shap.plot_shap(shap_values)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment