Skip to content

Instantly share code, notes, and snippets.

@Matt54
Created July 31, 2024 04:23
Show Gist options
  • Save Matt54/d999a987e04cc85f0a2ca4f2f9156c08 to your computer and use it in GitHub Desktop.
Save Matt54/d999a987e04cc85f0a2ca4f2f9156c08 to your computer and use it in GitHub Desktop.
RealityView that compares two implementations of a wavy-plane LowLevelMesh animation: one in pure Swift (CPU) and one as a Metal compute shader (GPU).
import SwiftUI
import RealityKit
import Metal
/// Animates a radially rippling plane whose vertices and indices are generated on the GPU
/// by the `updateWavyPlaneFromCenter` Metal compute kernel on every timer tick.
///
/// This is the Metal half of a comparison with `PureSwiftWavyPlaneView`, which performs the
/// same per-vertex work on the CPU.
struct MetalWavyPlaneView: View {
    /// Current wave phase in radians. It is wrapped into `[0, 2π)`: the wave is 2π-periodic in
    /// phase, so wrapping is invisible, but an ever-growing Float would steadily lose precision.
    @State private var phase: Float = 0.0
    /// Mesh created by `createPlaneMesh()`; `nil` until the RealityView content has been built.
    @State private var mesh: LowLevelMesh?
    /// Repeating timer that drives `updateMesh` while the view is on screen.
    @State private var timer: Timer?

    /// Number of vertices along each edge of the square grid.
    let resolution = 250

    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipeline: MTLComputePipelineState

    /// Creates the Metal device, command queue and compute pipeline for the wave kernel.
    ///
    /// Traps with a descriptive message when Metal or the kernel is unavailable, because the
    /// view cannot do anything useful without them.
    // NOTE(review): SwiftUI may re-create view structs frequently; if this view is used outside a
    // preview, consider hoisting the Metal setup into a long-lived shared object.
    init() {
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("Failed to create a Metal command queue")
        }
        guard let library = device.makeDefaultLibrary(),
              let updateFunction = library.makeFunction(name: "updateWavyPlaneFromCenter") else {
            fatalError("Kernel 'updateWavyPlaneFromCenter' not found in the default Metal library")
        }
        do {
            self.computePipeline = try device.makeComputePipelineState(function: updateFunction)
        } catch {
            fatalError("Failed to create compute pipeline: \(error)")
        }
        self.device = device
        self.commandQueue = commandQueue
    }

    var body: some View {
        RealityView { content in
            let planeEntity = try! getPlaneEntity()
            let lightEntity = try! getLightEntity()
            planeEntity.addChild(lightEntity)
            content.add(planeEntity)
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts a 120 Hz repeating timer that advances the phase and regenerates the mesh.
    private func startTimer() {
        // Drop any existing timer first so a repeated `onAppear` cannot leave two timers running.
        stopTimer()
        timer = Timer.scheduledTimer(withTimeInterval: 1/120.0, repeats: true) { _ in
            phase = (phase + 0.1).truncatingRemainder(dividingBy: 2 * Float.pi)
            updateMesh(phase: phase)
        }
    }

    /// Stops and releases the animation timer.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Builds a white point light offset 0.125 along +z from the plane's origin.
    func getLightEntity() throws -> Entity {
        let entity = Entity()
        // NOTE(review): PureSwiftWavyPlaneView uses attenuationRadius 0.25 here; if the two views are
        // meant to be compared side by side, their lighting should probably match.
        let pointLightComponent = PointLightComponent(
            cgColor: .init(red: 1, green: 1, blue: 1, alpha: 1),
            intensity: 10000,
            attenuationRadius: 0.5
        )
        entity.components.set(pointLightComponent)
        entity.position = .init(x: 0, y: 0, z: 0.125)
        return entity
    }

    /// Creates the plane entity, backed by a fresh `LowLevelMesh`, with a blue, double-sided
    /// (no face culling), non-metallic, zero-roughness PBR material.
    ///
    /// - Throws: Any error from creating the `LowLevelMesh` or the `MeshResource`.
    func getPlaneEntity() throws -> Entity {
        let mesh = try createPlaneMesh()
        // Propagate rather than trap: this function is already `throws`.
        let resource = try MeshResource(from: mesh)

        var material = PhysicallyBasedMaterial()
        material.baseColor.tint = .init(red: 0.0625, green: 0.125, blue: 1.0, alpha: 1.0)
        material.faceCulling = .none
        material.metallic = 0.0
        material.roughness = 0.0

        let modelComponent = ModelComponent(mesh: resource, materials: [material])
        let entity = Entity()
        entity.components.set(modelComponent)
        return entity
    }

    /// Allocates a `LowLevelMesh` sized for a `resolution × resolution` vertex grid and
    /// `(resolution - 1)² × 6` triangle indices, and stores it in `mesh` for later updates.
    /// The buffers are left for `updateMesh` to fill.
    func createPlaneMesh() throws -> LowLevelMesh {
        let vertexCount = resolution * resolution
        let indexCount = (resolution - 1) * (resolution - 1) * 6
        var desc = MyVertexWithNormal.descriptor
        desc.vertexCapacity = vertexCount
        desc.indexCapacity = indexCount
        let mesh = try LowLevelMesh(descriptor: desc)
        self.mesh = mesh
        return mesh
    }

    /// Encodes and commits one dispatch of `updateWavyPlaneFromCenter`, which rewrites the vertex
    /// buffer and the index buffer, then declares the single mesh part.
    ///
    /// - Parameters:
    ///   - amplitude: Peak wave height; also used as the part's z bounds.
    ///   - frequency: Radial angular frequency, in radians per unit of distance from the center.
    ///   - phase: Phase offset in radians; increasing it moves the ripples outward.
    func updateMesh(amplitude: Float = 0.1, frequency: Float = 60, phase: Float = 0.0) {
        guard let mesh,
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        let size: Float = 0.4
        let indexCount = (resolution - 1) * (resolution - 1) * 6

        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)
        let indexBuffer = mesh.replaceIndices(using: commandBuffer)

        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1)

        // Copied byte-for-byte into the kernel's `constant WavyPlaneParams&` at buffer(2).
        var params = WavyPlaneParams(
            resolution: Int32(resolution),
            size: size,
            amplitude: amplitude,
            frequency: frequency,
            phase: phase
        )
        computeEncoder.setBytes(&params, length: MemoryLayout<WavyPlaneParams>.size, index: 2)

        // One thread per vertex on a 1D grid; the kernel derives (x, y) from the linear id.
        let threadsPerGrid = MTLSize(width: resolution * resolution, height: 1, depth: 1)
        let threadsPerThreadgroup = MTLSize(width: 64, height: 1, depth: 1)
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()
        commandBuffer.commit()

        // |z| never exceeds the amplitude, so it bounds the surface on the z axis.
        let maxZ = amplitude
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: indexCount,
                topology: .triangle,
                bounds: BoundingBox(min: [-size/2, -size/2, -maxZ], max: [size/2, size/2, maxZ])
            )
        ])
    }
}
/// Wave parameters passed to the `updateWavyPlaneFromCenter` kernel with `setBytes` at buffer index 2.
///
/// Field order and types must match the `WavyPlaneParams` struct in the Metal shader source.
/// The bytes are copied verbatim, so reordering or retyping a field here silently corrupts the kernel's input.
struct WavyPlaneParams {
var resolution: Int32 // vertices along each edge of the grid
var size: Float // side length of the square plane (x and y span -size/2 ... size/2)
var amplitude: Float // peak wave height
var frequency: Float // radial angular frequency (multiplies the distance from the center)
var phase: Float // phase offset in radians
}
// Xcode canvas preview of the Metal (GPU) implementation of the wavy plane.
#Preview {
MetalWavyPlaneView()
}
/// Interleaved vertex record stored in buffer 0 of the `LowLevelMesh`.
///
/// The field order and types must stay in sync with the `MyVertexWithNormal` struct in the
/// Metal shader source, because the compute kernel writes this buffer directly.
struct MyVertexWithNormal {
    var position: SIMD3<Float> = .zero
    var normal: SIMD3<Float> = .zero

    /// Position and normal attributes. Offsets are taken from the real Swift memory layout, so
    /// they stay correct even with SIMD padding.
    /// These are immutable (`let`): mutable static state is never needed here and is rejected
    /// as concurrency-unsafe in Swift 6 language mode.
    static let vertexAttributes: [LowLevelMesh.Attribute] = [
        .init(semantic: .position, format: .float3, offset: MemoryLayout<Self>.offset(of: \.position)!),
        .init(semantic: .normal, format: .float3, offset: MemoryLayout<Self>.offset(of: \.normal)!),
    ]

    /// A single interleaved buffer (index 0) whose stride is the full record size.
    static let vertexLayouts: [LowLevelMesh.Layout] = [
        .init(bufferIndex: 0, bufferStride: MemoryLayout<Self>.stride)
    ]

    /// Base mesh descriptor with 32-bit indices. Callers set the vertex and index capacities.
    static var descriptor: LowLevelMesh.Descriptor {
        var desc = LowLevelMesh.Descriptor()
        desc.vertexAttributes = vertexAttributes
        desc.vertexLayouts = vertexLayouts
        desc.indexType = .uint32
        return desc
    }
}
import SwiftUI
import RealityKit
/// Animates a radially rippling plane whose vertices and indices are computed on the CPU on
/// every timer tick and written directly into a `LowLevelMesh`.
///
/// This is the pure-Swift half of a comparison with `MetalWavyPlaneView`, which performs the
/// same per-vertex work in a Metal compute kernel.
struct PureSwiftWavyPlaneView: View {
    /// Current wave phase in radians. It is wrapped into `[0, 2π)`: the wave is 2π-periodic in
    /// phase, so wrapping is invisible, but an ever-growing Float would steadily lose precision.
    @State private var phase: Float = 0.0
    /// Mesh created by `createPlaneMesh()`; `nil` until the RealityView content has been built.
    @State private var mesh: LowLevelMesh?
    /// Repeating timer that drives `updateMesh` while the view is on screen.
    @State private var timer: Timer?

    /// Number of vertices along each edge of the square grid.
    let resolution = 250

    var body: some View {
        RealityView { content in
            let planeEntity = try! getPlaneEntity()
            let lightEntity = try! getLightEntity()
            planeEntity.addChild(lightEntity)
            content.add(planeEntity)
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts a 120 Hz repeating timer that advances the phase and regenerates the mesh.
    private func startTimer() {
        // Drop any existing timer first so a repeated `onAppear` cannot leave two timers running.
        stopTimer()
        timer = Timer.scheduledTimer(withTimeInterval: 1/120.0, repeats: true) { _ in
            phase = (phase + 0.1).truncatingRemainder(dividingBy: 2 * Float.pi)
            updateMesh(phase: phase)
        }
    }

    /// Stops and releases the animation timer.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Builds a white point light offset 0.125 along +z from the plane's origin.
    func getLightEntity() throws -> Entity {
        let entity = Entity()
        // NOTE(review): MetalWavyPlaneView uses attenuationRadius 0.5 here; if the two views are
        // meant to be compared side by side, their lighting should probably match.
        let pointLightComponent = PointLightComponent(
            cgColor: .init(red: 1, green: 1, blue: 1, alpha: 1),
            intensity: 10000,
            attenuationRadius: 0.25
        )
        entity.components.set(pointLightComponent)
        entity.position = .init(x: 0, y: 0, z: 0.125)
        return entity
    }

    /// Creates the plane entity, backed by a fresh `LowLevelMesh`, with a blue, double-sided
    /// (no face culling), non-metallic, zero-roughness PBR material.
    ///
    /// - Throws: Any error from creating the `LowLevelMesh` or the `MeshResource`.
    func getPlaneEntity() throws -> Entity {
        let mesh = try createPlaneMesh()
        // Propagate rather than trap: this function is already `throws`.
        let resource = try MeshResource(from: mesh)

        var material = PhysicallyBasedMaterial()
        material.baseColor.tint = .init(red: 0.0625, green: 0.125, blue: 1.0, alpha: 1.0)
        material.faceCulling = .none
        material.metallic = 0.0
        material.roughness = 0.0

        let modelComponent = ModelComponent(mesh: resource, materials: [material])
        let entity = Entity()
        entity.components.set(modelComponent)
        return entity
    }

    /// Allocates a `LowLevelMesh` sized for a `resolution × resolution` vertex grid and
    /// `(resolution - 1)² × 6` triangle indices, and stores it in `mesh` for later updates.
    /// The buffers are left for `updateMesh` to fill.
    func createPlaneMesh() throws -> LowLevelMesh {
        let vertexCount = resolution * resolution
        let indexCount = (resolution - 1) * (resolution - 1) * 6
        var desc = MyVertexWithNormal.descriptor
        desc.vertexCapacity = vertexCount
        desc.indexCapacity = indexCount
        let mesh = try LowLevelMesh(descriptor: desc)
        self.mesh = mesh
        return mesh
    }

    /// Recomputes every vertex (position and normal) on the CPU, rewrites the triangle indices,
    /// and declares the single mesh part.
    ///
    /// Surface: z = amplitude · sin(frequency · r − phase), where r is the distance from the
    /// center of the plane.
    ///
    /// - Parameters:
    ///   - amplitude: Peak wave height; also used as the part's z bounds.
    ///   - frequency: Radial angular frequency, in radians per unit of distance from the center.
    ///   - phase: Phase offset in radians; increasing it moves the ripples outward.
    func updateMesh(amplitude: Float = 0.1, frequency: Float = 60.0, phase: Float = 0.0) {
        guard let mesh else { return }

        let size: Float = 0.4
        let halfSize = size / 2
        let lastIndex = Float(resolution - 1)
        let indexCount = (resolution - 1) * (resolution - 1) * 6

        mesh.withUnsafeMutableBytes(bufferIndex: 0) { rawBytes in
            let vertices = rawBytes.bindMemory(to: MyVertexWithNormal.self)
            for y in 0..<resolution {
                let yPos = Float(y) / lastIndex * size - halfSize
                for x in 0..<resolution {
                    let xPos = Float(x) / lastIndex * size - halfSize

                    // Distance from the center drives a radially symmetric sine wave.
                    let r = sqrt(xPos * xPos + yPos * yPos)
                    let angle = frequency * r - phase
                    let z = amplitude * sin(angle)

                    // Radial slope dz/dr. By the chain rule, ∂z/∂x = dz/dr · x/r and ∂z/∂y = dz/dr · y/r.
                    let dzdr = amplitude * frequency * cos(angle)

                    // The normals of z = f(x, y) are ±(−∂z/∂x, −∂z/∂y, 1). This keeps the original −z
                    // orientation, i.e. (∂z/∂x, ∂z/∂y, −1). The previous (−∂z/∂x, −∂z/∂y, −1) was not
                    // perpendicular to the surface. At r == 0 (reachable when `resolution` is odd) the
                    // slope direction is undefined, so use the flat normal instead of dividing by zero.
                    let normal: SIMD3<Float>
                    if r > 0 {
                        normal = simd_normalize(SIMD3<Float>(dzdr * xPos / r, dzdr * yPos / r, -1))
                    } else {
                        normal = SIMD3<Float>(0, 0, -1)
                    }

                    vertices[y * resolution + x] = MyVertexWithNormal(
                        position: SIMD3<Float>(xPos, yPos, z),
                        normal: normal
                    )
                }
            }
        }

        // The indices never change between frames. They are still rewritten every tick, as the
        // Metal kernel also does, which presumably keeps the CPU/GPU comparison like-for-like.
        mesh.withUnsafeMutableIndices { rawIndices in
            let indices = rawIndices.bindMemory(to: UInt32.self)
            var cursor = 0
            for y in 0..<(resolution - 1) {
                for x in 0..<(resolution - 1) {
                    let topLeft = UInt32(y * resolution + x)
                    let topRight = topLeft + 1
                    let bottomLeft = UInt32((y + 1) * resolution + x)
                    let bottomRight = bottomLeft + 1

                    // Two triangles per grid cell.
                    indices[cursor] = topLeft
                    indices[cursor + 1] = bottomLeft
                    indices[cursor + 2] = topRight
                    indices[cursor + 3] = topRight
                    indices[cursor + 4] = bottomLeft
                    indices[cursor + 5] = bottomRight
                    cursor += 6
                }
            }
        }

        // |z| never exceeds the amplitude, so it bounds the surface on the z axis.
        let maxZ = amplitude
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: indexCount,
                topology: .triangle,
                bounds: BoundingBox(min: [-size/2, -size/2, -maxZ], max: [size/2, size/2, maxZ])
            )
        ])
    }
}
// Xcode canvas preview of the pure-Swift (CPU) implementation of the wavy plane.
#Preview {
PureSwiftWavyPlaneView()
}
#include <metal_stdlib>
using namespace metal;
// Per-vertex record that the kernel writes into buffer(0).
// Must mirror the Swift `MyVertexWithNormal` (position first, then normal).
// NOTE(review): relies on MSL float3 being 16 bytes with 16-byte alignment, matching the
// stride and alignment of Swift's SIMD3<Float>; re-check if either side changes its field types.
struct MyVertexWithNormal {
float3 position;
float3 normal;
};
// Wave parameters read by the kernel from buffer(2) in the constant address space.
// Field order and types must mirror the Swift `WavyPlaneParams`, whose bytes are copied verbatim.
struct WavyPlaneParams {
int32_t resolution; // vertices along each edge of the grid
float size;         // side length of the square plane (x and y span -size/2 ... size/2)
float amplitude;    // peak wave height
float frequency;    // radial angular frequency (multiplies the distance from the center)
float phase;        // phase offset in radians
};
// Computes one vertex of the rippling plane per thread. Every vertex that is the top-left
// corner of a grid cell also writes that cell's two triangles.
//
// Surface: z = amplitude * sin(frequency * r - phase), where r is the distance from the center.
// The Swift host dispatches a 1D grid of resolution * resolution threads; thread `gid`
// handles grid vertex (gid % resolution, gid / resolution).
//
// buffer(0): vertices, resolution * resolution entries.
// buffer(1): 32-bit indices, (resolution - 1)^2 * 6 entries.
// buffer(2): wave parameters.
kernel void updateWavyPlaneFromCenter(device MyVertexWithNormal* vertices [[buffer(0)]],
                                      device uint* indices [[buffer(1)]],
                                      constant WavyPlaneParams& params [[buffer(2)]],
                                      uint gid [[thread_position_in_grid]])
{
    const int resolution = params.resolution;

    // The position math divides by (resolution - 1), so at least a 2x2 grid is required.
    // Threads past the last vertex exit early, in case the dispatched grid was rounded up.
    if (resolution < 2 || gid >= uint(resolution * resolution)) {
        return;
    }

    const int x = int(gid) % resolution;
    const int y = int(gid) / resolution;
    const int index = y * resolution + x;

    const float halfSize = params.size / 2.0f;
    const float xPos = float(x) / float(resolution - 1) * params.size - halfSize;
    const float yPos = float(y) / float(resolution - 1) * params.size - halfSize;

    // Distance from the center drives a radially symmetric sine wave.
    const float r = sqrt(xPos * xPos + yPos * yPos);
    const float angle = params.frequency * r - params.phase;
    const float z = params.amplitude * sin(angle);

    // Radial slope dz/dr. By the chain rule, dz/dx = dz/dr * x/r and dz/dy = dz/dr * y/r.
    const float dz_dr = params.amplitude * params.frequency * cos(angle);

    // The normals of z = f(x, y) are +-(-dz/dx, -dz/dy, 1). This keeps the original -z
    // orientation, i.e. (dz/dx, dz/dy, -1). The previous (-dz/dx, -dz/dy, -1) was not
    // perpendicular to the surface. At r == 0 (reachable when resolution is odd) the slope
    // direction is undefined, so use the flat normal instead of dividing by zero.
    float3 normal = float3(0.0f, 0.0f, -1.0f);
    if (r > 0.0f) {
        normal = normalize(float3(dz_dr * xPos / r, dz_dr * yPos / r, -1.0f));
    }

    vertices[index].position = float3(xPos, yPos, z);
    vertices[index].normal = normal;

    // Every vertex except those on the last row or column owns one grid cell (two triangles).
    if (x < resolution - 1 && y < resolution - 1) {
        const int indexBase = (y * (resolution - 1) + x) * 6;
        const uint topLeft = uint(y * resolution + x);
        const uint topRight = topLeft + 1;
        const uint bottomLeft = uint((y + 1) * resolution + x);
        const uint bottomRight = bottomLeft + 1;

        indices[indexBase] = topLeft;
        indices[indexBase + 1] = bottomLeft;
        indices[indexBase + 2] = topRight;
        indices[indexBase + 3] = topRight;
        indices[indexBase + 4] = bottomLeft;
        indices[indexBase + 5] = bottomRight;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment