metal : support MTLGPUFamily < Apple7, formatting, style (#3524) · CodeLinaro/llama.cpp@b0ec521 · GitHub
[go: up one dir, main page]

Skip to content

Commit b0ec521

Browse files
authored
metal : support MTLGPUFamily < Apple7, formatting, style (ggml-org#3524)
* metal : improve decoding speed for batches of 2-16
* metal : rename kernels mul_mat_ to mul_mv_
* metal : indentations
* minor
* metal : print more GPU info + disable mul_mm for MTLGPUFamily < Apple7
1 parent 63d3b06 commit b0ec521

File tree

2 files changed

+176
-118
lines changed

2 files changed

+176
-118
lines changed

ggml-metal.m

Lines changed: 123 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,18 @@
8181
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
8282
GGML_METAL_DECL_KERNEL(rms_norm);
8383
GGML_METAL_DECL_KERNEL(norm);
84-
GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
85-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
86-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
87-
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
88-
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
89-
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
90-
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
91-
GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
92-
GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
93-
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
94-
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
95-
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
84+
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
85+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
86+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
87+
GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
88+
GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
89+
GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
90+
GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
91+
GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
92+
GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
93+
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
94+
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
95+
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
9696
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
9797
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
9898
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -262,28 +262,30 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
262262
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
263263
GGML_METAL_ADD_KERNEL(rms_norm);
264264
GGML_METAL_ADD_KERNEL(norm);
265-
GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
266-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
267-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
268-
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
269-
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
270-
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
271-
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
272-
GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
273-
GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
274-
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
275-
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
276-
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
277-
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
278-
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
279-
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
280-
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
281-
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
282-
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
283-
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
284-
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
285-
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
286-
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
265+
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
266+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
267+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
268+
GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
269+
GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
270+
GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
271+
GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
272+
GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
273+
GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
274+
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
275+
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
276+
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
277+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
278+
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
279+
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
280+
GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
281+
GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
282+
GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
283+
GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
284+
GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
285+
GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
286+
GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
287+
GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
288+
}
287289
GGML_METAL_ADD_KERNEL(rope_f32);
288290
GGML_METAL_ADD_KERNEL(rope_f16);
289291
GGML_METAL_ADD_KERNEL(alibi_f32);
@@ -296,8 +298,22 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
296298
#undef GGML_METAL_ADD_KERNEL
297299
}
298300

299-
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
300301
#if TARGET_OS_OSX
302+
// print MTL GPU family:
303+
GGML_METAL_LOG_INFO("%s: GPU name: %s\n", __func__, [[ctx->device name] UTF8String]);
304+
GGML_METAL_LOG_INFO("%s: GPU arch: %s\n", __func__, [[ctx->device architecture].name UTF8String]);
305+
306+
// determine max supported GPU family
307+
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
308+
// https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
309+
for (int i = MTLGPUFamilyApple9 + 10; i >= MTLGPUFamilyApple1; --i) {
310+
if ([ctx->device supportsFamily:i]) {
311+
GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
312+
break;
313+
}
314+
}
315+
316+
GGML_METAL_LOG_INFO("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
301317
GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
302318
if (ctx->device.maxTransferRate != 0) {
303319
GGML_METAL_LOG_INFO("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -339,28 +355,30 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
339355
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
340356
GGML_METAL_DEL_KERNEL(rms_norm);
341357
GGML_METAL_DEL_KERNEL(norm);
342-
GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
343-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
344-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
345-
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
346-
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
347-
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
348-
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
349-
GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
350-
GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
351-
GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
352-
GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
353-
GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
354-
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
355-
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
356-
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
357-
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
358-
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
359-
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
360-
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
361-
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
362-
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
363-
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
358+
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
359+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
360+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
361+
GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
362+
GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
363+
GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
364+
GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
365+
GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
366+
GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
367+
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
368+
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
369+
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
370+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
371+
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
372+
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
373+
GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
374+
GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
375+
GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
376+
GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
377+
GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
378+
GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
379+
GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
380+
GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
381+
}
364382
GGML_METAL_DEL_KERNEL(rope_f32);
365383
GGML_METAL_DEL_KERNEL(rope_f16);
366384
GGML_METAL_DEL_KERNEL(alibi_f32);
@@ -986,21 +1004,46 @@ void ggml_metal_graph_compute(
9861004
} break;
9871005
case GGML_OP_MUL_MAT:
9881006
{
989-
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
990-
9911007
GGML_ASSERT(ne00 == ne10);
992-
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
993-
uint gqa = ne12/ne02;
9941008
GGML_ASSERT(ne03 == ne13);
9951009

1010+
const uint gqa = ne12/ne02;
1011+
1012+
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
1013+
// to the matrix-vector kernel
1014+
int ne11_mm_min = 1;
1015+
1016+
#if 0
1017+
// the numbers below are measured on M2 Ultra for 7B and 13B models
1018+
// these numbers do not translate to other devices or model sizes
1019+
// TODO: need to find a better approach
1020+
if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
1021+
switch (src0t) {
1022+
case GGML_TYPE_F16: ne11_mm_min = 2; break;
1023+
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
1024+
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
1025+
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
1026+
case GGML_TYPE_Q4_0:
1027+
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
1028+
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
1029+
case GGML_TYPE_Q5_0: // not tested yet
1030+
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
1031+
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
1032+
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
1033+
default: ne11_mm_min = 1; break;
1034+
}
1035+
}
1036+
#endif
1037+
9961038
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
9971039
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
998-
if (!ggml_is_transposed(src0) &&
1040+
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1041+
!ggml_is_transposed(src0) &&
9991042
!ggml_is_transposed(src1) &&
10001043
src1t == GGML_TYPE_F32 &&
1001-
[ctx->device supportsFamily:MTLGPUFamilyApple7] &&
1002-
ne00%32 == 0 &&
1003-
ne11 > 2) {
1044+
ne00 % 32 == 0 &&
1045+
ne11 > ne11_mm_min) {
1046+
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
10041047
switch (src0->type) {
10051048
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32]; break;
10061049
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
@@ -1029,30 +1072,31 @@ void ggml_metal_graph_compute(
10291072
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
10301073
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:13];
10311074
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
1032-
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
1075+
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
10331076
} else {
10341077
int nth0 = 32;
10351078
int nth1 = 1;
10361079
int nrows = 1;
1080+
//printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
10371081

10381082
// use custom matrix x vector kernel
10391083
switch (src0t) {
10401084
case GGML_TYPE_F32:
10411085
{
1042-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
1086+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
10431087
nrows = 4;
10441088
} break;
10451089
case GGML_TYPE_F16:
10461090
{
10471091
nth0 = 32;
10481092
nth1 = 1;
10491093
if (ne11 * ne12 < 4) {
1050-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
1094+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
10511095
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
1052-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
1096+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
10531097
nrows = ne11;
10541098
} else {
1055-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
1099+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
10561100
nrows = 4;
10571101
}
10581102
} break;
@@ -1063,7 +1107,7 @@ void ggml_metal_graph_compute(
10631107

10641108
nth0 = 8;
10651109
nth1 = 8;
1066-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
1110+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
10671111
} break;
10681112
case GGML_TYPE_Q4_1:
10691113
{
@@ -1072,7 +1116,7 @@ void ggml_metal_graph_compute(
10721116

10731117
nth0 = 8;
10741118
nth1 = 8;
1075-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
1119+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
10761120
} break;
10771121
case GGML_TYPE_Q8_0:
10781122
{
@@ -1081,7 +1125,7 @@ void ggml_metal_graph_compute(
10811125

10821126
nth0 = 8;
10831127
nth1 = 8;
1084-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
1128+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
10851129
} break;
10861130
case GGML_TYPE_Q2_K:
10871131
{
@@ -1090,7 +1134,7 @@ void ggml_metal_graph_compute(
10901134

10911135
nth0 = 2;
10921136
nth1 = 32;
1093-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
1137+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
10941138
} break;
10951139
case GGML_TYPE_Q3_K:
10961140
{
@@ -1099,7 +1143,7 @@ void ggml_metal_graph_compute(
10991143

11001144
nth0 = 2;
11011145
nth1 = 32;
1102-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
1146+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
11031147
} break;
11041148
case GGML_TYPE_Q4_K:
11051149
{
@@ -1108,7 +1152,7 @@ void ggml_metal_graph_compute(
11081152

11091153
nth0 = 4; //1;
11101154
nth1 = 8; //32;
1111-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
1155+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
11121156
} break;
11131157
case GGML_TYPE_Q5_K:
11141158
{
@@ -1117,7 +1161,7 @@ void ggml_metal_graph_compute(
11171161

11181162
nth0 = 2;
11191163
nth1 = 32;
1120-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
1164+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
11211165
} break;
11221166
case GGML_TYPE_Q6_K:
11231167
{
@@ -1126,7 +1170,7 @@ void ggml_metal_graph_compute(
11261170

11271171
nth0 = 2;
11281172
nth1 = 32;
1129-
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
1173+
[encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
11301174
} break;
11311175
default:
11321176
{
@@ -1155,7 +1199,7 @@ void ggml_metal_graph_compute(
11551199
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
11561200

11571201
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
1158-
src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
1202+
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
11591203
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
11601204
}
11611205
else if (src0t == GGML_TYPE_Q4_K) {

0 commit comments

Comments
 (0)
0