import Foundation
import RealityKit
import SwiftUI

/// Demo view that renders a procedurally generated, subdivided cube and slowly
/// tumbles it around all three axes.
struct CubeMeshExample: View {
    @State private var rotationAngles: SIMD3<Float> = [0, 0, 0]
    @State private var lastRotationUpdateTime = CACurrentMediaTime()
    // NOTE(review): accumulated every tick but never read in this view — confirm it is still needed.
    @State private var time: Double = 0.0
    @State private var rootEntity: Entity?
    /// The running rotation timer, retained so it can be invalidated when the view disappears.
    @State private var rotationTimer: Timer?

    var body: some View {
        RealityView { content in
            if let cubeMesh = try? CubeMesh(size: [0.2, 0.2, 0.2], dimensions: [12, 12]),
               let mesh = try? MeshResource(from: cubeMesh.mesh) {
                let cubeEntity = Entity()
                // Alternative lit material:
                // let material = SimpleMaterial(color: .white, isMetallic: true)
                let material = UnlitMaterial(color: .white)
                let cubeModel = ModelComponent(mesh: mesh, materials: [material])
                cubeEntity.components.set(cubeModel)
                content.add(cubeEntity)
                self.rootEntity = cubeEntity
            }
        }
        .onAppear {
            startTimer()
        }
        .onDisappear {
            stopTimer()
        }
    }

    /// Starts a ~120 Hz timer that advances the cube's rotation.
    ///
    /// `onAppear` can fire more than once over a view's lifetime. An already-running timer is
    /// therefore kept rather than stacking a second one. A second timer would double the spin
    /// rate and could never be invalidated.
    private func startTimer() {
        guard rotationTimer == nil else { return }

        // Measure the first tick from "now". If the value captured when the @State was created
        // were used, the cube would jump by however long the view waited before appearing.
        lastRotationUpdateTime = CACurrentMediaTime()

        rotationTimer = Timer.scheduledTimer(withTimeInterval: 1 / 120.0, repeats: true) { _ in
            // `scheduledTimer` installs the timer on the run loop of the calling thread. That is
            // the main thread here, since `onAppear` runs there, so the block executes on the
            // main actor and may touch @State directly.
            MainActor.assumeIsolated {
                advanceRotation()
            }
        }
    }

    /// Invalidates the rotation timer so it stops firing once the view is gone.
    private func stopTimer() {
        rotationTimer?.invalidate()
        rotationTimer = nil
    }

    /// Advances the rotation by the wall-clock time elapsed since the previous tick and applies
    /// it to the cube entity (if one was created).
    private func advanceRotation() {
        let currentTime = CACurrentMediaTime()
        let frameDuration = currentTime - lastRotationUpdateTime
        time += frameDuration

        // Rotate along all axes at different rates (radians per second) for a wonky roll effect.
        // NOTE(review): 0.0675 breaks the halving pattern of 0.25 / 0.125; 0.0625 may have been
        // intended — confirm.
        rotationAngles.x += Float(frameDuration * 0.25)
        rotationAngles.y += Float(frameDuration * 0.125)
        rotationAngles.z += Float(frameDuration * 0.0675)

        let rotationX = simd_quatf(angle: rotationAngles.x, axis: [1, 0, 0])
        let rotationY = simd_quatf(angle: rotationAngles.y, axis: [0, 1, 0])
        let rotationZ = simd_quatf(angle: rotationAngles.z, axis: [0, 0, 1])
        rootEntity?.transform.rotation = rotationX * rotationY * rotationZ

        lastRotationUpdateTime = currentTime
    }

    /// A cube centred on the origin whose six faces are each a regular grid of vertices,
    /// stored in a RealityKit `LowLevelMesh`.
    @MainActor struct CubeMesh {
        /// Errors thrown while building a `CubeMesh`.
        enum CubeMeshError: Error {
            /// Each face needs at least a 2×2 vertex grid (one cell). Smaller grids would divide
            /// by zero when normalising grid coordinates, or underflow the unsigned cell count.
            case invalidDimensions(SIMD2<UInt32>)
        }

        // Implicitly unwrapped only so `init` can call instance helpers before assigning it.
        // It is always non-nil once `init` returns.
        var mesh: LowLevelMesh!
        let size: SIMD3<Float>
        let dimensions: SIMD2<UInt32>
        let verticesPerPlane: UInt32

        /// Builds the cube mesh and fills its vertex, index and part data.
        ///
        /// - Parameters:
        ///   - size: Edge lengths of the cube along X, Y and Z.
        ///   - dimensions: Number of vertices per face along each grid axis; both must be ≥ 2.
        /// - Throws: `CubeMeshError.invalidDimensions` for grids smaller than 2×2, or any error
        ///   raised by `LowLevelMesh(descriptor:)`.
        init(size: SIMD3<Float>, dimensions: SIMD2<UInt32>) throws {
            guard dimensions.x >= 2, dimensions.y >= 2 else {
                throw CubeMeshError.invalidDimensions(dimensions)
            }

            self.size = size
            self.dimensions = dimensions
            self.verticesPerPlane = dimensions.x * dimensions.y

            self.mesh = try createMesh()
            initializeVertexData()
            initializeIndexData()
            initializeMeshParts()
        }

        /// Allocates a `LowLevelMesh` sized for six faces of `dimensions` vertices each.
        ///
        /// The index capacity is six indices (two triangles) per grid cell.
        private func createMesh() throws -> LowLevelMesh {
            // Derive attribute offsets from the real `PlaneVertex` layout instead of hard-coding
            // them, so the descriptor cannot drift from the struct definition.
            let vertexAttributes = [
                LowLevelMesh.Attribute(
                    semantic: .position,
                    format: .float3,
                    offset: MemoryLayout<PlaneVertex>.offset(of: \.position)!
                ),
                LowLevelMesh.Attribute(
                    semantic: .normal,
                    format: .float3,
                    offset: MemoryLayout<PlaneVertex>.offset(of: \.normal)!
                )
            ]
            let vertexLayouts = [
                LowLevelMesh.Layout(bufferIndex: 0, bufferStride: MemoryLayout<PlaneVertex>.stride)
            ]

            // Six planes worth of vertices and indices.
            let vertexCount = Int(verticesPerPlane * 6)
            let indicesPerTriangle = 3
            let trianglesPerCell = 2
            let cellsPerPlane = Int((dimensions.x - 1) * (dimensions.y - 1))
            let indexCount = indicesPerTriangle * trianglesPerCell * cellsPerPlane * 6

            return try LowLevelMesh(descriptor: .init(
                vertexCapacity: vertexCount,
                vertexAttributes: vertexAttributes,
                vertexLayouts: vertexLayouts,
                indexCapacity: indexCount
            ))
        }

        /// Writes positions and normals for all six faces into vertex buffer 0.
        ///
        /// Each face closure maps normalised grid coordinates in [0, 1] to a point on that face.
        private func initializeVertexData() {
            mesh.withUnsafeMutableBytes(bufferIndex: 0) { rawBytes in
                let vertices = rawBytes.bindMemory(to: PlaneVertex.self)

                // Front face (Z+)
                createPlane(vertices: vertices, planeIndex: 0, normal: [0, 0, 1]) { x, y in
                    [
                        size.x * (x - 0.5),
                        size.y * (y - 0.5),
                        size.z * 0.5
                    ]
                }

                // Back face (Z-)
                createPlane(vertices: vertices, planeIndex: 1, normal: [0, 0, -1]) { x, y in
                    [
                        size.x * (x - 0.5),
                        size.y * (y - 0.5),
                        size.z * -0.5
                    ]
                }

                // Right face (X+)
                createPlane(vertices: vertices, planeIndex: 2, normal: [1, 0, 0]) { z, y in
                    [
                        size.x * 0.5,
                        size.y * (y - 0.5),
                        size.z * (0.5 - z)
                    ]
                }

                // Left face (X-)
                createPlane(vertices: vertices, planeIndex: 3, normal: [-1, 0, 0]) { z, y in
                    [
                        size.x * -0.5,
                        size.y * (y - 0.5),
                        size.z * (0.5 - z)
                    ]
                }

                // Top face (Y+)
                createPlane(vertices: vertices, planeIndex: 4, normal: [0, 1, 0]) { x, z in
                    [
                        size.x * (x - 0.5),
                        size.y * 0.5,
                        size.z * (0.5 - z)
                    ]
                }

                // Bottom face (Y-)
                createPlane(vertices: vertices, planeIndex: 5, normal: [0, -1, 0]) { x, z in
                    [
                        size.x * (x - 0.5),
                        size.y * -0.5,
                        size.z * (0.5 - z)
                    ]
                }
            }
        }

        /// Fills one face's vertex grid.
        ///
        /// - Parameters:
        ///   - vertices: The mesh's vertex buffer, bound to `PlaneVertex`.
        ///   - planeIndex: Which face (0...5) to write; selects the slice of the buffer.
        ///   - normal: Normal assigned to every vertex of the face.
        ///   - positionFor: Maps normalised grid coordinates (u, v) in [0, 1] to a position.
        private func createPlane(
            vertices: UnsafeMutableBufferPointer<PlaneVertex>,
            planeIndex: UInt32,
            normal: SIMD3<Float>,
            positionFor: (Float, Float) -> SIMD3<Float>
        ) {
            for u in 0..<dimensions.x {
                for v in 0..<dimensions.y {
                    // `init` guarantees both dimensions are ≥ 2, so these divisors are non-zero.
                    let u01 = Float(u) / Float(dimensions.x - 1)
                    let v01 = Float(v) / Float(dimensions.y - 1)

                    let vertexIndex = Int(vertexIndex(u, v, planeIndex: planeIndex))
                    vertices[vertexIndex].position = positionFor(u01, v01)
                    vertices[vertexIndex].normal = normal
                }
            }
        }

        /// Writes two triangles (six `UInt32` indices) per grid cell for all six faces.
        private func initializeIndexData() {
            mesh.withUnsafeMutableIndices { rawIndices in
                let indices = rawIndices.bindMemory(to: UInt32.self)
                var currentIndex = 0

                // Face ids: 0 = front, 1 = back, 2 = right, 3 = left, 4 = top, 5 = bottom.
                // Faces are emitted in this order (rather than 0...5) so that a line-strip
                // rendering would not draw segments across the cube's centre.
                let planeOrder: [UInt32] = [0, 5, 3, 1, 4, 2]
                for planeIndex in planeOrder {
                    // Back, left and bottom faces use the same grid-to-position mapping as their
                    // opposite faces. Their winding is reversed so every face winds the same way
                    // when seen from outside the cube.
                    let reversesWinding = planeIndex == 1 || planeIndex == 3 || planeIndex == 5

                    for y in 0..<(dimensions.y - 1) {
                        for x in 0..<(dimensions.x - 1) {
                            let bottomLeft = vertexIndex(x, y, planeIndex: planeIndex)
                            let bottomRight = vertexIndex(x + 1, y, planeIndex: planeIndex)
                            let topLeft = vertexIndex(x, y + 1, planeIndex: planeIndex)
                            let topRight = vertexIndex(x + 1, y + 1, planeIndex: planeIndex)

                            if reversesWinding {
                                indices[currentIndex] = bottomLeft
                                indices[currentIndex + 1] = topLeft
                                indices[currentIndex + 2] = bottomRight
                                indices[currentIndex + 3] = bottomRight
                                indices[currentIndex + 4] = topLeft
                                indices[currentIndex + 5] = topRight
                            } else {
                                indices[currentIndex] = bottomLeft
                                indices[currentIndex + 1] = bottomRight
                                indices[currentIndex + 2] = topLeft
                                indices[currentIndex + 3] = topLeft
                                indices[currentIndex + 4] = bottomRight
                                indices[currentIndex + 5] = topRight
                            }
                            currentIndex += 6
                        }
                    }
                }
            }
        }

        /// Registers a single mesh part covering every index, bounded by the cube's extents.
        private func initializeMeshParts() {
            let halfSize = size * 0.5
            let bounds = BoundingBox(
                min: -halfSize,
                max: halfSize
            )

            // NOTE(review): the index buffer is laid out as a triangle list (6 indices per cell)
            // but drawn with `.line` topology. Each cell therefore yields three line segments,
            // one of them zero-length. Confirm this wireframe look is intended; use `.triangle`
            // for a solid cube.
            mesh.parts.replaceAll([
                LowLevelMesh.Part(
                    indexCount: mesh.descriptor.indexCapacity,
                    topology: .line,
                    bounds: bounds
                )
            ])
        }

        /// Flat index of grid vertex (x, y) within face `planeIndex`; faces are stored back to
        /// back, each occupying `verticesPerPlane` consecutive slots in row-major order.
        private func vertexIndex(_ x: UInt32, _ y: UInt32, planeIndex: UInt32) -> UInt32 {
            (x + dimensions.x * y) + (verticesPerPlane * planeIndex)
        }
    }
}

// Xcode canvas preview that renders `CubeMeshExample`.
#Preview {
    CubeMeshExample()
}