Skip to content

Instantly share code, notes, and snippets.

@cat-state
Created October 16, 2024 23:24
Show Gist options
  • Save cat-state/709eee029dd70c635e8e0a8deddbbdbd to your computer and use it in GitHub Desktop.
Save cat-state/709eee029dd70c635e8e0a8deddbbdbd to your computer and use it in GitHub Desktop.
Morton Code Attention
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
def quantize_coords(coords: torch.Tensor, bits: int = 21):
    """Map unit-interval coordinates onto a ``bits``-wide integer grid.

    Values are clamped to ``[0, 1]``, scaled by ``2**bits - 1`` and then cast
    to ``int64`` with ``.long()`` (the fractional part is discarded).

    Args:
        coords: Tensor of coordinates; entries outside ``[0, 1]`` are clamped.
        bits: Number of bits of resolution per coordinate.

    Returns:
        An ``int64`` tensor with the same shape as ``coords``.
    """
    largest_level = (1 << bits) - 1
    unit_coords = torch.clamp(coords, 0, 1)
    scaled = unit_coords * largest_level
    return scaled.long()
def split_by_3_bits_21(x: torch.Tensor):
    """Spread the low 21 bits of ``x`` so that bit ``i`` lands at bit ``3 * i``.

    This is the "magic bits" dilation used for 3D Morton codes: after the
    call, two zero bits separate every original bit, leaving room to
    interleave two more coordinates.

    Args:
        x: Integer tensor. The masks below need a 64-bit integer dtype
            (assumes ``int64`` as produced by ``.long()`` -- TODO confirm for
            other callers). Only the low 21 bits are used.

    Returns:
        Tensor of the same shape whose set bits all sit at positions
        ``0, 3, 6, ..., 60``.
    """
    # Drop anything above bit 20 first; otherwise stray high bits (48-52)
    # would survive the masks below and alias with genuine coordinate bits.
    x = x & 0x1FFFFF
    x = (x | (x << 32)) & 0x1f00000000ffff
    x = (x | (x << 16)) & 0x1f0000ff0000ff
    x = (x | (x << 8)) & 0x100f00f00f00f00f
    x = (x | (x << 4)) & 0x10c30c30c30c30c3
    x = (x | (x << 2)) & 0x1249249249249249
    return x
# @torch.compile
def morton_encode(coords: torch.Tensor, bits: int = 21):
    """Compute 3D Morton (Z-order) codes for points with coordinates in [0, 1].

    Each coordinate is quantized to ``bits`` bits, its bits are spread out by
    ``split_by_3_bits_21`` and the three results are interleaved: x occupies
    bit ``3 * i``, y bit ``3 * i + 1`` and z bit ``3 * i + 2``.

    Args:
        coords: Tensor whose last dimension holds at least (x, y, z); values
            outside ``[0, 1]`` are clamped by ``quantize_coords``.
        bits: Quantization resolution per axis, between 1 and 21.

    Returns:
        Integer tensor of Morton codes with shape ``coords.shape[:-1]``.

    Raises:
        ValueError: If ``bits`` is outside ``[1, 21]``. The bit spreader only
            handles 21 bits per axis, so wider inputs would silently yield
            wrong codes.
    """
    if not 1 <= bits <= 21:
        raise ValueError(f"bits must be in [1, 21], got {bits}")
    grid = quantize_coords(coords, bits)
    x = split_by_3_bits_21(grid[..., 0])
    y = split_by_3_bits_21(grid[..., 1]) << 1
    z = split_by_3_bits_21(grid[..., 2]) << 2
    return x | y | z
def create_cube_points(n=2, device="cuda"):
    """Return the ``n**3`` points of a regular grid spanning the unit cube.

    Args:
        n: Number of evenly spaced samples per axis (endpoints 0 and 1
            included).
        device: Device the result is placed on. Defaults to ``"cuda"`` to
            match the previous hard-coded behaviour; pass ``"cpu"`` to run
            without a GPU.

    Returns:
        A ``float32`` tensor of shape ``(n**3, 3)``.
    """
    # All three axes share the same sample positions.
    axis = np.linspace(0, 1, n)
    points = np.array(np.meshgrid(axis, axis, axis)).T.reshape(-1, 3)
    return torch.tensor(points, dtype=torch.float32).to(device)
# Draw n**2 uniformly random 3D points and move them to the GPU.
n = 1024
cube_points = torch.rand(n**2, 3).cuda()

# Min-max normalise every axis to [0, 1], the range morton_encode quantizes.
_axis_min = cube_points.min(dim=0).values
_axis_max = cube_points.max(dim=0).values
scaled_points = (cube_points - _axis_min) / (_axis_max - _axis_min)

# Morton-encode the normalised points, then order the points by their code.
morton_codes = morton_encode(scaled_points)
sorted_indices = morton_codes.argsort()
# NOTE(review): the gather below reads the *unscaled* cube_points, whereas the
# codes were computed from scaled_points -- confirm this is intended.
sorted_points = cube_points[sorted_indices]
import time
# Disabled timing harness for the encode + argsort + gather pipeline.
# NOTE(review): as written it runs a single iteration but divides the elapsed
# time by 100; make the divisor match the loop count before re-enabling it.
# start = time.time()
# for _ in range(1):
# morton_codes = morton_encode(scaled_points)
# sorted_indices = torch.argsort(morton_codes)
# sorted_points = cube_points[sorted_indices]
# end = time.time()
# print(f"Time taken: {(end - start) / 100} seconds")
# Copy the Morton-ordered points to host memory as a NumPy array for plotting.
sorted_points_np = sorted_points.cpu().numpy()
# Plot the results
def plot_3d_arrows(points, order):
    """Scatter ``points`` in 3D and draw arrows along the visiting ``order``.

    One red arrow is drawn from each entry of ``order`` to the next one. The
    figure is written to ``morton_order.png`` and then displayed.

    Args:
        points: Array indexable as ``points[:, c]`` and ``points[i]``;
            columns 0, 1 and 2 are used as x, y and z.
        order: Sequence of row indices into ``points``.
    """
    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection='3d')

    # Point cloud.
    axes.scatter(points[:, 0], points[:, 1], points[:, 2])

    # One arrow per consecutive pair in the ordering.
    for step in range(1, len(order)):
        tail = points[order[step - 1]]
        head = points[order[step]]
        axes.quiver(
            tail[0], tail[1], tail[2],
            head[0] - tail[0], head[1] - tail[1], head[2] - tail[2],
            color='r', arrow_length_ratio=0.1,
        )

    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Z')
    plt.title('Morton Order Visualization')
    plt.savefig("morton_order.png")
    plt.show()
# Arrow plot is disabled; it draws one arrow per consecutive pair of points.
# plot_3d_arrows(sorted_points_np, range(len(sorted_points_np)))
# Dump the Morton codes and the permutation that sorts them.
print("Morton codes:", morton_codes)
print("Sorted indices:", sorted_indices)
import numpy as np
import matplotlib.pyplot as plt
import torchdr
def compute_adjacency_matrix(points, threshold=1.1, k=4):
    """Build a sparse k-nearest-neighbour adjacency matrix for ``points``.

    For every point the ``k + 1`` nearest points are queried, one entry per
    row is discarded (the point itself when present) and the remaining
    neighbours whose distance is at most ``threshold`` become edges. The
    result is directed: ``A[i, j]`` being set does not imply ``A[j, i]``.

    Args:
        points: ``(N, D)`` tensor of point coordinates.
        threshold: Maximum distance for a neighbour to be kept as an edge.
        k: Number of neighbours retained per point before thresholding.

    Returns:
        A boolean sparse COO tensor of shape ``(N, N)`` on ``points.device``.
    """
    num_points = points.shape[0]
    k_min_vals, k_closest = torchdr.pairwise_distances(
        points, points, metric="euclidean", keops=True
    ).Kmin_argKmin(K=k + 1, dim=1)
    print(k_closest.device)

    # Remove self-connections. Exactly one neighbour is dropped per row so the
    # reshape to (N, k) below can never fail: the point itself when it is
    # listed, otherwise (e.g. many coincident points) the last column.
    # NOTE(review): assumes Kmin_argKmin returns neighbours in ascending
    # distance order, making the last column the farthest -- confirm.
    row_ids = torch.arange(num_points, device=points.device).unsqueeze(1)
    drop = k_closest == row_ids
    self_missing = ~drop.any(dim=1)
    drop[self_missing, -1] = True
    keep = ~drop
    k_closest = k_closest[keep].view(num_points, k)
    k_min_vals = k_min_vals[keep].view(num_points, k)

    # Filter out connections that are too far
    valid_connections = k_min_vals <= threshold

    # Flattened (row, col) pairs of the surviving edges
    row_indices = row_ids.expand(-1, k)[valid_connections]
    col_indices = k_closest[valid_connections]
    indices = torch.stack([row_indices, col_indices])
    values = torch.ones(indices.shape[1], dtype=torch.bool, device=points.device)
    adj_matrix = torch.sparse_coo_tensor(
        indices, values, (num_points, num_points), device=points.device
    )
    return adj_matrix
def plot_adjacency_matrices(original_adj_matrix, sorted_adj_matrix):
    """Show two adjacency matrices next to each other as binary images.

    The left panel shows ``original_adj_matrix`` and the right panel
    ``sorted_adj_matrix``, each with its own colorbar. The figure is saved to
    ``adjacency_matrices_comparison.png``, displayed and then closed.
    """
    figure, axes_pair = plt.subplots(1, 2, figsize=(20, 10))
    panels = (
        (original_adj_matrix, "Original Adjacency Matrix"),
        (sorted_adj_matrix, "Sorted Adjacency Matrix (Morton Order)"),
    )
    for axes, (matrix, heading) in zip(axes_pair, panels):
        image = axes.imshow(matrix, cmap='binary')
        axes.set_title(heading)
        figure.colorbar(image, ax=axes)

    plt.tight_layout()
    plt.savefig("adjacency_matrices_comparison.png")
    plt.show()
    plt.close()
# Compute adjacency matrix for original points
# NOTE(review): both point sets lie in the unit cube, so a threshold of 1000
# looks larger than any possible pairwise distance; if so the distance filter
# is effectively disabled and every point keeps all k neighbours -- confirm.
threshold = 1000 # 1.5 * 1.0 / 100
original_adj_matrix = compute_adjacency_matrix(scaled_points, threshold)
# Compute adjacency matrix for sorted points
sorted_adj_matrix = compute_adjacency_matrix(sorted_points, threshold)
# Plot adjacency matrices side by side
# plot_adjacency_matrices(original_adj_matrix.dense().cpu().numpy(), sorted_adj_matrix.dense().cpu().numpy())
# Print some statistics
# NOTE(review): the value printed is sum / numel, i.e. the fraction of entries
# that are set (a density), although the label says "sparsity".
print(f"Original adjacency matrix sparsity: {original_adj_matrix.sum() / original_adj_matrix.numel():.4f}")
print(f"Sorted adjacency matrix sparsity: {sorted_adj_matrix.sum() / sorted_adj_matrix.numel():.4f}")
def compute_block_sparsity(adj_matrix, block_sizes=(8, 16, 32, 64, 512, 1024, 2048, 4096)):
    """Return, per block size, the fraction of blocks that hold an entry.

    The matrix is tiled into ``k x k`` blocks (the last block row/column may
    be partial) and a block counts as non-empty when at least one stored
    index of ``adj_matrix`` falls inside it.

    Args:
        adj_matrix: Sparse COO tensor; only ``shape[0]`` is used for the
            tiling, so the matrix is assumed to be square.
        block_sizes: Iterable of positive block edge lengths ``k``.

    Returns:
        Dict mapping each block size to ``nonempty_blocks / total_blocks``
        (``0.0`` when the matrix has no rows or no stored entries).
    """
    results = {}
    num_rows = adj_matrix.shape[0]
    # Stored (row, col) index pairs of the sparse COO tensor, shape (2, nnz).
    nonzero_indices = adj_matrix._indices()
    has_entries = nonzero_indices.shape[1] > 0

    for block_size in block_sizes:
        blocks_per_side = (num_rows + block_size - 1) // block_size  # ceiling division
        total_blocks = blocks_per_side * blocks_per_side
        if not has_entries or total_blocks == 0:
            # Nothing stored (or nothing to tile): avoid a division by zero.
            results[block_size] = 0.0
            continue
        # Map every entry to its block coordinates and deduplicate the pairs.
        block_indices = (nonzero_indices // block_size).unique(dim=1)
        results[block_size] = block_indices.shape[1] / total_blocks
    return results
# Fraction of non-empty blocks per block size, before and after Morton sorting.
original_block_sparsity = compute_block_sparsity(original_adj_matrix)
sorted_block_sparsity = compute_block_sparsity(sorted_adj_matrix)

# Side-by-side table, one row per block size (dicts keep the block-size order).
print("\nBlock Sparsity (number of nonempty blocks):")
print("Block Size | Original | Sorted")
print("-" * 35)
for k in original_block_sparsity:
    print(f"{k:10d} | {original_block_sparsity[k]:8f} | {sorted_block_sparsity[k]:6f}")
def compute_block_diagonal_sparsity(adj_matrix, block_sizes=(8, 16, 32, 64, 512, 1024, 2048, 4096)):
    """Return, per block size, the share of entries lying on the block diagonal.

    A stored entry ``(row, col)`` is on the block diagonal for block size
    ``k`` when ``row // k == col // k``.

    Args:
        adj_matrix: Sparse COO tensor whose stored indices are inspected.
        block_sizes: Iterable of positive block edge lengths ``k``.

    Returns:
        Dict mapping each block size to
        ``entries_on_block_diagonal / total_entries`` (``0.0`` when the
        matrix stores no entries).
    """
    results = {}
    # Stored (row, col) index pairs of the sparse COO tensor, shape (2, nnz).
    nonzero_indices = adj_matrix._indices()
    row_ids, col_ids = nonzero_indices[0], nonzero_indices[1]
    # The entry count does not depend on the block size, so compute it once.
    total_nonzero = nonzero_indices.shape[1]

    for block_size in block_sizes:
        if total_nonzero == 0:
            # No entries at all: report 0.0 instead of dividing by zero.
            results[block_size] = 0.0
            continue
        on_block_diagonal = (row_ids // block_size) == (col_ids // block_size)
        results[block_size] = on_block_diagonal.sum().item() / total_nonzero
    return results
# Share of entries on the block diagonal, before and after Morton sorting.
original_block_diagonal_sparsity = compute_block_diagonal_sparsity(original_adj_matrix)
sorted_block_diagonal_sparsity = compute_block_diagonal_sparsity(sorted_adj_matrix)

# Side-by-side table, one row per block size (dicts keep the block-size order).
print("\nBlock Diagonal Sparsity (ratio of nonzero elements on block diagonal):")
print("Block Size | Original | Sorted")
print("-" * 35)
for k in original_block_diagonal_sparsity:
    print(f"{k:10d} | {original_block_diagonal_sparsity[k]:8f} | {sorted_block_diagonal_sparsity[k]:6f}")
# Benchmark 100 sparse (CSR) @ dense products for the original and the
# Morton-sorted adjacency matrix.
#
# CUDA kernels are launched asynchronously, so the device is synchronized
# before starting and before stopping the clock; without that the timer mostly
# measures kernel-launch overhead rather than the products themselves.
# NOTE(review): v is fed back into the next product, so its magnitude can grow
# with every iteration and may overflow float32 -- harmless for timing, but
# the final v should not be treated as meaningful.
original_adj_matrix = original_adj_matrix.float().to_sparse_csr()
v = torch.randn_like(cube_points)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
    v = original_adj_matrix @ v
torch.cuda.synchronize()
print(f"Time taken: {(time.perf_counter() - t) / 100} seconds")

sorted_adj_matrix = sorted_adj_matrix.float().to_sparse_csr()
v = torch.randn_like(sorted_points)
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
    v = sorted_adj_matrix @ v
torch.cuda.synchronize()
print(f"Time taken: {(time.perf_counter() - t) / 100} seconds")
def plot_3d_connections(points, adj_matrix, title):
    """Draw a 3D scatter of ``points`` with a line for every adjacent pair.

    Only pairs ``(i, j)`` with ``i < j`` are examined, so an edge is drawn
    when the upper-triangular entry ``adj_matrix[i, j]`` is truthy. The figure
    is saved as ``<title lower-cased, spaces replaced by underscores>_3d.png``,
    displayed and then closed.

    Args:
        points: Array indexable as ``points[:, c]`` and ``points[i, c]``;
            columns 0, 1 and 2 are used as x, y and z.
        adj_matrix: Object indexable as ``adj_matrix[i, j]``.
        title: Axes title, also used to derive the output file name.
    """
    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection='3d')

    # Point cloud.
    axes.scatter(points[:, 0], points[:, 1], points[:, 2], c='b', s=20)

    # One faint red segment per connected pair in the upper triangle.
    num_points = len(points)
    for src in range(num_points):
        for dst in range(src + 1, num_points):
            if not adj_matrix[src, dst]:
                continue
            xs = [points[src, 0], points[dst, 0]]
            ys = [points[src, 1], points[dst, 1]]
            zs = [points[src, 2], points[dst, 2]]
            axes.plot(xs, ys, zs, 'r-', linewidth=0.5, alpha=0.3)

    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Z')
    axes.set_title(title)
    plt.tight_layout()
    plt.savefig(f"{title.lower().replace(' ', '_')}_3d.png")
    plt.show()
    plt.close()
# Plot 3D connections for original points
# plot_3d_connections(cube_points.cpu().numpy(), original_adj_matrix.cpu().numpy(), "Original 3D Connections")
# Plot 3D connections for sorted points
# plot_3d_connections(sorted_points.cpu().numpy(), sorted_adj_matrix.cpu().numpy(), "Sorted 3D Connections (Morton Order)")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment