fix naive implementation · Pints-AI/llama.cpp@0afe47f

Commit 0afe47f
fix naive implementation
1 parent b1479df

File tree: 1 file changed, +3 −1 lines

tests/test-flash-attention.cpp

Lines changed: 3 additions & 1 deletion
@@ -207,7 +207,9 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
     struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
     kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
     kq = ggml_soft_max(ctx0, kq);
-    kq = ggml_mul_mat(ctx0, model.v, kq);
+    kq = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, model.v)), kq);
+    kq = ggml_permute (ctx0, kq, 0, 2, 1, 3);
+    //kq = ggml_cont_2d (ctx0, kq, model.q->ne[0] * model.q->ne[2], model.q->ne[1]);
     ggml_build_forward_expand(gf, kq);
 }
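
Editor's note on the change (context, not part of the commit): ggml_mul_mat(ctx, a, b) contracts over ne[0] of both operands, so in row-major terms it computes b·aᵀ. The naive attention reference needs softmax(Q·Kᵀ/√d)·V, which is why V must enter the second matmul transposed; ggml_transpose only returns a view, so ggml_cont is needed to make it contiguous first, and the trailing ggml_permute reorders the result to match the layout of the fused flash-attention output the test compares against. Below is a minimal sketch of the fixed naive path, assuming Q, K, and V share a [head_dim, n_seq, n_head, 1] layout; the helper name and that layout are assumptions for illustration, not taken from the repository.

    #include <math.h>
    #include "ggml.h"

    // Sketch: naive scaled-dot-product attention, mirroring the fixed lines above.
    // Assumes q, k, v all have ne = [head_dim, n_seq, n_head, 1] (illustrative).
    static struct ggml_tensor * naive_attention(struct ggml_context * ctx,
                                                struct ggml_tensor * q,
                                                struct ggml_tensor * k,
                                                struct ggml_tensor * v) {
        // scores = Q·Kᵀ / sqrt(head_dim); ggml_mul_mat contracts over ne[0]
        struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
        kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrtf((float)q->ne[0]));
        kq = ggml_soft_max(ctx, kq);
        // out = softmax(scores)·V: transpose V and copy it into a contiguous
        // tensor, since ggml_mul_mat cannot consume the transposed view directly
        struct ggml_tensor * v_t = ggml_cont(ctx, ggml_transpose(ctx, v));
        struct ggml_tensor * out = ggml_mul_mat(ctx, v_t, kq);
        // reorder to the per-head layout produced by the fused attention kernel
        return ggml_permute(ctx, out, 0, 2, 1, 3);
    }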

0 commit comments