Skip to content

Instantly share code, notes, and snippets.

@ynagatomo
Created December 4, 2022 08:56
Show Gist options
  • Save ynagatomo/fd87876a6a6ab92e5a3f5816c420f982 to your computer and use it in GitHub Desktop.
Save ynagatomo/fd87876a6a6ab92e5a3f5816c420f982 to your computer and use it in GitHub Desktop.
An iOS app that generates images using Stable-Diffusion-v2 CoreML models.
//
// ContentView.swift
// coremlsd2test
//
// Created by Yasuhito Nagatomo on 2022/12/03.
//
// Sample code that uses Apple's ml-stable-diffusion library.
// Preparation:
// 1. Convert the PyTorch Stable Diffusion v2 model to Core ML models using Apple's tools.
// 2. Import the Core ML models into the iOS project:
//      models2/
//      +-- Resources/
//          +-- merges, vocab, TextEncoder, Unet, VAEDecoder
//    If you use different directories, modify the resourceURL path in the code.
// 3. Add the Swift package https://github.com/apple/ml-stable-diffusion to the iOS project.
//
// tested with macOS 13.1 beta 4, Xcode 14.1
// for iPhone 12+/iOS 16.2, iPad Pro/M1/M2/iPadOS 16.2
import SwiftUI
import StableDiffusion
/// Form for editing the generation inputs: the prompt text plus steppers for
/// image count, iteration steps, and seed.
struct PromptView: View {
    /// The generation parameters owned by the parent view and edited in place here.
    @Binding var parameter: ImageGenerator.GenerationParameter

    var body: some View {
        VStack {
            TextField("Prompt:", text: $parameter.prompt)
                .textFieldStyle(.roundedBorder)

            Stepper("Image Count: \(parameter.imageCount)",
                    value: $parameter.imageCount,
                    in: 1...10)

            Stepper("Iteration steps: \(parameter.stepCount)",
                    value: $parameter.stepCount,
                    in: 1...100)

            Stepper("Seed: \(parameter.seed)",
                    value: $parameter.seed,
                    in: 0...10000)
        }
        .padding()
    }
}
/// Root screen: a title, the prompt form, a "Generate" button (replaced by a
/// spinner while a generation is running), and the most recently generated images.
struct ContentView: View {
    @StateObject var imageGenerator = ImageGenerator()
    @State private var generationParameter = ImageGenerator.GenerationParameter(
        prompt: "a photo of an astronaut riding a horse on mars",
        seed: 100,
        stepCount: 20,
        imageCount: 1,
        disableSafety: false
    )

    /// True when no generation is in flight.
    private var isIdle: Bool {
        imageGenerator.generationState == .idle
    }

    var body: some View {
        ScrollView {
            VStack {
                Text("Stable Diffusion v2")
                    .font(.title)
                    .padding()

                // The form is locked while the generator is busy.
                PromptView(parameter: $generationParameter)
                    .disabled(!isIdle)

                if isIdle {
                    Button {
                        generate()
                    } label: {
                        Text("Generate").font(.title)
                    }
                    .buttonStyle(.borderedProminent)
                } else {
                    ProgressView()
                }

                if let result = imageGenerator.generatedImages {
                    ForEach(result.images) { generated in
                        Image(uiImage: generated.uiImage)
                            .resizable()
                            .scaledToFit()
                    }
                }
            }
        }
        .padding()
    }

    /// Hands the current parameters to the image generator.
    func generate() {
        imageGenerator.generateImages(generationParameter)
    }
}
/// Runs the Stable Diffusion pipeline and publishes its state and results for SwiftUI.
///
/// The class is isolated to the main actor, so its `@Published` properties are only
/// mutated there. The pipeline call itself runs inside a detached task and hops back
/// to the main actor through `setState(_:)` / `setGeneratedImages(_:)`.
@MainActor
final class ImageGenerator: ObservableObject {
    /// Inputs for a single generation request; forwarded as-is to the pipeline.
    struct GenerationParameter {
        var prompt: String
        var seed: Int
        var stepCount: Int
        var imageCount: Int
        var disableSafety: Bool
    }

    /// One generated image, given a fresh identity so it can be used in `ForEach`.
    struct GeneratedImage: Identifiable {
        let id: UUID = UUID()
        let uiImage: UIImage
    }

    /// The images of one generation together with the parameters that produced them.
    struct GeneratedImages {
        let prompt: String
        let imageCount: Int
        let stepCount: Int
        let seed: Int
        let disableSafety: Bool
        let images: [GeneratedImage]
    }

    /// Whether a generation is currently running.
    ///
    /// `Equatable` is synthesized by the compiler: two `.generating` values are equal
    /// when their `progressStep` values are equal, and `.idle` equals only `.idle`.
    enum GenerationState: Equatable {
        case idle
        case generating(progressStep: Int)
    }

    @Published var generationState: GenerationState = .idle
    @Published var generatedImages: GeneratedImages?
    private let sdpipeline: StableDiffusionPipeline

    /// Creates the pipeline from the `models2/Resources` directory in the main bundle.
    ///
    /// Terminates the process if the directory is missing or the pipeline cannot be
    /// created, since the app cannot do anything useful without the models.
    init() {
        guard let path = Bundle.main.path(forResource: "Resources", ofType: nil, inDirectory: "models2") else {
            fatalError("Fatal error: failed to find the CoreML models.")
        }
        let resourceURL = URL(fileURLWithPath: path)
        do {
            sdpipeline = try StableDiffusionPipeline(resourcesAt: resourceURL)
        } catch {
            // Include the underlying error so the crash log says why loading failed.
            fatalError("Fatal error: failed to create the Stable-Diffusion-Pipeline. \(error)")
        }
    }

    /// Updates `generationState`; exists so detached tasks can `await` onto the main actor.
    func setState(_ state: GenerationState) { // for actor isolation
        generationState = state
    }

    /// Updates `generatedImages`; exists so detached tasks can `await` onto the main actor.
    func setGeneratedImages(_ images: GeneratedImages) { // for actor isolation
        generatedImages = images
    }

    /// Starts generating images for `parameter` unless a generation is already running.
    ///
    /// On success the result is published through `generatedImages`; on failure the
    /// error is printed and the previous result is left untouched. In both cases the
    /// state returns to `.idle` once the work is finished.
    ///
    /// - Parameter parameter: The prompt and sampling settings passed to the pipeline.
    func generateImages(_ parameter: GenerationParameter) {
        guard generationState == .idle else { return }

        // Mark the generator busy before leaving the main actor. Doing this inside the
        // detached task would leave a window in which a second call still sees `.idle`
        // and starts a concurrent generation.
        generationState = .generating(progressStep: 0)

        Task.detached(priority: .high) {
            do {
                // generateImages(prompt: String, imageCount: Int = 1, stepCount: Int = 50, seed: Int = 0,
                //                disableSafety: Bool = false,
                //                progressHandler: (StableDiffusionPipeline.Progress) -> Bool = { _ in true }) throws -> [CGImage?]
                // TODO: use the progressHandler
                let cgImages = try self.sdpipeline.generateImages(prompt: parameter.prompt,
                                                                  imageCount: parameter.imageCount,
                                                                  stepCount: parameter.stepCount,
                                                                  seed: parameter.seed,
                                                                  disableSafety: parameter.disableSafety)
                print("images were generated.")

                // The pipeline returns optional images; drop the nil entries.
                let images = cgImages
                    .compactMap { $0 }
                    .map { GeneratedImage(uiImage: UIImage(cgImage: $0)) }

                await self.setGeneratedImages(GeneratedImages(prompt: parameter.prompt,
                                                              imageCount: parameter.imageCount,
                                                              stepCount: parameter.stepCount,
                                                              seed: parameter.seed,
                                                              disableSafety: parameter.disableSafety,
                                                              images: images))
            } catch {
                // Report the actual error instead of discarding it.
                print("failed: \(error)")
            }
            await self.setState(.idle)
        }
    }
}
@LEOS-83
Copy link

LEOS-83 commented Dec 4, 2022

O

@ynagatomo
Copy link
Author

latest Xcode project is on GitHub: https://github.com/ynagatomo/ImgGenSD2

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment