@@ -84,6 +84,14 @@ static void gelu_quick(const T *x, T *dst, int k,
     dst[i] = x[i] * (static_cast<T>(1.0f) / (static_cast<T>(1.0f) + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
+template <typename T>
+static void gelu_erf(const T * x, T * dst, const int k, const sycl::nd_item<3> &item_ct1) {
+    const T SQRT_2_INV = static_cast<T>(0.70710678118654752440084436210484f);
+    for (auto i = item_ct1.get_global_id(2); i < (const size_t)k; i += item_ct1.get_global_range(2)) {
+        dst[i] = static_cast<T>(0.5f)*x[i]*(static_cast<T>(1.0f) + sycl::erf(x[i]*SQRT_2_INV));
+    }
+}
+
 template <typename T>
 static void tanh(const T *x, T *dst, int k,
                  const sycl::nd_item<3> &item_ct1) {
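For context, the kernel above applies the erf-based GELU, 0.5*x*(1 + erf(x/sqrt(2))), element by element. A minimal host-side reference in plain C++ (the name gelu_erf_ref and the use of std::erf are this sketch's own choices, not part of the patch) that computes the same per-element value and can be used to spot-check the SYCL output:

#include <cmath>
#include <cstdio>

// Plain C++ reference for the erf-based GELU: 0.5*x*(1 + erf(x/sqrt(2))).
static void gelu_erf_ref(const float * x, float * dst, int k) {
    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
    for (int i = 0; i < k; ++i) {
        dst[i] = 0.5f * x[i] * (1.0f + std::erf(x[i] * SQRT_2_INV));
    }
}

int main() {
    const float x[3] = { -1.0f, 0.0f, 1.0f };
    float y[3];
    gelu_erf_ref(x, y, 3);
    std::printf("%f %f %f\n", y[0], y[1], y[2]);  // approx. -0.158655 0.000000 0.841345
    return 0;
}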
@@ -400,6 +408,20 @@ static void gelu_quick_sycl(const T *x, T *dst, const int k,
         });
 }
 
+
+template <typename T>
+static void gelu_erf_sycl(const T *x, T *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_erf(x, dst, k, item_ct1);
+        });
+}
+
 template <typename T>
 static void tanh_sycl(const T *x, T *dst, const int k,
                       queue_ptr stream) {
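The launcher rounds the number of work-groups up so that num_blocks * SYCL_GELU_BLOCK_SIZE >= k; the grid-stride loop in the kernel then lets any excess work-items fall out of the loop condition. A small stand-alone sketch of that rounding (the helper name ceil_div_ref and the block size of 256 are this sketch's own choices; the backend's ceil_div is assumed to behave the same way):

#include <cstdio>

// Round-up integer division: smallest n such that n * b >= a (for a, b > 0).
static int ceil_div_ref(int a, int b) {
    return (a + b - 1) / b;
}

int main() {
    const int block_size = 256;  // stand-in value for SYCL_GELU_BLOCK_SIZE
    std::printf("%d\n", ceil_div_ref(1000, block_size));  // 4 work-groups cover 1000 elements
    std::printf("%d\n", ceil_div_ref(1024, block_size));  // exactly 4, no extra group
    return 0;
}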
@@ -816,6 +838,38 @@ inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor
     }
 }
 
+inline void ggml_sycl_op_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor *dst) {
+#if defined(GGML_SYCL_F16)
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
+#else
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+#endif
+    GGML_ASSERT(dst->src[0]->type == dst->type);
+    dpct::queue_ptr main_stream = ctx.stream();
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    switch (dst->type) {
+#if defined(GGML_SYCL_F16)
+        case GGML_TYPE_F16:
+            {
+                auto data_pts = cast_data<sycl::half>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+#endif
+        case GGML_TYPE_F32:
+            {
+                auto data_pts = cast_data<float>(dst);
+                gelu_erf_sycl(data_pts.src, data_pts.dst, ggml_nelements(dst->src[0]), main_stream);
+                break;
+            }
+        default:
+            GGML_ABORT("GGML tensor type not supported!\n");
+    }
+}
+
+
 inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
 #if defined(GGML_SYCL_F16)
     GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16);
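The cast_data<T> helper used in the switch is defined elsewhere in the backend; for this dispatch it is assumed to hand back the tensor's source and destination device pointers re-typed to T, which is what lets the single templated gelu_erf_sycl serve both the F16 and F32 branches. An illustrative stand-in showing just that shape (the name typed_pointers is hypothetical, not ggml's):

// Illustrative stand-in only; the real cast_data works on ggml_tensor and is
// assumed to return the device pointers of dst->src[0] and dst cast to T.
template <typename T>
struct typed_pointers {
    const T * src;  // input elements, already on the device
    T *       dst;  // output elements, same count and type as src
};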
@@ -1432,6 +1486,12 @@ void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
+void ggml_sycl_gelu_erf(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
+    ggml_sycl_op_gelu_erf(ctx, dst);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
+
 void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
     GGML_SYCL_DEBUG("call %s: DST Tensor type: %s\n", __func__, ggml_type_name(dst->type));
     ggml_sycl_op_tanh(ctx, dst);