OPTIMIZE with me
import torch
import torch.nn as nn
import torch.optim as optim
# Define the differentiable orthonormal linear layer.
class OrthonormalLayer(nn.Module):
    def __init__(self, n):
        """
        Initializes a learnable layer with an orthonormal weight matrix.
        :param n: Dimension of the square weight matrix.
        """
        super().__init__()
        self.n = n
        # Create an unconstrained parameter B (a full n x n matrix),
        # initialized with small random values.
        self.B = nn.Parameter(torch.randn(n, n) * 0.1)
    def forward(self, X):
        """
        Forward pass using an orthogonal weight transformation.
        :param X: Input tensor of shape (n, *), where the first dimension is n.
        :return: Transformed tensor Y = W @ X, with W orthogonal.
        """
        # Ensure skew-symmetry: A = (B - B^T) / 2.
        A = (self.B - self.B.t()) / 2.0
        # Compute the matrix exponential W = exp(A); torch.matrix_exp is
        # differentiable, so gradients flow back to B.
        W = torch.matrix_exp(A)
        # Apply the orthogonal transform.
        return W @ X
    def compute_W(self):
        """
        Computes the orthogonal matrix W from the parameter B.
        :return: The orthogonal matrix W.
        """
        A = (self.B - self.B.t()) / 2.0
        return torch.matrix_exp(A)
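
# Illustrative note (a sketch, not part of the original demo): if A is
# skew-symmetric (A^T = -A), then W = exp(A) satisfies
#   W^T W = exp(A^T) exp(A) = exp(-A) exp(A) = I,
# and det(W) = e^{tr(A)} = 1 since tr(A) = 0, so W lies in SO(n).
# A minimal standalone check of that identity (helper name is ours):
def _check_exp_skew_is_orthogonal(n=8, tol=1e-5):
    B = torch.randn(n, n)
    A = (B - B.t()) / 2.0  # skew-symmetric by construction
    W = torch.matrix_exp(A)
    return torch.norm(W.t() @ W - torch.eye(n), p='fro').item() < tol

# Recent PyTorch versions also provide a built-in alternative,
# torch.nn.utils.parametrizations.orthogonal, which can constrain an
# existing nn.Linear weight; the hand-rolled layer above keeps the
# mechanics explicit.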

# Demonstration code.
def main():
    # Set the dimension.
    n = 32
    # Create an instance of the layer.
    layer = OrthonormalLayer(n)
    # Create an optimizer for demonstration purposes.
    optimizer = optim.Adam(layer.parameters(), lr=1e-3)
    # Generate a random 32x32 input tensor.
    X = torch.randn(n, n)
    # Forward pass: apply the orthonormal layer.
    Y = layer(X)
    # Compute the Frobenius norm (energy) of the original and transformed
    # inputs; an orthogonal map should leave this unchanged.
    energy_initial = torch.norm(X, p='fro').item()
    energy_transformed = torch.norm(Y, p='fro').item()
    print(f'Initial energy: {energy_initial:.4f}')
    print(f'Energy after one application of orthonormal layer: {energy_transformed:.4f}')
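    # Illustrative aside (worked equation, not in the original): since
    # W^T W = I,
    #   ||W X||_F^2 = tr(X^T W^T W X) = tr(X^T X) = ||X||_F^2,
    # so the two energies printed above should agree to floating-point error.
    assert abs(energy_initial - energy_transformed) < 1e-3 * max(energy_initial, 1.0)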
    # Verify that W is orthonormal by checking W^T W ~ I.
    with torch.no_grad():
        W = layer.compute_W()
        identity = torch.eye(n)
        orthogonality_error = torch.norm(W.t() @ W - identity, p='fro').item()
    print(f'Orthogonality error (should be close to 0): {orthogonality_error:.2e}')
    # Demonstrate differentiability: create a dummy loss and backpropagate.
    # For example, let the loss be the sum of the output Y.
    loss = Y.sum()
    loss.backward()
    # Check that the gradient flows back to B.
    print("Gradient norm for parameter B:", layer.B.grad.norm().item())
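    # Illustrative check (ours; the index choice is arbitrary): compare
    # autograd's gradient through torch.matrix_exp against a central finite
    # difference on a single entry of B. The two numbers should roughly
    # agree in float32.
    with torch.no_grad():
        eps = 1e-3
        i, j = 0, 1
        b_orig = layer.B[i, j].item()

        def loss_at(b):
            layer.B[i, j] = b
            A = (layer.B - layer.B.t()) / 2.0
            return (torch.matrix_exp(A) @ X).sum().item()

        fd = (loss_at(b_orig + eps) - loss_at(b_orig - eps)) / (2 * eps)
        layer.B[i, j] = b_orig  # restore the original value
    print(f"Finite-difference vs autograd grad at B[0, 1]: "
          f"{fd:.4f} vs {layer.B.grad[i, j].item():.4f}")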
    # Run a longer optimization to show that the parameters actually update.
    num_steps = 50000

    # Make the target T an orthogonal rotation of X.
    def random_orthogonal_qr(n):
        # Start with a random Gaussian matrix.
        A = torch.randn(n, n)
        # QR decomposition.
        Q, R = torch.linalg.qr(A)
        # Q is only unique up to the sign of each column; fixing the signs
        # so that diag(R) > 0 makes Q uniformly (Haar) distributed over the
        # orthogonal group.
        d = torch.diag(R)
        Q = Q * torch.sign(d).unsqueeze(0)
        return Q
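    # Illustrative sanity check (ours, not part of the original demo): Q
    # should be orthogonal with or without the sign fix; the fix only
    # corrects the distribution, per Mezzadri, "How to generate random
    # matrices from the classical compact groups" (2007).
    Q_test = random_orthogonal_qr(n)
    assert torch.allclose(Q_test.t() @ Q_test, torch.eye(n), atol=1e-5)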
    W_0 = random_orthogonal_qr(n)
    # W = exp(A) with A skew-symmetric always has det(W) = +1, so W^3 can
    # never match a target with det(W_0) = -1 (which the QR draw produces
    # half the time). Flip one column if needed so that W_0 lies in SO(n).
    if torch.det(W_0) < 0:
        W_0[:, 0] = -W_0[:, 0]
    T = W_0 @ X
    for step in range(num_steps):
        optimizer.zero_grad()
        # Apply the layer three times, so the composed map is W^3.
        X2 = layer(X)
        X3 = layer(X2)
        Y = layer(X3)
        # MSE loss between Y and the target T.
        loss = nn.functional.mse_loss(Y, T)
        loss.backward()
        optimizer.step()
        # Periodically report the loss, orthogonality error, and energy;
        # printing every one of the 50000 steps would flood the console.
        if (step + 1) % 1000 == 0:
            with torch.no_grad():
                W = layer.compute_W()
                orthogonality_error = torch.norm(W.t() @ W - torch.eye(n), p='fro').item()
                energy_transformed = torch.norm(Y, p='fro').item()
            print(f"Step {step+1}: Loss {loss.item():.3e}, "
                  f"Orthogonality error: {orthogonality_error:.2e}, "
                  f"Energy: {energy_transformed:.4f}")
    # After training, the cube of the layer's W should be close to W_0: at
    # zero loss W^3 X = W_0 X, and a Gaussian n x n matrix X is invertible
    # with probability 1, which forces W^3 = W_0.
    W_from_layer = layer.compute_W()
    W_power = torch.matrix_power(W_from_layer, 3)
    # Compute the mean absolute error.
    error = torch.abs(W_power - W_0).mean().item()
    print(f"Error between W^3 and W_0: {error:.2e}")
    print(W_power, W_0)
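    # Illustrative aside (ours): the argument above assumes X is invertible;
    # a finite condition number confirms it for this particular draw.
    print(f"cond(X) = {torch.linalg.cond(X).item():.2e} (finite => X invertible)")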

if __name__ == '__main__':
    main()