-
-
Save KittenYang/fbddca9878e093cd5361069e91f2b583 to your computer and use it in GitHub Desktop.
An iOS app that generates images using Stable-Diffusion-v2 CoreML models.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// | |
// ContentView.swift | |
// coremlsd2test | |
// | |
// Created by Yasuhito Nagatomo on 2022/12/03. | |
// | |
// Sample code using Apple's ml-stable-diffusion library.
// Preparation:
// 1. Convert the PyTorch Stable Diffusion v2 model to Core ML models using Apple's conversion tools.
// 2. Import the Core ML models into the iOS project with this directory layout:
//      models2/
//      +-- Resources/
//          +-- merges, vocab, TextEncoder, Unet, VAEDecoder
//    If you use a different layout, update the resource path lookup in ImageGenerator.init().
// 3. Add the Swift package https://github.com/apple/ml-stable-diffusion to the iOS project.
// | |
// tested with macOS 13.1 beta 4, Xcode 14.1 | |
// for iPhone 12+/iOS 16.2, iPad Pro/M1/M2/iPadOS 16.2 | |
import SwiftUI | |
import StableDiffusion | |
/// Editing form for a single generation request.
///
/// Edits go straight through the binding into the caller's `GenerationParameter`:
/// the prompt text, the image count (1...10), the iteration step count (1...100),
/// and the random seed (0...10000).
struct PromptView: View {
    @Binding var parameter: ImageGenerator.GenerationParameter

    var body: some View {
        VStack {
            TextField("Prompt:", text: $parameter.prompt)
                .textFieldStyle(.roundedBorder)

            Stepper("Image Count: \(parameter.imageCount)",
                    value: $parameter.imageCount,
                    in: 1...10)

            Stepper("Iteration steps: \(parameter.stepCount)",
                    value: $parameter.stepCount,
                    in: 1...100)

            Stepper("Seed: \(parameter.seed)",
                    value: $parameter.seed,
                    in: 0...10000)
        }
        .padding()
    }
}
/// Root screen: a prompt form, a Generate button (a spinner while busy), and the latest results.
struct ContentView: View {
    @StateObject var imageGenerator = ImageGenerator()
    @State private var generationParameter = ImageGenerator.GenerationParameter(
        prompt: "a photo of an astronaut riding a horse on mars",
        seed: 100,
        stepCount: 20,
        imageCount: 1,
        disableSafety: false
    )

    /// True when no generation request is running, so the form may be edited and submitted.
    private var isIdle: Bool {
        imageGenerator.generationState == .idle
    }

    var body: some View {
        ScrollView {
            VStack {
                Text("Stable Diffusion v2")
                    .font(.title)
                    .padding()
                PromptView(parameter: $generationParameter)
                    .disabled(!isIdle)
                generateButtonOrSpinner
                resultImages
            }
        }
        .padding()
    }

    /// The Generate button while idle; an indeterminate progress indicator otherwise.
    @ViewBuilder
    private var generateButtonOrSpinner: some View {
        if isIdle {
            Button(action: generate) {
                Text("Generate")
                    .font(.title)
            }
            .buttonStyle(.borderedProminent)
        } else {
            ProgressView()
        }
    }

    /// The images from the most recent successful request, each scaled to fit the width.
    @ViewBuilder
    private var resultImages: some View {
        if let latest = imageGenerator.generatedImages {
            ForEach(latest.images) { generated in
                Image(uiImage: generated.uiImage)
                    .resizable()
                    .scaledToFit()
            }
        }
    }

    /// Submits the current form values to the generator.
    func generate() {
        imageGenerator.generateImages(generationParameter)
    }
}
/// Owns the Stable Diffusion pipeline and publishes generation state and results for SwiftUI.
///
/// The class is main-actor isolated, so every `@Published` mutation happens on the main actor.
/// The pipeline's `generateImages` call is synchronous and long-running, so it is executed from a
/// detached task rather than on the main actor.
@MainActor
final class ImageGenerator: ObservableObject {
    /// Inputs for a single generation request.
    struct GenerationParameter {
        var prompt: String
        var seed: Int
        var stepCount: Int
        var imageCount: Int
        var disableSafety: Bool
    }

    /// One generated image, tagged with a fresh `UUID` so it can be listed by `ForEach`.
    struct GeneratedImage: Identifiable {
        let id: UUID = UUID()
        let uiImage: UIImage
    }

    /// The images produced by one request, together with the parameters that produced them.
    struct GeneratedImages {
        let prompt: String
        let imageCount: Int
        let stepCount: Int
        let seed: Int
        let disableSafety: Bool
        let images: [GeneratedImage]
    }

    /// Whether a generation request is currently running.
    ///
    /// `Equatable` is synthesized by the compiler: `.idle == .idle`, and two `.generating` values
    /// are equal exactly when their `progressStep`s are equal.
    enum GenerationState: Equatable {
        case idle
        case generating(progressStep: Int)
    }

    /// Current generation state; anything other than `.idle` means a request is in flight.
    @Published var generationState: GenerationState = .idle
    /// Result of the most recent successful request, or `nil` if none has completed yet.
    @Published var generatedImages: GeneratedImages?
    private let sdpipeline: StableDiffusionPipeline

    /// Loads the Core ML resources bundled under `models2/Resources` and builds the pipeline.
    ///
    /// Terminates with `fatalError` if the resources are missing or the pipeline cannot be created.
    init() {
        guard let path = Bundle.main.path(forResource: "Resources", ofType: nil, inDirectory: "models2") else {
            fatalError("Fatal error: failed to find the CoreML models.")
        }
        let resourceURL = URL(fileURLWithPath: path)
        do {
            sdpipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)
        } catch {
            // Include the underlying error so the cause of the failure is visible in the crash log.
            fatalError("Fatal error: failed to create the Stable-Diffusion-Pipeline. \(error)")
        }
    }

    /// Sets `generationState`; lets nonisolated code update it with `await`.
    func setState(_ state: GenerationState) {
        generationState = state
    }

    /// Sets `generatedImages`; lets nonisolated code update it with `await`.
    func setGeneratedImages(_ images: GeneratedImages) {
        generatedImages = images
    }

    /// Starts generating images for `parameter` unless a request is already running.
    ///
    /// Returns immediately. `generationState` is `.generating` by the time this method returns and
    /// goes back to `.idle` when the request finishes. On success, `generatedImages` is replaced.
    /// On failure, the error is printed and the previous `generatedImages` is left unchanged.
    func generateImages(_ parameter: GenerationParameter) {
        guard generationState == .idle else { return }
        // Claim the generator synchronously, before any suspension point. If the state were only
        // set inside the task, a second call arriving before the task first ran (e.g. a quick
        // double tap) would also pass the guard and start a second, concurrent generation.
        generationState = .generating(progressStep: 0)

        // Detached so the long, synchronous pipeline call runs off the main actor and does not
        // block the UI.
        Task.detached(priority: .high) {
            do {
                // TODO: pass a `progressHandler` to report progress; `progressStep` currently stays at 0.
                let cgImages = try self.sdpipeline.generateImages(prompt: parameter.prompt,
                                                                  imageCount: parameter.imageCount,
                                                                  stepCount: parameter.stepCount,
                                                                  seed: parameter.seed,
                                                                  disableSafety: parameter.disableSafety)
                print("images were generated.")
                // The pipeline returns `[CGImage?]`; `nil` entries are dropped.
                let uiImages = cgImages.compactMap { $0 }.map { UIImage(cgImage: $0) }
                await self.setGeneratedImages(GeneratedImages(prompt: parameter.prompt,
                                                              imageCount: parameter.imageCount,
                                                              stepCount: parameter.stepCount,
                                                              seed: parameter.seed,
                                                              disableSafety: parameter.disableSafety,
                                                              images: uiImages.map { GeneratedImage(uiImage: $0) }))
            } catch {
                // Report the actual error rather than discarding it.
                print("failed to generate images: \(error)")
            }
            await self.setState(.idle)
        }
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment