Created
          October 16, 2024 23:24 
        
      - 
      
 - 
        
Save cat-state/709eee029dd70c635e8e0a8deddbbdbd to your computer and use it in GitHub Desktop.  
    Morton Code Attention
  
        
  
    
      This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
      Learn more about bidirectional Unicode characters
    
  
  
    
  | import torch | |
| import matplotlib.pyplot as plt | |
| from mpl_toolkits.mplot3d import Axes3D | |
| import numpy as np | |
def quantize_coords(coords: torch.Tensor, bits: int = 21):
    """Map coordinates onto an integer grid with ``bits`` bits per axis.

    Values are clamped to [0, 1], multiplied by ``2**bits - 1`` and then
    truncated to int64.
    """
    largest_level = (1 << bits) - 1
    unit_coords = torch.clamp(coords, min=0, max=1)
    scaled = unit_coords * largest_level
    return scaled.to(torch.int64)
def split_by_3_bits_21(x: torch.Tensor):
    """Spread the bits of ``x`` so that input bit ``i`` ends up at bit ``3 * i``.

    Intended for non-negative values below ``2**21``; the two empty bits
    between consecutive output bits leave room for two other interleaved axes.
    """
    # Each stage shifts a copy of the value left and keeps only the bit groups
    # named by the mask, halving the group size until every bit stands alone.
    stages = (
        (32, 0x1f00000000ffff),
        (16, 0x1f0000ff0000ff),
        (8, 0x100f00f00f00f00f),
        (4, 0x10c30c30c30c30c3),
        (2, 0x1249249249249249),
    )
    spread = x
    for shift, keep_mask in stages:
        spread = (spread | (spread << shift)) & keep_mask
    return spread
# @torch.compile
def morton_encode(coords: torch.Tensor, bits: int = 21):
    """Compute 3D Morton (Z-order) codes for coordinates in the unit cube.

    Args:
        coords: tensor whose last dimension holds at least (x, y, z); values
            are clamped to [0, 1] by ``quantize_coords``.
        bits: quantisation bits per axis, at most 21 (3 * 21 = 63 bits is the
            most that fits in a signed 64-bit code).

    Returns:
        An int64 tensor of Morton codes with the last dimension removed.
        x occupies bits 0, 3, 6, ...; y bits 1, 4, 7, ...; z bits 2, 5, 8, ...

    Raises:
        ValueError: if ``bits`` is outside ``0..21``. ``split_by_3_bits_21``
            only spreads the low 21 bits, so larger values would silently
            drop the most significant bits and break the ordering.
    """
    if not 0 <= bits <= 21:
        raise ValueError(f"bits must be between 0 and 21, got {bits}")

    grid = quantize_coords(coords, bits)
    x = split_by_3_bits_21(grid[..., 0])
    y = split_by_3_bits_21(grid[..., 1]) << 1
    z = split_by_3_bits_21(grid[..., 2]) << 2
    return x | y | z
def create_cube_points(n=2, device=None):
    """Return the ``n**3`` points of a regular grid spanning the unit cube.

    Args:
        n: number of grid samples along each axis.
        device: target torch device. Defaults to CUDA when it is available
            and falls back to CPU otherwise, so the helper also works on
            machines without a GPU.

    Returns:
        A float32 tensor of shape ``(n**3, 3)``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    axis = np.linspace(0, 1, n)
    points = np.array(np.meshgrid(axis, axis, axis)).T.reshape(-1, 3)
    return torch.tensor(points, dtype=torch.float32, device=device)
# Sample n**2 uniformly random points and move them to the GPU.
n = 1024
cube_points = torch.rand(n**2, 3).cuda()

# Min-max normalise every axis into [0, 1], the range morton_encode expects.
scaled_points = cube_points  # / cube_points.norm(dim=1, keepdim=True)
axis_min = scaled_points.min(dim=0).values
axis_max = scaled_points.max(dim=0).values
scaled_points = (scaled_points - axis_min) / (axis_max - axis_min)

# Morton-encode the normalised points.
morton_codes = morton_encode(scaled_points)

# Reorder the (un-normalised) points by ascending Morton code.
sorted_indices = morton_codes.argsort()
sorted_points = cube_points[sorted_indices]

import time

# start = time.time()
# for _ in range(1):
#     morton_codes = morton_encode(scaled_points)
#     sorted_indices = torch.argsort(morton_codes)
#     sorted_points = cube_points[sorted_indices]
# end = time.time()
# print(f"Time taken: {(end - start) / 100} seconds")

# Host-side numpy copy for plotting.
sorted_points_np = sorted_points.cpu().numpy()
| # Plot the results | |
def plot_3d_arrows(points, order):
    """Scatter ``points`` in 3D and draw a red arrow between consecutive entries of ``order``.

    ``order`` is a sequence of row indices into ``points``. The figure is
    saved to ``morton_order.png`` and then shown.
    """
    figure = plt.figure(figsize=(10, 10))
    axes = figure.add_subplot(111, projection='3d')

    # The points themselves.
    axes.scatter(points[:, 0], points[:, 1], points[:, 2])

    # One arrow per consecutive pair in the visiting order.
    for step in range(1, len(order)):
        tail = points[order[step - 1]]
        head = points[order[step]]
        dx = head[0] - tail[0]
        dy = head[1] - tail[1]
        dz = head[2] - tail[2]
        axes.quiver(tail[0], tail[1], tail[2],
                    dx, dy, dz,
                    color='r', arrow_length_ratio=0.1)

    axes.set_xlabel('X')
    axes.set_ylabel('Y')
    axes.set_zlabel('Z')
    plt.title('Morton Order Visualization')
    plt.savefig("morton_order.png")
    plt.show()
# plot_3d_arrows(sorted_points_np, range(len(sorted_points_np)))

# Dump the raw codes and the permutation that sorts them.
for _caption, _tensor in (
    ("Morton codes:", morton_codes),
    ("Sorted indices:", sorted_indices),
):
    print(_caption, _tensor)
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import torchdr | |
def compute_adjacency_matrix(points, threshold=1.1, k=4):
    """Build a sparse boolean k-nearest-neighbour adjacency matrix.

    Args:
        points: tensor of shape (N, D), read via ``points.shape[0]`` and
            passed to ``torchdr.pairwise_distances``.
        threshold: neighbours whose distance value exceeds this are dropped.
        k: number of neighbours kept per point (before thresholding).

    Returns:
        A sparse COO bool tensor of shape (N, N) on ``points.device`` with a
        True entry at (i, j) for every retained neighbour j of point i.
    """
    num_points = points.shape[0]

    # Ask for k + 1 candidates because each point is normally its own nearest
    # neighbour.
    k_min_vals, k_closest = torchdr.pairwise_distances(
        points, points, metric="euclidean", keops=True
    ).Kmin_argKmin(K=k + 1, dim=1)
    print(k_closest.device)

    # Remove self-connections. Exactly one candidate must be dropped per row
    # so the result reshapes to (N, k). If a row does not contain its own
    # index (exact duplicates can displace it), drop the last candidate
    # instead -- presumably the farthest one; confirm Kmin ordering.
    row_ids = torch.arange(num_points, device=points.device).unsqueeze(1)
    drop = k_closest == row_ids
    rows_without_self = ~drop.any(dim=1)
    drop[rows_without_self, -1] = True
    keep = ~drop
    k_closest = k_closest[keep].view(num_points, k)
    k_min_vals = k_min_vals[keep].view(num_points, k)

    # Filter out connections that are too far.
    valid_connections = k_min_vals <= threshold

    # Flattened (row, col) index pairs for the surviving connections.
    row_indices = row_ids.expand(-1, k)[valid_connections]
    col_indices = k_closest[valid_connections]
    indices = torch.stack([row_indices, col_indices])
    values = torch.ones(indices.shape[1], dtype=torch.bool, device=points.device)

    adj_matrix = torch.sparse_coo_tensor(
        indices, values, (num_points, num_points), device=points.device
    )
    return adj_matrix
def plot_adjacency_matrices(original_adj_matrix, sorted_adj_matrix):
    """Show two dense adjacency matrices side by side and save the figure.

    The left panel is the original ordering, the right panel the Morton
    ordering. Output goes to ``adjacency_matrices_comparison.png``.
    """
    figure, axes_pair = plt.subplots(1, 2, figsize=(20, 10))

    panels = (
        (original_adj_matrix, "Original Adjacency Matrix"),
        (sorted_adj_matrix, "Sorted Adjacency Matrix (Morton Order)"),
    )
    for axis, (matrix, heading) in zip(axes_pair, panels):
        image = axis.imshow(matrix, cmap='binary')
        axis.set_title(heading)
        figure.colorbar(image, ax=axis)

    plt.tight_layout()
    plt.savefig("adjacency_matrices_comparison.png")
    plt.show()
    plt.close()
# Distance cut-off handed to compute_adjacency_matrix. 1000 is far larger than
# any distance inside the unit cube, so the filter is effectively disabled.
threshold = 1000  # 1.5 * 1.0 / 100

# k-NN adjacency in the original point order ...
original_adj_matrix = compute_adjacency_matrix(scaled_points, threshold)

# ... and in Morton order.
sorted_adj_matrix = compute_adjacency_matrix(sorted_points, threshold)

# Plot adjacency matrices side by side
# plot_adjacency_matrices(original_adj_matrix.dense().cpu().numpy(), sorted_adj_matrix.dense().cpu().numpy())

# Fraction of entries that are set, for each ordering.
original_fill = original_adj_matrix.sum() / original_adj_matrix.numel()
print(f"Original adjacency matrix sparsity: {original_fill:.4f}")
sorted_fill = sorted_adj_matrix.sum() / sorted_adj_matrix.numel()
print(f"Sorted adjacency matrix sparsity: {sorted_fill:.4f}")
def compute_block_sparsity(adj_matrix, block_sizes=(8, 16, 32, 64, 512, 1024, 2048, 4096)):
    """Return, per block size, the fraction of k-by-k tiles holding a nonzero.

    The matrix is tiled into ``ceil(n / k) ** 2`` blocks, where ``n`` is
    ``adj_matrix.shape[0]``; a block counts as non-empty when at least one
    stored index falls inside it.

    Args:
        adj_matrix: sparse COO tensor (its ``_indices()`` are read).
        block_sizes: iterable of positive block edge lengths.

    Returns:
        dict mapping each block size to ``nonempty_blocks / total_blocks``
        (0.0 when the matrix has no blocks or no stored entries).
    """
    results = {}
    n = adj_matrix.shape[0]

    # Stored (row, col) coordinates, shape (2, nnz).
    nonzero_indices = adj_matrix._indices()
    has_entries = nonzero_indices.shape[1] > 0

    for k in block_sizes:
        num_blocks = (n + k - 1) // k  # ceiling division
        total_blocks = num_blocks * num_blocks

        # Empty matrix or no stored entries: nothing can be occupied, and
        # this avoids dividing by zero.
        if total_blocks == 0 or not has_entries:
            results[k] = 0.0
            continue

        # Map every entry to its block coordinate and count distinct blocks.
        block_indices = (nonzero_indices // k).unique(dim=1)
        results[k] = block_indices.shape[1] / total_blocks

    return results
# Fraction of non-empty blocks, before and after Morton sorting.
original_block_sparsity = compute_block_sparsity(original_adj_matrix)
sorted_block_sparsity = compute_block_sparsity(sorted_adj_matrix)

# Tabulate the two results per block size.
print("\nBlock Sparsity (number of nonempty blocks):")
print("Block Size | Original | Sorted")
print("-" * 35)
for k in (8, 16, 32, 64, 512, 1024, 2048, 4096):
    before = original_block_sparsity[k]
    after = sorted_block_sparsity[k]
    print(f"{k:10d} | {before:8f} | {after:6f}")
def compute_block_diagonal_sparsity(adj_matrix, block_sizes=(8, 16, 32, 64, 512, 1024, 2048, 4096)):
    """Return, per block size, the share of stored entries on the block diagonal.

    An entry (i, j) lies on the block diagonal for block size ``k`` when
    ``i // k == j // k``.

    Args:
        adj_matrix: sparse COO tensor (its ``_indices()`` are read).
        block_sizes: iterable of positive block edge lengths.

    Returns:
        dict mapping each block size to
        ``entries_on_block_diagonal / total_entries`` (0.0 when the matrix
        stores no entries).
    """
    # Stored (row, col) coordinates, shape (2, nnz); the same for every k.
    nonzero_indices = adj_matrix._indices()
    total_nonzero = nonzero_indices.shape[1]
    rows, cols = nonzero_indices[0], nonzero_indices[1]

    results = {}
    for k in block_sizes:
        # No entries at all: report 0.0 instead of dividing by zero.
        if total_nonzero == 0:
            results[k] = 0.0
            continue

        on_block_diagonal = (rows // k == cols // k).sum().item()
        results[k] = on_block_diagonal / total_nonzero

    return results
# Share of entries on the block diagonal, before and after Morton sorting.
original_block_diagonal_sparsity = compute_block_diagonal_sparsity(original_adj_matrix)
sorted_block_diagonal_sparsity = compute_block_diagonal_sparsity(sorted_adj_matrix)

# Tabulate the two results per block size.
print("\nBlock Diagonal Sparsity (ratio of nonzero elements on block diagonal):")
print("Block Size | Original | Sorted")
print("-" * 35)
for k in (8, 16, 32, 64, 512, 1024, 2048, 4096):
    before = original_block_diagonal_sparsity[k]
    after = sorted_block_diagonal_sparsity[k]
    print(f"{k:10d} | {before:8f} | {after:6f}")
# Benchmark 100 sparse-matrix @ dense-matrix products for each ordering.
# CUDA kernels are launched asynchronously, so the device is synchronised
# before the clock starts and again before it is read; otherwise the timing
# mostly reflects launch overhead rather than the actual work.

# Original point order.
original_adj_matrix = original_adj_matrix.float().to_sparse_csr()
v = torch.randn_like(cube_points).cuda()
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
    v = original_adj_matrix @ v
torch.cuda.synchronize()
print(f"Time taken: {(time.perf_counter() - t) / 100} seconds")

# Morton order.
sorted_adj_matrix = sorted_adj_matrix.float().to_sparse_csr()
v = torch.randn_like(sorted_points).cuda()
torch.cuda.synchronize()
t = time.perf_counter()
for _ in range(100):
    v = sorted_adj_matrix @ v
torch.cuda.synchronize()
print(f"Time taken: {(time.perf_counter() - t) / 100} seconds")
def plot_3d_connections(points, adj_matrix, title):
    """Plot 3D points and a line for every connected pair (i, j) with i < j.

    Args:
        points: array indexed as ``points[i, axis]`` with three axes.
        adj_matrix: dense square array-like; a truthy entry at (i, j) means
            points i and j are connected. Only the strict upper triangle is
            drawn, so each undirected edge appears once.
        title: axes title, also used (lower-cased, spaces replaced by
            underscores) to build the output file name ``<title>_3d.png``.
    """
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Plot points
    ax.scatter(points[:, 0], points[:, 1], points[:, 2], c='b', s=20)

    # Find the connected pairs in one vectorised pass rather than testing all
    # n**2 index pairs in Python. argwhere yields (i, j) in row-major order,
    # the same order the nested loops would visit them.
    num_points = len(points)
    adjacency = np.asarray(adj_matrix)[:num_points, :num_points]
    edges = np.argwhere(np.triu(adjacency, k=1))

    # Plot connections
    for i, j in edges:
        ax.plot([points[i, 0], points[j, 0]],
                [points[i, 1], points[j, 1]],
                [points[i, 2], points[j, 2]], 'r-', linewidth=0.5, alpha=0.3)

    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.set_title(title)
    plt.tight_layout()
    plt.savefig(f"{title.lower().replace(' ', '_')}_3d.png")
    plt.show()
    plt.close()
| # Plot 3D connections for original points | |
| # plot_3d_connections(cube_points.cpu().numpy(), original_adj_matrix.cpu().numpy(), "Original 3D Connections") | |
| # Plot 3D connections for sorted points | |
| # plot_3d_connections(sorted_points.cpu().numpy(), sorted_adj_matrix.cpu().numpy(), "Sorted 3D Connections (Morton Order)") | 
  
    Sign up for free
    to join this conversation on GitHub.
    Already have an account?
    Sign in to comment