Skip to content

Instantly share code, notes, and snippets.

@Matt54
Created December 5, 2024 04:50
CPU and GPU implementations of Apple's LowLevelMesh plane example
import Foundation
import RealityKit
import SwiftUI
struct ApplePlaneExample: View {
    var body: some View {
        RealityView { content in
            // Build the plane on the CPU. If either the low-level mesh or the mesh
            // resource cannot be created, nothing is added to the scene.
            if let planeMesh = try? PlaneMesh(size: [0.2, 0.2], dimensions: [16, 16]),
               let mesh = try? MeshResource(from: planeMesh.mesh) {
                // Create an entity with the plane mesh.
                let planeEntity = Entity()
                let planeModel = ModelComponent(mesh: mesh, materials: [SimpleMaterial()])
                planeEntity.components.set(planeModel)
                // Add the plane entity to the scene.
                content.add(planeEntity)
            }
        }
    }

    /// A flat plane stored in a RealityKit `LowLevelMesh` whose vertex and index
    /// buffers are filled on the CPU.
    @MainActor struct PlaneMesh {
        /// Errors thrown while building a `PlaneMesh`.
        enum PlaneMeshError: Error {
            /// Each axis needs at least two vertices so the grid spans at least one cell.
            case invalidDimensions(SIMD2<UInt32>)
        }

        /// The plane low-level mesh.
        var mesh: LowLevelMesh!
        /// The size of the plane mesh (width along x, height along y).
        let size: SIMD2<Float>
        /// The number of vertices in each dimension of the plane mesh.
        let dimensions: SIMD2<UInt32>
        /// The maximum offset depth for the vertices of the plane mesh.
        ///
        /// Use this to ensure the bounds of the plane encompass its vertices, even if they are offset.
        let maxVertexDepth: Float

        /// Creates a low-level mesh and fills its vertex and index buffers to form a plane
        /// with the given size and dimensions.
        ///
        /// - Parameters:
        ///   - size: Width (x) and height (y) of the plane.
        ///   - dimensions: Number of vertices along x and y; each must be at least 2.
        ///   - maxVertexDepth: Upper z bound of the mesh part's bounding box.
        /// - Throws: `PlaneMeshError.invalidDimensions` if either dimension is below 2,
        ///   or any error thrown while creating the `LowLevelMesh`.
        init(size: SIMD2<Float>, dimensions: SIMD2<UInt32>, maxVertexDepth: Float = 1) throws {
            // With 0 vertices on an axis, `dimensions - 1` underflows (a runtime trap).
            // With 1 vertex, the [0, 1] remap divides by zero and yields NaN positions.
            guard dimensions.x >= 2, dimensions.y >= 2 else {
                throw PlaneMeshError.invalidDimensions(dimensions)
            }
            self.size = size
            self.dimensions = dimensions
            self.maxVertexDepth = maxVertexDepth
            // Create the low-level mesh. `mesh` is implicitly unwrapped (nil until now),
            // which is what lets this call a method on `self` before assigning it.
            self.mesh = try createMesh()
            // Fill the mesh's vertex buffer with data.
            initializeVertexData()
            // Fill the mesh's index buffer with data.
            initializeIndexData()
            // Initialize the mesh parts.
            initializeMeshParts()
        }

        /// Creates a low-level mesh sized for a `dimensions.x` x `dimensions.y` grid of
        /// `PlaneVertex` vertices and two triangles per grid cell.
        private func createMesh() throws -> LowLevelMesh {
            // Define the vertex attributes of `PlaneVertex` from its actual memory layout.
            let positionAttributeOffset = MemoryLayout.offset(of: \PlaneVertex.position) ?? 0
            let normalAttributeOffset = MemoryLayout.offset(of: \PlaneVertex.normal) ?? 16
            let positionAttribute = LowLevelMesh.Attribute(semantic: .position, format: .float3, offset: positionAttributeOffset)
            let normalAttribute = LowLevelMesh.Attribute(semantic: .normal, format: .float3, offset: normalAttributeOffset)
            let vertexAttributes = [positionAttribute, normalAttribute]
            // Both attributes live interleaved in buffer 0.
            let vertexLayouts = [LowLevelMesh.Layout(bufferIndex: 0, bufferStride: MemoryLayout<PlaneVertex>.stride)]
            let vertexCount = Int(dimensions.x * dimensions.y)
            let indicesPerTriangle = 3
            let trianglesPerCell = 2
            let cellCount = Int((dimensions.x - 1) * (dimensions.y - 1))
            let indexCount = indicesPerTriangle * trianglesPerCell * cellCount
            // Create a low-level mesh with the necessary `PlaneVertex` capacity.
            let meshDescriptor = LowLevelMesh.Descriptor(vertexCapacity: vertexCount,
                                                         vertexAttributes: vertexAttributes,
                                                         vertexLayouts: vertexLayouts,
                                                         indexCapacity: indexCount)
            return try LowLevelMesh(descriptor: meshDescriptor)
        }

        /// Initializes the vertices of the mesh, positioning them to form an xy-plane with
        /// the given size, centred on the origin, with every normal pointing along +z.
        private func initializeVertexData() {
            mesh.withUnsafeMutableBytes(bufferIndex: 0) { rawBytes in
                // Convert `rawBytes` into a `PlaneVertex` buffer pointer.
                let vertices = rawBytes.bindMemory(to: PlaneVertex.self)
                // Define the normal direction for the vertices.
                let normalDirection: SIMD3<Float> = [0, 0, 1]
                // Iterate through each vertex.
                for xCoord in 0..<dimensions.x {
                    for yCoord in 0..<dimensions.y {
                        // Remap the x and y vertex coordinates to the range [0, 1].
                        let xCoord01 = Float(xCoord) / Float(dimensions.x - 1)
                        let yCoord01 = Float(yCoord) / Float(dimensions.y - 1)
                        // Derive the vertex position from the remapped coordinates and the size.
                        let xPosition = size.x * xCoord01 - size.x / 2
                        let yPosition = size.y * yCoord01 - size.y / 2
                        let zPosition = Float(0)
                        // Locate the vertex in the buffer and set its position and normal.
                        let index = Int(vertexIndex(xCoord, yCoord))
                        vertices[index].position = [xPosition, yPosition, zPosition]
                        vertices[index].normal = normalDirection
                    }
                }
            }
        }

        // The winding order of the vertices in a triangle determines which side of the
        // triangle is the front. RealityKit considers a counterclockwise winding order to
        // be front-facing.
        /// Initializes the indices of the mesh two triangles at a time for each cell in the mesh.
        private func initializeIndexData() {
            mesh.withUnsafeMutableIndices { rawIndices in
                // Convert `rawIndices` into a UInt32 pointer.
                guard var indices = rawIndices.baseAddress?.assumingMemoryBound(to: UInt32.self) else { return }
                // Iterate through each cell.
                for xCoord in 0..<(dimensions.x - 1) {
                    for yCoord in 0..<(dimensions.y - 1) {
                        /*
                         Each cell in the plane mesh consists of two triangles:

                                         topLeft ____ topRight
                                                |\   |
                              1st Triangle -->  | \  |  <-- 2nd Triangle
                                                |  \ |
                                                |___\|
                               +y     bottomLeft     bottomRight
                                ^
                                |
                                *---> +x
                         */
                        let bottomLeft = vertexIndex(xCoord, yCoord)
                        let bottomRight = vertexIndex(xCoord + 1, yCoord)
                        let topLeft = vertexIndex(xCoord, yCoord + 1)
                        let topRight = vertexIndex(xCoord + 1, yCoord + 1)
                        // Create the 1st triangle with a counterclockwise winding order.
                        indices[0] = bottomLeft
                        indices[1] = bottomRight
                        indices[2] = topLeft
                        // Create the 2nd triangle with a counterclockwise winding order.
                        indices[3] = topLeft
                        indices[4] = bottomRight
                        indices[5] = topRight
                        indices += 6
                    }
                }
            }
        }

        /// Initializes mesh parts, indicating topology and bounds.
        ///
        /// NOTE: the part uses `.lineStrip`, which joins consecutive indices into one
        /// connected polyline, so the triangle-list index buffer renders as line segments
        /// rather than filled triangles. Use `.triangle` for a solid surface.
        func initializeMeshParts() {
            // Create a bounding box that encompasses the plane's size and max vertex depth.
            let bounds = BoundingBox(min: [-size.x / 2, -size.y / 2, 0],
                                     max: [size.x / 2, size.y / 2, maxVertexDepth])
            mesh.parts.replaceAll([LowLevelMesh.Part(indexCount: mesh.descriptor.indexCapacity,
                                                     topology: .lineStrip,
                                                     bounds: bounds)])
        }

        /// Converts a 2D vertex coordinate to a 1D (row-major) vertex buffer index.
        private func vertexIndex(_ xCoord: UInt32, _ yCoord: UInt32) -> UInt32 {
            xCoord + dimensions.x * yCoord
        }
    }
}
// Xcode canvas preview of the plane whose mesh buffers are filled on the CPU.
#Preview {
ApplePlaneExample()
}
import Foundation
import SwiftUI
import RealityKit
import Metal
/// Displays the same plane as `ApplePlaneExample`, but fills the `LowLevelMesh` vertex
/// and index buffers on the GPU with the `updatePlaneMesh` compute kernel.
struct ApplePlaneMetalExample: View {
    /// The GPU-filled mesh, kept so later `updateMesh()` calls can rewrite it.
    @State private var mesh: LowLevelMesh?
    /// Width (x) and height (y) of the plane.
    let size: SIMD2<Float>
    /// Number of vertices along x and y; each is at least 2.
    let dimensions: SIMD2<UInt32>
    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    /// Compute pipeline for the `updatePlaneMesh` kernel.
    let computePipeline: MTLComputePipelineState

    /// Creates the view and the Metal objects used to fill the mesh.
    ///
    /// - Parameters:
    ///   - size: Width and height of the plane.
    ///   - dimensions: Vertices along x and y; each must be at least 2.
    /// - Precondition: Metal is available and the default library contains
    ///   `updatePlaneMesh`; otherwise this traps with a descriptive message.
    init(size: SIMD2<Float> = [0.2, 0.2], dimensions: SIMD2<UInt32> = [16, 16]) {
        // The kernel divides by `dimensions - 1` and the index count multiplies by it,
        // so fewer than two vertices per axis yields NaN positions or an underflow.
        precondition(dimensions.x >= 2 && dimensions.y >= 2,
                     "ApplePlaneMetalExample needs at least 2 vertices per axis, got \(dimensions)")
        self.size = size
        self.dimensions = dimensions
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("Failed to create a Metal command queue")
        }
        guard let library = device.makeDefaultLibrary() else {
            fatalError("Failed to load the default Metal library")
        }
        guard let updateFunction = library.makeFunction(name: "updatePlaneMesh") else {
            fatalError("Metal function 'updatePlaneMesh' was not found in the default library")
        }
        do {
            self.computePipeline = try device.makeComputePipelineState(function: updateFunction)
        } catch {
            fatalError("Failed to create the updatePlaneMesh pipeline: \(error)")
        }
        self.device = device
        self.commandQueue = commandQueue
    }

    var body: some View {
        RealityView { content in
            let mesh = try! createMesh()
            let resource = try! MeshResource(from: mesh)
            let modelComponent = ModelComponent(mesh: resource, materials: [SimpleMaterial()])
            let entity = Entity()
            entity.components.set(modelComponent)
            content.add(entity)
            self.mesh = mesh
            // Fill the freshly created mesh directly rather than reading it back through
            // `@State`, so the first GPU pass does not depend on that write being visible.
            updateMesh(mesh)
        }
    }

    /// Number of vertices in the grid.
    var vertexCount: Int {
        Int(dimensions.x) * Int(dimensions.y)
    }

    /// Number of indices: two triangles (six indices) per grid cell.
    var indexCount: Int {
        6 * (Int(dimensions.x) - 1) * (Int(dimensions.y) - 1)
    }

    /// Creates an empty low-level mesh whose single interleaved buffer holds
    /// `PlaneVertex` values (position, then normal).
    func createMesh() throws -> LowLevelMesh {
        // Derive offsets from the shared `PlaneVertex` layout so the attribute
        // descriptions cannot drift from the struct the kernel writes.
        let positionOffset = MemoryLayout.offset(of: \PlaneVertex.position) ?? 0
        let normalOffset = MemoryLayout.offset(of: \PlaneVertex.normal) ?? MemoryLayout<SIMD3<Float>>.stride
        let positionAttribute = LowLevelMesh.Attribute(semantic: .position, format: .float3, offset: positionOffset)
        let normalAttribute = LowLevelMesh.Attribute(semantic: .normal, format: .float3, offset: normalOffset)
        let vertexAttributes = [positionAttribute, normalAttribute]
        let vertexLayouts = [LowLevelMesh.Layout(bufferIndex: 0, bufferStride: MemoryLayout<PlaneVertex>.stride)]
        let meshDescriptor = LowLevelMesh.Descriptor(
            vertexCapacity: vertexCount,
            vertexAttributes: vertexAttributes,
            vertexLayouts: vertexLayouts,
            indexCapacity: indexCount
        )
        return try LowLevelMesh(descriptor: meshDescriptor)
    }

    /// Encodes and commits a compute pass that rewrites a mesh's vertex and index
    /// buffers on the GPU, then replaces its parts with one part covering every index.
    ///
    /// - Parameter target: The mesh to fill. When `nil`, the stored `mesh` is used. The
    ///   call does nothing if no mesh is available or Metal cannot create a command
    ///   buffer or encoder.
    ///
    /// NOTE: the part uses `.lineStrip`, which joins consecutive indices into one
    /// connected polyline, so the triangle-list indices render as line segments rather
    /// than filled triangles. Use `.triangle` for a solid surface.
    func updateMesh(_ target: LowLevelMesh? = nil) {
        guard let mesh = target ?? self.mesh,
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }
        var params = PlaneParams(size: size, dimensions: dimensions)
        // Buffers handed out by `replace` are swapped into the mesh by RealityKit once
        // the command buffer completes.
        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)
        let indexBuffer = mesh.replaceIndices(using: commandBuffer)
        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1)
        computeEncoder.setBytes(&params, length: MemoryLayout<PlaneParams>.stride, index: 2)
        // One thread per vertex, rounded up to whole 8x8 threadgroups; the kernel
        // discards threads that fall outside the vertex grid.
        let threadgroupSize = MTLSize(width: 8, height: 8, depth: 1)
        let threadgroups = MTLSize(
            width: (Int(dimensions.x) + threadgroupSize.width - 1) / threadgroupSize.width,
            height: (Int(dimensions.y) + threadgroupSize.height - 1) / threadgroupSize.height,
            depth: 1
        )
        computeEncoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
        computeEncoder.endEncoding()
        commandBuffer.commit()
        // The kernel writes a flat plane at z = 0, centred on the origin.
        let bounds = BoundingBox(
            min: [-size.x / 2, -size.y / 2, 0],
            max: [size.x / 2, size.y / 2, 0]
        )
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: mesh.descriptor.indexCapacity,
                topology: .lineStrip,
                bounds: bounds
            )
        ])
    }
}
// Xcode canvas preview of the plane whose mesh buffers are filled by a Metal compute kernel.
#Preview {
ApplePlaneMetalExample()
}
#ifndef PlaneParams_h
#define PlaneParams_h
struct PlaneParams {
simd_float2 size;
simd_uint2 dimensions;
};
#endif /* PlaneParams_h */
#ifndef PlaneVertex_h
#define PlaneVertex_h
struct PlaneVertex {
simd_float3 position;
simd_float3 normal;
};
#endif /* PlaneVertex_h */
#include <metal_stdlib>
using namespace metal;
#include "PlaneVertex.h"
#include "PlaneParams.h"
/// Writes one vertex of a `dimensions.x` x `dimensions.y` plane grid per thread and,
/// when that vertex is the bottom-left corner of a cell, the cell's six triangle-list
/// indices.
///
/// Vertices are stored row-major (`y * dimensions.x + x`), lie in the z = 0 plane
/// centred on the origin, and have a +z normal. Cell (x, y) owns the six indices
/// starting at `6 * (y * (dimensions.x - 1) + x)`.
kernel void updatePlaneMesh(device PlaneVertex* vertices [[buffer(0)]],
                            device uint* indices [[buffer(1)]],
                            constant PlaneParams& params [[buffer(2)]],
                            uint2 id [[thread_position_in_grid]])
{
    const uint columns = params.dimensions.x;
    const uint rows = params.dimensions.y;

    // The launch may cover more threads than there are vertices; those threads do nothing.
    if (any(id >= params.dimensions)) {
        return;
    }

    const uint vertexIndex = id.y * columns + id.x;

    // Remap the grid coordinate into [0, 1], scale by the plane size, then shift so the
    // plane is centred on the origin.
    const float u = float(id.x) / float(columns - 1);
    const float v = float(id.y) / float(rows - 1);
    const float3 position = float3(params.size.x * u - params.size.x / 2,
                                   params.size.y * v - params.size.y / 2,
                                   0.0f);

    device PlaneVertex& vertex = vertices[vertexIndex];
    vertex.position = position;
    vertex.normal = float3(0, 0, 1);

    // Vertices on the last column or the last row do not start a cell.
    if (id.x + 1 >= columns || id.y + 1 >= rows) {
        return;
    }

    const uint bottomLeft = vertexIndex;
    const uint bottomRight = bottomLeft + 1;
    const uint topLeft = bottomLeft + columns;
    const uint topRight = topLeft + 1;

    device uint* cell = indices + 6 * (id.y * (columns - 1) + id.x);
    // Lower-left triangle, then upper-right triangle, both counterclockwise seen from +z.
    cell[0] = bottomLeft;
    cell[1] = bottomRight;
    cell[2] = topLeft;
    cell[3] = topLeft;
    cell[4] = bottomRight;
    cell[5] = topRight;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment