Skip to content

Commit cb10fc0

Browse files
committed
fix build issues
1 parent 9a1e473 commit cb10fc0

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

Libraries/MLXVLM/Models/Qwen25VL.swift

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -382,15 +382,14 @@ private enum Vision {
382382

383383
// Create attention mask
384384
let attentionMask = full(
385-
(1, sequenceLength, sequenceLength),
386-
MLXArray.finfo(q.dtype).min,
387-
dtype: q.dtype)
385+
[1, sequenceLength, sequenceLength],
386+
values: -Float32.greatestFiniteMagnitude)
388387

389388
// Update mask for each sequence
390389
for i in 1 ..< cuSeqlens.size {
391-
let start = Int(cuSeqlens[i - 1].item())
392-
let end = Int(cuSeqlens[i].item())
393-
attentionMask[0..., start ..< end, start ..< end] = 0
390+
let start = cuSeqlens[i - 1].item(Int.self)
391+
let end = cuSeqlens[i].item(Int.self)
392+
attentionMask[0..., start ..< end, start ..< end] = MLXArray(0)
394393
}
395394

396395
q = q.reshaped(1, sequenceLength, numHeads, -1).transposed(0, 2, 1, 3)
@@ -414,9 +413,9 @@ private enum Vision {
414413
@ModuleInfo(key: "down_proj") var down: Linear
415414

416415
public init(dimensions: Int, hiddenDimensions: Int) {
417-
self.gate = Linear(dimensions, hiddenDimensions)
418-
self.up = Linear(dimensions, hiddenDimensions)
419-
self.down = Linear(hiddenDimensions, dimensions)
416+
self._gate.wrappedValue = Linear(dimensions, hiddenDimensions)
417+
self._up.wrappedValue = Linear(dimensions, hiddenDimensions)
418+
self._down.wrappedValue = Linear(hiddenDimensions, dimensions)
420419
}
421420

422421
public func callAsFunction(_ x: MLXArray) -> MLXArray {
@@ -559,11 +558,11 @@ private enum Vision {
559558
let numWindowsW = (llmGridW + padW) / vitMergerWindowSize
560559

561560
// Pad the index
562-
let indexPadded = pad(
561+
let indexPadded = padded(
563562
index,
564-
paddings: [(0, 0), (0, padH), (0, padW)],
563+
widths: [[0, 0], [0, padH], [0, padW]],
565564
mode: .constant,
566-
constantValues: -100
565+
value: MLXArray(-100)
567566
)
568567

569568
// Reshape and transpose
@@ -583,7 +582,7 @@ private enum Vision {
583582
)
584583

585584
// Calculate sequence lengths
586-
let seqlens = sum(indexTransposed != -100, axes: [2, 3]).reshaped(-1)
585+
let seqlens = sum(indexTransposed .!= -100, axes: [2, 3]).reshaped(-1)
587586

588587
// Get valid indices
589588
let indexFlattened = indexTransposed.flattened()
@@ -671,7 +670,7 @@ private enum Vision {
671670
hiddenStates = patchMerger(hiddenStates)
672671

673672
// Reorder back to original sequence
674-
let reverseIndices = argsort(windowIndex, axis: 0)
673+
let reverseIndices = argSort(windowIndex, axis: 0)
675674
hiddenStates = hiddenStates[reverseIndices, 0...]
676675

677676
return hiddenStates

0 commit comments

Comments
 (0)