Skip to content

Instantly share code, notes, and snippets.

@Matt54
Last active January 6, 2025 17:29
Show Gist options
  • Save Matt54/835a275bca08d782ea1b9be3fe6ab72b to your computer and use it in GitHub Desktop.
Save Matt54/835a275bca08d782ea1b9be3fe6ab72b to your computer and use it in GitHub Desktop.
Cube / Sphere RealityView using LowLevelMesh with a Metal compute shader
#ifndef cubeShaderTypes_h
#define cubeShaderTypes_h
#include <simd/simd.h>
// Types shared by the Swift host code and the `updateCubeMesh` Metal kernel.
// Both sides depend on the exact field order and types below, so keep them in
// sync with the Swift vertex layout and with the kernel when editing.

// One mesh vertex as written by the kernel.
// vector_float3 occupies 16 bytes, so `normal` starts at byte 16; the Swift side
// declares the normal attribute at offset MemoryLayout<SIMD3<Float>>.stride to match,
// and uses MemoryLayout<CubeVertex>.stride as the vertex buffer stride.
struct CubeVertex {
vector_float3 position;
vector_float3 normal;
};
// Per-dispatch parameters; the Swift code passes them to the kernel with setBytes at buffer index 2.
struct CubeParams {
vector_float3 size; // full extent of the cube along x, y and z
uint32_t dimensions_x; // vertices per grid row on each face
uint32_t dimensions_y; // vertices per grid column on each face
float cubeSphereInterpolationRatio; // 0 = cube, 1 = fully normalized (sphere)
};
#endif /* cubeShaderTypes_h */
import SwiftUI
import RealityKit
import Metal
/// Geometry inputs used to (re)build the cube/sphere mesh.
struct CubeSphereState {
    /// Full edge lengths of the cube along x, y and z.
    var size = SIMD3<Float>(repeating: 0.3)
    /// Vertices per face along each of the face grid's two axes.
    var planeResolution = SIMD2<UInt32>(repeating: 16)
    /// Blend factor between shapes: 0 is a cube, 1 is a sphere.
    var cubeSphereInterpolationRatio: Float = 0
}
/// A RealityKit view showing a wireframe mesh that continuously morphs between a cube
/// and a sphere while tumbling about all three axes.
///
/// The geometry lives in a `LowLevelMesh`. Its vertex and index buffers are regenerated
/// on the GPU by the `updateCubeMesh` Metal compute kernel whenever `state` is rebuilt.
/// A 120 Hz `Timer` drives both the morph ratio and the rotation.
struct MetalCubeExample: View {
    @State private var rootEntity: Entity?
    @State private var mesh: LowLevelMesh?
    @State var state: CubeSphereState = .init()

    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipeline: MTLComputePipelineState

    /// `true` while morphing toward the sphere, `false` while morphing back to the cube.
    @State var isForward: Bool = true
    /// Seconds accumulated by `updateRotation()`.
    @State var time: Double = 0.0
    @State var timer: Timer?
    /// Remaining pause before the morph resumes; counts down by `deadbandStep` per tick.
    @State var deadBandValue: Double = 0.5
    @State private var rotationAngles: SIMD3<Float> = [0, 0, 0]
    @State private var lastRotationUpdateTime = CACurrentMediaTime()

    /// Amount `deadBandValue` decreases per timer tick.
    let deadbandStep = 0.005
    /// Amount the cube/sphere interpolation ratio changes per timer tick.
    let modulationStep: Float = 0.02
    /// Largest per-face grid resolution the mesh buffers are allocated for.
    let maxResolution: UInt32 = 128

    init() {
        // Without a GPU, the default library or the kernel this view cannot work at all.
        // These are setup errors, so fail loudly with a description instead of a bare `!`.
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("MetalCubeExample: Metal is not available on this device")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("MetalCubeExample: failed to create a Metal command queue")
        }
        guard let library = device.makeDefaultLibrary(),
              let updateFunction = library.makeFunction(name: "updateCubeMesh") else {
            fatalError("MetalCubeExample: kernel 'updateCubeMesh' not found in the default library")
        }
        do {
            self.computePipeline = try device.makeComputePipelineState(function: updateFunction)
        } catch {
            fatalError("MetalCubeExample: failed to build the compute pipeline: \(error)")
        }
        self.device = device
        self.commandQueue = commandQueue
    }

    var body: some View {
        RealityView { content in
            let mesh: LowLevelMesh
            let resource: MeshResource
            do {
                mesh = try createMesh()
                resource = try MeshResource(from: mesh)
            } catch {
                fatalError("MetalCubeExample: failed to create the mesh: \(error)")
            }

            let modelComponent = ModelComponent(mesh: resource, materials: [UnlitMaterial()])
            let entity = Entity()
            entity.components.set(modelComponent)
            content.add(entity)

            self.mesh = mesh
            // Hand the new mesh over directly instead of reading it back from @State.
            updateMesh(with: state, on: mesh)
            self.rootEntity = entity
        } update: { _ in
            // Rebuild the geometry from the current state whenever SwiftUI updates the view.
            updateMesh(with: state)
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts the 120 Hz timer that advances the rotation and the cube <-> sphere morph.
    ///
    /// Any running timer is invalidated first, so repeated `onAppear` calls cannot stack
    /// up several timers. The rotation clock is also reset so the model does not jump by
    /// the whole time it spent off screen.
    private func startTimer() {
        stopTimer()
        lastRotationUpdateTime = CACurrentMediaTime()

        timer = Timer.scheduledTimer(withTimeInterval: 1 / 120.0, repeats: true) { _ in
            updateRotation()

            // Hold at either end of the morph until the dead band has run out.
            if deadBandValue > 0 {
                deadBandValue -= deadbandStep
                return
            }

            var ratio = state.cubeSphereInterpolationRatio
            if isForward {
                ratio += modulationStep
                if ratio >= 1.0 {
                    ratio = 1.0
                    isForward = false
                    deadBandValue = 1.0
                }
            } else {
                ratio -= modulationStep
                if ratio <= 0.0 {
                    ratio = 0.0
                    isForward = true
                    deadBandValue = 1.0
                }
            }
            state.cubeSphereInterpolationRatio = ratio
        }
    }

    /// Advances the tumbling rotation by the wall-clock time elapsed since the previous call.
    func updateRotation() {
        let currentTime = CACurrentMediaTime()
        let frameDuration = currentTime - lastRotationUpdateTime
        time += frameDuration

        // Rotate about each axis at a different rate for a wonky roll effect.
        rotationAngles.x += Float(frameDuration * 0.25)
        rotationAngles.y += Float(frameDuration * 0.125)
        rotationAngles.z += Float(frameDuration * 0.0675)

        let rotationX = simd_quatf(angle: rotationAngles.x, axis: [1, 0, 0])
        let rotationY = simd_quatf(angle: rotationAngles.y, axis: [0, 1, 0])
        let rotationZ = simd_quatf(angle: rotationAngles.z, axis: [0, 0, 1])
        rootEntity?.transform.rotation = rotationX * rotationY * rotationZ
        lastRotationUpdateTime = currentTime
    }

    /// Stops and releases the animation timer, if any.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Vertex count for a per-face grid of `resolution` vertices, across 6 faces.
    private func vertexCount(for resolution: SIMD2<UInt32>) -> Int {
        Int(resolution.x) * Int(resolution.y) * 6
    }

    /// Index count for a per-face grid of `resolution`: 6 indices per quad,
    /// (x - 1) * (y - 1) quads per face, 6 faces. Requires both components >= 1.
    private func indexCount(for resolution: SIMD2<UInt32>) -> Int {
        6 * (Int(resolution.x) - 1) * (Int(resolution.y) - 1) * 6
    }

    /// Vertex buffer capacity: enough for a `maxResolution` x `maxResolution` grid per face.
    private var vertexCapacity: Int {
        vertexCount(for: SIMD2(repeating: maxResolution))
    }

    /// Index buffer capacity: enough for a `maxResolution` x `maxResolution` grid per face.
    private var indexCapacity: Int {
        indexCount(for: SIMD2(repeating: maxResolution))
    }

    /// Allocates the mesh with one interleaved vertex buffer (`CubeVertex`: position, then
    /// normal) and buffers sized for the maximum resolution, so later resolution changes
    /// never need a reallocation.
    private func createMesh() throws -> LowLevelMesh {
        let vertexAttributes = [
            LowLevelMesh.Attribute(semantic: .position, format: .float3, offset: 0),
            LowLevelMesh.Attribute(semantic: .normal, format: .float3, offset: MemoryLayout<SIMD3<Float>>.stride)
        ]
        let vertexLayouts = [
            LowLevelMesh.Layout(bufferIndex: 0, bufferStride: MemoryLayout<CubeVertex>.stride)
        ]
        return try LowLevelMesh(descriptor: .init(
            vertexCapacity: vertexCapacity,
            vertexAttributes: vertexAttributes,
            vertexLayouts: vertexLayouts,
            indexCapacity: indexCapacity
        ))
    }

    /// Regenerates the mesh for `state` on the GPU and updates its single line part.
    ///
    /// The compute kernel writes one vertex per thread, plus the indices of the quad that
    /// vertex owns. The mesh part then draws exactly the indices written for this resolution.
    ///
    /// - Parameters:
    ///   - state: Geometry to build. `planeResolution` must lie in `2...maxResolution` on
    ///     both axes; other values are rejected (with an assertion in debug builds).
    ///   - targetMesh: Mesh to write into; `nil` means the stored `mesh`.
    private func updateMesh(with state: CubeSphereState, on targetMesh: LowLevelMesh? = nil) {
        guard let mesh = targetMesh ?? self.mesh else { return }

        let resolution = state.planeResolution
        // The kernel divides by (dimension - 1), and the buffers only hold `maxResolution`
        // grids, so out-of-range values would produce NaNs or out-of-bounds GPU writes.
        guard resolution.x >= 2, resolution.y >= 2,
              resolution.x <= maxResolution, resolution.y <= maxResolution else {
            assertionFailure("planeResolution \(resolution) must lie within 2...\(maxResolution)")
            return
        }

        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        var params = CubeParams(
            size: state.size,
            dimensions_x: resolution.x,
            dimensions_y: resolution.y,
            cubeSphereInterpolationRatio: state.cubeSphereInterpolationRatio
        )

        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)
        let indexBuffer = mesh.replaceIndices(using: commandBuffer)

        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1)
        computeEncoder.setBytes(&params, length: MemoryLayout<CubeParams>.stride, index: 2)

        // One thread per vertex, rounded up to whole threadgroups; the kernel ignores the excess.
        let threadgroupSize = MTLSize(width: 64, height: 1, depth: 1)
        let threadgroups = MTLSize(
            width: (vertexCount(for: resolution) + threadgroupSize.width - 1) / threadgroupSize.width,
            height: 1,
            depth: 1
        )
        computeEncoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
        computeEncoder.endEncoding()
        commandBuffer.commit()

        // Every vertex stays inside the cube's box while morphing, so half the size bounds it.
        let halfSize = state.size * 0.5
        let bounds = BoundingBox(min: -halfSize, max: halfSize)

        // Draw only the indices the kernel wrote for this resolution. The rest of the index
        // buffer (sized for `maxResolution`) holds no valid data for this grid.
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: indexCount(for: resolution),
                topology: .line,
                bounds: bounds
            )
        ])
    }
}
// Xcode canvas preview of the morphing cube/sphere view.
#Preview {
MetalCubeExample()
}
#include <metal_stdlib>
#include "cubeShaderTypes.h"
using namespace metal;
// Index-buffer slot -> face: the quads of face planeOrder[i] are written to slot i.
constant uint planeOrder[6] = {0, 5, 3, 1, 4, 2};

// Generates the cube/sphere mesh on the GPU, one thread per vertex.
//
// Thread `id` writes vertex `id` of a 6-face grid, where each face has
// dimensions_x * dimensions_y vertices. Each point is first placed on the box of extent
// `params.size`. It is then blended toward the matching point on the inscribed
// sphere (an ellipsoid for non-uniform sizes) by `cubeSphereInterpolationRatio`
// (0 = cube, 1 = sphere). Every vertex that is the bottom-left corner of a grid quad
// also writes that quad's 6 indices.
//
// Buffers: 0 = vertices, 1 = indices (uint), 2 = CubeParams.
kernel void updateCubeMesh(device CubeVertex* vertices [[buffer(0)]],
                           device uint* indices [[buffer(1)]],
                           constant CubeParams& params [[buffer(2)]],
                           uint id [[thread_position_in_grid]])
{
    // u and v divide by (dimension - 1), and a zero dimension would make the integer
    // division below divide by zero, so each face needs at least a 2x2 vertex grid.
    if (params.dimensions_x < 2 || params.dimensions_y < 2) return;

    // Calculate which face and which vertex within it this thread handles.
    uint verticesPerPlane = params.dimensions_x * params.dimensions_y;
    uint planeIndex = id / verticesPerPlane;
    uint vertexInPlane = id % verticesPerPlane;
    // The dispatch is rounded up to whole threadgroups; surplus threads do nothing.
    if (planeIndex >= 6) return;

    uint x = vertexInPlane % params.dimensions_x;
    uint y = vertexInPlane / params.dimensions_x;
    float u = float(x) / float(params.dimensions_x - 1);
    float v = float(y) / float(params.dimensions_y - 1);

    float3 position = float3(0.0f);
    float3 normal = float3(0.0f);

    // Face ordering: Front (Z+), Back (Z-), Right (X+), Left (X-), Top (Y+), Bottom (Y-)
    switch (planeIndex) {
        case 0: // Front face (Z+)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * (v - 0.5),
                params.size.z * 0.5
            );
            normal = float3(0, 0, 1);
            break;
        case 1: // Back face (Z-)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * (v - 0.5),
                params.size.z * -0.5
            );
            normal = float3(0, 0, -1);
            break;
        case 2: // Right face (X+)
            position = float3(
                params.size.x * 0.5,
                params.size.y * (v - 0.5),
                params.size.z * (0.5 - u)
            );
            normal = float3(1, 0, 0);
            break;
        case 3: // Left face (X-)
            position = float3(
                params.size.x * -0.5,
                params.size.y * (v - 0.5),
                params.size.z * (0.5 - u)
            );
            normal = float3(-1, 0, 0);
            break;
        case 4: // Top face (Y+)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * 0.5,
                params.size.z * (0.5 - v)
            );
            normal = float3(0, 1, 0);
            break;
        case 5: // Bottom face (Y-)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * -0.5,
                params.size.z * (0.5 - v)
            );
            normal = float3(0, -1, 0);
            break;
    }

    // Blend from the box surface toward the inscribed sphere (0 = cube, 1 = fully normalized).
    float t = params.cubeSphereInterpolationRatio;
    float3 scale = params.size * 0.5;
    float3 normalizedPos = normalize(position) * scale;
    position = mix(position, normalizedPos, t);

    vertices[id].position = position;
    // Blend the flat face normal (cube) toward the radial direction (sphere). At t = 1 this
    // equals normalize(position), as before. At t = 0 the faces now get their true flat
    // normal instead of a radial one. The two directions have a positive dot product, so
    // the mix is never zero.
    vertices[id].normal = normalize(mix(normal, normalize(position), t));

    // Only the bottom-left vertex of each quad writes that quad's indices.
    if (x < params.dimensions_x - 1 && y < params.dimensions_y - 1) {
        // Find the slot of this face in the index buffer (inverse of planeOrder).
        uint orderedPlaneIndex = 0;
        for (uint i = 0; i < 6; i++) {
            if (planeIndex == planeOrder[i]) {
                orderedPlaneIndex = i;
                break;
            }
        }

        uint quadsPerPlane = (params.dimensions_x - 1) * (params.dimensions_y - 1);
        uint indexBase = (orderedPlaneIndex * quadsPerPlane +
                          y * (params.dimensions_x - 1) + x) * 6;

        // Corner vertex indices within the whole vertex array (face offset included).
        uint planeOffset = planeIndex * verticesPerPlane;
        uint bottomLeft = vertexInPlane + planeOffset;
        uint bottomRight = bottomLeft + 1;
        uint topLeft = bottomLeft + params.dimensions_x;
        uint topRight = topLeft + 1;

        // NOTE(review): the Swift side draws these with .line topology, which reads the
        // 6 indices as three segments (one degenerate), i.e. two edges per quad.
        if (planeIndex == 1 || planeIndex == 3 || planeIndex == 5) {
            // Back, Left, Bottom faces need reversed winding.
            indices[indexBase]     = bottomLeft;
            indices[indexBase + 1] = topLeft;
            indices[indexBase + 2] = bottomRight;
            indices[indexBase + 3] = bottomRight;
            indices[indexBase + 4] = topLeft;
            indices[indexBase + 5] = topRight;
        } else {
            // Front, Right, Top faces keep the original winding.
            indices[indexBase]     = bottomLeft;
            indices[indexBase + 1] = bottomRight;
            indices[indexBase + 2] = topLeft;
            indices[indexBase + 3] = topLeft;
            indices[indexBase + 4] = bottomRight;
            indices[indexBase + 5] = topRight;
        }
    }
}
@Matt54
Copy link
Author

Matt54 commented Jan 6, 2025

Glad you got it working and happy to hear that you're enjoying the gists. I plan to create more organized resources for these soon. At the very least, including a project file would save readers from bridging header and path issues 😆

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment