Trying to use kernel approximations by explicit feature maps for softmax self-attention.
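The trick rests on a standard factorization of the exponential (softmax) kernel, which is presumably what the `tweak` factor in the samplers below exploits:

$$\exp(x \cdot k) = e^{\lVert x \rVert^2 / 2} \, e^{\lVert k \rVert^2 / 2} \, \exp\!\left(-\tfrac{1}{2}\lVert x - k \rVert^2\right),$$

so multiplying an RBF feature map with $\gamma = 1/2$ by $e^{\lVert x \rVert^2 / 2}$ gives features whose inner products approximate $\exp(x \cdot k)$.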
import numpy as np
from sklearn.utils.extmath import softmax
from sklearn.kernel_approximation import RBFSampler
from sklearn_extra.kernel_approximation import Fastfood

seed = 42
rng = np.random.RandomState(seed)
D = 20
# It seems that this does not work in every case: it seems to work better when
# the samples are positive and each row is divided by its sum. See this example
# from https://github.com/scikit-learn/scikit-learn/commit/3f7cec39997e28b4056bdea4fd04572cfaad0080#diff-364d5b0b1ecfc510277a8f8072d884e7
X = rng.random_sample(size=(100, D))
X /= X.sum(axis=1)[:, np.newaxis]
# Let's take the mean as the key, for instance
key = np.mean(X, axis=0, keepdims=True)
class TweakedRBFSampler(RBFSampler):
    def transform(self, X):
        # Multiply the RBF features by exp(||x||^2 / 2) so that inner products
        # of the tweaked features approximate exp(x . y) rather than the RBF
        # kernel
        tweak = np.exp((np.linalg.norm(X, axis=1, keepdims=True)**2) / 2)
        return super(TweakedRBFSampler, self).transform(X) * tweak
class TweakedFastfoodSampler(Fastfood):
    def transform(self, X):
        # Same exp(||x||^2 / 2) correction, applied to the Fastfood features
        tweak = np.exp((np.linalg.norm(X, axis=1, keepdims=True)**2) / 2)
        return super(TweakedFastfoodSampler, self).transform(X) * tweak
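# A quick sanity check (my addition, not part of the original gist): with
# gamma=0.5, inner products of the tweaked RBF features should approximate the
# exponential kernel exp(x . y); print the max absolute error on a few samples.
checker = TweakedRBFSampler(n_components=10000, gamma=0.5, random_state=seed)
X_check = checker.fit_transform(X[:5])
print('Sanity check, max abs error vs exp(x . y):',
      np.abs(X_check.dot(X_check.T) - np.exp(X[:5].dot(X[:5].T))).max())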
def attention_projection(X, key):
    # Exact softmax self-attention: a weighted average of the rows of X, with
    # weights softmax(X . key)
    return softmax(X.dot(key.T).T)[0].dot(X)
def attention_projection_approx(sampler, X, key):
    # Linearized attention: both the numerator sum_i <phi(x_i), phi(key)> * x_i
    # and the normalizer sum_i <phi(x_i), phi(key)> are computed from the
    # explicit feature maps, so the softmax weights are never formed explicitly
    X_f = sampler.fit_transform(X)
    key_f = sampler.transform(key)
    A = X.T.dot(X_f)
    Z = X_f.sum(axis=0)
    # ravel so the result has the same 1-D shape as attention_projection
    return (A.dot(key_f.T) / Z.dot(key_f.T)).ravel()
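# For reference (my addition, not part of the original gist): the same
# computation with the exact kernel values exp(x_i . key) in place of the
# feature-map estimates recovers attention_projection exactly; this is the
# quantity the tweaked samplers are approximating.
def attention_projection_exact_kernel(X, key):
    w = np.exp(X.dot(key.T)).ravel()  # exact exponential kernel values
    return w.dot(X) / w.sum()         # softmax-weighted average of the rows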
# The two samplers do not use the same parametrization of the bandwidth: for
# RBFSampler it's gamma (kernel exp(-gamma * ||x - y||^2)), and for Fastfood
# it's sigma (kernel exp(-||x - y||^2 / (2 * sigma**2))), so gamma=0.5
# corresponds to sigma=1.
sampler_1 = TweakedRBFSampler(n_components=10000, gamma=0.5)
sampler_2 = TweakedFastfoodSampler(n_components=10000, sigma=1.)
print('True value of self-attention:')
print(attention_projection(X, key))
print("Approximation of self-attention by a modified version of scikit-learn's "
      "RBFSampler:")
print(attention_projection_approx(sampler_1, X, key))
print('Approximation of self-attention by a modified Fastfood sampler:')
print(attention_projection_approx(sampler_2, X, key))
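# A quick way to quantify the approximation quality (my addition, not part of
# the original gist): relative l2 error of each estimate against the exact
# value. Each call refits the sampler, so the random features are redrawn.
true_value = attention_projection(X, key)
for name, sampler in [('RBFSampler', sampler_1), ('Fastfood', sampler_2)]:
    approx = attention_projection_approx(sampler, X, key)
    rel_err = np.linalg.norm(approx - true_value) / np.linalg.norm(true_value)
    print('Relative error (%s): %.3e' % (name, rel_err))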