-
-
Save Hais/cefa58546aec3af0875816733691a3bd to your computer and use it in GitHub Desktop.
A Swift class that implements GPU-based image processing.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// filter.swift | |
// fltr | |
// | |
// Created by Avinash on 18/06/19. | |
// Copyright © 2019 eightyfive. All rights reserved. | |
// | |
import Metal
import MetalKit
import UIKit
// UIImage -> CGImage -> MTLTexture -> COMPUTE HAPPENS | | |
// UIImage <- CGImage <- MTLTexture <-- | |
/// Applies a Metal compute kernel to a bundled image and returns the filtered result.
///
/// Data flow: UIImage -> CGImage -> MTLTexture -> compute kernel -> MTLTexture -> CGImage -> UIImage.
/// The kernel receives the output texture at index 0 and the input texture at index 1.
class Filter {
    var device: MTLDevice
    var defaultLib: MTLLibrary?
    /// Compute function loaded from the default library (named "black" unless overridden in `init`).
    var grayscaleShader: MTLFunction?
    var commandQueue: MTLCommandQueue?
    /// Command buffer used by the most recent `applyFilter()` call.
    /// A new one is created per call because a Metal command buffer can only be committed once.
    var commandBuffer: MTLCommandBuffer?
    /// Compute encoder used by the most recent `applyFilter()` call.
    var commandEncoder: MTLComputeCommandEncoder?
    var pipelineState: MTLComputePipelineState?
    var inputImage: UIImage
    /// Image size in points. `getCGImage(from: UIImage)` renders at scale 1,
    /// so these are also the pixel dimensions of the texture that gets processed.
    var height, width: Int
    /// 16 x 16 = 256 threads per threadgroup, below the 512-thread limit of older iOS GPUs.
    let threadsPerBlock = MTLSize(width: 16, height: 16, depth: 1)

    /// Creates a filter for a bundled image and a kernel from the app's default Metal library.
    ///
    /// - Parameters:
    ///   - imageName: Name passed to `UIImage(named:)`. Defaults to `"spidey.jpg"`.
    ///   - functionName: Name of the compute function in the default library. Defaults to `"black"`.
    /// - Note: Traps via `fatalError` if Metal is unavailable, the kernel or pipeline cannot be
    ///   created, or the image cannot be loaded.
    init(imageName: String = "spidey.jpg", functionName: String = "black") {
        guard let device = MTLCreateSystemDefaultDevice() else {
            fatalError("Metal is not supported on this device")
        }
        let library = device.makeDefaultLibrary()
        let shader = library?.makeFunction(name: functionName)
        self.device = device
        self.defaultLib = library
        self.grayscaleShader = shader
        self.commandQueue = device.makeCommandQueue()
        // The command buffer and encoder are deliberately NOT created here: an encoder that is
        // never ended trips Metal validation, and a buffer can only be committed once.
        guard let shader = shader else { fatalError("unable to make compute pipeline") }
        do {
            self.pipelineState = try device.makeComputePipelineState(function: shader)
        } catch {
            fatalError("unable to make compute pipeline: \(error)")
        }
        guard let image = UIImage(named: imageName) else {
            fatalError("Unable to load image named \(imageName)")
        }
        self.inputImage = image
        self.height = Int(image.size.height)
        self.width = Int(image.size.width)
    }

    /// Renders `uiimg` into a scale-1 bitmap context and returns the resulting CGImage.
    /// Drawing the image, instead of reading `uiimg.cgImage`, applies its orientation to the pixels.
    func getCGImage(from uiimg: UIImage) -> CGImage? {
        UIGraphicsBeginImageContext(uiimg.size)
        defer { UIGraphicsEndImageContext() }
        uiimg.draw(in: CGRect(origin: .zero, size: uiimg.size))
        return UIGraphicsGetImageFromCurrentImageContext()?.cgImage
    }

    /// Uploads `cgimg` into a new GPU texture via `MTKTextureLoader` (shader-readable by default).
    /// - Note: Traps via `fatalError` if the loader fails.
    func getMTLTexture(from cgimg: CGImage) -> MTLTexture {
        let textureLoader = MTKTextureLoader(device: device)
        do {
            return try textureLoader.newTexture(cgImage: cgimg, options: nil)
        } catch {
            fatalError("Couldn't convert CGImage to MTLtexture")
        }
    }

    /// Copies mip level 0 of `mtlTexture` back to the CPU and wraps it in a CGImage.
    ///
    /// Assumes 4 bytes per pixel in RGBA order, like the texture from `getEmptyMTLTexture()`.
    /// The GPU work that writes the texture must have completed before this is called.
    /// - Returns: `nil` if the texture is empty or the CGImage cannot be created.
    func getCGImage(from mtlTexture: MTLTexture) -> CGImage? {
        let texWidth = mtlTexture.width
        let texHeight = mtlTexture.height
        guard texWidth > 0, texHeight > 0 else { return nil }
        let bytesPerRow = 4 * texWidth
        var pixels = [UInt8](repeating: 0, count: bytesPerRow * texHeight)
        pixels.withUnsafeMutableBytes { buffer in
            mtlTexture.getBytes(buffer.baseAddress!,
                                bytesPerRow: bytesPerRow,
                                from: MTLRegionMake2D(0, 0, texWidth, texHeight),
                                mipmapLevel: 0)
        }
        // The data provider owns its own copy of the bytes, so the CGImage stays valid after this
        // function returns. A pointer made from `&pixels` is only valid for a single call and
        // must not be handed to an object that outlives it.
        guard let provider = CGDataProvider(data: Data(pixels) as CFData) else { return nil }
        let bitmapInfo = CGBitmapInfo(rawValue: CGBitmapInfo.byteOrder32Big.rawValue
                                        | CGImageAlphaInfo.premultipliedLast.rawValue)
        return CGImage(width: texWidth,
                       height: texHeight,
                       bitsPerComponent: 8,
                       bitsPerPixel: 32,
                       bytesPerRow: bytesPerRow,
                       space: CGColorSpaceCreateDeviceRGB(),
                       bitmapInfo: bitmapInfo,
                       provider: provider,
                       decode: nil,
                       shouldInterpolate: false,
                       intent: .defaultIntent)
    }

    /// Wraps `cgimg` in a UIImage (never actually `nil`; optional kept for API compatibility).
    func getUIImage(from cgimg: CGImage) -> UIImage? {
        return UIImage(cgImage: cgimg)
    }

    /// Allocates an uninitialized RGBA8 texture of the image's size that shaders can read and write.
    func getEmptyMTLTexture() -> MTLTexture? {
        let textureDescriptor = MTLTextureDescriptor.texture2DDescriptor(
            pixelFormat: .rgba8Unorm,
            width: width,
            height: height,
            mipmapped: false)
        textureDescriptor.usage = [.shaderRead, .shaderWrite]
        return device.makeTexture(descriptor: textureDescriptor)
    }

    /// Converts `inputImage` into a GPU texture.
    /// - Note: Traps via `fatalError` if the image cannot be rendered to a CGImage.
    func getInputMTLTexture() -> MTLTexture? {
        guard let cgImage = getCGImage(from: inputImage) else {
            fatalError("Unable to convert Input image to MTLTexture")
        }
        return getMTLTexture(from: cgImage)
    }

    /// Number of threadgroups needed to cover every pixel of the image.
    ///
    /// Rounds up, so images whose size is not a multiple of `threadsPerBlock` still get their
    /// right and bottom edges processed. The extra threads fall outside the texture, so the
    /// kernel must skip out-of-range positions.
    /// NOTE(review): the kernel source is not visible here; confirm it guards `gid` against the
    /// output texture's width and height.
    func getBlockDimensions() -> MTLSize {
        let blockWidth = (width + threadsPerBlock.width - 1) / threadsPerBlock.width
        let blockHeight = (height + threadsPerBlock.height - 1) / threadsPerBlock.height
        return MTLSizeMake(blockWidth, blockHeight, 1)
    }

    /// Runs the compute kernel over `inputImage` and returns the filtered image.
    ///
    /// Creates a fresh command buffer and encoder on every call, so it can be called repeatedly.
    /// Blocks the calling thread until the GPU finishes.
    /// - Returns: The filtered image, or `nil` if the GPU reports that the command buffer failed.
    /// - Note: Traps via `fatalError` if a Metal object or image conversion cannot be created.
    func applyFilter() -> UIImage? {
        // Create the textures before the encoder, so a failure here never leaves an
        // un-ended encoder behind.
        guard let pipeline = pipelineState,
              let outputTexture = getEmptyMTLTexture(),
              let inputTexture = getInputMTLTexture(),
              let buffer = commandQueue?.makeCommandBuffer(),
              let encoder = buffer.makeComputeCommandEncoder() else {
            fatalError("optional unwrapping failed")
        }
        commandBuffer = buffer
        commandEncoder = encoder

        // Without a pipeline state the dispatch has no kernel to run.
        encoder.setComputePipelineState(pipeline)
        // Index 0 = output, index 1 = input.
        encoder.setTextures([outputTexture, inputTexture], range: 0..<2)
        encoder.dispatchThreadgroups(getBlockDimensions(), threadsPerThreadgroup: threadsPerBlock)
        encoder.endEncoding()
        buffer.commit()
        buffer.waitUntilCompleted()

        if buffer.status == .error {
            return nil
        }
        guard let outputImage = getCGImage(from: outputTexture) else {
            fatalError("Couldn't obtain CGImage from MTLTexture")
        }
        return getUIImage(from: outputImage)
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment