     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
-    GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
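For orientation, GGML_METAL_DECL_KERNEL declares the per-kernel function/pipeline members that these names expand into; a rough sketch of the macro pattern (approximate, not quoted verbatim from the file) is:

    // approximate shape of the declaration macro used above
    #define GGML_METAL_DECL_KERNEL(name) \
        id<MTLFunction>             function_##name; \
        id<MTLComputePipelineState> pipeline_##name

so renaming mul_mat_* to mul_mv_* here also renames the pipeline_mul_mv_* members referenced later in the compute path.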
@@ -262,28 +262,30 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
-        GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
@@ -296,8 +298,22 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
+    GGML_METAL_LOG_INFO("%s: GPU arch:   %s\n", __func__, [[ctx->device architecture].name UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple9 + 10; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -339,28 +355,30 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
-    GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
@@ -986,21 +1004,46 @@ void ggml_metal_graph_compute(
             } break;
         case GGML_OP_MUL_MAT:
             {
-                // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
                 GGML_ASSERT(ne00 == ne10);
-                // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
-                uint gqa = ne12/ne02;
                 GGML_ASSERT(ne03 == ne13);

+                const uint gqa = ne12/ne02;
+
+                // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                // to the matrix-vector kernel
+                int ne11_mm_min = 1;
+
+#if 0
+                // the numbers below are measured on M2 Ultra for 7B and 13B models
+                // these numbers do not translate to other devices or model sizes
+                // TODO: need to find a better approach
+                if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                    switch (src0t) {
+                        case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+                        case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+                        case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+                        case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+                        case GGML_TYPE_Q4_0:
+                        case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+                        case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+                        case GGML_TYPE_Q5_0:                   // not tested yet
+                        case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+                        case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+                        case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+                        default:             ne11_mm_min = 1;  break;
+                    }
+                }
+#endif
+
                 // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                 // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                if (!ggml_is_transposed(src0) &&
+                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                    !ggml_is_transposed(src0) &&
                     !ggml_is_transposed(src1) &&
                     src1t == GGML_TYPE_F32 &&
-                    [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                    ne00%32 == 0 &&
-                    ne11 > 2) {
+                    ne00 % 32 == 0 &&
+                    ne11 > ne11_mm_min) {
+                    //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                     switch (src0->type) {
                         case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@@ -1029,30 +1072,31 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
                     [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
                     [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                    [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                 } else {
                     int nth0 = 32;
                     int nth1 = 1;
                     int nrows = 1;
+                    //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);

                     // use custom matrix x vector kernel
                     switch (src0t) {
                         case GGML_TYPE_F32:
                             {
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                 nrows = 4;
                             } break;
                         case GGML_TYPE_F16:
                             {
                                 nth0 = 32;
                                 nth1 = 1;
                                 if (ne11 * ne12 < 4) {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
                                 } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
                                     nrows = ne11;
                                 } else {
-                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
                                     nrows = 4;
                                 }
                             } break;
@@ -1063,7 +1107,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 8;
                                 nth1 = 8;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                             } break;
                         case GGML_TYPE_Q4_1:
                             {
@@ -1072,7 +1116,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 8;
                                 nth1 = 8;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                             } break;
                         case GGML_TYPE_Q8_0:
                             {
@@ -1081,7 +1125,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 8;
                                 nth1 = 8;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                             } break;
                         case GGML_TYPE_Q2_K:
                             {
@@ -1090,7 +1134,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                             } break;
                         case GGML_TYPE_Q3_K:
                             {
@@ -1099,7 +1143,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                             } break;
                         case GGML_TYPE_Q4_K:
                             {
@@ -1108,7 +1152,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 4; //1;
                                 nth1 = 8; //32;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                             } break;
                         case GGML_TYPE_Q5_K:
                             {
@@ -1117,7 +1161,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                             } break;
                         case GGML_TYPE_Q6_K:
                             {
@@ -1126,7 +1170,7 @@ void ggml_metal_graph_compute(

                                 nth0 = 2;
                                 nth1 = 32;
-                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
+                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
                             } break;
                         default:
                             {
@@ -1155,7 +1199,7 @@ void ggml_metal_graph_compute(
                     [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];

                     if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                        src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+                        src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                     else if (src0t == GGML_TYPE_Q4_K) {
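To make the dispatch arithmetic of the two paths concrete, here is a worked example with hypothetical tensor sizes (the numbers are illustrative, not taken from the patch):

    // hypothetical shapes: ne01 = 4096 rows of src0, ne11 = 512 columns of src1, ne12 = 1 batch
    // matrix-matrix path (mul_mm_*), taken when ne11 > ne11_mm_min on Apple7+ GPUs:
    //   grid = ((512 + 31)/32, (4096 + 63)/64, 1) = (16, 64, 1) threadgroups of 128 threads,
    //   i.e. each threadgroup covers roughly a 64x32 tile of the output
    // matrix-vector path (e.g. mul_mv_q4_0_f32 with ne11 = 1):
    //   grid = ((4096 + 7)/8, 1, 1) = (512, 1, 1) threadgroups of nth0 x nth1 = 8 x 8 threads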