|
82 | 82 | GGML_METAL_KERNEL_TYPE_RMS_NORM,
|
83 | 83 | GGML_METAL_KERNEL_TYPE_GROUP_NORM,
|
84 | 84 | GGML_METAL_KERNEL_TYPE_NORM,
|
85 |
| - GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, |
86 |
| - GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, |
87 | 85 | GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,
|
88 | 86 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,
|
89 | 87 | GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,
|
@@ -544,8 +542,6 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){
|
544 | 542 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
|
545 | 543 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
|
546 | 544 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
|
547 |
| - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); |
548 |
| - GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); |
549 | 545 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
|
550 | 546 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
|
551 | 547 | GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
|
@@ -807,9 +803,6 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
807 | 803 | return false;
|
808 | 804 | }
|
809 | 805 | return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
810 |
| - case GGML_OP_SSM_CONV: |
811 |
| - case GGML_OP_SSM_SCAN: |
812 |
| - return true; |
813 | 806 | case GGML_OP_MUL_MAT:
|
814 | 807 | case GGML_OP_MUL_MAT_ID:
|
815 | 808 | return ctx->support_simdgroup_reduction &&
|
@@ -1545,121 +1538,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
1545 | 1538 | [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
1546 | 1539 | }
|
1547 | 1540 | } break;
|
1548 |
| - case GGML_OP_SSM_CONV: |
1549 |
| - { |
1550 |
| - GGML_ASSERT(src0t == GGML_TYPE_F32); |
1551 |
| - GGML_ASSERT(src1t == GGML_TYPE_F32); |
1552 |
| - |
1553 |
| - GGML_ASSERT(ggml_is_contiguous(src0)); |
1554 |
| - GGML_ASSERT(ggml_is_contiguous(src1)); |
1555 |
| - |
1556 |
| - id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; |
1557 |
| - |
1558 |
| - [encoder setComputePipelineState:pipeline]; |
1559 |
| - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
1560 |
| - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
1561 |
| - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; |
1562 |
| - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; |
1563 |
| - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; |
1564 |
| - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; |
1565 |
| - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; |
1566 |
| - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; |
1567 |
| - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; |
1568 |
| - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9]; |
1569 |
| - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10]; |
1570 |
| - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11]; |
1571 |
| - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12]; |
1572 |
| - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; |
1573 |
| - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; |
1574 |
| - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:15]; |
1575 |
| - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:16]; |
1576 |
| - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:17]; |
1577 |
| - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:18]; |
1578 |
| - |
1579 |
| - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne1, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
1580 |
| - } break; |
1581 |
| - case GGML_OP_SSM_SCAN: |
1582 |
| - { |
1583 |
| - struct ggml_tensor * src3 = gf->nodes[i]->src[3]; |
1584 |
| - struct ggml_tensor * src4 = gf->nodes[i]->src[4]; |
1585 |
| - struct ggml_tensor * src5 = gf->nodes[i]->src[5]; |
1586 |
| - |
1587 |
| - GGML_ASSERT(src3); |
1588 |
| - GGML_ASSERT(src4); |
1589 |
| - GGML_ASSERT(src5); |
1590 |
| - |
1591 |
| - size_t offs_src3 = 0; |
1592 |
| - size_t offs_src4 = 0; |
1593 |
| - size_t offs_src5 = 0; |
1594 |
| - |
1595 |
| - id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; |
1596 |
| - id<MTLBuffer> id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; |
1597 |
| - id<MTLBuffer> id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; |
1598 |
| - |
1599 |
| - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); |
1600 |
| - const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); |
1601 |
| - |
1602 |
| - const uint64_t nb30 = src3->nb[0]; |
1603 |
| - const uint64_t nb31 = src3->nb[1]; |
1604 |
| - |
1605 |
| - const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); |
1606 |
| - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); |
1607 |
| - const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); |
1608 |
| - |
1609 |
| - const uint64_t nb40 = src4->nb[0]; |
1610 |
| - const uint64_t nb41 = src4->nb[1]; |
1611 |
| - const uint64_t nb42 = src4->nb[2]; |
1612 |
| - |
1613 |
| - const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); |
1614 |
| - const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); |
1615 |
| - const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); |
1616 |
| - |
1617 |
| - const uint64_t nb50 = src5->nb[0]; |
1618 |
| - const uint64_t nb51 = src5->nb[1]; |
1619 |
| - const uint64_t nb52 = src5->nb[2]; |
1620 |
| - |
1621 |
| - const int64_t d_state = ne00; |
1622 |
| - const int64_t d_inner = ne01; |
1623 |
| - const int64_t n_seq_tokens = ne11; |
1624 |
| - const int64_t n_seqs = ne02; |
1625 |
| - |
1626 |
| - id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; |
1627 |
| - |
1628 |
| - [encoder setComputePipelineState:pipeline]; |
1629 |
| - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; |
1630 |
| - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; |
1631 |
| - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; |
1632 |
| - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; |
1633 |
| - [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; |
1634 |
| - [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; |
1635 |
| - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; |
1636 |
| - |
1637 |
| - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; |
1638 |
| - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; |
1639 |
| - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; |
1640 |
| - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; |
1641 |
| - |
1642 |
| - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; |
1643 |
| - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; |
1644 |
| - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; |
1645 |
| - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; |
1646 |
| - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; |
1647 |
| - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; |
1648 |
| - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; |
1649 |
| - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; |
1650 |
| - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; |
1651 |
| - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; |
1652 |
| - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; |
1653 |
| - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; |
1654 |
| - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; |
1655 |
| - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; |
1656 |
| - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; |
1657 |
| - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; |
1658 |
| - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; |
1659 |
| - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; |
1660 |
| - |
1661 |
| - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; |
1662 |
| - } break; |
1663 | 1541 | case GGML_OP_MUL_MAT:
|
1664 | 1542 | {
|
1665 | 1543 | GGML_ASSERT(ne00 == ne10);
|
|
0 commit comments