Skip to content

Instantly share code, notes, and snippets.

@Matt54
Last active January 6, 2025 17:29
Show Gist options
  • Save Matt54/835a275bca08d782ea1b9be3fe6ab72b to your computer and use it in GitHub Desktop.
Save Matt54/835a275bca08d782ea1b9be3fe6ab72b to your computer and use it in GitHub Desktop.
Cube/Sphere RealityView using LowLevelMesh with a Metal shader
// Types shared by the Metal kernel (`#include "cubeShaderTypes.h"`) and by Swift
// (exposed through the Objective-C bridging header). Field order and types must
// stay identical on both sides, because both read the same bytes.
#ifndef cubeShaderTypes_h
#define cubeShaderTypes_h
#include <simd/simd.h>
// One mesh vertex as written by `updateCubeMesh`. The Swift LowLevelMesh
// attributes assume `position` at offset 0 and `normal` one SIMD3<Float>
// stride later, with a buffer stride of MemoryLayout<CubeVertex>.stride.
struct CubeVertex {
vector_float3 position;
vector_float3 normal;
};
// Per-dispatch parameters for `updateCubeMesh` (bound at buffer index 2).
// size:                         full cube extents along x, y, z
// dimensions_x / dimensions_y:  grid points per face along u / v
// cubeSphereInterpolationRatio: 0 = cube, 1 = fully normalized (sphere)
struct CubeParams {
vector_float3 size;
uint32_t dimensions_x;
uint32_t dimensions_y;
float cubeSphereInterpolationRatio;
};
#endif /* cubeShaderTypes_h */
import SwiftUI
import RealityKit
import Metal
/// Geometry and animation inputs for the GPU-generated cube/sphere mesh.
struct CubeSphereState {
    /// Full extents of the cube along x, y and z.
    var size = SIMD3<Float>(repeating: 0.3)
    /// Grid points per face along the face's u and v directions.
    var planeResolution = SIMD2<UInt32>(repeating: 16)
    /// Blend factor between the cube (0) and the sphere (1).
    var cubeSphereInterpolationRatio: Float = 0
}
/// A `RealityView` showing a cube mesh that morphs into a sphere and back.
///
/// Vertex and index data live in a `LowLevelMesh` allocated once for
/// `maxResolution`. The `updateCubeMesh` compute kernel regenerates the data
/// whenever `state` changes. A 1/120 s timer rotates the entity and ping-pongs
/// `state.cubeSphereInterpolationRatio` between 0 (cube) and 1 (sphere),
/// pausing for a dead band at each end.
struct MetalCubeExample: View {
    @State private var rootEntity: Entity?
    @State private var mesh: LowLevelMesh?
    @State var state: CubeSphereState = .init()

    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipeline: MTLComputePipelineState

    /// `true` while the ratio is increasing toward the sphere.
    @State var isForward: Bool = true
    /// Seconds accumulated by `updateRotation()`.
    @State var time: Double = 0.0
    @State var timer: Timer?
    /// Remaining pause before the morph continues; decremented by `deadbandStep` per tick.
    @State var deadBandValue: Double = 0.5
    @State private var rotationAngles: SIMD3<Float> = [0, 0, 0]
    @State private var lastRotationUpdateTime = CACurrentMediaTime()

    let deadbandStep = 0.005
    /// Change in the interpolation ratio per timer tick.
    let modulationStep: Float = 0.02
    /// Largest per-face grid resolution the mesh buffers are allocated for.
    let maxResolution: UInt32 = 128

    /// Sets up Metal. Traps if there is no Metal device or if the default
    /// library does not contain the `updateCubeMesh` kernel.
    init() {
        self.device = MTLCreateSystemDefaultDevice()!
        self.commandQueue = device.makeCommandQueue()!
        let library = device.makeDefaultLibrary()!
        let updateFunction = library.makeFunction(name: "updateCubeMesh")!
        self.computePipeline = try! device.makeComputePipelineState(function: updateFunction)
    }

    var body: some View {
        RealityView { content in
            let mesh = try! createMesh()
            let resource = try! MeshResource(from: mesh)
            let modelComponent = ModelComponent(mesh: resource, materials: [UnlitMaterial()])
            let entity = Entity()
            entity.components.set(modelComponent)
            content.add(entity)
            self.mesh = mesh
            updateMesh(with: state)
            self.rootEntity = entity
        } update: { _ in
            updateMesh(with: state)
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    // MARK: - Animation

    /// Starts the animation timer. Any timer that is already running is replaced.
    private func startTimer() {
        // onAppear can fire more than once; don't leave an orphaned timer running.
        stopTimer()
        // Measure rotation from now, not from view creation or the last disappearance,
        // so the first tick doesn't jump.
        lastRotationUpdateTime = CACurrentMediaTime()
        timer = Timer.scheduledTimer(withTimeInterval: 1 / 120.0, repeats: true) { _ in
            updateRotation()
            stepInterpolation()
        }
    }

    /// Advances the cube/sphere blend by one tick.
    ///
    /// While `deadBandValue` is positive it only counts down. Otherwise the ratio
    /// moves by `modulationStep` in the current direction. On reaching 0 or 1 the
    /// direction reverses and a new dead band of 1.0 starts.
    private func stepInterpolation() {
        if deadBandValue > 0 {
            deadBandValue -= deadbandStep
            return
        }
        var ratio = state.cubeSphereInterpolationRatio
        if isForward {
            ratio += modulationStep
            if ratio >= 1.0 {
                deadBandValue = 1.0
                ratio = 1
                isForward = false
            }
        } else {
            ratio -= modulationStep
            if ratio <= 0.0 {
                deadBandValue = 1.0
                ratio = 0.0
                isForward = true
            }
        }
        state.cubeSphereInterpolationRatio = ratio
    }

    /// Advances the entity's rotation by the wall-clock time since the last call.
    func updateRotation() {
        let currentTime = CACurrentMediaTime()
        let frameDuration = currentTime - lastRotationUpdateTime
        self.time += frameDuration
        // Rotate about all three axes at different rates for a wonky roll effect.
        rotationAngles.x += Float(frameDuration * 0.25)
        rotationAngles.y += Float(frameDuration * 0.125)
        rotationAngles.z += Float(frameDuration * 0.0675)
        let rotationX = simd_quatf(angle: rotationAngles.x, axis: [1, 0, 0])
        let rotationY = simd_quatf(angle: rotationAngles.y, axis: [0, 1, 0])
        let rotationZ = simd_quatf(angle: rotationAngles.z, axis: [0, 0, 1])
        rootEntity?.transform.rotation = rotationX * rotationY * rotationZ
        lastRotationUpdateTime = currentTime
    }

    /// Stops and releases the animation timer, if any.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    // MARK: - Mesh sizing

    /// Vertices needed for `resolution` grid points per face axis (6 faces).
    private static func vertexCount(for resolution: SIMD2<UInt32>) -> Int {
        Int(resolution.x * resolution.y * 6)
    }

    /// Indices needed for `resolution`: 6 per grid quad, 6 faces.
    /// Requires each axis >= 1 (otherwise the `- 1` underflows `UInt32` and traps).
    private static func indexCount(for resolution: SIMD2<UInt32>) -> Int {
        Int(6 * (resolution.x - 1) * (resolution.y - 1) * 6)
    }

    private var vertexCapacity: Int {
        Self.vertexCount(for: [maxResolution, maxResolution])
    }

    private var indexCapacity: Int {
        Self.indexCount(for: [maxResolution, maxResolution])
    }

    /// Clamps each axis of `requested` to `2...maxResolution`.
    ///
    /// The lower bound has two reasons. The kernel divides by `dimensions - 1`,
    /// and `indexCount(for:)` subtracts 1 per axis, so at least 2 grid points are
    /// required. The upper bound has one: anything above `maxResolution` would
    /// write past buffers sized by `vertexCapacity` / `indexCapacity`.
    private func clampedResolution(_ requested: SIMD2<UInt32>) -> SIMD2<UInt32> {
        SIMD2(
            min(max(requested.x, 2), maxResolution),
            min(max(requested.y, 2), maxResolution)
        )
    }

    // MARK: - Mesh generation

    /// Allocates a `LowLevelMesh` sized for `maxResolution`, so any clamped
    /// resolution fits without reallocating.
    ///
    /// The single interleaved vertex buffer mirrors `CubeVertex`: position at
    /// offset 0, normal one `SIMD3<Float>` stride later.
    private func createMesh() throws -> LowLevelMesh {
        let vertexAttributes = [
            LowLevelMesh.Attribute(semantic: .position, format: .float3, offset: 0),
            LowLevelMesh.Attribute(semantic: .normal, format: .float3, offset: MemoryLayout<SIMD3<Float>>.stride)
        ]
        let vertexLayouts = [
            LowLevelMesh.Layout(bufferIndex: 0, bufferStride: MemoryLayout<CubeVertex>.stride)
        ]
        return try LowLevelMesh(descriptor: .init(
            vertexCapacity: vertexCapacity,
            vertexAttributes: vertexAttributes,
            vertexLayouts: vertexLayouts,
            indexCapacity: indexCapacity
        ))
    }

    /// Regenerates vertices and indices for `state` on the GPU and updates the
    /// mesh part to draw exactly those indices.
    ///
    /// Does nothing if the mesh has not been created yet or if Metal cannot
    /// create a command buffer or encoder.
    private func updateMesh(with state: CubeSphereState) {
        guard let mesh,
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        let resolution = clampedResolution(state.planeResolution)
        var params = CubeParams(
            size: state.size,
            dimensions_x: resolution.x,
            dimensions_y: resolution.y,
            cubeSphereInterpolationRatio: state.cubeSphereInterpolationRatio
        )

        // The kernel writes both buffers as part of `commandBuffer`.
        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)
        let indexBuffer = mesh.replaceIndices(using: commandBuffer)

        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1)
        computeEncoder.setBytes(&params, length: MemoryLayout<CubeParams>.stride, index: 2)

        // One thread per vertex, rounded up to whole threadgroups.
        let vertexCount = Self.vertexCount(for: resolution)
        let threadgroupSize = MTLSize(width: 64, height: 1, depth: 1)
        let threadgroups = MTLSize(
            width: (vertexCount + threadgroupSize.width - 1) / threadgroupSize.width,
            height: 1,
            depth: 1
        )
        computeEncoder.dispatchThreadgroups(threadgroups, threadsPerThreadgroup: threadgroupSize)
        computeEncoder.endEncoding()
        commandBuffer.commit()

        // The cube and every cube/sphere blend stay inside +/- size/2 on each axis.
        let halfSize = state.size * 0.5
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                // Only the indices written by this dispatch. The rest of the
                // capacity-sized index buffer was not written for this resolution.
                indexCount: Self.indexCount(for: resolution),
                topology: .line,
                bounds: BoundingBox(min: -halfSize, max: halfSize)
            )
        ])
    }
}
// Xcode canvas preview of the animated cube/sphere view.
#Preview {
    MetalCubeExample.init()
}
#include <metal_stdlib>
#include "cubeShaderTypes.h"
using namespace metal;
// Order in which faces are laid out in the index buffer: index slot i holds
// the quads of face planeOrder[i] (face numbers as in the switch below).
constant uint planeOrder[6] = {0, 5, 3, 1, 4, 2};

// Generates the cube/sphere mesh: one thread per vertex.
//
// Vertices are stored face by face (6 faces x dimensions_x * dimensions_y).
// Each thread also writes the 6 indices of the grid quad whose bottom-left
// corner it owns. Positions are blended from the cube face toward the
// normalized (ellipsoid/sphere) point by params.cubeSphereInterpolationRatio.
// Normals are blended from the flat face normal (cube) toward the radial
// normal (sphere) by the same ratio.
//
// Requires dimensions_x >= 2 and dimensions_y >= 2, because u and v divide by
// (dimensions - 1).
kernel void updateCubeMesh(device CubeVertex* vertices [[buffer(0)]],
                           device uint* indices [[buffer(1)]],
                           constant CubeParams& params [[buffer(2)]],
                           uint id [[thread_position_in_grid]])
{
    // Calculate which face and vertex we're working with
    uint verticesPerPlane = params.dimensions_x * params.dimensions_y;
    uint planeIndex = id / verticesPerPlane;
    uint vertexInPlane = id % verticesPerPlane;
    // Threads past the last vertex (e.g. from a dispatch rounded up to whole
    // threadgroups) do nothing.
    if (planeIndex >= 6) return;

    uint x = vertexInPlane % params.dimensions_x;
    uint y = vertexInPlane / params.dimensions_x;
    float u = float(x) / float(params.dimensions_x - 1);
    float v = float(y) / float(params.dimensions_y - 1);

    float3 position;
    float3 normal;

    // Face ordering: Front (Z+), Back (Z-), Right (X+), Left (X-), Top (Y+), Bottom (Y-)
    switch (planeIndex) {
        case 0: // Front face (Z+)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * (v - 0.5),
                params.size.z * 0.5
            );
            normal = float3(0, 0, 1);
            break;
        case 1: // Back face (Z-)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * (v - 0.5),
                params.size.z * -0.5
            );
            normal = float3(0, 0, -1);
            break;
        case 2: // Right face (X+)
            position = float3(
                params.size.x * 0.5,
                params.size.y * (v - 0.5),
                params.size.z * (0.5 - u)
            );
            normal = float3(1, 0, 0);
            break;
        case 3: // Left face (X-)
            position = float3(
                params.size.x * -0.5,
                params.size.y * (v - 0.5),
                params.size.z * (0.5 - u)
            );
            normal = float3(-1, 0, 0);
            break;
        case 4: // Top face (Y+)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * 0.5,
                params.size.z * (0.5 - v)
            );
            normal = float3(0, 1, 0);
            break;
        case 5: // Bottom face (Y-)
            position = float3(
                params.size.x * (u - 0.5),
                params.size.y * -0.5,
                params.size.z * (0.5 - v)
            );
            normal = float3(0, -1, 0);
            break;
        default:
            // Unreachable (guarded above); keeps position/normal definitely initialized.
            return;
    }

    // Proportionally normalize based on the interpolation ratio (0 = cube, 1 = fully normalized).
    float ratio = params.cubeSphereInterpolationRatio;
    float3 scale = params.size * 0.5;
    float3 normalizedPos = normalize(position) * scale;
    position = mix(position, normalizedPos, ratio);

    vertices[id].position = position;
    // Flat face normal at ratio 0, radial normal at ratio 1 (same as before), blended
    // in between. The face normal and the radial direction point into the same
    // hemisphere, so the blend cannot collapse to a zero vector.
    vertices[id].normal = normalize(mix(normal, normalize(position), ratio));

    // Each vertex that is the bottom-left corner of a quad writes that quad's
    // indices, placed in the face order given by planeOrder.
    if (x < params.dimensions_x - 1 && y < params.dimensions_y - 1) {
        // Find this face's slot in planeOrder.
        uint orderedPlaneIndex = 0;
        for (uint i = 0; i < 6; i++) {
            if (planeIndex == planeOrder[i]) {
                orderedPlaneIndex = i;
                break;
            }
        }

        uint indexBase = (orderedPlaneIndex * (params.dimensions_x - 1) * (params.dimensions_y - 1) +
                          y * (params.dimensions_x - 1) + x) * 6;

        uint bottomLeft = vertexInPlane;
        uint bottomRight = bottomLeft + 1;
        uint topLeft = bottomLeft + params.dimensions_x;
        uint topRight = topLeft + 1;

        // Convert face-local vertex numbers to global vertex indices.
        bottomLeft += planeIndex * verticesPerPlane;
        bottomRight += planeIndex * verticesPerPlane;
        topLeft += planeIndex * verticesPerPlane;
        topRight += planeIndex * verticesPerPlane;

        // Match the winding order from the original implementation
        if (planeIndex == 1 || planeIndex == 3 || planeIndex == 5) {
            // Back, Left, Bottom faces need reversed winding
            indices[indexBase] = bottomLeft;
            indices[indexBase + 1] = topLeft;
            indices[indexBase + 2] = bottomRight;
            indices[indexBase + 3] = bottomRight;
            indices[indexBase + 4] = topLeft;
            indices[indexBase + 5] = topRight;
        } else {
            // Front, Right, Top faces keep original winding
            indices[indexBase] = bottomLeft;
            indices[indexBase + 1] = bottomRight;
            indices[indexBase + 2] = topLeft;
            indices[indexBase + 3] = topLeft;
            indices[indexBase + 4] = bottomRight;
            indices[indexBase + 5] = topRight;
        }
    }
}
@rth238
Copy link

rth238 commented Dec 31, 2024

Thanks for providing these code examples; they have been very insightful! With this one, for some reason, although I have the .h file in the project, I get 3 error messages: Cannot find type 'CubeVertex' in scope and 2 of the same: Cannot find type 'CubeParams' in scope. I've removed the .h file, reinserted, etc., but it seems a setting may be preventing the use of the header file? Any suggestions would be welcomed!

@Matt54
Copy link
Author

Matt54 commented Jan 5, 2025

Sorry for the late reply here @rth238! Do you have a bridging header that includes the path to your header file that defines these types?

The file should include something like:
#include ".../cubeShaderTypes.h"

This bridging header file makes sure that Swift can see the C types; it's the bridge between C and Swift. You'll want to make sure that this is set up correctly in your project: check the "Objective-C Bridging Header" setting in Build Settings.

@rth238
Copy link

rth238 commented Jan 6, 2025

OK, thanks. I've not had to use the bridging header or modify the build setting before, so this got me pointed in the right direction. (But instead of #include, it was #import.) So thanks for this suggestion, and then thanks again for opening up all of your remarkable work by posting these examples!

@Matt54
Copy link
Author

Matt54 commented Jan 6, 2025

Glad you got it working and happy to hear that you're enjoying the gists. I plan to create more organized resources for these soon. At the very least, including a project file would save readers from bridging header and path issues 😆

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment