Skip to content

Commit a2ebad7

Browse files
authored
fixes from python side (#158)
1 parent 566c8ec commit a2ebad7

File tree

2 files changed

+6
-5
lines changed

2 files changed

+6
-5
lines changed

Libraries/StableDiffusion/StableDiffusion.swift

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -372,10 +372,11 @@ open class StableDiffusionXL: StableDiffusion, TextToImageGenerator, ImageToImag
372372
conditioning2.hiddenStates.dropLast().last!,
373373
],
374374
axis: -1)
375-
let pooledConditionng = conditioning2.pooledOutput
375+
var pooledConditionng = conditioning2.pooledOutput
376376

377377
if imageCount > 1 {
378378
conditioning = repeated(conditioning, count: imageCount, axis: 0)
379+
pooledConditionng = repeated(pooledConditionng, count: imageCount, axis: 0)
379380
}
380381

381382
return (conditioning, pooledConditionng)

Libraries/StableDiffusion/UNet.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class Transformer2D: Module {
136136

137137
// Perform the input norm and projection
138138
let (B, H, W, C) = x.shape4
139-
x = norm(x.asType(.float32)).asType(dtype).reshaped(B, -1, C)
139+
x = norm(x).reshaped(B, -1, C)
140140
x = projectIn(x)
141141

142142
// apply the transformer
@@ -195,7 +195,7 @@ class ResnetBlock2D: Module {
195195
func callAsFunction(_ x: MLXArray, timeEmbedding: MLXArray? = nil) -> MLXArray {
196196
let dtype = x.dtype
197197

198-
var y = norm1(x.asType(.float32)).asType(dtype)
198+
var y = norm1(x)
199199
y = silu(y)
200200
y = conv1(y)
201201

@@ -204,7 +204,7 @@ class ResnetBlock2D: Module {
204204
y = y + timeEmbedding[0..., .newAxis, .newAxis, 0...]
205205
}
206206

207-
y = norm2(y.asType(.float32)).asType(dtype)
207+
y = norm2(y)
208208
y = silu(y)
209209
y = conv2(y)
210210

@@ -501,7 +501,7 @@ class UNetModel: Module {
501501

502502
// postprocess the output
503503
let dtype = x.dtype
504-
x = convNormOut(x.asType(.float32)).asType(dtype)
504+
x = convNormOut(x)
505505
x = silu(x)
506506
x = convOut(x)
507507

0 commit comments

Comments
 (0)