Skip to content

Commit cf29855

Browse files
authored
implement stable diffusion example (#120)
* implement stable diffusion example

  Based on https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion

  - example code for these two models:
    - https://huggingface.co/stabilityai/sdxl-turbo
    - https://huggingface.co/stabilityai/stable-diffusion-2-1
  - command line tool example for text-to-image
  - command line tool example for image-to-image
  - example application for same
1 parent a4f278a commit cf29855

File tree

24 files changed

+4397
-92
lines changed

24 files changed

+4397
-92
lines changed
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
{
  "colors" : [
    {
      "idiom" : "universal"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
{
  "images" : [
    {
      "idiom" : "universal",
      "platform" : "ios",
      "size" : "1024x1024"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "16x16"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "32x32"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "128x128"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "256x256"
    },
    {
      "idiom" : "mac",
      "scale" : "1x",
      "size" : "512x512"
    },
    {
      "idiom" : "mac",
      "scale" : "2x",
      "size" : "512x512"
    }
  ],
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
{
  "info" : {
    "author" : "xcode",
    "version" : 1
  }
}
Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
// Copyright © 2024 Apple Inc.
2+
3+
import MLX
4+
import StableDiffusion
5+
import SwiftUI
6+
7+
/// Text-to-image UI for the stable diffusion example: shows phase progress,
/// the most recently decoded image, and prompt controls.
struct ContentView: View {

    @State var prompt = "dismal swamp, dense, very dark, realistic"
    @State var negativePrompt = ""
    @State var evaluator = StableDiffusionEvaluator()
    @State var showProgress = false

    var body: some View {
        VStack {
            // download / load / generation progress, visible only while active
            HStack {
                if let progress = evaluator.progress {
                    ProgressView(progress.title, value: progress.current, total: progress.limit)
                }
            }
            .frame(height: 20)

            Spacer()
            if let image = evaluator.image {
                Image(image, scale: 1.0, label: Text(""))
                    .resizable()
                    .aspectRatio(contentMode: .fit)
                    .frame(minHeight: 200)
            }
            Spacer()

            Grid {
                GridRow {
                    TextField("prompt", text: $prompt)
                        .onSubmit(generate)
                        .disabled(evaluator.progress != nil)
                        #if os(visionOS)
                            .textFieldStyle(.roundedBorder)
                        #endif

                    Button(action: { prompt = "" }) {
                        Label("clear", systemImage: "xmark.circle.fill").font(.system(size: 10))
                    }
                    .labelStyle(.iconOnly)
                    .buttonStyle(.plain)

                    Button("generate", action: generate)
                        .disabled(evaluator.progress != nil)
                        .keyboardShortcut("r")
                }
                // second row is only shown when the model supports the controls
                if evaluator.modelFactory.canShowProgress
                    || evaluator.modelFactory.canUseNegativeText
                {
                    GridRow {
                        if evaluator.modelFactory.canUseNegativeText {
                            TextField("negative prompt", text: $negativePrompt)
                                .onSubmit(generate)
                                .disabled(evaluator.progress != nil)
                                #if os(visionOS)
                                    .textFieldStyle(.roundedBorder)
                                #endif

                            // fix: this button previously cleared `prompt` --
                            // it belongs to the negative prompt field
                            Button(action: { negativePrompt = "" }) {
                                Label("clear", systemImage: "xmark.circle.fill").font(
                                    .system(size: 10))
                            }
                            .labelStyle(.iconOnly)
                            .buttonStyle(.plain)
                        } else {
                            // keep the grid columns aligned when the field is hidden
                            EmptyView()
                            EmptyView()
                        }

                        if evaluator.modelFactory.canShowProgress {
                            Toggle("Show Progress", isOn: $showProgress)
                        }
                    }
                }
            }
            .frame(minWidth: 300)
        }
        .padding()
    }

    /// Kick off an async generation with the current prompt state.
    private func generate() {
        Task {
            await evaluator.generate(
                prompt: prompt, negativePrompt: negativePrompt, showProgress: showProgress)
        }
    }
}
91+
92+
/// A titled progress snapshot (e.g. "Download", "Generate Latents") suitable
/// for driving a `ProgressView(_:value:total:)`.
struct Progress: Equatable {
    /// short label for the phase being reported
    let title: String
    /// work completed so far and the total work, in the same units
    let current, limit: Double
}
98+
99+
/// Async model factory: downloads and instantiates the SDXL Turbo
/// text-to-image model, caching it in memory when resources allow.
actor ModelFactory {

    /// State machine for the one-shot load: idle -> loading(task) -> loaded.
    enum LoadState {
        case idle
        case loading(Task<ModelContainer<TextToImageGenerator>, Error>)
        case loaded(ModelContainer<TextToImageGenerator>)
    }

    enum SDError: Error {
        case unableToLoad
    }

    // fixed preset for this example app
    public nonisolated let configuration = StableDiffusionConfiguration.presetSDXLTurbo

    /// if true we show UI that lets users see the intermediate steps
    public nonisolated let canShowProgress: Bool

    /// if true we show UI to give negative text
    public nonisolated let canUseNegativeText: Bool

    private var loadState = LoadState.idle
    // quantize is flipped on below when memory is constrained
    private var loadConfiguration = LoadConfiguration(float16: true, quantize: false)

    /// true on low-memory machines: quantize weights, cap GPU memory,
    /// and do not keep the model resident between generations
    public nonisolated let conserveMemory: Bool

    init() {
        let defaultParameters = configuration.defaultParameters()
        // models that run >4 steps are slow enough that intermediate
        // progress images are worth showing
        self.canShowProgress = defaultParameters.steps > 4
        // negative text only has effect when classifier-free guidance is on
        self.canUseNegativeText = defaultParameters.cfgWeight > 1

        // this will be true e.g. if the computer has 8G of memory or less
        self.conserveMemory = MLX.GPU.memoryLimit < 8 * 1024 * 1024 * 1024

        if conserveMemory {
            print("conserving memory")
            loadConfiguration.quantize = true
            MLX.GPU.set(cacheLimit: 1 * 1024 * 1024)
            MLX.GPU.set(memoryLimit: 3 * 1024 * 1024 * 1024)
        } else {
            MLX.GPU.set(cacheLimit: 256 * 1024 * 1024)
        }
    }

    /// Download (if needed) and build the generator, reporting progress via
    /// `reportProgress`. Concurrent callers share the in-flight task; the
    /// result is cached unless `conserveMemory` is set.
    public func load(reportProgress: @escaping @Sendable (Progress) -> Void) async throws
        -> ModelContainer<TextToImageGenerator>
    {
        switch loadState {
        case .idle:
            let task = Task {
                try await configuration.download { progress in
                    // suppress the final ~100% tick so the bar does not
                    // linger at full before the next phase starts
                    if progress.fractionCompleted < 0.99 {
                        reportProgress(
                            .init(
                                title: "Download", current: progress.fractionCompleted * 100,
                                limit: 100))
                    }
                }

                let container = try ModelContainer<TextToImageGenerator>.createTextToImageGenerator(
                    configuration: configuration, loadConfiguration: loadConfiguration)

                await container.setConserveMemory(conserveMemory)

                try await container.perform { model in
                    reportProgress(.init(title: "Loading weights", current: 0, limit: 1))
                    // when conserving memory the weights are loaded lazily as used
                    if !conserveMemory {
                        model.ensureLoaded()
                    }
                }

                return container
            }
            // publish the in-flight task before awaiting so concurrent
            // callers join it rather than starting a second download
            self.loadState = .loading(task)

            // NOTE(review): if the task throws, loadState stays .loading with a
            // failed task, so later callers re-receive the same error rather
            // than retrying -- confirm this is the intended behavior
            let container = try await task.value

            if conserveMemory {
                // if conserving memory return the model but do not keep it in memory
                self.loadState = .idle
            } else {
                // cache the model in memory to make it faster to run with new prompts
                self.loadState = .loaded(container)
            }

            return container

        case .loading(let task):
            let generator = try await task.value
            return generator

        case .loaded(let generator):
            return generator
        }
    }

}
196+
197+
/// Observable bridge between the UI and the stable diffusion model:
/// publishes progress, error messages, and the latest decoded image.
@Observable @MainActor
class StableDiffusionEvaluator {

    /// current phase progress, or nil when idle
    var progress: Progress?
    /// error text from the last failed generation, if any
    var message: String?
    /// most recently decoded image
    var image: CGImage?

    let modelFactory = ModelFactory()

    /// Publish progress from any context by hopping to the main actor.
    @Sendable
    nonisolated private func updateProgress(progress: Progress?) {
        Task { @MainActor in
            self.progress = progress
        }
    }

    /// Publish a new image from any context by hopping to the main actor.
    @Sendable
    nonisolated private func updateImage(image: CGImage?) {
        Task { @MainActor in
            self.image = image
        }
    }

    /// Convert decoded output to a CGImage and publish it.
    nonisolated private func display(decoded: MLXArray) {
        // scale to 8-bit pixels; squeeze drops the singleton batch dimension
        let raster = (decoded * 255).asType(.uint8).squeezed()
        let image = Image(raster).asCGImage()

        // updateImage already dispatches to the main actor -- no extra Task needed
        updateImage(image: image)
    }

    func generate(prompt: String, negativePrompt: String, showProgress: Bool) async {
        progress = .init(title: "Preparing", current: 0, limit: 1)
        message = nil

        // the parameters that control the generation of the image. See
        // EvaluateParameters for more information. For example adjusting
        // the latentSize parameter will change the size of the generated
        // image. imageCount could be used to generate a gallery of
        // images at the same time.
        //
        // built once here and captured by both stages below (previously an
        // identical copy was rebuilt inside the first stage, shadowing this one)
        let parameters = {
            var p = modelFactory.configuration.defaultParameters()
            p.prompt = prompt
            p.negativePrompt = negativePrompt

            // per measurement each step consumes memory that we want to conserve. trade
            // off steps (quality) for memory
            if modelFactory.conserveMemory {
                p.steps = 1
            }

            return p
        }()

        do {
            // note: the optionals are used to discard parts of the model
            // as it runs -- this is used to conserveMemory in devices
            // with less memory
            let container = try await modelFactory.load(reportProgress: updateProgress)

            try await container.performTwoStage { generator in
                // generate the latent images -- this is fast as it is just generating
                // the graphs that will be evaluated below
                let latents: DenoiseIterator? = generator.generateLatents(parameters: parameters)

                // when conserveMemory is true this will discard the first part of
                // the model and just evaluate the decode portion
                return (generator.detachedDecoder(), latents)

            } second: { decoder, latents in
                var lastXt: MLXArray?
                for (i, xt) in latents!.enumerated() {
                    // release the previous step's array before evaluating the next
                    lastXt = nil
                    eval(xt)
                    lastXt = xt

                    // periodically decode and show an intermediate image
                    if showProgress, i % 10 == 0 {
                        display(decoded: decoder(xt))
                    }

                    updateProgress(
                        progress: .init(
                            title: "Generate Latents", current: Double(i),
                            limit: Double(parameters.steps)))
                }

                // decode and show the final latent
                if let lastXt {
                    display(decoded: decoder(lastXt))
                }
                updateProgress(progress: nil)
            }

        } catch {
            progress = nil
            message = "Failed: \(error)"
        }
    }
}

0 commit comments

Comments
 (0)