Created
July 31, 2024 04:23
-
-
Save Matt54/d999a987e04cc85f0a2ca4f2f9156c08 to your computer and use it in GitHub Desktop.
A RealityView demo comparing two implementations of an animated wavy-plane LowLevelMesh: one written in pure Swift and one using a Metal compute shader.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import SwiftUI | |
import RealityKit | |
import Metal | |
/// A SwiftUI view that animates a radial ripple across a square `LowLevelMesh` grid,
/// generating the mesh data on the GPU with the `updateWavyPlaneFromCenter` Metal kernel.
///
/// A repeating `Timer` advances `phase` and re-dispatches the kernel on every tick.
struct MetalWavyPlaneView: View {
    /// Current animation phase passed to the kernel; advanced by 0.1 per timer tick.
    @State private var phase: Float = 0.0
    /// The mesh the kernel writes into; `nil` until the `RealityView` content has been built.
    @State private var mesh: LowLevelMesh?
    /// Drives the animation while the view is on screen.
    @State private var timer: Timer?

    /// Number of vertices along each edge of the square grid.
    let resolution = 250

    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipeline: MTLComputePipelineState

    /// Creates the Metal device, command queue and compute pipeline for the ripple kernel.
    ///
    /// There is no sensible fallback if any of these fail, so failure still traps, but with
    /// a message naming the step that failed instead of an anonymous force-unwrap crash.
    /// NOTE(review): SwiftUI can re-create this struct frequently, and each init rebuilds the
    /// pipeline state. Consider caching these objects outside the view.
    init() {
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("Failed to create a Metal command queue")
        }
        guard let library = device.makeDefaultLibrary() else {
            fatalError("Failed to load the default Metal library")
        }
        guard let updateFunction = library.makeFunction(name: "updateWavyPlaneFromCenter") else {
            fatalError("Metal function 'updateWavyPlaneFromCenter' not found in the default library")
        }
        do {
            computePipeline = try device.makeComputePipelineState(function: updateFunction)
        } catch {
            fatalError("Failed to create compute pipeline: \(error)")
        }
        self.device = device
        self.commandQueue = commandQueue
    }

    var body: some View {
        RealityView { content in
            // Entity construction failures are treated as fatal here (`try!`).
            let planeEntity = try! getPlaneEntity()
            let lightEntity = try! getLightEntity()
            planeEntity.addChild(lightEntity)
            content.add(planeEntity)
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts (or restarts) a repeating timer that fires roughly every 1/120 s and
    /// advances the ripple.
    private func startTimer() {
        // `onAppear` can fire more than once without an intervening `onDisappear`.
        // Drop any previous timer so two timers never advance `phase` at the same time.
        stopTimer()
        timer = Timer.scheduledTimer(withTimeInterval: 1 / 120.0, repeats: true) { _ in
            phase += 0.1
            updateMesh(phase: phase)
        }
    }

    /// Stops the animation timer, if one is running.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Builds a white point light placed 0.125 m in front of the plane (+z).
    ///
    /// NOTE(review): `PureSwiftWavyPlaneView` uses `attenuationRadius: 0.25`. If the two views
    /// are meant as a like-for-like comparison, their lighting should probably match.
    func getLightEntity() throws -> Entity {
        let entity = Entity()
        let pointLightComponent = PointLightComponent(
            cgColor: .init(red: 1, green: 1, blue: 1, alpha: 1),
            intensity: 10000,
            attenuationRadius: 0.5
        )
        entity.components.set(pointLightComponent)
        entity.position = .init(x: 0, y: 0, z: 0.125)
        return entity
    }

    /// Builds the plane entity. It creates the low-level mesh, wraps it in a `MeshResource`,
    /// and applies a blue, non-metallic, roughness-0 PBR material with face culling disabled.
    ///
    /// - Throws: Any error from creating the `LowLevelMesh` or the `MeshResource`.
    func getPlaneEntity() throws -> Entity {
        let mesh = try createPlaneMesh()
        // The function already throws, so propagate instead of trapping with `try!`.
        let resource = try MeshResource(from: mesh)
        var material = PhysicallyBasedMaterial()
        material.baseColor.tint = .init(red: 0.0625, green: 0.125, blue: 1.0, alpha: 1.0)
        material.faceCulling = .none
        material.metallic = 0.0
        material.roughness = 0.0
        let modelComponent = ModelComponent(mesh: resource, materials: [material])
        let entity = Entity()
        entity.components.set(modelComponent)
        return entity
    }

    /// Allocates a `LowLevelMesh` with room for `resolution²` vertices and six indices per
    /// grid cell, and stores it in `mesh` so `updateMesh` can fill it.
    ///
    /// - Throws: Any error from `LowLevelMesh(descriptor:)`.
    func createPlaneMesh() throws -> LowLevelMesh {
        let vertexCount = resolution * resolution
        let indexCount = (resolution - 1) * (resolution - 1) * 6
        var desc = MyVertexWithNormal.descriptor
        desc.vertexCapacity = vertexCount
        desc.indexCapacity = indexCount
        let mesh = try LowLevelMesh(descriptor: desc)
        self.mesh = mesh
        return mesh
    }

    /// Encodes and commits one dispatch of the ripple kernel, then updates the mesh part.
    ///
    /// The dispatch uses one thread per vertex and binds fresh vertex (index 0) and index
    /// (index 1) buffers for the kernel to write. The mesh part is then set to cover every
    /// index, with bounds of ±`amplitude` in z.
    ///
    /// - Parameters:
    ///   - amplitude: Wave amplitude; also the half-depth of the part's bounding box.
    ///   - frequency: Spatial frequency forwarded to the kernel.
    ///   - phase: Animation phase forwarded to the kernel.
    func updateMesh(amplitude: Float = 0.1, frequency: Float = 60, phase: Float = 0.0) {
        guard let mesh,
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        let size: Float = 0.4
        let indexCount = (resolution - 1) * (resolution - 1) * 6

        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)
        let indexBuffer = mesh.replaceIndices(using: commandBuffer)

        computeEncoder.setComputePipelineState(computePipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1)

        var params = WavyPlaneParams(
            resolution: Int32(resolution),
            size: size,
            amplitude: amplitude,
            frequency: frequency,
            phase: phase
        )
        computeEncoder.setBytes(&params, length: MemoryLayout<WavyPlaneParams>.size, index: 2)

        let threadsPerGrid = MTLSize(width: resolution * resolution, height: 1, depth: 1)
        // Never request more threads per group than this pipeline supports.
        let threadgroupWidth = min(64, computePipeline.maxTotalThreadsPerThreadgroup)
        let threadsPerThreadgroup = MTLSize(width: threadgroupWidth, height: 1, depth: 1)
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()
        commandBuffer.commit()

        let maxZ = amplitude
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: indexCount,
                topology: .triangle,
                bounds: BoundingBox(min: [-size / 2, -size / 2, -maxZ], max: [size / 2, size / 2, maxZ])
            )
        ])
    }
}
/// Parameter block for the `updateWavyPlaneFromCenter` compute kernel.
///
/// It is handed to the GPU as raw bytes (`setBytes` with `MemoryLayout<WavyPlaneParams>.size`).
/// Field order, types and the resulting layout must therefore match the kernel-side struct,
/// which is not visible here, so keep the two in sync.
struct WavyPlaneParams {
    /// Vertices along each edge of the square grid.
    var resolution: Int32
    /// Plane edge length, then the wave's amplitude, spatial frequency and phase (declaration order is the memory order).
    var size, amplitude, frequency, phase: Float
}
// Xcode canvas preview of the Metal compute-shader variant, built with its default initializer.
#Preview { | |
MetalWavyPlaneView() | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// Interleaved vertex format used by the wavy-plane meshes: a position followed by a
/// surface normal, both stored in vertex buffer 0.
struct MyVertexWithNormal {
    var position: SIMD3<Float> = .zero
    var normal: SIMD3<Float> = .zero

    /// Describes where each field sits inside one vertex.
    ///
    /// This is computed rather than a stored `static var`. The table is never mutated, and a
    /// mutable static is flagged as not concurrency-safe under Swift 6 strict concurrency.
    static var vertexAttributes: [LowLevelMesh.Attribute] {
        [
            // `offset(of:)` returns nil only for key paths that are not stored properties,
            // so unwrapping these stored fields cannot fail.
            .init(semantic: .position, format: .float3, offset: MemoryLayout<Self>.offset(of: \.position)!),
            .init(semantic: .normal, format: .float3, offset: MemoryLayout<Self>.offset(of: \.normal)!),
        ]
    }

    /// A single interleaved layout: buffer 0, one `MyVertexWithNormal` per stride.
    static var vertexLayouts: [LowLevelMesh.Layout] {
        [
            .init(bufferIndex: 0, bufferStride: MemoryLayout<Self>.stride)
        ]
    }

    /// A mesh descriptor preconfigured with this vertex format and 32-bit indices.
    /// Callers still need to set `vertexCapacity` and `indexCapacity`.
    static var descriptor: LowLevelMesh.Descriptor {
        var desc = LowLevelMesh.Descriptor()
        desc.vertexAttributes = vertexAttributes
        desc.vertexLayouts = vertexLayouts
        desc.indexType = .uint32
        return desc
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import SwiftUI | |
import RealityKit | |
/// A SwiftUI view that animates a radial ripple across a square `LowLevelMesh` grid,
/// recomputing every vertex on the CPU in plain Swift on each timer tick.
struct PureSwiftWavyPlaneView: View {
    /// Current animation phase, kept in [0, 2π) and advanced by 0.1 per timer tick.
    @State private var phase: Float = 0.0
    /// The mesh rewritten by `updateMesh`; `nil` until the `RealityView` content has been built.
    @State private var mesh: LowLevelMesh?
    /// Drives the animation while the view is on screen.
    @State private var timer: Timer?

    /// Number of vertices along each edge of the square grid.
    let resolution = 250

    var body: some View {
        RealityView { content in
            // Entity construction failures are treated as fatal here (`try!`).
            let planeEntity = try! getPlaneEntity()
            let lightEntity = try! getLightEntity()
            planeEntity.addChild(lightEntity)
            content.add(planeEntity)
        }
        .onAppear { startTimer() }
        .onDisappear { stopTimer() }
    }

    /// Starts (or restarts) a repeating timer that fires roughly every 1/120 s and
    /// advances the ripple.
    private func startTimer() {
        // `onAppear` can fire more than once without an intervening `onDisappear`.
        // Drop any previous timer so two timers never advance `phase` at the same time.
        stopTimer()
        timer = Timer.scheduledTimer(withTimeInterval: 1 / 120.0, repeats: true) { _ in
            // `phase` only enters the surface through sin/cos, so wrapping it into [0, 2π)
            // is visually identical. It also stops Float precision from degrading as the
            // value grows, which would otherwise eventually stall the animation.
            phase = (phase + 0.1).truncatingRemainder(dividingBy: 2 * .pi)
            updateMesh(phase: phase)
        }
    }

    /// Stops the animation timer, if one is running.
    private func stopTimer() {
        timer?.invalidate()
        timer = nil
    }

    /// Builds a white point light placed 0.125 m in front of the plane (+z).
    func getLightEntity() throws -> Entity {
        let entity = Entity()
        let pointLightComponent = PointLightComponent(
            cgColor: .init(red: 1, green: 1, blue: 1, alpha: 1),
            intensity: 10000,
            attenuationRadius: 0.25
        )
        entity.components.set(pointLightComponent)
        entity.position = .init(x: 0, y: 0, z: 0.125)
        return entity
    }

    /// Builds the plane entity. It creates the low-level mesh, wraps it in a `MeshResource`,
    /// and applies a blue, non-metallic, roughness-0 PBR material with face culling disabled.
    ///
    /// - Throws: Any error from creating the `LowLevelMesh` or the `MeshResource`.
    func getPlaneEntity() throws -> Entity {
        let mesh = try createPlaneMesh()
        // The function already throws, so propagate instead of trapping with `try!`.
        let resource = try MeshResource(from: mesh)
        var material = PhysicallyBasedMaterial()
        material.baseColor.tint = .init(red: 0.0625, green: 0.125, blue: 1.0, alpha: 1.0)
        material.faceCulling = .none
        material.metallic = 0.0
        material.roughness = 0.0
        let modelComponent = ModelComponent(mesh: resource, materials: [material])
        let entity = Entity()
        entity.components.set(modelComponent)
        return entity
    }

    /// Allocates a `LowLevelMesh` with room for `resolution²` vertices and six indices per
    /// grid cell, and stores it in `mesh` so `updateMesh` can fill it.
    ///
    /// - Throws: Any error from `LowLevelMesh(descriptor:)`.
    func createPlaneMesh() throws -> LowLevelMesh {
        let vertexCount = resolution * resolution
        let indexCount = (resolution - 1) * (resolution - 1) * 6
        var desc = MyVertexWithNormal.descriptor
        desc.vertexCapacity = vertexCount
        desc.indexCapacity = indexCount
        let mesh = try LowLevelMesh(descriptor: desc)
        self.mesh = mesh
        return mesh
    }

    /// Recomputes the rippled surface `z = amplitude · sin(frequency · r − phase)` on the CPU,
    /// where `r` is the distance from the plane's center. It then writes vertices and indices
    /// into `mesh` and sets a single part with bounds of ±`amplitude` in z.
    ///
    /// - Parameters:
    ///   - amplitude: Peak displacement along z; also the half-depth of the part's bounds.
    ///   - frequency: Spatial frequency of the ripple (radians per unit distance from the center).
    ///   - phase: Offset subtracted inside the sine; increasing it moves the crests outward.
    func updateMesh(amplitude: Float = 0.1, frequency: Float = 60.0, phase: Float = 0.0) {
        // A grid needs at least two vertices per edge; otherwise the vertex spacing divides by zero.
        guard let mesh, resolution > 1 else { return }
        let size: Float = 0.4
        let indexCount = (resolution - 1) * (resolution - 1) * 6

        writeVertices(into: mesh, size: size, amplitude: amplitude, frequency: frequency, phase: phase)
        writeIndices(into: mesh)

        let maxZ = amplitude
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: indexCount,
                topology: .triangle,
                bounds: BoundingBox(min: [-size / 2, -size / 2, -maxZ], max: [size / 2, size / 2, maxZ])
            )
        ])
    }

    /// Writes the displaced position and surface normal of every grid vertex into buffer 0.
    private func writeVertices(into mesh: LowLevelMesh, size: Float, amplitude: Float, frequency: Float, phase: Float) {
        mesh.withUnsafeMutableBytes(bufferIndex: 0) { rawBytes in
            let vertices = rawBytes.bindMemory(to: MyVertexWithNormal.self)
            for y in 0..<resolution {
                for x in 0..<resolution {
                    let index = y * resolution + x
                    let xPos = Float(x) / Float(resolution - 1) * size - size / 2
                    let yPos = Float(y) / Float(resolution - 1) * size - size / 2

                    let distanceFromCenter = sqrt(xPos * xPos + yPos * yPos)
                    let angle = frequency * distanceFromCenter - phase
                    let z = amplitude * sin(angle)
                    let position = SIMD3<Float>(xPos, yPos, z)

                    let normal: SIMD3<Float>
                    if distanceFromCenter > 0 {
                        // For z = f(r): ∂z/∂x = f'(r)·x/r and ∂z/∂y = f'(r)·y/r.
                        // Keeping the -z orientation, the normal perpendicular to the surface is
                        // (∂z/∂x, ∂z/∂y, -1). Its dot product with the tangents (1, 0, ∂z/∂x) and
                        // (0, 1, ∂z/∂y) is zero. The previous (-∂z/∂x, -∂z/∂y, -1) was not
                        // perpendicular and mirrored the shading slope.
                        let dzdr = amplitude * frequency * cos(angle)
                        let nx = dzdr * xPos / distanceFromCenter
                        let ny = dzdr * yPos / distanceFromCenter
                        normal = simd_normalize(SIMD3<Float>(nx, ny, -1))
                    } else {
                        // Exactly at the center (only reachable with an odd resolution) the radial
                        // direction is undefined; dividing by r would produce a NaN normal.
                        normal = SIMD3<Float>(0, 0, -1)
                    }

                    vertices[index] = MyVertexWithNormal(position: position, normal: normal)
                }
            }
        }
    }

    /// Writes two triangles per grid cell as 32-bit indices: (top-left, bottom-left, top-right)
    /// and (top-right, bottom-left, bottom-right).
    private func writeIndices(into mesh: LowLevelMesh) {
        mesh.withUnsafeMutableIndices { rawIndices in
            let indices = rawIndices.bindMemory(to: UInt32.self)
            var index = 0
            for y in 0..<(resolution - 1) {
                for x in 0..<(resolution - 1) {
                    let topLeft = UInt32(y * resolution + x)
                    let topRight = topLeft + 1
                    let bottomLeft = UInt32((y + 1) * resolution + x)
                    let bottomRight = bottomLeft + 1
                    indices[index] = topLeft
                    indices[index + 1] = bottomLeft
                    indices[index + 2] = topRight
                    indices[index + 3] = topRight
                    indices[index + 4] = bottomLeft
                    indices[index + 5] = bottomRight
                    index += 6
                }
            }
        }
    }
}
// Xcode canvas preview of the pure-Swift (CPU) variant, built with its default initializer.
#Preview { | |
PureSwiftWavyPlaneView() | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment