Testing the Haufe trick
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Test the Haufe trick"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "from sklearn.linear_model import LinearRegression\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
"This notebook tests the computation of patterns from weights, following equation (6) of Haufe et al. 2014 (Neuroimage).\n", | |
"\n", | |
"$ A = \\Sigma_X\\, W\\, \\Sigma_{\\,\\hat{Y}}^{-1} $\n", | |
"\n", | |
"(The original equation uses $\\hat{s}$ instead of $\\hat{Y}$, but we'll be following scikit-learn's notation here.)\n", | |
"\n", | |
"\n", | |
"We will test by generating test data with a known \"pattern\" (i.e. forward model), that satisfies all the necessary assumptions for the equation to work." | |
] | |
}, | |
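  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick bit of dimension bookkeeping: $\\Sigma_X$ is `(n_features, n_features)`, $W$ (the filters, scikit-learn's `coef_.T`) is `(n_features, n_targets)`, and $\\Sigma_{\\hat{Y}}^{-1}$ is `(n_targets, n_targets)`, so $A$ comes out as `(n_features, n_targets)`: one pattern per target."
   ]
  },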
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Fix random seed for consistency\n",
    "np.random.seed(42)\n",
    "\n",
    "\n",
"def _gen_data(noise_scale=2):\n", | |
" \"\"\"Generate some testing data.\n", | |
"\n", | |
" Parameters\n", | |
" ----------\n", | |
" noise_scale : float\n", | |
" The amount of noise (in standard deviations) to add to the data.\n", | |
"\n", | |
" Returns\n", | |
" -------\n", | |
" X : ndarray, shape (n_samples, n_features)\n", | |
" The measured data.\n", | |
" Y : ndarray, shape (n_samples, n_targets)\n", | |
" The latent variables generating the data.\n", | |
" A : ndarray, shape (n_features, n_targets)\n", | |
" The forward model, mapping the latent variables (=Y) to the measured\n", | |
" data (=X).\n", | |
" \"\"\"\n", | |
" N = 1000 # Number of samples\n", | |
" M = 5 # Number of features\n", | |
"\n", | |
" # Y has 3 targets and the following covariance:\n", | |
" cov_Y = np.array([\n", | |
" [10, 1, 2],\n", | |
" [1, 5, 1],\n", | |
" [2, 1, 3],\n", | |
" ])\n", | |
" mean_Y = np.array([1, -3, 7])\n", | |
" Y = np.random.multivariate_normal(mean_Y, cov_Y, size=N)\n", | |
" Y += [1, 4, 2] # Put an offset\n", | |
"\n", | |
" # The pattern (=forward model)\n", | |
" A = np.array([\n", | |
" [1, 10, -3],\n", | |
" [4, 1, 8],\n", | |
" [3, -2, 4],\n", | |
" [1, 1, 1],\n", | |
" [7, 6, 0],\n", | |
" ]).astype(float)\n", | |
"\n", | |
" X = Y.dot(A.T)\n", | |
" X += noise_scale * np.random.randn(N, M)\n", | |
" X += [5, 2, 6, 3, 9] # Put an offset\n", | |
"\n", | |
" return X, Y, A" | |
] | |
}, | |
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate data without any noise, so we can perfectly reconstruct the pattern A from the measured data\n",
    "X, Y, A = _gen_data(noise_scale=0)\n",
    "\n",
    "# This data is not normalized (i.e. not zero-mean and not unit variance)\n",
    "assert (np.abs(X.mean(axis=0)) > 0.1).all()\n",
    "assert (X.std(axis=0) > 1.1).all()\n",
    "assert (np.abs(Y.mean(axis=0)) > 0.1).all()\n",
    "assert (Y.std(axis=0) > 1.1).all()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now proceed by fitting a standard linear regression model to the data and applying the equation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
"model = LinearRegression(normalize=False, fit_intercept=True).fit(X, Y)\n", | |
"W = model.coef_\n", | |
"\n", | |
"def haufe_trick(W, X, Y):\n", | |
" \"\"\"Perform the Haufe trick.\"\"\"\n", | |
" # Computing the covariance of X and Y involves removing the mean\n", | |
" X_ = X - X.mean(axis=0)\n", | |
" Y_ = Y - Y.mean(axis=0)\n", | |
" cov_X = X_.T.dot(X_)\n", | |
" cov_Y = Y_.T.dot(Y_)\n", | |
"\n", | |
" # The Haufe trick\n", | |
" A_hat = cov_X.dot(W.T).dot(np.linalg.pinv(cov_Y))\n", | |
" return A_hat\n", | |
"\n", | |
"A_hat = haufe_trick(W, X, Y)" | |
] | |
}, | |
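  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before comparing patterns, a quick sanity check on the filters themselves: since the data are noiseless and $X = Y A^T$ (plus an offset), the fitted filters should exactly invert the forward model, i.e. $W A = I$. A minimal check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# W has shape (n_targets, n_features) and A has shape (n_features, n_targets),\n",
    "# so W.dot(A) should be the (n_targets, n_targets) identity matrix.\n",
    "print('W A =')\n",
    "print(np.round(W.dot(A), decimals=10))\n",
    "print('Is it the identity?', np.allclose(W.dot(A), np.eye(A.shape[1])))"
   ]
  },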
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "How did we do?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Real pattern\n",
      "[[ 1. 10. -3.]\n",
      " [ 4. 1. 8.]\n",
      " [ 3. -2. 4.]\n",
      " [ 1. 1. 1.]\n",
      " [ 7. 6. 0.]]\n",
      "Estimated pattern\n",
      "[[ 1.00000000e+00 1.00000000e+01 -3.00000000e+00]\n",
      " [ 4.00000000e+00 1.00000000e+00 8.00000000e+00]\n",
      " [ 3.00000000e+00 -2.00000000e+00 4.00000000e+00]\n",
      " [ 1.00000000e+00 1.00000000e+00 1.00000000e+00]\n",
      " [ 7.00000000e+00 6.00000000e+00 -8.88178420e-15]]\n",
      "Are they equal? True\n"
     ]
    }
   ],
   "source": [
    "print('Real pattern')\n",
    "print(A)\n",
    "print('Estimated pattern')\n",
    "print(A_hat)\n",
    "print('Are they equal?', np.allclose(A, A_hat))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The thing about normalization\n",
    "\n",
    "In the above case, the linear regression model did not normalize the data. What happens if we do?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Real pattern\n",
      "[[ 1. 10. -3.]\n",
      " [ 4. 1. 8.]\n",
      " [ 3. -2. 4.]\n",
      " [ 1. 1. 1.]\n",
      " [ 7. 6. 0.]]\n",
      "Estimated pattern\n",
      "[[ 1.00000000e+00 1.00000000e+01 -3.00000000e+00]\n",
      " [ 4.00000000e+00 1.00000000e+00 8.00000000e+00]\n",
      " [ 3.00000000e+00 -2.00000000e+00 4.00000000e+00]\n",
      " [ 1.00000000e+00 1.00000000e+00 1.00000000e+00]\n",
      " [ 7.00000000e+00 6.00000000e+00 1.24344979e-14]]\n",
      "Are they equal? True\n"
     ]
    }
   ],
   "source": [
"model = LinearRegression(normalize=True, fit_intercept=True).fit(X, Y)\n", | |
"W = model.coef_\n", | |
"A_hat = haufe_trick(W, X, Y)\n", | |
"\n", | |
"print('Real pattern')\n", | |
"print(A)\n", | |
"print('Estimated pattern')\n", | |
"print(A_hat)\n", | |
"print('Are they equal?', np.allclose(A, A_hat))" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"It still works! That is because scikit-learn reverses the normalization when storing the final filter weights." | |
] | |
} | |
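  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, a minimal sketch of the noisy case (using the default `noise_scale=2` of `_gen_data`). With noise the fit is no longer exact, so equation (6) strictly calls for the covariance of the *predictions* $\\hat{Y}$ rather than of $Y$ itself, and the reconstructed pattern is only approximate; we print the largest deviation rather than asserting equality:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "X_noisy, Y_noisy, A_noisy = _gen_data(noise_scale=2)\n",
    "model = LinearRegression(fit_intercept=True).fit(X_noisy, Y_noisy)\n",
    "\n",
    "# Apply the trick with the model's predictions, as in Haufe et al. eq. (6)\n",
    "Y_hat = model.predict(X_noisy)\n",
    "A_hat_noisy = haufe_trick(model.coef_, X_noisy, Y_hat)\n",
    "\n",
    "print('Largest deviation from the true pattern:',\n",
    "      np.abs(A_noisy - A_hat_noisy).max())"
   ]
  }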
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
} |