Created
June 3, 2024 17:14
-
-
Save cloneofsimo/0ac4b4aa5549b24799697310ff4fe1e4 to your computer and use it in GitHub Desktop.
Is your backprop secretly a linear solver?
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| # motivated by https://x.com/yaroslavvb/status/1797662470859071892 | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
def cosine_similarity(v1, v2):
    """Return the cosine of the angle between vectors ``v1`` and ``v2``.

    Computed as ``<v1, v2> / (||v1|| * ||v2||)``. No guard is applied for a
    zero vector: the dot product and the norm product are then both zero, so
    NumPy returns NaN (with a RuntimeWarning).
    """
    inner_product = np.dot(v1, v2)
    norm_product = np.linalg.norm(v1) * np.linalg.norm(v2)
    return inner_product / norm_product
def stable_rank(matrix):
    """Return ``(sum of singular values) ** 2 / (sum of squared singular values)``.

    Equivalently ``(||A||_* / ||A||_F) ** 2``: the squared ratio of the nuclear
    norm to the Frobenius norm of ``matrix``.

    NOTE(review): the textbook "stable rank" is ``||A||_F**2 / ||A||_2**2``
    (``sum(s**2) / max(s)**2``). This function returns a different
    effective-rank measure. Both lie in ``[1, rank(A)]`` for a nonzero matrix
    but differ in general — confirm which one the plot is meant to report.
    """
    singular_values = np.linalg.svd(matrix, compute_uv=False)
    nuclear_norm = singular_values.sum()
    frobenius_norm_sq = (singular_values ** 2).sum()
    return nuclear_norm ** 2 / frobenius_norm_sq
def compute_values_varying_rank(n, num_samples, ranks, epsilon, rng=None):
    """Measure how well ``A.T @ y`` aligns with ``solve(A, y)`` as the rank varies.

    For each ``rank`` in ``ranks``, runs ``num_samples`` trials. Each trial:

    - builds ``A = X @ Y + epsilon * I`` from standard-normal factors ``X``
      (n x rank) and ``Y`` (rank x n);
    - draws a standard-normal ``x`` and sets ``y = A @ x``;
    - records the cosine similarity between ``A.T @ y`` and ``solve(A, y)``,
      together with ``stable_rank(A)``.

    Args:
        n: Size of the square matrix ``A``.
        num_samples: Number of random trials per rank.
        ranks: Iterable of inner dimensions used for the factors ``X`` and ``Y``.
        epsilon: Multiple of the identity added to ``X @ Y``. Needed because
            ``X @ Y`` alone has rank at most ``rank`` and is therefore singular
            whenever ``rank < n``.
        rng: Optional random source exposing ``standard_normal(size)``, e.g.
            ``np.random.default_rng(seed)`` or a ``np.random.RandomState``.
            Defaults to the legacy global ``np.random`` state. In that case the
            draws match the original ``np.random.randn`` calls exactly and
            honour ``np.random.seed``.

    Returns:
        Tuple ``(mean_cos_values, std_cos_values, stable_ranks)`` of lists with
        one entry per element of ``ranks``:

        - the mean cosine similarity;
        - its population standard deviation (``np.std``, ddof=0);
        - the mean ``stable_rank(A)``.
    """
    # ``np.random.randn(*shape)`` delegates to ``standard_normal(size=shape)``
    # on the global RandomState, so the default path reproduces the original
    # random stream while also accepting an explicit Generator/RandomState.
    source = np.random if rng is None else rng

    mean_cos_values = []
    std_cos_values = []
    stable_ranks = []
    for rank in ranks:
        cos_values = []
        sr_values = []
        for _ in range(num_samples):
            # Draw order (X, Y, x) matches the original so seeded runs agree.
            X = source.standard_normal((n, rank))
            Y = source.standard_normal((rank, n))
            A = X @ Y + epsilon * np.eye(n)
            x = source.standard_normal(n)
            y = A @ x
            # Analytically solve(A, y) == x; the explicit solve makes the
            # comparison use the numerically solved vector.
            A_inv_y = np.linalg.solve(A, y)
            # A.T @ y is the vector-Jacobian product of the map x -> A @ x
            # with cotangent y, i.e. what backprop would propagate.
            cos_values.append(cosine_similarity(A.T @ y, A_inv_y))
            sr_values.append(stable_rank(A))
        mean_cos_values.append(np.mean(cos_values))
        std_cos_values.append(np.std(cos_values))
        stable_ranks.append(np.mean(sr_values))
    return mean_cos_values, std_cos_values, stable_ranks
# Sweep the inner dimension of the low-rank factor over 1, 6, 11, ..., 96
# for 100 x 100 matrices, 20 trials per rank, with a 1e-4 identity shift.
ranks = np.arange(1, 100, 5)
mean_cos_values, std_cos_values, stable_ranks = compute_values_varying_rank(
    n=100,
    num_samples=20,
    ranks=ranks,
    epsilon=1e-4,
)
import plotly.graph_objects as go

# One marker per swept rank: x is the mean stable rank, y the mean cosine
# similarity, and the symmetric error bars are +/- one standard deviation
# of the per-rank cosine similarities.
fig = go.Figure()
fig.add_trace(
    go.Scatter(
        name='Mean Cosine Similarity',
        mode='markers',
        x=stable_ranks,
        y=mean_cos_values,
        marker={'color': 'blue', 'size': 10},
        error_y={'type': 'data', 'array': std_cos_values, 'visible': True},
    )
)

fig.update_layout(
    template='plotly_white',
    title='Cosine Similarity vs. Stable Rank for Low-Rank Matrices',
    xaxis_title='Stable Rank',
    yaxis_title='Mean Cosine Similarity',
)

# Export as a standalone HTML file in the current working directory.
fig.write_html('cosine_similarity_vs_stable_rank.html')
cloneofsimo
commented
Jun 3, 2024
Author

Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment