Skip to content

Instantly share code, notes, and snippets.

@Matt54
Created August 10, 2024 23:00
Show Gist options
  • Save Matt54/8a483b8e59835d6b89575637cafb65b1 to your computer and use it in GitHub Desktop.
Save Matt54/8a483b8e59835d6b89575637cafb65b1 to your computer and use it in GitHub Desktop.
A RealityView comparing RealityKit's built-in generateSphere with a LowLevelMesh implementation, both with and without UV data in the vertex attributes.
import SwiftUI
import RealityKit
/// Shows three spheres side by side so RealityKit's built-in
/// `MeshResource.generateSphere` can be compared with two sphere meshes that are
/// generated on the GPU into a `LowLevelMesh`:
///
/// - left: `LowLevelMesh` whose vertices carry only position and normal,
/// - middle: `LowLevelMesh` whose vertices also carry UVs,
/// - right: `MeshResource.generateSphere`.
///
/// All three share one material whose texture is filled by the `sphereTexture`
/// compute kernel.
struct SphereMeshTextureComparisonView: View {
    /// Subdivisions along the polar angle (pole to pole) of the generated spheres.
    let latitudeBands = 128
    /// Subdivisions around the vertical axis of the generated spheres.
    let longitudeBands = 80
    /// Radius shared by all three spheres.
    let radius: Float = 0.1

    /// When true, all three spheres use a plain reflective red material instead
    /// of the generated gradient texture.
    let useReflectiveRedMaterial = false

    // Metal-related properties
    let device: MTLDevice
    let commandQueue: MTLCommandQueue
    let computePipeline: MTLComputePipelineState
    let computePipelineExcludingUV: MTLComputePipelineState
    let textureComputePipeline: MTLComputePipelineState

    @State private var texture: LowLevelTexture?

    /// Vertices per generated sphere: one per (latitude, longitude) grid point.
    /// The extra row and column close the poles and the UV seam.
    private var vertexCount: Int { (latitudeBands + 1) * (longitudeBands + 1) }

    /// Indices per generated sphere: two triangles (six indices) per grid quad.
    private var indexCount: Int { latitudeBands * longitudeBands * 6 }

    /// Creates the Metal device, command queue and the three compute pipelines.
    ///
    /// A failure here means Metal is unavailable or the default library is missing
    /// one of the kernels. That is a configuration error, so it traps, but with a
    /// message naming what is missing rather than a bare force-unwrap.
    init() {
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        guard let commandQueue = device.makeCommandQueue() else {
            fatalError("Failed to create a Metal command queue")
        }
        guard let library = device.makeDefaultLibrary() else {
            fatalError("Failed to load the default Metal library")
        }
        self.device = device
        self.commandQueue = commandQueue
        self.computePipeline = Self.makePipeline(named: "sphereVertices", in: library, device: device)
        self.computePipelineExcludingUV = Self.makePipeline(named: "sphereVerticesExcludingUV", in: library, device: device)
        self.textureComputePipeline = Self.makePipeline(named: "sphereTexture", in: library, device: device)
    }

    /// Builds a compute pipeline for the kernel called `name`.
    /// Traps with a descriptive message if the function is missing or fails to compile.
    private static func makePipeline(named name: String,
                                     in library: MTLLibrary,
                                     device: MTLDevice) -> MTLComputePipelineState {
        guard let function = library.makeFunction(name: name) else {
            fatalError("Missing Metal kernel '\(name)' in the default library")
        }
        do {
            return try device.makeComputePipelineState(function: function)
        } catch {
            fatalError("Failed to create compute pipeline for '\(name)': \(error)")
        }
    }

    var body: some View {
        RealityView { content in
            // left (no uv)
            let meshNoUV = try! getMesh(includeUV: false)
            let meshResourceLowLevelNoUV = try! MeshResource(from: meshNoUV)
            // middle
            let mesh = try! getMesh()
            let meshResourceLowLevel = try! MeshResource(from: mesh)
            // right - standard sphere
            let meshResourceGenerateSphere = MeshResource.generateSphere(radius: radius)

            let texture = try! LowLevelTexture(descriptor: textureDescriptor)
            let resource = try! TextureResource(from: texture)

            var material = UnlitMaterial()
            material.color.texture = .init(resource)
            material.blending = .transparent(opacity: 1.0)

            let materials: [any RealityKit.Material]
            if useReflectiveRedMaterial {
                materials = [SimpleMaterial(color: .red, isMetallic: true)]
            } else {
                materials = [material]
            }

            let entityLowLevelMeshNoUV = ModelEntity(mesh: meshResourceLowLevelNoUV, materials: materials)
            entityLowLevelMeshNoUV.transform.translation.x = -0.25
            content.add(entityLowLevelMeshNoUV)

            let entityLowLevelMesh = ModelEntity(mesh: meshResourceLowLevel, materials: materials)
            content.add(entityLowLevelMesh)

            let entityGenerateSphere = ModelEntity(mesh: meshResourceGenerateSphere, materials: materials)
            entityGenerateSphere.transform.translation.x = 0.25
            content.add(entityGenerateSphere)

            self.texture = texture
            // Hand the texture over directly instead of reading back the @State
            // value that was assigned a line earlier inside this closure.
            updateTexture(texture)
            updateMesh(mesh)
            updateMesh(meshNoUV, includeUV: false)
        }
    }

    /// CPU-side mirror of the Metal `VertexData` struct written by `sphereVertices`.
    /// The field order and types must stay in sync with the shader. The attribute
    /// offsets and buffer stride below are derived from this Swift layout.
    struct VertexData {
        var position: SIMD3<Float> = .zero
        var normal: SIMD3<Float> = .zero
        var uv: SIMD2<Float> = .zero

        // Computed rather than a mutable `static var`: a stored mutable static is
        // shared global state that Swift 6 strict concurrency rejects.
        static var vertexAttributes: [LowLevelMesh.Attribute] {
            [
                .init(semantic: .position, format: .float3, offset: MemoryLayout<Self>.offset(of: \.position)!),
                .init(semantic: .normal, format: .float3, offset: MemoryLayout<Self>.offset(of: \.normal)!),
                .init(semantic: .uv0, format: .float2, offset: MemoryLayout<Self>.offset(of: \.uv)!)
            ]
        }

        static var vertexLayouts: [LowLevelMesh.Layout] {
            [.init(bufferIndex: 0, bufferStride: MemoryLayout<Self>.stride)]
        }

        /// Mesh descriptor with this vertex layout and 32-bit indices.
        /// Capacities are left for the caller to set.
        static var descriptor: LowLevelMesh.Descriptor {
            var desc = LowLevelMesh.Descriptor()
            desc.vertexAttributes = vertexAttributes
            desc.vertexLayouts = vertexLayouts
            desc.indexType = .uint32
            return desc
        }
    }

    /// CPU-side mirror of the Metal `VertexDataExcludingUV` struct written by
    /// `sphereVerticesExcludingUV`: position and normal only, no UV attribute.
    struct VertexDataExcludingUV {
        var position: SIMD3<Float> = .zero
        var normal: SIMD3<Float> = .zero

        static var vertexAttributes: [LowLevelMesh.Attribute] {
            [
                .init(semantic: .position, format: .float3, offset: MemoryLayout<Self>.offset(of: \.position)!),
                .init(semantic: .normal, format: .float3, offset: MemoryLayout<Self>.offset(of: \.normal)!)
            ]
        }

        static var vertexLayouts: [LowLevelMesh.Layout] {
            [.init(bufferIndex: 0, bufferStride: MemoryLayout<Self>.stride)]
        }

        /// Mesh descriptor with this vertex layout and 32-bit indices.
        /// Capacities are left for the caller to set.
        static var descriptor: LowLevelMesh.Descriptor {
            var desc = LowLevelMesh.Descriptor()
            desc.vertexAttributes = vertexAttributes
            desc.vertexLayouts = vertexLayouts
            desc.indexType = .uint32
            return desc
        }
    }

    /// 2048×2048 BGRA8 texture that the `sphereTexture` kernel writes and the
    /// unlit material samples.
    var textureDescriptor: LowLevelTexture.Descriptor {
        var desc = LowLevelTexture.Descriptor()
        desc.textureType = .type2D
        desc.arrayLength = 1
        desc.width = 2048
        desc.height = 2048
        desc.depth = 1
        desc.mipmapLevelCount = 1
        desc.pixelFormat = .bgra8Unorm
        desc.textureUsage = [.shaderRead, .shaderWrite]
        desc.swizzle = .init(red: .red, green: .green, blue: .blue, alpha: .alpha)
        return desc
    }

    /// Allocates an empty `LowLevelMesh` sized for one sphere.
    /// - Parameter includeUV: Chooses the vertex layout with (`true`) or without a UV attribute.
    /// - Throws: Whatever `LowLevelMesh(descriptor:)` throws.
    func getMesh(includeUV: Bool = true) throws -> LowLevelMesh {
        var desc = includeUV ? VertexData.descriptor : VertexDataExcludingUV.descriptor
        desc.vertexCapacity = vertexCount
        desc.indexCapacity = indexCount
        return try LowLevelMesh(descriptor: desc)
    }

    /// Fills `mesh`'s vertex and index buffers on the GPU and sets its single part.
    /// - Parameters:
    ///   - mesh: A mesh created by `getMesh(includeUV:)` with the same `includeUV` value.
    ///   - includeUV: Selects the kernel whose vertex struct matches the mesh layout.
    func updateMesh(_ mesh: LowLevelMesh, includeUV: Bool = true) {
        let pipeline = includeUV ? computePipeline : computePipelineExcludingUV
        guard let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        let vertexBuffer = mesh.replace(bufferIndex: 0, using: commandBuffer)
        let indexBuffer = mesh.replaceIndices(using: commandBuffer)

        computeEncoder.setComputePipelineState(pipeline)
        computeEncoder.setBuffer(vertexBuffer, offset: 0, index: 0)
        computeEncoder.setBuffer(indexBuffer, offset: 0, index: 1)
        var params = SphereParams(
            latitudeBands: Int32(latitudeBands),
            longitudeBands: Int32(longitudeBands),
            radius: radius
        )
        computeEncoder.setBytes(&params, length: MemoryLayout<SphereParams>.size, index: 2)

        // One thread per vertex. The kernel also writes the six indices of the quad
        // that starts at that vertex. dispatchThreads lets the grid size be a
        // non-multiple of the threadgroup width, which assumes GPU support for
        // non-uniform threadgroups.
        let threadsPerGrid = MTLSize(width: vertexCount, height: 1, depth: 1)
        let threadsPerThreadgroup = MTLSize(
            width: min(64, pipeline.maxTotalThreadsPerThreadgroup),
            height: 1,
            depth: 1
        )
        computeEncoder.dispatchThreads(threadsPerGrid, threadsPerThreadgroup: threadsPerThreadgroup)
        computeEncoder.endEncoding()
        commandBuffer.commit()

        let meshBounds = BoundingBox(min: [-radius, -radius, -radius], max: [radius, radius, radius])
        mesh.parts.replaceAll([
            LowLevelMesh.Part(
                indexCount: indexCount,
                topology: .triangle,
                bounds: meshBounds
            )
        ])
    }

    /// Fills a texture with the `sphereTexture` gradient on the GPU.
    /// - Parameter texture: The texture to write. When `nil`, the stored `texture`
    ///   state is used. Nothing happens if both are `nil`.
    func updateTexture(_ texture: LowLevelTexture? = nil) {
        guard let texture = texture ?? self.texture,
              let commandBuffer = commandQueue.makeCommandBuffer(),
              let computeEncoder = commandBuffer.makeComputeCommandEncoder() else { return }

        computeEncoder.setComputePipelineState(textureComputePipeline)
        let outTexture: MTLTexture = texture.replace(using: commandBuffer)
        computeEncoder.setTexture(outTexture, index: 0)

        let w = textureComputePipeline.threadExecutionWidth
        let h = textureComputePipeline.maxTotalThreadsPerThreadgroup / w
        let threadGroupSize = MTLSizeMake(w, h, 1)
        // Size the grid from the texture actually being written, not from a
        // freshly rebuilt descriptor. Round up so partial edge threadgroups still run.
        let threadGroupCount = MTLSizeMake(
            (outTexture.width + w - 1) / w,
            (outTexture.height + h - 1) / h,
            1)
        computeEncoder.dispatchThreadgroups(threadGroupCount, threadsPerThreadgroup: threadGroupSize)
        computeEncoder.endEncoding()
        commandBuffer.commit()
    }

    /// Kernel parameters passed with `setBytes` at buffer index 2.
    /// Must match the Metal `SphereParams` layout (two int32 and one float).
    struct SphereParams {
        var latitudeBands: Int32
        var longitudeBands: Int32
        var radius: Float
    }
}
// Xcode canvas preview that renders the three-sphere comparison view.
#Preview {
SphereMeshTextureComparisonView()
}
#include <metal_stdlib>
using namespace metal;
// Vertex layout written by `sphereVertices`. It must match the Swift
// `VertexData` struct byte for byte. float3 is 16 bytes with 16-byte alignment
// (like SIMD3<Float>), so position sits at offset 0, normal at 16 and uv at 32,
// and the struct is padded to a 48-byte stride.
struct VertexData {
    float3 position;
    float3 normal;
    float2 uv;
};
// Catch layout drift at shader compile time rather than as garbled geometry.
static_assert(sizeof(VertexData) == 48, "VertexData must match the 48-byte stride of the Swift VertexData");
// Sphere tessellation parameters, bound at buffer index 2. They must match
// the Swift `SphereParams` struct (Int32, Int32, Float: 12 bytes total).
struct SphereParams {
    int32_t latitudeBands;
    int32_t longitudeBands;
    float radius;
};
static_assert(sizeof(SphereParams) == 12, "SphereParams must match the 12-byte Swift SphereParams");
// Writes one UV-sphere vertex (position, normal, uv) per thread. It also writes
// the six indices of the grid quad that starts at that vertex.
//
// Vertices form a row-major grid of (latitudeBands + 1) rows by
// (longitudeBands + 1) columns. Row 0 is the south pole (y = -radius). The last
// column repeats the first so the UVs can run from 0 to 1 across the seam.
// Threads that fall outside the grid return without writing anything.
kernel void sphereVertices(device VertexData* vertices [[buffer(0)]],
                           device uint* indices [[buffer(1)]],
                           constant SphereParams& params [[buffer(2)]],
                           uint2 id [[thread_position_in_grid]])
{
    const int columns = params.longitudeBands + 1;
    const int vertexIndex = id.x;
    const int row = vertexIndex / columns;
    const int column = vertexIndex % columns;

    const bool insideGrid = row <= params.latitudeBands && column <= params.longitudeBands;
    if (!insideGrid) {
        return;
    }

    // Normalized grid coordinates: v goes from south pole (0) to north pole (1),
    // u goes once around the vertical axis.
    const float v = float(row) / float(params.latitudeBands);
    const float u = float(column) / float(params.longitudeBands);

    const float theta = (1.0 - v) * M_PI_F;  // polar angle measured from +Y
    const float phi = u * 2 * M_PI_F;         // azimuth in the XZ plane

    const float sinTheta = sin(theta);
    const float3 direction = float3(sinTheta * cos(phi), cos(theta), sinTheta * sin(phi));

    device VertexData& dst = vertices[vertexIndex];
    dst.position = direction * params.radius;
    dst.normal = normalize(direction);
    dst.uv = float2(u, v);

    // The final row and the seam column only close quads; they never start one.
    if (row == params.latitudeBands || column == params.longitudeBands) {
        return;
    }

    const uint32_t here = row * columns + column;
    const uint32_t north = here + columns;  // same column, one row closer to +Y
    const uint32_t quad[6] = { here, north, here + 1, here + 1, north, north + 1 };

    const int base = (row * params.longitudeBands + column) * 6;
    for (int k = 0; k < 6; ++k) {
        indices[base + k] = quad[k];
    }
}
// Fills `outTexture` with a gradient along the texture's y axis:
// - RGB blends from (0.8, 0.0, 0.8) at row 0 to (0.7, 0.7, 0.0) at the last row.
// - Alpha is 1 - 1.5 * y^4, clamped to [0, 1], where y is the normalized row.
//   It stays opaque for most rows and reaches 0 once y exceeds about 0.904.
kernel void sphereTexture(
    texture2d<half, access::write> outTexture [[texture(0)]],
    uint2 gid [[thread_position_in_grid]])
{
    const uint width = outTexture.get_width();
    const uint height = outTexture.get_height();

    // The host rounds the threadgroup count up, so threads in edge threadgroups
    // can land outside the texture. Skip them instead of writing out of bounds.
    if (gid.x >= width || gid.y >= height) {
        return;
    }

    // Texture coordinate in [0, 1) along each axis. Divide in float because half
    // cannot represent every integer above 2048, then narrow to half.
    const half2 texCoord = half2(float2(float(gid.x) / float(width),
                                        float(gid.y) / float(height)));

    // Compute the color as a linear gradient
    half3 color = mix(
        half3 { 0.8, 0.0, 0.8 },
        half3 { 0.7, 0.7, 0.0 },
        texCoord.y);

    // pow(y, 4) * 1.5 exceeds 1 near the last rows. Clamp explicitly so alpha
    // never goes negative, rather than relying on the pixel format to clamp it.
    half interpolationValue = pow(texCoord.y, half(4)) * 1.5;
    half alpha = mix(half { 1.0 }, half { 0.0 }, saturate(interpolationValue));

    outTexture.write(half4(color, alpha), gid);
}
// Vertex layout written by `sphereVerticesExcludingUV`: position and normal
// only. It must match the Swift `VertexDataExcludingUV` struct. Each float3
// occupies 16 bytes, giving a 32-byte stride.
struct VertexDataExcludingUV {
    float3 position;
    float3 normal;
};
static_assert(sizeof(VertexDataExcludingUV) == 32, "VertexDataExcludingUV must match the 32-byte stride of the Swift VertexDataExcludingUV");
// Same tessellation as `sphereVertices`, but the output vertices carry only
// position and normal. The normalized grid coordinates are still computed
// because the angles are derived from them.
//
// Vertices form a row-major grid of (latitudeBands + 1) rows by
// (longitudeBands + 1) columns, starting at the south pole. Each in-grid thread
// writes its vertex. It also writes the six indices of the quad that starts at
// that vertex, unless the vertex is on the last row or last column.
kernel void sphereVerticesExcludingUV(device VertexDataExcludingUV* vertices [[buffer(0)]],
                                      device uint* indices [[buffer(1)]],
                                      constant SphereParams& params [[buffer(2)]],
                                      uint2 id [[thread_position_in_grid]])
{
    const int columns = params.longitudeBands + 1;
    const int vertexIndex = id.x;
    const int row = vertexIndex / columns;
    const int column = vertexIndex % columns;

    if (!(row <= params.latitudeBands && column <= params.longitudeBands)) {
        return;
    }

    const float v = float(row) / float(params.latitudeBands);
    const float u = float(column) / float(params.longitudeBands);

    const float theta = (1.0 - v) * M_PI_F;  // polar angle measured from +Y
    const float phi = u * 2 * M_PI_F;         // azimuth in the XZ plane

    const float sinTheta = sin(theta);
    const float3 direction = float3(sinTheta * cos(phi), cos(theta), sinTheta * sin(phi));

    device VertexDataExcludingUV& dst = vertices[vertexIndex];
    dst.position = direction * params.radius;
    dst.normal = normalize(direction);

    // The final row and the seam column only close quads; they never start one.
    if (row == params.latitudeBands || column == params.longitudeBands) {
        return;
    }

    const uint32_t here = row * columns + column;
    const uint32_t north = here + columns;  // same column, one row closer to +Y
    const uint32_t quad[6] = { here, north, here + 1, here + 1, north, north + 1 };

    const int base = (row * params.longitudeBands + column) * 6;
    for (int k = 0; k < 6; ++k) {
        indices[base + k] = quad[k];
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment