Skip to content

Commit db579f3

Browse files
authored
pick up qwen2-vl fixes (#163)
- see awni/mlx-vlm@85cdc44
1 parent 197f241 commit db579f3

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

Libraries/MLXVLM/Models/Qwen2VL.swift

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -373,8 +373,8 @@ private enum Vision {
373 373   let B = gridThw[0].t
374 374   let L = sequenceLength / B
375 375
376     - let qkv = qkv(x).reshaped(sequenceLength, 3, -1)
377     - let s = split(qkv, parts: 3, axis: 1)
    376 + let qkv = qkv(x)
    377 + let s = split(qkv, parts: 3, axis: -1)
378 378   var (q, k, v) = (s[0], s[1], s[2])
379 379
380 380   q = q.reshaped(sequenceLength, numHeads, -1)
@@ -512,7 +512,7 @@ private enum Vision {
512 512   .flattened()
513 513
514 514   let stackedPosIds = stacked([hposIds, wposIds], axis: -1)
515     - positionIds.append(repeated(stackedPosIds, count: t, axis: 0))
    515 + positionIds.append(tiled(stackedPosIds, repetitions: [t, 1]))
516 516   }
517 517
518 518   let indices = concatenated(positionIds, axis: 0)

0 commit comments

Comments
 (0)