@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202
202
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203
203
GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204
204
GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213
+ GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205
214
GGML_METAL_KERNEL_TYPE_RMS_NORM,
206
215
GGML_METAL_KERNEL_TYPE_L2_NORM,
207
216
GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1169,6 +1178,15 @@ @implementation GGMLMetalClass
1169
1178
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL, get_rows_iq4_nl, true );
1170
1179
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS, get_rows_iq4_xs, true );
1171
1180
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true );
1181
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32, set_rows_f32, true );
1182
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16, set_rows_f16, true );
1183
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16, set_rows_bf16, use_bfloat);
1184
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0, set_rows_q8_0, true );
1185
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0, set_rows_q4_0, true );
1186
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1, set_rows_q4_1, true );
1187
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0, set_rows_q5_0, true );
1188
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1, set_rows_q5_1, true );
1189
+ GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL, set_rows_iq4_nl, true );
1172
1190
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, has_simdgroup_reduction);
1173
1191
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM, l2_norm, has_simdgroup_reduction);
1174
1192
GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, has_simdgroup_reduction);
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1635
1653
const bool use_bfloat = ctx_dev->use_bfloat ;
1636
1654
1637
1655
if (!use_bfloat) {
1656
+ if (op->type == GGML_TYPE_BF16) {
1657
+ return false ;
1658
+ }
1659
+
1638
1660
for (size_t i = 0 , n = 3 ; i < n; ++i) {
1639
1661
if (op->src [i] != NULL && op->src [i]->type == GGML_TYPE_BF16) {
1640
1662
return false ;
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
1804
1826
{
1805
1827
return op->ne [3 ] == 1 ;
1806
1828
}
1829
+ case GGML_OP_SET_ROWS:
1830
+ {
// SET_ROWS support check: the op is accepted only when src[0] is F32
// and the destination tensor type (op->type) is one of the types listed
// in the switch below. Every other combination is reported as unsupported.
1831
+ if (op->src [0 ]->type != GGML_TYPE_F32) {
1832
+ return false ;
1833
+ }
1834
+
// Destination types accepted here; the list mirrors the dst->type cases
// of the SET_ROWS pipeline selection in ggml_metal_encode_node.
1835
+ switch (op->type ) {
1836
+ case GGML_TYPE_F32:
1837
+ case GGML_TYPE_F16:
1838
+ case GGML_TYPE_BF16:
1839
+ case GGML_TYPE_Q8_0:
1840
+ case GGML_TYPE_Q4_0:
1841
+ case GGML_TYPE_Q4_1:
1842
+ case GGML_TYPE_Q5_0:
1843
+ case GGML_TYPE_Q5_1:
1844
+ case GGML_TYPE_IQ4_NL:
1845
+ return true ;
1846
+ default :
1847
+ return false ;
1848
+ };
1849
+ }
1807
1850
default :
1808
1851
return false ;
1809
1852
}
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
3777
3820
};
3778
3821
3779
3822
[encoder setComputePipelineState: pipeline];
3780
- [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
3781
- [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
3782
- [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
3783
- [encoder setBytes: &args length: sizeof (args) atIndex: 3 ];
3823
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3824
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3825
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3826
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3784
3827
3785
3828
[encoder dispatchThreadgroups: MTLSizeMake (ne10, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
3786
3829
} break ;
3830
+ case GGML_OP_SET_ROWS:
3831
+ {
// Encodes a SET_ROWS dispatch: picks the compute pipeline from dst->type,
// derives the threadgroup shape from the number of blocks per row, binds
// the kernel arguments and the src0/src1/dst buffers, then dispatches.
3832
+ id <MTLComputePipelineState > pipeline = nil ;
3833
+
// One pipeline per destination type; any other dst->type aborts.
3834
+ switch (dst->type ) {
3835
+ case GGML_TYPE_F32: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F32 ].pipeline ; break ;
3836
+ case GGML_TYPE_F16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F16 ].pipeline ; break ;
3837
+ case GGML_TYPE_BF16: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16 ].pipeline ; break ;
3838
+ case GGML_TYPE_Q8_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0 ].pipeline ; break ;
3839
+ case GGML_TYPE_Q4_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0 ].pipeline ; break ;
3840
+ case GGML_TYPE_Q4_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1 ].pipeline ; break ;
3841
+ case GGML_TYPE_Q5_0: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0 ].pipeline ; break ;
3842
+ case GGML_TYPE_Q5_1: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1 ].pipeline ; break ;
3843
+ case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline ; break ;
3844
+ default : GGML_ABORT (" not implemented" );
3845
+ }
3846
+
// nk0 = ne0 divided by the block size of the destination type, i.e. the
// number of dst-type blocks along ne0 (presumably the dst row length —
// ne0 is defined outside this block; confirm).
3847
+ const int32_t nk0 = ne0/ggml_blck_size (dst->type );
3848
+
3849
+ int nth = 32 ; // SIMD width
3850
+
// Double nth until it reaches nk0 or the pipeline's per-threadgroup
// thread limit, whichever comes first.
3851
+ while (nth < nk0 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3852
+ nth *= 2 ;
3853
+ }
3854
+
// nrptg is used as the Y dimension of the threadgroup and as the divisor
// of ne01 in the dispatch below. When nth exceeds nk0, nrptg becomes
// ceil(nth/nk0) and nth is reduced to nk0; nrptg is then decremented once
// if nrptg*nth would exceed the pipeline's thread limit.
// NOTE(review): if nk0 were 0 (ne0 smaller than the block size) the
// division below would divide by zero — confirm callers guarantee nk0 >= 1.
3855
+ int nrptg = 1 ;
3856
+ if (nth > nk0) {
3857
+ nrptg = (nth + nk0 - 1 )/nk0;
3858
+ nth = nk0;
3859
+
3860
+ if (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) {
3861
+ nrptg--;
3862
+ }
3863
+ }
3864
+
// Final clamp of nth to nk0; after the branch above nth is already <= nk0.
3865
+ nth = MIN (nth, nk0);
3866
+
// Kernel arguments, in the field order named by the inline labels.
3867
+ ggml_metal_kargs_set_rows args = {
3868
+ /* .nk0 =*/ nk0,
3869
+ /* .ne01 =*/ ne01,
3870
+ /* .nb01 =*/ nb01,
3871
+ /* .nb02 =*/ nb02,
3872
+ /* .nb03 =*/ nb03,
3873
+ /* .ne11 =*/ ne11,
3874
+ /* .ne12 =*/ ne12,
3875
+ /* .nb10 =*/ nb10,
3876
+ /* .nb11 =*/ nb11,
3877
+ /* .nb12 =*/ nb12,
3878
+ /* .nb1 =*/ nb1,
3879
+ /* .nb2 =*/ nb2,
3880
+ /* .nb3 =*/ nb3,
3881
+ };
3882
+
// Binding layout: args at index 0, src0 at 1, src1 at 2, dst at 3.
3883
+ [encoder setComputePipelineState: pipeline];
3884
+ [encoder setBytes: &args length: sizeof (args) atIndex: 0 ];
3885
+ [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 1 ];
3886
+ [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 2 ];
3887
+ [encoder setBuffer: id_dst offset: offs_dst atIndex: 3 ];
3888
+
// Grid: ceil(ne01/nrptg) x ne02 x ne03 threadgroups, each of
// nth x nrptg x 1 threads.
3889
+ [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )];
3890
+ } break ;
3787
3891
case GGML_OP_RMS_NORM:
3788
3892
{
3789
3893
GGML_ASSERT (ne00 % 4 == 0 );
0 commit comments