SYCL: add gelu_erf kernel (#13749) · ggml-org/llama.cpp@f3101a8 · GitHub
[go: up one dir, main page]

Skip to content

Commit f3101a8

Browse files
qnixsynapse and AD2605 authored
SYCL: add gelu_erf kernel (#13749)
* SYCL: add gelu_erf kernel
* refactor code
  Co-authored-by: Atharva Dubey <atharva.dubey@codeplay.com>
* Use scope_op_debug_print
---------
Co-authored-by: Atharva Dubey <atharva.dubey@codeplay.com>
1 parent 1c49c70 commit f3101a8

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ static void gelu_quick(const T *x, T *dst, int k,
8484
dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
8585
}
8686

87+
template<typename T>
88+
static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
89+
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
90+
for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
91+
auto x_i = x[i];
92+
dst[i] = static_cast<T>(0.5f) * x_i * (static_cast<T>(1.0f) + sycl::erf(x_i * SQRT_2_INV));
93+
}
94+
}
95+
8796
template<typename T>
8897
static void tanh(const T *x, T *dst, int k,
8998
const sycl::nd_item<3> &item_ct1) {
@@ -400,6 +409,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
400409
});
401410
}
402411

412+
413+
template<typename T>
414+
static void gelu_erf_sycl(const T *x, T *dst, const int k,
415+
queue_ptr stream) {
416+
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
417+
stream->parallel_for(
418+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
419+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
420+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
421+
[=](sycl::nd_item<3> item_ct1) {
422+
gelu_erf(x, dst, k, item_ct1);
423+
});
424+
}
425+
403426
template<typename T>
404427
static void tanh_sycl(const T *x, T *dst, const int k,
405428
queue_ptr stream) {
@@ -816,6 +839,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
816839
}
817840
}
818841

842+
inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
843+
#if defined (GGML_SYCL_F16)
844+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
845+
GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
846+
#else
847+
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
848+
GGML_ASSERT(dst->type == GGML_TYPE_F32);
849+
#endif
850+
GGML_ASSERT(dst->src[0]->type == dst->type);
851+
dpct::queue_ptr main_stream = ctx.stream();
852+
SYCL_CHECK(ggml_sycl_set_device(ctx.device));
853+
switch (dst->type) {
854+
#if defined (GGML_SYCL_F16)
855+
case GGML_TYPE_F16:
856+
{
857+
auto data_pts = cast_data<sycl::half>(dst);
858+
gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
859+
break;
860+
}
861+
#endif
862+
case GGML_TYPE_F32:
863+
{
864+
auto data_pts = cast_data<float>(dst);
865+
gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
866+
break;
867+
}
868+
default:
869+
GGML_ABORT("GGML tensor type not supported!\n");
870+
}
871+
}
872+
873+
819874
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
820875
#if defined (GGML_SYCL_F16)
821876
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1425,6 +1480,11 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
14251480
ggml_sycl_op_gelu_quick(ctx, dst);
14261481
}
14271482

1483+
void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
1484+
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
1485+
ggml_sycl_op_gelu_erf(ctx, dst);
1486+
}
1487+
14281488
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
14291489
scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/1);
14301490
ggml_sycl_op_tanh(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
3838

3939
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4040

41+
void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
42+
4143
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4244

4345
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3543,6 +3543,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35433543
case GGML_UNARY_OP_GELU_QUICK:
35443544
ggml_sycl_gelu_quick(ctx, dst);
35453545
break;
3546+
case GGML_UNARY_OP_GELU_ERF:
3547+
ggml_sycl_gelu_erf(ctx, dst);
3548+
break;
35463549
case GGML_UNARY_OP_TANH:
35473550
ggml_sycl_tanh(ctx, dst);
35483551
break;
@@ -4096,6 +4099,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
40964099
case GGML_UNARY_OP_HARDSIGMOID:
40974100
case GGML_UNARY_OP_HARDSWISH:
40984101
case GGML_UNARY_OP_GELU_QUICK:
4102+
case GGML_UNARY_OP_GELU_ERF:
40994103
case GGML_UNARY_OP_TANH:
41004104
case GGML_UNARY_OP_EXP:
41014105
case GGML_UNARY_OP_SGN:

0 commit comments

Comments
 (0)
0