Revert "ggml : add SSM Metal kernels (#8546)" · Nexesenex/croco.cpp@5078aa8 · GitHub

Commit 5078aa8

Revert "ggml : add SSM Metal kernels (ggml-org#8546)"
This reverts commit fc18425.
1 parent e22739c commit 5078aa8
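
Scope of the revert, as the diff below shows: it removes the two SSM kernel enum values and their registration from ggml/src/ggml-metal.m, the GGML_OP_SSM_CONV / GGML_OP_SSM_SCAN cases from ggml_metal_supports_op and from the graph-compute dispatch, the kernel_ssm_conv_f32 and kernel_ssm_scan_f32 kernels from ggml/src/ggml-metal.metal, and the corresponding cases from tests/test-backend-ops.cpp. With the Metal backend no longer reporting support for these ops, graphs containing them should fall back to the CPU implementations in ggml/src/ggml.c; the only non-deletion left is a two-line whitespace change in ggml_compute_forward_ssm_scan_f32.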


4 files changed (+2 additions, −303 deletions)


ggml/src/ggml-metal.m

Lines changed: 0 additions & 122 deletions
@@ -82,8 +82,6 @@
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
     GGML_METAL_KERNEL_TYPE_NORM,
-    GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
-    GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
     GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
@@ -544,8 +542,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
@@ -807,9 +803,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
                 return false;
             }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
-        case GGML_OP_SSM_CONV:
-        case GGML_OP_SSM_SCAN:
-            return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
             return ctx->support_simdgroup_reduction &&
@@ -1545,121 +1538,6 @@ static enum ggml_status ggml_metal_graph_compute(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 }
             } break;
-        case GGML_OP_SSM_CONV:
-            {
-                GGML_ASSERT(src0t == GGML_TYPE_F32);
-                GGML_ASSERT(src1t == GGML_TYPE_F32);
-
-                GGML_ASSERT(ggml_is_contiguous(src0));
-                GGML_ASSERT(ggml_is_contiguous(src1));
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
-                [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
-                [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
-                [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-                [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
-                [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
-                [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
-                [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15];
-                [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16];
-                [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17];
-                [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
-        case GGML_OP_SSM_SCAN:
-            {
-                struct ggml_tensor * src3 = gf->nodes[i]->src[3];
-                struct ggml_tensor * src4 = gf->nodes[i]->src[4];
-                struct ggml_tensor * src5 = gf->nodes[i]->src[5];
-
-                GGML_ASSERT(src3);
-                GGML_ASSERT(src4);
-                GGML_ASSERT(src5);
-
-                size_t offs_src3 = 0;
-                size_t offs_src4 = 0;
-                size_t offs_src5 = 0;
-
-                id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
-                id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil;
-                id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil;
-
-                const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30);
-                const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31);
-
-                const uint64_t nb30 = src3->nb[0];
-                const uint64_t nb31 = src3->nb[1];
-
-                const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40);
-                const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41);
-                const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42);
-
-                const uint64_t nb40 = src4->nb[0];
-                const uint64_t nb41 = src4->nb[1];
-                const uint64_t nb42 = src4->nb[2];
-
-                const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50);
-                const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51);
-                const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52);
-
-                const uint64_t nb50 = src5->nb[0];
-                const uint64_t nb51 = src5->nb[1];
-                const uint64_t nb52 = src5->nb[2];
-
-                const int64_t d_state = ne00;
-                const int64_t d_inner = ne01;
-                const int64_t n_seq_tokens = ne11;
-                const int64_t n_seqs = ne02;
-
-                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline;
-
-                [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
-                [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
-                [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4];
-                [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5];
-                [encoder setBuffer:id_dst offset:offs_dst atIndex:6];
-
-                [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7];
-                [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8];
-                [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9];
-                [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10];
-
-                [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11];
-                [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12];
-                [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13];
-                [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
-                [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
-                [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
-                [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17];
-                [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18];
-                [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19];
-                [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20];
-                [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21];
-                [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22];
-                [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23];
-                [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24];
-                [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25];
-                [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26];
-                [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27];
-                [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
-            } break;
         case GGML_OP_MUL_MAT:
             {
                 GGML_ASSERT(ne00 == ne10);
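
A reading note on the removed dispatches: both launch with threadsPerThreadgroup:MTLSizeMake(1, 1, 1), i.e. a single GPU thread per threadgroup. kernel_ssm_conv_f32 gets a (rows × tokens × sequences) grid, while kernel_ssm_scan_f32 gets only (d_inner × n_seqs), because the scan over tokens is sequential; the kernels' own "TODO: optimize" comments (below) mark this layout as a deliberately unoptimized first cut.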

ggml/src/ggml-metal.metal

Lines changed: 0 additions & 121 deletions
@@ -667,127 +667,6 @@ kernel void kernel_diag_mask_inf_8(
     }
 }
 
-// ref: ggml.c:ggml_compute_forward_ssm_conv_f32
-// TODO: optimize
-kernel void kernel_ssm_conv_f32(
-        device const void * src0,
-        device const void * src1,
-        device float * dst,
-        constant int64_t & ne00,
-        constant int64_t & ne01,
-        constant int64_t & ne02,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant int64_t & ne10,
-        constant int64_t & ne11,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant int64_t & ne0,
-        constant int64_t & ne1,
-        constant int64_t & ne2,
-        constant uint64_t & nb0,
-        constant uint64_t & nb1,
-        constant uint64_t & nb2,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i2 = tgpig.y;
-    const int64_t i3 = tgpig.z;
-
-    const int64_t nc  = ne10;
-    const int64_t ncs = ne00;
-    const int64_t nr  = ne01;
-    const int64_t n_t = ne1;
-    const int64_t n_s = ne2;
-
-    device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
-    device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
-    device float * x = (device float *) ((device char *) dst + ir*nb0 + i2*nb1 + i3*nb2);
-
-    float sumf = 0.0f;
-
-    for (int64_t i0 = 0; i0 < nc; ++i0) {
-        sumf += s[i0] * c[i0];
-    }
-
-    x[0] = sumf;
-}
-
-// ref: ggml.c:ggml_compute_forward_ssm_scan_f32
-// TODO: optimize
-kernel void kernel_ssm_scan_f32(
-        device const void * src0,
-        device const void * src1,
-        device const void * src2,
-        device const void * src3,
-        device const void * src4,
-        device const void * src5,
-        device float * dst,
-        constant int64_t & d_state,
-        constant int64_t & d_inner,
-        constant int64_t & n_seq_tokens,
-        constant int64_t & n_seqs,
-        constant uint64_t & nb00,
-        constant uint64_t & nb01,
-        constant uint64_t & nb02,
-        constant uint64_t & nb10,
-        constant uint64_t & nb11,
-        constant uint64_t & nb12,
-        constant uint64_t & nb13,
-        constant uint64_t & nb20,
-        constant uint64_t & nb21,
-        constant uint64_t & nb22,
-        constant uint64_t & nb30,
-        constant uint64_t & nb31,
-        constant uint64_t & nb40,
-        constant uint64_t & nb41,
-        constant uint64_t & nb42,
-        constant uint64_t & nb50,
-        constant uint64_t & nb51,
-        constant uint64_t & nb52,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3 ntg[[threads_per_threadgroup]]) {
-    const int64_t ir = tgpig.x;
-    const int64_t i3 = tgpig.y;
-
-    const int64_t nc  = d_state;
-    const int64_t nr  = d_inner;
-    const int64_t n_t = n_seq_tokens;
-    const int64_t n_s = n_seqs;
-
-    for (int64_t i2 = 0; i2 < n_t; ++i2) {
-        device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
-        device const float * x  = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12);
-        device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22);
-        device const float * A  = (device const float *) ((device const char *) src3 + ir*nb31);
-        device const float * B  = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42);
-        device const float * C  = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52);
-        device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides
-        device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13);
-
-        if (i2 > 0) {
-            s0 = s;
-        }
-
-        // i1 == 0
-        float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
-        float x_dt = x[0] * dt_soft_plus;
-        float sumf = 0.0f;
-
-        for (int64_t i0 = 0; i0 < nc; ++i0) {
-            int64_t i = i0;
-            float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt);
-            sumf += state * C[i0];
-            s[i] = state;
-        }
-
-        y[0] = sumf;
-    }
-}
-
 kernel void kernel_norm(
         device const void * src0,
         device float * dst,
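
For readers reconstructing what the deleted kernels computed, here is a minimal C sketch of the per-thread bodies above. The helper names (ssm_conv_ref, ssm_scan_ref) are hypothetical and the arrays are assumed flattened and contiguous; the real implementations, like ggml_compute_forward_ssm_conv_f32 / ggml_compute_forward_ssm_scan_f32 in ggml.c, walk strided tensors instead.

    #include <math.h>
    #include <stdint.h>

    // kernel_ssm_conv_f32, one thread's work: the output element is a dot
    // product between a window of the rolled state (s) and the conv weights (c).
    static float ssm_conv_ref(const float * s, const float * c, int64_t nc) {
        float sumf = 0.0f;
        for (int64_t i0 = 0; i0 < nc; ++i0) {
            sumf += s[i0] * c[i0];
        }
        return sumf;
    }

    // kernel_ssm_scan_f32, one (row, sequence) pair: a sequential scan over
    // n_t tokens with nc = d_state. Updating s in place reproduces the
    // kernel's "s0 = s" aliasing after the first token.
    static void ssm_scan_ref(
            float * s,        // state, nc floats, updated in place
            const float * x,  // input, one float per token (n_t)
            const float * dt, // step sizes, n_t floats
            const float * A,  // decay coefficients, nc floats
            const float * B,  // input projection, n_t * nc
            const float * C,  // output projection, n_t * nc
            float * y,        // output, n_t floats
            int64_t nc, int64_t n_t) {
        for (int64_t i2 = 0; i2 < n_t; ++i2) {
            // softplus with the same large-input shortcut as the kernel
            const float dt_soft_plus = dt[i2] <= 20.0f ? logf(1.0f + expf(dt[i2])) : dt[i2];
            const float x_dt = x[i2] * dt_soft_plus;
            float sumf = 0.0f;
            for (int64_t i0 = 0; i0 < nc; ++i0) {
                const float state = (s[i0] * expf(dt_soft_plus * A[i0])) + (B[i2*nc + i0] * x_dt);
                sumf += state * C[i2*nc + i0];
                s[i0] = state;
            }
            y[i2] = sumf;
        }
    }

The recurrence is the selective-scan step used by Mamba-style models: state ← state·exp(Δ·A) + B·x·Δ and y ← Σ state·C, with Δ passed through a softplus.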

ggml/src/ggml.c

Lines changed: 2 additions & 2 deletions
@@ -15898,8 +15898,8 @@ static void ggml_compute_forward_ssm_scan_f32(
             const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner}
             const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s}
             const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s}
-            float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
-            float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
+            float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s}
+            float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s}
 
             // use the output as the source for the next token-wise iterations
             if (i2 > 0) { s0 = s; }
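
This hunk is the only survivor of the revert (the "+2" above): the removed and added lines differ only in whitespace inside the casts ("( float *)" vs. "(float *)"). The addressing itself is unchanged — y is indexed with src1's strides, and the updated state s is written past the y block at offset src1->nb[3] using src0's strides.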

tests/test-backend-ops.cpp

Lines changed: 0 additions & 58 deletions
@@ -949,58 +949,6 @@ struct test_rms_norm : public test_case {
     }
 };
 
-// GGML_OP_SSM_CONV
-struct test_ssm_conv : public test_case {
-    const ggml_type type;
-    const std::array<int64_t, 4> ne_a;
-    const std::array<int64_t, 4> ne_b;
-
-    std::string vars() override {
-        return VARS_TO_STR3(type, ne_a, ne_b);
-    }
-
-    test_ssm_conv(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne_a = {10, 10, 10, 1},
-            std::array<int64_t, 4> ne_b = {3, 3, 1, 1})
-        : type(type), ne_a(ne_a), ne_b(ne_b) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
-        ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
-        ggml_tensor * out = ggml_ssm_conv(ctx, a, b);
-        return out;
-    }
-};
-
-// GGML_OP_SSM_SCAN
-struct test_ssm_scan : public test_case {
-    const ggml_type type;
-
-    const int64_t d_state;
-    const int64_t d_inner;
-    const int64_t n_seq_tokens;
-    const int64_t n_seqs;
-
-    std::string vars() override {
-        return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs);
-    }
-
-    test_ssm_scan(ggml_type type = GGML_TYPE_F32,
-            int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32)
-        : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * s   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, n_seqs, 1 }.data());
-        ggml_tensor * x   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * dt  = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_inner, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * A   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, d_inner, 1, 1 }.data());
-        ggml_tensor * B   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * C   = ggml_new_tensor(ctx, type, 4, std::vector<int64_t>{ d_state, n_seq_tokens, n_seqs, 1 }.data());
-        ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C);
-        return out;
-    }
-};
-
 // GGML_OP_MUL_MAT
 struct test_mul_mat : public test_case {
     const ggml_type type_a;
@@ -2292,12 +2240,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
     }
 
-    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 1, 1}, {4, 1536, 1, 1}));
-    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1}));
-    test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1}));
-
-    test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4));
-
 #if 1
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
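
The deleted structs plugged GGML_OP_SSM_CONV and GGML_OP_SSM_SCAN into test-backend-ops' generic CPU-vs-backend comparison. For reference, a minimal sketch of building the same scan graph through the public ggml API — the function name build_ssm_scan is hypothetical, while the shapes and the ggml_ssm_scan call match the deleted test_ssm_scan above:

    #include "ggml.h"

    // Shapes from the deleted test: s {d_state, d_inner, n_seqs},
    // x/dt {d_inner, n_seq_tokens, n_seqs}, A {d_state, d_inner},
    // B/C {d_state, n_seq_tokens, n_seqs}.
    static struct ggml_tensor * build_ssm_scan(
            struct ggml_context * ctx,
            int64_t d_state, int64_t d_inner, int64_t n_seq_tokens, int64_t n_seqs) {
        struct ggml_tensor * s  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner,      n_seqs);
        struct ggml_tensor * x  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
        struct ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs);
        struct ggml_tensor * A  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, d_state, d_inner);
        struct ggml_tensor * B  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
        struct ggml_tensor * C  = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs);
        return ggml_ssm_scan(ctx, s, x, dt, A, B, C);
    }

As in the CPU path shown earlier, the result tensor packs the per-token outputs y first and the final states after them, which is why ggml_compute_forward_ssm_scan_f32 writes s at an offset of src1->nb[3] into dst.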
