diff --git a/src/tensor_ops/accurate_gelu/accurate_gelu.cu b/src/tensor_ops/accurate_gelu/accurate_gelu.cu
index 0d3f46e5a..b636ccaa4 100644
--- a/src/tensor_ops/accurate_gelu/accurate_gelu.cu
+++ b/src/tensor_ops/accurate_gelu/accurate_gelu.cu
@@ -15,34 +15,28 @@ template <typename T> __device__ T accurate_gelu_bwd(T x) {
     T one = 1.0;
     T half = 0.5;
     T alpha = M_SQRT1_2;
+    // d/dx erf(alpha * x) = alpha * (2 / sqrt(pi)) * exp(-(alpha * x)^2); with
+    // alpha = 1/sqrt(2) the exponent is -x^2 / 2 and the prefactor is sqrt(2 / pi).
+    T scale = M_SQRT1_2 * M_2_SQRTPI;
     T x_sq = x * x;
-    T beta = M_2_SQRTPI;
-    T normal_dist = beta * expg(half * -x_sq);
+    T arg = -half * x_sq;
+    T norm = scale * expg(arg);

     T left = half * x;
     T right = one + erfg(alpha * x);

     T left_derivative = half * right;

-    T right_derivative = left * normal_dist;
+    T right_derivative = left * norm;

     return left_derivative + right_derivative;
 }

 UNARY_OP(__half, accurate_gelu_fwd_f16, accurate_gelu_bwd_f16,
-        AccurateGeLUKernelOp,
-        accurate_gelu_fwd(x),
-        accurate_gelu_bwd(x)
-)
+         AccurateGeLUKernelOp, accurate_gelu_fwd(x), accurate_gelu_bwd(x))

 UNARY_OP(float, accurate_gelu_fwd_f32, accurate_gelu_bwd_f32,
-        AccurateGeLUKernelOp,
-        accurate_gelu_fwd(x),
-        accurate_gelu_bwd(x)
-)
+         AccurateGeLUKernelOp, accurate_gelu_fwd(x), accurate_gelu_bwd(x))

 UNARY_OP(double, accurate_gelu_fwd_f64, accurate_gelu_bwd_f64,
-        AccurateGeLUKernelOp,
-        accurate_gelu_fwd(x),
-        accurate_gelu_bwd(x)
-)
+         AccurateGeLUKernelOp, accurate_gelu_fwd(x), accurate_gelu_bwd(x))
diff --git a/src/tensor_ops/utilities/cuda_utils.cuh b/src/tensor_ops/utilities/cuda_utils.cuh
index e3a78465c..e1278824f 100644
--- a/src/tensor_ops/utilities/cuda_utils.cuh
+++ b/src/tensor_ops/utilities/cuda_utils.cuh
@@ -137,6 +137,9 @@ __device__ __forceinline__ __half logg(__half a) { return hlog(a); }
 __device__ __forceinline__ float expg(float a) { return expf(a); }
 __device__ __forceinline__ double expg(double a) { return exp(a); }
 __device__ __forceinline__ __half expg(__half a) { return hexp(a); }
+__device__ __forceinline__ float erfg(float a) { return erff(a); }
+__device__ __forceinline__ double erfg(double a) { return erf(a); }
+__device__ __forceinline__ __half erfg(__half a) { return __float2half(erff(__half2float(a))); } // erf evaluated in float, result narrowed back to half
 __device__ __forceinline__ float absg(float a) { return fabsf(a); }
 __device__ __forceinline__ double absg(double a) { return fabs(a); }
 __device__ __forceinline__ __half absg(__half a) { return __habs(a); }