@N8python
Created January 19, 2025 05:32
OPTIMIZE with me
import torch
import torch.nn as nn
import torch.optim as optim


# Define the differentiable orthonormal linear layer
class OrthonormalLayer(nn.Module):
    def __init__(self, n):
        """
        Initializes a learnable layer with an orthonormal weight matrix.

        :param n: Dimension of the square weight matrix.
        """
        super(OrthonormalLayer, self).__init__()
        self.n = n
        # Create an unconstrained parameter B using small random values.
        # B is a full n x n matrix; the small init keeps A = (B - B^T)/2
        # near zero, so W = exp(A) starts near the identity.
        self.B = nn.Parameter(torch.randn(n, n) * 0.1)
    def forward(self, X):
        """
        Forward pass using an orthogonal weight transformation.

        :param X: Input tensor of shape (n, *) where the first dimension is n.
        :return: Transformed tensor Y = W @ X, with W orthogonal.
        """
        # Ensure skew-symmetry: A = (B - B^T) / 2
        A = (self.B - self.B.t()) / 2.0
        # Compute the matrix exponential: W = exp(A).
        # Since A^T = -A, we get W^T W = exp(A^T) exp(A) = exp(-A) exp(A) = I,
        # so W is orthogonal. torch.matrix_exp is differentiable.
        W = torch.matrix_exp(A)
        # Apply the orthogonal transform; assuming X is of shape (n, *).
        return W @ X
    def compute_W(self):
        """
        Computes the orthogonal matrix W from the parameter B.

        :return: The orthogonal matrix W.
        """
        A = (self.B - self.B.t()) / 2.0
        return torch.matrix_exp(A)
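

# Added sketch, not part of the original demo: recent PyTorch versions
# (assumption: >= ~1.10) ship the same exp(skew-symmetric) construction as a
# built-in parametrization, torch.nn.utils.parametrizations.orthogonal.
def builtin_orthogonal_linear(n):
    """Equivalent layer built from nn.Linear plus the stock parametrization."""
    from torch.nn.utils.parametrizations import orthogonal
    linear = nn.Linear(n, n, bias=False)
    return orthogonal(linear, "weight", orthogonal_map="matrix_exp")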


# Demonstration code
def main():
    # Set dimension
    n = 32
    # Create an instance of the layer
    layer = OrthonormalLayer(n)
    # Create an optimizer for demonstration purposes
    optimizer = optim.Adam(layer.parameters(), lr=1e-3)
    # Generate a random 32x32 input tensor
    X = torch.randn(n, n)
    # Forward pass: apply the orthonormal layer
    Y = layer(X)
    # Compute the Frobenius norm (energy) of the original and transformed inputs
    energy_initial = torch.norm(X, p='fro').item()
    energy_transformed = torch.norm(Y, p='fro').item()
    print(f'Initial energy: {energy_initial:.4f}')
    print(f'Energy after one application of orthonormal layer: {energy_transformed:.4f}')
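    # Added sanity check: an orthogonal W preserves the Frobenius norm, so the
    # two energies should agree up to float32 precision (the tolerance below
    # is a loose assumption, not from the original).
    assert abs(energy_initial - energy_transformed) < 1e-3 * energy_initial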
    # Verify that W is orthonormal by checking W^T W ~ I.
    # Compute W explicitly.
    with torch.no_grad():
        A = (layer.B - layer.B.t()) / 2.0
        W = torch.matrix_exp(A)
        identity = torch.eye(n)
        orthogonality_error = torch.norm(W.t() @ W - identity, p='fro').item()
    print(f'Orthogonality error (should be close to 0): {orthogonality_error:.2e}')
    # Demonstrate differentiability: create a dummy loss and backpropagate.
    # For example, let our loss be the sum of the output Y.
    loss = Y.sum()
    loss.backward()
    # Check that the gradient flows back to B.
    print("Gradient norm for parameter B:", layer.B.grad.norm().item())
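
    # Added sketch of a stricter differentiability check: verify the analytic
    # gradient of the exp(skew) map against finite differences with
    # torch.autograd.gradcheck. A small 4x4 float64 matrix keeps this cheap
    # (the size and dtype here are assumptions, not from the original).
    def exp_skew(B):
        return torch.matrix_exp((B - B.t()) / 2.0)
    B_small = torch.randn(4, 4, dtype=torch.float64, requires_grad=True)
    assert torch.autograd.gradcheck(exp_skew, (B_small,))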
    # Optionally, run a number of training steps to demonstrate that the
    # parameters update.
    num_steps = 50000

    # Make T an orthogonal rotation of X.
    def random_orthogonal_qr(n):
        # Start with a random normal matrix
        A = torch.randn(n, n)
        # QR decomposition
        Q, R = torch.linalg.qr(A)
        # Make Q uniformly (Haar) distributed by fixing the signs of the
        # diagonal entries of R; Q is only unique up to the sign of each column.
        d = torch.diag(R)
        Q = Q * torch.sign(d).unsqueeze(0)
        return Q

    W_0 = random_orthogonal_qr(n)
    T = W_0 @ X
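    # Added sanity check (loose tolerance, an assumption): the sampled
    # rotation should itself be orthogonal.
    torch.testing.assert_close(W_0.t() @ W_0, torch.eye(n), atol=1e-4, rtol=0)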
    for step in range(num_steps):
        optimizer.zero_grad()
        # Apply the layer three times, so Y = W^3 @ X.
        X2 = layer(X)
        X3 = layer(X2)
        Y = layer(X3)
        # Loss: MSE between Y and the target T = W_0 @ X.
        loss = nn.MSELoss()(Y, T)
        loss.backward()
        optimizer.step()
        # Periodically recompute the orthogonality error and energy
        # (logging every one of the 50000 steps would be unreadable).
        if (step + 1) % 1000 == 0:
            with torch.no_grad():
                A = (layer.B - layer.B.t()) / 2.0
                W = torch.matrix_exp(A)
                orthogonality_error = torch.norm(W.t() @ W - torch.eye(n), p='fro').item()
                energy_transformed = torch.norm(Y, p='fro').item()
            print(f"Step {step+1}: Loss {loss.item():.3e}, Orthogonality error: {orthogonality_error:.2e}, Energy: {energy_transformed:.4f}")
    # After training, the cube of the layer's W should be close to W_0:
    # X is almost surely invertible, so W^3 @ X = W_0 @ X at zero loss
    # implies W^3 = W_0.
    W_from_layer = layer.compute_W()
    W_power = torch.matrix_power(W_from_layer, 3)
    # Compute the mean absolute error.
    error = torch.abs(W_power - W_0).mean().item()
    print(f"Error between W^3 and W_0: {error:.2e}")
    print(W_power, W_0)


if __name__ == '__main__':
    main()