Testing the Haufe trick
# Test the Haufe trick

```python
from __future__ import print_function

from sklearn.linear_model import LinearRegression
import numpy as np
```
This notebook tests the computation of patterns from weights, following equation (6) of Haufe et al. 2014 (NeuroImage):

$ A = \Sigma_X\, W\, \Sigma_{\hat{Y}}^{-1} $

(The original equation uses $\hat{s}$ instead of $\hat{Y}$; we follow scikit-learn's notation here.)

We test the equation by generating data from a known "pattern" (i.e. forward model) that satisfies all the assumptions required for it to hold.
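As a brief aside on why patterns matter rather than raw weights, here is a minimal two-channel suppressor example in the spirit of Haufe et al. (a standalone sketch, not part of this notebook): the least-squares filter assigns a large weight to a channel that contains no signal at all, while the pattern obtained from the same formula correctly assigns it (near) zero.

```python
import numpy as np

np.random.seed(0)
n = 10000
s = np.random.randn(n)   # signal of interest
d = np.random.randn(n)   # distractor, independent of s
X = np.c_[s + d, d]      # channel 0 = signal + distractor, channel 1 = distractor only

# Least-squares filter extracting s from X: w approaches [1, -1], i.e. the
# signal-free channel 1 receives a large (negative) weight.
w = np.linalg.lstsq(X, s, rcond=None)[0]

# The corresponding pattern, A = cov(X) w / var(s_hat), approaches [1, 0]:
# only channel 0 actually carries the signal.
s_hat = X.dot(w)
Xc = X - X.mean(axis=0)
A = Xc.T.dot(Xc).dot(w) / ((s_hat - s_hat.mean()) ** 2).sum()
print(w, A)
```

The filter's weight on channel 1 exists only to cancel the distractor; the pattern makes that distinction visible.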
```python
# Fix the random seed for reproducibility
np.random.seed(42)


def _gen_data(noise_scale=2):
    """Generate some testing data.

    Parameters
    ----------
    noise_scale : float
        Standard deviation of the Gaussian noise added to the data.

    Returns
    -------
    X : ndarray, shape (n_samples, n_features)
        The measured data.
    Y : ndarray, shape (n_samples, n_targets)
        The latent variables generating the data.
    A : ndarray, shape (n_features, n_targets)
        The forward model, mapping the latent variables (=Y) to the
        measured data (=X).
    """
    N = 1000  # Number of samples
    M = 5     # Number of features

    # Y has 3 targets and the following covariance:
    cov_Y = np.array([
        [10, 1, 2],
        [1, 5, 1],
        [2, 1, 3],
    ])
    mean_Y = np.array([1, -3, 7])
    Y = np.random.multivariate_normal(mean_Y, cov_Y, size=N)
    Y += [1, 4, 2]  # Add an offset so Y is not zero-mean

    # The pattern (= forward model)
    A = np.array([
        [1, 10, -3],
        [4, 1, 8],
        [3, -2, 4],
        [1, 1, 1],
        [7, 6, 0],
    ]).astype(float)

    X = Y.dot(A.T)
    X += noise_scale * np.random.randn(N, M)
    X += [5, 2, 6, 3, 9]  # Add an offset so X is not zero-mean

    return X, Y, A
```
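Because `_gen_data` draws $Y$ from a known covariance and maps it through a linear forward model, the covariance of the noise-free data satisfies $\Sigma_X = A\, \Sigma_Y\, A^\top$. A quick standalone check of this (a sketch re-using the numbers above, with the offsets dropped since covariances ignore means):

```python
import numpy as np

np.random.seed(0)
cov_Y = np.array([[10., 1., 2.],
                  [1., 5., 1.],
                  [2., 1., 3.]])
A = np.array([[1., 10., -3.],
              [4., 1., 8.],
              [3., -2., 4.],
              [1., 1., 1.],
              [7., 6., 0.]])

# Draw a large sample so the empirical covariance is close to its limit
Y = np.random.multivariate_normal(np.zeros(3), cov_Y, size=200000)
X = Y.dot(A.T)  # noise-free forward model

emp_cov_X = np.cov(X, rowvar=False)
expected = A.dot(cov_Y).dot(A.T)  # cov(X) = A cov(Y) A^T
```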
```python
# Generate data without any noise, so that the pattern A can be
# reconstructed perfectly from the measured data
X, Y, A = _gen_data(noise_scale=0)

# This data is not normalized (i.e. neither zero-mean nor unit-variance)
assert (np.abs(X.mean(axis=0)) > 0.1).all()
assert (X.std(axis=0) > 1.1).all()
assert (np.abs(Y.mean(axis=0)) > 0.1).all()
assert (Y.std(axis=0) > 1.1).all()
```
We now fit a standard linear regression model to the data and apply the equation.
```python
model = LinearRegression(normalize=False, fit_intercept=True).fit(X, Y)
W = model.coef_


def haufe_trick(W, X, Y):
    """Perform the Haufe trick."""
    # Computing the covariance of X and Y starts with removing the mean.
    # These are unnormalized (scatter) matrices; the 1/(N - 1) factors
    # would cancel in the formula anyway.
    X_ = X - X.mean(axis=0)
    Y_ = Y - Y.mean(axis=0)
    cov_X = X_.T.dot(X_)
    cov_Y = Y_.T.dot(Y_)

    # The Haufe trick: A = cov(X) W^T pinv(cov(Y_hat)). Since the model
    # fits the noiseless data perfectly, Y_hat equals Y here.
    A_hat = cov_X.dot(W.T).dot(np.linalg.pinv(cov_Y))
    return A_hat


A_hat = haufe_trick(W, X, Y)
```
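The claim that the scatter-matrix normalization cancels can be checked directly (a standalone sketch with made-up numbers, not part of the notebook): using `np.cov`, which divides by $N - 1$, yields exactly the same `A_hat`, and in the noiseless case both variants recover the true pattern.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(0)
n = 500
Y = np.random.randn(n, 2)          # latent variables
A_true = np.array([[2., 0.],
                   [1., 3.],
                   [0., 1.]])      # (n_features, n_targets)
X = Y.dot(A_true.T)                # noise-free measurements

W = LinearRegression().fit(X, Y).coef_

# Scatter-matrix version, as in haufe_trick above
Xc = X - X.mean(axis=0)
Yc = Y - Y.mean(axis=0)
A1 = Xc.T.dot(Xc).dot(W.T).dot(np.linalg.pinv(Yc.T.dot(Yc)))

# np.cov version: both matrices carry a 1/(n - 1) factor, which cancels
A2 = np.cov(X, rowvar=False).dot(W.T).dot(np.linalg.pinv(np.cov(Y, rowvar=False)))
```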
How did we do?
```python
print('Real pattern')
print(A)
print('Estimated pattern')
print(A_hat)
print('Are they equal?', np.allclose(A, A_hat))
```

Output:

```
Real pattern
[[ 1. 10. -3.]
 [ 4. 1. 8.]
 [ 3. -2. 4.]
 [ 1. 1. 1.]
 [ 7. 6. 0.]]
Estimated pattern
[[ 1.00000000e+00 1.00000000e+01 -3.00000000e+00]
 [ 4.00000000e+00 1.00000000e+00 8.00000000e+00]
 [ 3.00000000e+00 -2.00000000e+00 4.00000000e+00]
 [ 1.00000000e+00 1.00000000e+00 1.00000000e+00]
 [ 7.00000000e+00 6.00000000e+00 -8.88178420e-15]]
Are they equal? True
```
## The thing about normalization

In the case above, the linear regression model did not normalize the data. What happens if we do?
```python
model = LinearRegression(normalize=True, fit_intercept=True).fit(X, Y)
W = model.coef_
A_hat = haufe_trick(W, X, Y)

print('Real pattern')
print(A)
print('Estimated pattern')
print(A_hat)
print('Are they equal?', np.allclose(A, A_hat))
```

Output:

```
Real pattern
[[ 1. 10. -3.]
 [ 4. 1. 8.]
 [ 3. -2. 4.]
 [ 1. 1. 1.]
 [ 7. 6. 0.]]
Estimated pattern
[[ 1.00000000e+00 1.00000000e+01 -3.00000000e+00]
 [ 4.00000000e+00 1.00000000e+00 8.00000000e+00]
 [ 3.00000000e+00 -2.00000000e+00 4.00000000e+00]
 [ 1.00000000e+00 1.00000000e+00 1.00000000e+00]
 [ 7.00000000e+00 6.00000000e+00 1.24344979e-14]]
Are they equal? True
```
It still works! That is because scikit-learn reverses the normalization when storing the final filter weights: `coef_` is always expressed on the scale of the original (unnormalized) data. (Note that the `normalize` parameter of `LinearRegression` was deprecated in scikit-learn 1.0 and removed in 1.2; in recent versions the same effect is obtained by standardizing the data explicitly, e.g. with `StandardScaler`.)
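To see concretely what "reverses the normalization" means, here is a standalone sketch (not from the notebook): the historical `normalize=True` option centered `X` and divided each column by its l2 norm before fitting, then rescaled the coefficients back to the original units. Doing the same by hand gives weights identical to a plain fit on the raw data.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

np.random.seed(0)
# Features on wildly different scales, with nonzero means
X = np.random.randn(200, 3) * [1., 5., 0.2] + [2., -1., 4.]
Y = X.dot(np.array([[1., 0.5, -2.]]).T) + 3.

# Plain fit on the raw data
W_plain = LinearRegression().fit(X, Y).coef_

# Reproduce normalize=True by hand: centre, divide each column by its
# l2 norm, fit, then scale the coefficients back to the original units
Xc = X - X.mean(axis=0)
norms = np.linalg.norm(Xc, axis=0)
W_back = LinearRegression().fit(Xc / norms, Y).coef_ / norms
```

Since the back-scaled weights equal the plain-fit weights, the Haufe trick necessarily produces the same pattern in both cases.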
*(Notebook kernel: Python 2.7.13.)*