Skip to content

Instantly share code, notes, and snippets.

@Hais
Forked from avinashselvam/filter.swift
Created December 13, 2023 19:18
Show Gist options
  • Save Hais/cefa58546aec3af0875816733691a3bd to your computer and use it in GitHub Desktop.
Save Hais/cefa58546aec3af0875816733691a3bd to your computer and use it in GitHub Desktop.
A Swift class that implements GPU-based image processing.
//
// filter.swift
// fltr
//
// Created by Avinash on 18/06/19.
// Copyright © 2019 eightyfive. All rights reserved.
//
import Metal
import MetalKit
// Pipeline: UIImage -> CGImage -> MTLTexture -> compute kernel runs on the GPU
//           UIImage <- CGImage <- MTLTexture <- (output texture)
/// Applies a single Metal compute kernel (the default-library function named "black")
/// to the bundled image "spidey.jpg" and returns the result as a `UIImage`.
///
/// Data flow: UIImage -> CGImage -> MTLTexture (input) -> compute kernel
/// -> MTLTexture (output) -> CGImage -> UIImage.
class Filter {
    var device: MTLDevice
    var defaultLib: MTLLibrary?
    var grayscaleShader: MTLFunction?
    var commandQueue: MTLCommandQueue?
    /// Command buffer used by the most recent `applyFilter()` call; nil before the first call.
    /// A fresh buffer is made per call because Metal command buffers are single-use.
    var commandBuffer: MTLCommandBuffer?
    /// Compute encoder used by the most recent `applyFilter()` call; nil before the first call.
    var commandEncoder: MTLComputeCommandEncoder?
    var pipelineState: MTLComputePipelineState?
    var inputImage: UIImage
    /// Size of `inputImage` in points. `getCGImage(from: UIImage)` redraws at scale 1, so for
    /// whole-point sizes these are also the pixel dimensions of the input texture.
    var height, width: Int
    // 16 x 16 = 256 threads per threadgroup, below the 512-threads-per-group limit of older devices.
    let threadsPerBlock = MTLSize(width: 16, height: 16, depth: 1)

    /// Creates the Metal device, loads the "black" kernel, builds its compute pipeline and
    /// loads "spidey.jpg" from the app bundle. Traps with `fatalError` if any step fails.
    init() {
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        let library = device.makeDefaultLibrary()
        guard let shader = library?.makeFunction(name: "black") else {
            fatalError("unable to make compute pipeline")
        }
        let pipeline: MTLComputePipelineState
        do {
            // Surface pipeline compilation errors instead of silently leaving the pipeline nil.
            pipeline = try device.makeComputePipelineState(function: shader)
        } catch {
            fatalError("unable to make compute pipeline: \(error)")
        }
        guard let image = UIImage(named: "spidey.jpg") else {
            fatalError("Unable to load spidey.jpg")
        }
        self.device = device
        self.defaultLib = library
        self.grayscaleShader = shader
        self.commandQueue = device.makeCommandQueue()
        self.pipelineState = pipeline
        self.inputImage = image
        self.height = Int(image.size.height)
        self.width = Int(image.size.width)
    }

    /// Redraws `uiimg` into an offscreen bitmap context of the same size at scale 1 and
    /// returns its `CGImage`, or nil if the context produced no image.
    func getCGImage(from uiimg: UIImage) -> CGImage? {
        UIGraphicsBeginImageContext(uiimg.size)
        defer { UIGraphicsEndImageContext() }
        uiimg.draw(in: CGRect(origin: .zero, size: uiimg.size))
        return UIGraphicsGetImageFromCurrentImageContext()?.cgImage
    }

    /// Uploads `cgimg` into a new Metal texture via `MTKTextureLoader` (default options).
    /// Traps with `fatalError` if the loader fails.
    func getMTLTexture(from cgimg: CGImage) -> MTLTexture {
        let textureLoader = MTKTextureLoader(device: device)
        do {
            return try textureLoader.newTexture(cgImage: cgimg, options: nil)
        } catch {
            fatalError("Couldn't convert CGImage to MTLtexture: \(error)")
        }
    }

    /// Copies the pixels of `mtlTexture` (mip level 0) into a new `CGImage`.
    ///
    /// Assumes a 4-byte-per-pixel RGBA texture, such as the `.rgba8Unorm` texture made by
    /// `getEmptyMTLTexture()`, whose storage is CPU-readable so `getBytes` can be used.
    /// Returns nil if the bitmap context cannot be created (e.g. a zero-sized texture).
    func getCGImage(from mtlTexture: MTLTexture) -> CGImage? {
        let textureWidth = mtlTexture.width
        let textureHeight = mtlTexture.height
        let bitmapInfo = CGBitmapInfo(rawValue: CGBitmapInfo.byteOrder32Big.rawValue
                                          | CGImageAlphaInfo.premultipliedLast.rawValue)
        // `data: nil` makes Core Graphics own the pixel buffer. A `&array` pointer would only be
        // valid during the initializer call, but the context keeps using its buffer until
        // `makeImage()`, so a Swift array must not be passed here.
        guard let context = CGContext(data: nil,
                                      width: textureWidth,
                                      height: textureHeight,
                                      bitsPerComponent: 8,
                                      bytesPerRow: 4 * textureWidth,
                                      space: CGColorSpaceCreateDeviceRGB(),
                                      bitmapInfo: bitmapInfo.rawValue),
              let pixels = context.data else {
            return nil
        }
        // Use the context's actual row stride in case Core Graphics padded the rows.
        mtlTexture.getBytes(pixels,
                            bytesPerRow: context.bytesPerRow,
                            from: MTLRegionMake2D(0, 0, textureWidth, textureHeight),
                            mipmapLevel: 0)
        return context.makeImage()
    }

    /// Wraps `cgimg` in a `UIImage`. Never returns nil; the optional is kept for API compatibility.
    func getUIImage(from cgimg: CGImage) -> UIImage? {
        return UIImage(cgImage: cgimg)
    }

    /// Allocates an `.rgba8Unorm` texture of `width` x `height` that shaders can read and write.
    func getEmptyMTLTexture() -> MTLTexture? {
        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: .rgba8Unorm,
            width: width,
            height: height,
            mipmapped: false)
        textureDescriptor.usage = [.shaderRead, .shaderWrite]
        return device.makeTexture(descriptor: textureDescriptor)
    }

    /// Converts `inputImage` into a Metal texture. Traps with `fatalError` if the
    /// intermediate `CGImage` cannot be produced.
    func getInputMTLTexture() -> MTLTexture? {
        guard let cgImage = getCGImage(from: inputImage) else {
            fatalError("Unable to convert Input image to MTLTexture")
        }
        return getMTLTexture(from: cgImage)
    }

    /// Number of threadgroups needed to cover `width` x `height` with `threadsPerBlock`.
    ///
    /// Rounds up so that right/bottom edge pixels are covered when the image size is not a
    /// multiple of the threadgroup size.
    /// NOTE(review): the kernel (not visible here) must ignore thread positions outside the
    /// texture bounds — confirm the "black" kernel performs that check.
    func getBlockDimensions() -> MTLSize {
        let blockWidth = (width + threadsPerBlock.width - 1) / threadsPerBlock.width
        let blockHeight = (height + threadsPerBlock.height - 1) / threadsPerBlock.height
        return MTLSizeMake(blockWidth, blockHeight, 1)
    }

    /// Runs the compute kernel over `inputImage` and returns the filtered image.
    ///
    /// Uses a fresh command buffer and encoder on every call, so it can be called repeatedly.
    /// Blocks until the GPU finishes. Returns nil if the command buffer did not complete
    /// successfully. Traps with `fatalError` if a Metal object or texture cannot be created.
    func applyFilter() -> UIImage? {
        // Create the textures before the encoder so a failure here never leaves an
        // encoder that was started but not ended.
        guard let pipeline = pipelineState,
              let queue = commandQueue,
              let outputTexture = getEmptyMTLTexture(),
              let inputTexture = getInputMTLTexture(),
              let buffer = queue.makeCommandBuffer(),
              let encoder = buffer.makeComputeCommandEncoder() else {
            fatalError("optional unwrapping failed")
        }
        commandBuffer = buffer
        commandEncoder = encoder

        // The pipeline must be bound before dispatching.
        encoder.setComputePipelineState(pipeline)
        // Output at texture index 0, input at index 1; presumably matches the kernel's bindings.
        encoder.setTextures([outputTexture, inputTexture], range: 0..<2)
        encoder.dispatchThreadgroups(getBlockDimensions(), threadsPerThreadgroup: threadsPerBlock)
        encoder.endEncoding()
        buffer.commit()
        buffer.waitUntilCompleted()

        // GPU-side failures are reported on the command buffer rather than thrown.
        guard buffer.status == .completed else { return nil }
        guard let outputImage = getCGImage(from: outputTexture) else {
            fatalError("Couldn't obtain CGImage from MTLTexture")
        }
        return getUIImage(from: outputImage)
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment