@@ -382,15 +382,14 @@ private enum Vision {
382
382
383
383
// Create attention mask
384
384
let attentionMask = full (
385
- ( 1 , sequenceLength, sequenceLength) ,
386
- MLXArray . finfo ( q. dtype) . min,
387
- dtype: q. dtype)
385
+ [ 1 , sequenceLength, sequenceLength] ,
386
+ values: - Float32. greatestFiniteMagnitude)
388
387
389
388
// Update mask for each sequence
390
389
for i in 1 ..< cuSeqlens. size {
391
- let start = Int ( cuSeqlens [ i - 1 ] . item ( ) )
392
- let end = Int ( cuSeqlens [ i] . item ( ) )
393
- attentionMask [ 0 ... , start ..< end, start ..< end] = 0
390
+ let start = cuSeqlens [ i - 1 ] . item ( Int . self )
391
+ let end = cuSeqlens [ i] . item ( Int . self )
392
+ attentionMask [ 0 ... , start ..< end, start ..< end] = MLXArray ( 0 )
394
393
}
395
394
396
395
q = q. reshaped ( 1 , sequenceLength, numHeads, - 1 ) . transposed ( 0 , 2 , 1 , 3 )
@@ -414,9 +413,9 @@ private enum Vision {
414
413
@ModuleInfo ( key: " down_proj " ) var down : Linear
415
414
416
415
/// Creates the three `Linear` projections (`gate`, `up`, `down`) of this block.
///
/// - Parameters:
///   - dimensions: Passed as the first argument when building `gate` and `up`,
///     and as the second argument when building `down`.
///   - hiddenDimensions: Passed as the second argument when building `gate` and `up`,
///     and as the first argument when building `down`.
public init ( dimensions: Int , hiddenDimensions: Int ) {
417
- self . gate = Linear ( dimensions, hiddenDimensions)
418
- self . up = Linear ( dimensions, hiddenDimensions)
419
- self . down = Linear ( hiddenDimensions, dimensions)
// The replacement lines assign through `_gate` / `_up` / `_down` (`.wrappedValue`)
// instead of the plain property names. `down` is declared with `@ModuleInfo` just
// above; presumably `gate` and `up` are as well, and the wrapper requires
// initialization via its backing storage — TODO confirm against the MLX Swift docs.
416
+ self . _gate . wrappedValue = Linear ( dimensions, hiddenDimensions)
417
+ self . _up . wrappedValue = Linear ( dimensions, hiddenDimensions)
418
+ self . _down . wrappedValue = Linear ( hiddenDimensions, dimensions)
420
419
}
421
420
422
421
public func callAsFunction( _ x: MLXArray ) -> MLXArray {
@@ -559,11 +558,11 @@ private enum Vision {
559
558
let numWindowsW = ( llmGridW + padW) / vitMergerWindowSize
560
559
561
560
// Pad the index
562
- let indexPadded = pad (
561
+ let indexPadded = padded (
563
562
index,
564
- paddings : [ ( 0 , 0 ) , ( 0 , padH) , ( 0 , padW) ] ,
563
+ widths : [ [ 0 , 0 ] , [ 0 , padH] , [ 0 , padW] ] ,
565
564
mode: . constant,
566
- constantValues : - 100
565
+ value : MLXArray ( - 100 )
567
566
)
568
567
569
568
// Reshape and transpose
@@ -583,7 +582,7 @@ private enum Vision {
583
582
)
584
583
585
584
// Calculate sequence lengths
586
- let seqlens = sum ( indexTransposed != - 100 , axes: [ 2 , 3 ] ) . reshaped ( - 1 )
585
+ let seqlens = sum ( indexTransposed . != - 100 , axes: [ 2 , 3 ] ) . reshaped ( - 1 )
587
586
588
587
// Get valid indices
589
588
let indexFlattened = indexTransposed. flattened ( )
@@ -671,7 +670,7 @@ private enum Vision {
671
670
hiddenStates = patchMerger ( hiddenStates)
672
671
673
672
// Reorder back to original sequence
674
- let reverseIndices = argsort ( windowIndex, axis: 0 )
673
+ let reverseIndices = argSort ( windowIndex, axis: 0 )
675
674
hiddenStates = hiddenStates [ reverseIndices, 0 ... ]
676
675
677
676
return hiddenStates
0 commit comments