8000 vulkan: support CPY from any type to itself (#13695) · ggml-org/llama.cpp@1dcd019 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1dcd019

Browse files
authored
vulkan: support CPY from any type to itself (#13695)
Reuse the f16/f32 copy shaders, and just scale the number of elements according to the type size.
1 parent c10ed6c commit 1dcd019

File tree

1 file changed

+44
-2
lines changed

1 file changed

+44
-2
lines changed

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4676,6 +4676,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
46764676
}
46774677
}
46784678

4679+
if (src->type == to) {
4680+
// Copy two or four bytes at a time, depending on block size.
4681+
// For quantized types, we scale by block size/type size. But
4682+
// this path is also used for bf16->bf16 for example, where the
4683+
// type size must be exactly 2 or 4.
4684+
GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
4685+
if ((ggml_type_size(src->type) % 4) == 0) {
4686+
return ctx->device->pipeline_contig_cpy_f32_f32;
4687+
} else {
4688+
return ctx->device->pipeline_contig_cpy_f16_f16;
4689+
}
4690+
}
4691+
46794692
std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
46804693
GGML_ABORT("fatal error");
46814694
}
@@ -6737,7 +6750,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
67376750
case GGML_OP_UNARY:
67386751
case GGML_OP_CONV_2D_DW:
67396752
{
6740-
const uint32_t ne = ggml_nelements(dst);
6753+
uint32_t ne = ggml_nelements(dst);
6754+
if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
6755+
// Convert from number of logical elements to 2- or 4-byte units.
6756+
ne /= ggml_blck_size(src0->type);
6757+
if ((ggml_type_size(src0->type) % 4) == 0) {
6758+
ne *= ggml_type_size(src0->type) / 4;
6759+
} else {
6760+
ne *= ggml_type_size(src0->type) / 2;
6761+
}
6762+
}
67416763
if (ne > 262144) {
67426764
elements = { 512, 512, CEIL_DIV(ne, 262144) };
67436765
} else if (ne > 512) {
@@ -7287,8 +7309,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
72877309
const uint32_t src0_type_size = ggml_type_size(src0->type);
72887310
const uint32_t dst_type_size = ggml_type_size(dst->type);
72897311

7312+
uint32_t ne = (uint32_t)ggml_nelements(src0);
7313+
if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
7314+
// Convert from number of logical elements to 2- or 4-byte units.
7315+
ne /= ggml_blck_size(src0->type);
7316+
if ((ggml_type_size(src0->type) % 4) == 0) {
7317+
ne *= ggml_type_size(src0->type) / 4;
7318+
} else {
7319+
ne *= ggml_type_size(src0->type) / 2;
7320+
}
7321+
}
7322+
72907323
ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
7291-
(uint32_t)ggml_nelements(src0),
7324+
ne,
72927325
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
72937326
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
72947327
0,
@@ -9872,6 +9905,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
98729905
if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
98739906
return true;
98749907
}
9908+
9909+
// We can handle copying from a type to the same type if it's
9910+
// contiguous (memcpy). We use f16 or f32 shaders to do the copy,
9911+
// so the type/block size must be a multiple of 4.
9912+
if (src0_type == src1_type &&
9913+
ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
9914+
(ggml_type_size(src0_type) % 2) == 0) {
9915+
return true;
9916+
}
98759917
return false;
98769918
} break;
98779919
case GGML_OP_REPEAT:

0 commit comments

Comments
 (0)
0