Created March 4, 2025 18:58
-
-
Save patcon/d3f096483156dd63f7836291a8ba8097 to your computer and use it in GitHub Desktop.
Using TabPFN for interpretability of Polis data. c/o @ThenWho in https://discord.com/channels/815450304421691412/1336804926629875713/1344408335532687380
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from tabpfn_extensions import interpretability
from tabpfn import TabPFNClassifier
import numpy as np

# Fit a TabPFN classifier on a Polis vote matrix and explain it with SHAP.
#
# NOTE(review): `PolisClient` is used below but is not imported anywhere in
# this script. Import it from the Polis client package this gist was written
# against before running; that package is not shown here.

# Initialize and load data
client = PolisClient()
client.load_data(report_id="r8xhmkwp6shm9yfermteh")

# Generate and process the vote matrix
vote_matrix = client.get_matrix(is_filtered=True)
client.run_pca()
client.scale_projected_data()

# Label each feature with its real vote-matrix column, presumably the Polis
# statement ID. With is_filtered=True some statements may be dropped, so a
# positional index (0, 1, 2, ...) would not map back to the right statement
# in the SHAP plot.
feature_names = [f"Vote_{col}" for col in vote_matrix.columns]

# Number of test rows to explain. Permutation SHAP cost grows with this value.
n_samples = 50

# Features: the raw vote matrix.
# Target: an illustrative binary label, i.e. whether each row lies above the
# mean on the first PCA component.
X = vote_matrix.values
first_pc = client.projected_data[:, 0]
y = (first_pc > first_pc.mean()).astype(int)
if len(y) != len(X):
    # Labels are derived from the projection, so its rows must line up
    # one-to-one with the vote-matrix rows.
    raise ValueError(
        f"PCA projection has {len(y)} rows but vote matrix has {len(X)}; "
        "rows must correspond one-to-one."
    )

# 50-50 train/test split on a seeded random permutation. Splitting on raw row
# order would bias the halves if rows are ordered (e.g. by participant join
# order). It could also leave the training half with only one class.
rng = np.random.default_rng(seed=0)
order = rng.permutation(len(X))
n_train = len(X) // 2
train_idx, test_idx = order[:n_train], order[n_train:]
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]  # y_test is held out for evaluation
if np.unique(y_train).size < 2:
    raise ValueError(
        "Training split contains only one class; cannot fit a classifier."
    )

# Initialize and train model
clf = TabPFNClassifier()
clf.fit(X_train, y_train)

# Calculate SHAP values for (at most) the first n_samples test rows
shap_values = interpretability.shap.get_shap_values(
    estimator=clf,
    test_x=X_test[:n_samples],
    attribute_names=feature_names,
    algorithm="permutation",
)

# Create visualization
fig = interpretability.shap.plot_shap(shap_values)
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.