Skip to content

Instantly share code, notes, and snippets.

@vlondon
Created March 1, 2021 20:51
Show Gist options
  • Save vlondon/491c2e7829d60e835d53a1f6810a34ed to your computer and use it in GitHub Desktop.
Save vlondon/491c2e7829d60e835d53a1f6810a34ed to your computer and use it in GitHub Desktop.
// Don't forget to add to the project:
// 1. DeepLabV3 - https://developer.apple.com/machine-learning/models/
// 2. CoreMLHelpers - https://github.com/hollance/CoreMLHelpers

/// Which component of the segmentation result `removeBackground(returnResult:)`
/// should return.
enum RemoveBackroundResult {
    /// The blurred segmentation mask on its own.
    case background
    /// The original image composited with the mask (i.e. background removed).
    case finalImage
}

/// Correctly-spelled alias for the misspelled enum name above.
/// `RemoveBackroundResult` is kept so existing call sites keep compiling.
typealias RemoveBackgroundResult = RemoveBackroundResult
extension UIImage {
    /// Runs DeepLabV3 semantic segmentation on the receiver and returns either
    /// the blurred segmentation mask (`.background`) or the receiver composited
    /// against that mask (`.finalImage`), resized back to the original size.
    ///
    /// - Parameter returnResult: Which component of the result to return.
    /// - Returns: The requested image, or `nil` if the model cannot be loaded
    ///   or any pixel-buffer / Core Image step fails.
    func removeBackground(returnResult: RemoveBackroundResult) -> UIImage? {
        guard let model = getDeepLabV3Model() else { return nil }

        // DeepLabV3 expects a fixed 513x513 input.
        let width: CGFloat = 513
        let height: CGFloat = 513
        // BUG FIX: the original passed `height` for BOTH dimensions
        // (CGSize(width: height, height: height)).
        let resizedImage = resized(to: CGSize(width: width, height: height), scale: 1)

        guard
            let pixelBuffer = resizedImage.pixelBuffer(width: Int(width), height: Int(height)),
            let outputPredictionImage = try? model.prediction(image: pixelBuffer),
            let outputImage = outputPredictionImage.semanticPredictions.image(min: 0, max: 1, axes: (0, 0, 1)),
            let outputCIImage = CIImage(image: outputImage),
            let maskImage = outputCIImage.removeWhitePixels(),
            let maskBlurImage = maskImage.applyBlurEffect()
        else { return nil }

        switch returnResult {
        case .finalImage:
            guard
                let resizedCIImage = CIImage(image: resizedImage),
                let compositedImage = resizedCIImage.composite(with: maskBlurImage)
            else { return nil }
            // Scale the 513x513 composite back up to the receiver's original size.
            return UIImage(ciImage: compositedImage)
                .resized(to: CGSize(width: size.width, height: size.height))
        case .background:
            return UIImage(
                ciImage: maskBlurImage,
                scale: scale,
                orientation: imageOrientation
            ).resized(to: CGSize(width: size.width, height: size.height))
        }
    }

    /// Loads the bundled DeepLabV3 Core ML model.
    /// Logs and returns `nil` if instantiation throws.
    private func getDeepLabV3Model() -> DeepLabV3? {
        do {
            return try DeepLabV3(configuration: MLModelConfiguration())
        } catch {
            log("Error loading model: \(error)")
            return nil
        }
    }
}
extension CIImage {
    /// Makes fully bright (brightness == 1) pixels transparent via a
    /// chroma-key color cube, leaving everything else opaque.
    func removeWhitePixels() -> CIImage? {
        let chromaCIFilter = chromaKeyFilter()
        chromaCIFilter?.setValue(self, forKey: kCIInputImageKey)
        return chromaCIFilter?.outputImage
    }

    /// Composites the receiver over `mask` using source-out compositing,
    /// so the receiver shows only where the mask is transparent.
    func composite(with mask: CIImage) -> CIImage? {
        CIFilter(
            name: "CISourceOutCompositing",
            parameters: [
                kCIInputImageKey: self,
                kCIInputBackgroundImageKey: mask
            ]
        )?.outputImage
    }

    /// Applies a small Gaussian blur (radius 2) and returns a rendered copy.
    func applyBlurEffect() -> CIImage? {
        let context = CIContext(options: nil)

        // Clamp edges first so the blur doesn't fade toward the image border.
        let clampFilter = CIFilter(name: "CIAffineClamp")!
        clampFilter.setDefaults()
        clampFilter.setValue(self, forKey: kCIInputImageKey)

        guard let blurFilter = CIFilter(name: "CIGaussianBlur") else { return nil }
        blurFilter.setValue(clampFilter.outputImage, forKey: kCIInputImageKey)
        blurFilter.setValue(2, forKey: "inputRadius")

        // Render to CGImage so the (infinite-extent) clamped output is cropped
        // back to the receiver's extent.
        guard
            let output = blurFilter.outputImage,
            let cgimg = context.createCGImage(output, from: extent)
        else { return nil }
        return CIImage(cgImage: cgimg)
    }

    // modified from https://developer.apple.com/documentation/coreimage/applying_a_chroma_key_effect
    /// Builds a CIColorCube filter whose cube zeroes out (alpha = 0,
    /// premultiplied RGB = 0) every color with maximum HSB brightness.
    private func chromaKeyFilter() -> CIFilter? {
        let size = 64
        var cubeRGB = [Float]()
        cubeRGB.reserveCapacity(size * size * size * 4)
        for z in 0 ..< size {
            let blue = CGFloat(z) / CGFloat(size - 1)
            for y in 0 ..< size {
                let green = CGFloat(y) / CGFloat(size - 1)
                for x in 0 ..< size {
                    let red = CGFloat(x) / CGFloat(size - 1)
                    let brightness = getBrightness(red: red, green: green, blue: blue)
                    let alpha: CGFloat = brightness == 1 ? 0 : 1
                    // CIColorCube expects premultiplied RGBA.
                    cubeRGB.append(Float(red * alpha))
                    cubeRGB.append(Float(green * alpha))
                    cubeRGB.append(Float(blue * alpha))
                    cubeRGB.append(Float(alpha))
                }
            }
        }

        // BUG FIX: `Data(buffer: UnsafeBufferPointer(start: &cubeRGB, ...))`
        // forms a pointer that is only valid for the duration of that one call
        // (dangling afterwards — undefined behavior; crashes in Release builds).
        // Copy the bytes inside withUnsafeBufferPointer, where the pointer is
        // guaranteed valid.
        let data = cubeRGB.withUnsafeBufferPointer { Data(buffer: $0) }

        return CIFilter(
            name: "CIColorCube",
            parameters: [
                "inputCubeDimension": size,
                "inputCubeData": data
            ]
        )
    }

    // modified from https://developer.apple.com/documentation/coreimage/applying_a_chroma_key_effect
    /// Returns the HSB brightness component of the given RGB color.
    private func getBrightness(red: CGFloat, green: CGFloat, blue: CGFloat) -> CGFloat {
        let color = UIColor(red: red, green: green, blue: blue, alpha: 1)
        var brightness: CGFloat = 0
        color.getHue(nil, saturation: nil, brightness: &brightness, alpha: nil)
        return brightness
    }
}
@DanielZanchi
Copy link

Nice one!
What are these indicating?
.semanticPredictions.image(min: 0, max: 1, axes: (0, 0, 1)),

Is there a way to convert the black mask created on outputImage with another color?

@JonathanWMorris
Copy link

Hi, thanks for making this! It was very helpful, however, there is one issue I found. Line 107 produces a warning in Xcode and it causes the app to crash in Release. To fix it, you can replace that line with the following:

var data = Data()
cubeRGB.withUnsafeBufferPointer { ptr in data = Data(buffer: ptr) }

@carissaboo1212
Copy link

Cannot find type 'UIImage' in scope

@Didami
Copy link

Didami commented Sep 8, 2022

import UIKit

@marcozabo
Copy link

:D

@bhavik-nexios
Copy link

bhavik-nexios commented Aug 13, 2025

import UIKit
import CoreML


extension MLMultiArray {
    /// Converts a 3-channel (RGB, interleaved) multiarray into a `UIImage`,
    /// linearly mapping values from `[min, max]` to `[0, 255]`.
    ///
    /// - Parameters:
    ///   - min: Value that maps to 0.
    ///   - max: Value that maps to 255. Must differ from `min`.
    ///   - axes: Indices into `shape` for (width, height, channels).
    /// - Returns: The rendered image, or `nil` for unsupported channel counts
    ///   or if bitmap-context creation fails.
    func image(min: Float, max: Float, axes: (Int, Int, Int)) -> UIImage? {
        let width = self.shape[axes.0].intValue
        let height = self.shape[axes.1].intValue
        let channels = self.shape[axes.2].intValue

        guard channels == 3 else {
            print("Unsupported number of channels: \(channels). Only RGB images are supported.")
            return nil
        }
        guard max != min else {
            print("Invalid range: min and max must differ.")
            return nil
        }

        // BUG FIX: the original computed `v * 255/(max-min) + min`, which only
        // normalizes correctly when min == 0; the general mapping is
        // (v - min) * 255/(max-min). Values are also clamped to 0...255 because
        // UInt8(_:) traps on out-of-range floats.
        let scale = 255 / (max - min)
        func byte(_ v: Float) -> UInt8 {
            UInt8(Swift.max(0, Swift.min(255, (v - min) * scale)))
        }

        var pixelData = [UInt8](repeating: 0, count: width * height * 4)

        for y in 0..<height {
            for x in 0..<width {
                let srcIndex = (y * width + x) * channels
                let dstIndex = (y * width + x) * 4
                pixelData[dstIndex] = byte(self[srcIndex].floatValue)
                pixelData[dstIndex + 1] = byte(self[srcIndex + 1].floatValue)
                pixelData[dstIndex + 2] = byte(self[srcIndex + 2].floatValue)
                pixelData[dstIndex + 3] = 255 // Alpha channel: fully opaque
            }
        }

        let colorSpace = CGColorSpaceCreateDeviceRGB()
        guard let context = CGContext(
            data: &pixelData,
            width: width,
            height: height,
            bitsPerComponent: 8,
            bytesPerRow: width * 4,
            space: colorSpace,
            bitmapInfo: CGImageAlphaInfo.premultipliedLast.rawValue
        ) else {
            print("Failed to create CGContext")
            return nil
        }

        guard let cgImage = context.makeImage() else {
            print("Failed to create CGImage")
            return nil
        }

        return UIImage(cgImage: cgImage, scale: UIScreen.main.scale, orientation: .up)
    }
}

extension UIImage {
    /// Redraws the receiver at `size`.
    ///
    /// - Parameters:
    ///   - size: Target size in points.
    ///   - scale: Rendering scale factor (pixels per point). Defaults to 1.
    /// - Returns: The resized image.
    func resized(to size: CGSize, scale: CGFloat = 1) -> UIImage {
        // BUG FIX: the original ignored `scale` entirely —
        // UIGraphicsImageRenderer defaults to the main screen's scale, so
        // callers asking for scale 1 silently got 2x/3x bitmaps. Honor the
        // parameter via an explicit renderer format.
        let format = UIGraphicsImageRendererFormat()
        format.scale = scale
        let renderer = UIGraphicsImageRenderer(size: size, format: format)
        return renderer.image { _ in
            self.draw(in: CGRect(origin: .zero, size: size))
        }
    }

    /// Renders the receiver into a newly created 32ARGB `CVPixelBuffer` of the
    /// given pixel dimensions (suitable as Core ML model input).
    ///
    /// - Returns: The filled buffer, or `nil` if buffer or context creation fails.
    func pixelBuffer(width: Int, height: Int) -> CVPixelBuffer? {
        let attrs = [
            kCVPixelBufferCGImageCompatibilityKey: true,
            kCVPixelBufferCGBitmapContextCompatibilityKey: true
        ] as CFDictionary
        var pixelBuffer: CVPixelBuffer?
        let status = CVPixelBufferCreate(
            kCFAllocatorDefault, width, height, kCVPixelFormatType_32ARGB, attrs, &pixelBuffer
        )
        guard status == kCVReturnSuccess, let buffer = pixelBuffer else { return nil }

        CVPixelBufferLockBaseAddress(buffer, [])
        let rgbColorSpace = CGColorSpaceCreateDeviceRGB()
        guard let context = CGContext(
            data: CVPixelBufferGetBaseAddress(buffer),
            width: width, height: height, bitsPerComponent: 8,
            bytesPerRow: CVPixelBufferGetBytesPerRow(buffer),
            space: rgbColorSpace, bitmapInfo: CGImageAlphaInfo.premultipliedFirst.rawValue
        ) else {
            CVPixelBufferUnlockBaseAddress(buffer, [])
            return nil
        }

        // Flip the coordinate system: CGContext is bottom-left origin, UIKit
        // drawing is top-left origin.
        UIGraphicsPushContext(context)
        UIGraphicsGetCurrentContext()?.scaleBy(x: 1, y: -1)
        UIGraphicsGetCurrentContext()?.translateBy(x: 0, y: -CGFloat(height))
        self.draw(in: CGRect(x: 0, y: 0, width: CGFloat(width), height: CGFloat(height)))
        UIGraphicsPopContext()

        CVPixelBufferUnlockBaseAddress(buffer, [])
        return buffer
    }
}

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment