SYCL: add gelu_erf kernel · ggml-org/llama.cpp@bffd38f · GitHub
[go: up one dir, main page]

Skip to content

Commit bffd38f

Browse files
committed
SYCL: add gelu_erf kernel
1 parent f5cd27b commit bffd38f

File tree

3 files changed

+66
-0
lines changed

3 files changed

+66
-0
lines changed

ggml/src/ggml-sycl/element_wise.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,14 @@ static void gelu_quick(const T *x, T *dst, int k,
8484
dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
8585
}
8686

87+
template<typename T>
88+
static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
89+
const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
90+
for(auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
91+
dst[i] = static_cast<T>(0.5f)*x[i]*(static_cast<T>(1.0f) + sycl::erf(x[i]*SQRT_2_INV));
92+
}
93+
}
94+
8795
template<typename T>
8896
static void tanh(const T *x, T *dst, int k,
8997
const sycl::nd_item<3> &item_ct1) {
@@ -400,6 +408,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
400408
});
401409
}
402410

411+
412+
template<typename T>
413+
static void gelu_erf_sycl(const T *x, T *dst, const int k,
414+
queue_ptr stream) {
415+
const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
416+
stream->parallel_for(
417+
sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
418+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
419+
sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
420+
[=](sycl::nd_item<3> item_ct1) {
421+
gelu_erf(x, dst, k, item_ct1);
422+
});
423+
}
424+
403425
template<typename T>
404426
static void tanh_sycl(const T *x, T *dst, const int k,
405427
queue_ptr stream) {
@@ -816,6 +838,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
816838
}
817839
}
818840

841+
// Runs the erf-based GELU on dst->src[0], writing into dst.
// Source and destination must share the same type: F32 always, and F16 as well
// when GGML_SYCL_F16 is defined. Any other type aborts.
inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
    // Type preconditions; the accepted set depends on the build configuration.
#if defined (GGML_SYCL_F16)
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
#else
    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
    GGML_ASSERT(dst->type == GGML_TYPE_F32);
#endif
    GGML_ASSERT(dst->src[0]->type == dst->type);

    dpct::queue_ptr queue = ctx.stream();
    SYCL_CHECK(ggml_sycl_set_device(ctx.device));

    // Pick the element type of the kernel from the destination tensor type.
    switch (dst->type) {
#if defined (GGML_SYCL_F16)
        case GGML_TYPE_F16:
            {
                auto ptrs = cast_data<sycl::half>(dst);
                gelu_erf_sycl(ptrs.src, ptrs.dst, ggml_nelements(dst->src[0]), queue);
            }
            break;
#endif
        case GGML_TYPE_F32:
            {
                auto ptrs = cast_data<float>(dst);
                gelu_erf_sycl(ptrs.src, ptrs.dst, ggml_nelements(dst->src[0]), queue);
            }
            break;
        default:
            GGML_ABORT("GGML tensor type not supported!\n");
    }
}
871+
872+
819873
inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
820874
#if defined (GGML_SYCL_F16)
821875
GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
@@ -1432,6 +1486,12 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
14321486
GGML_SYCL_DEBUG("call %s done\n", __func__);
14331487
}
14341488

1489+
// Public entry point for the erf-based GELU op: emits a debug line with the
// destination tensor type, runs ggml_sycl_op_gelu_erf, then emits a completion
// debug line.
void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst)
{
    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));

    ggml_sycl_op_gelu_erf(ctx, dst);

    GGML_SYCL_DEBUG("call %s done\n", __func__);
}
1494+
14351495
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
14361496
GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
14371497
ggml_sycl_op_tanh(ctx, dst);

ggml/src/ggml-sycl/element_wise.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ void ggml_sycl_silu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
3838

3939
void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4040

41+
// Runs the SYCL erf-based GELU operation for `dst` using the backend context `ctx`.
void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
42+
4143
void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
4244

4345
void ggml_sycl_relu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);

ggml/src/ggml-sycl/ggml-sycl.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3508,6 +3508,9 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
35083508
case GGML_UNARY_OP_GELU_QUICK:
35093509
ggml_sycl_gelu_quick(ctx, dst);
35103510
break;
3511+
case GGML_UNARY_OP_GELU_ERF:
3512+
ggml_sycl_gelu_erf(ctx, dst);
3513+
break;
35113514
case GGML_UNARY_OP_TANH:
35123515
ggml_sycl_tanh(ctx, dst);
35133516
break;
@@ -4048,6 +4051,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
40484051
case GGML_UNARY_OP_HARDSIGMOID:
40494052
case GGML_UNARY_OP_HARDSWISH:
40504053
case GGML_UNARY_OP_GELU_QUICK:
4054+
case GGML_UNARY_OP_GELU_ERF:
40514055
case GGML_UNARY_OP_TANH:
40524056
case GGML_UNARY_OP_EXP:
40534057
case GGML_UNARY_OP_SGN:

0 commit comments

Comments
 (0)
0