Skip to content

Commit

Permalink
reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist1 committed Jul 25, 2023
1 parent 6adffdb commit e7ebf8e
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/tensor_ops/accurate_gelu/accurate_gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,29 @@
// Empty marker type for the accurate (erf-based) GeLU operation.
// NOTE(review): presumably passed as the op type to the UNARY_OP macro below;
// the full invocation is not visible here -- confirm.
struct AccurateGeLUKernelOp {};

// Forward pass of the exact (erf-based) GeLU:
//
//     gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
//
// `erfg` is defined outside this file section; presumably a type-generic
// wrapper around erf for the supported floating-point types -- confirm.
template <typename T> __device__ T accurate_gelu_fwd(T x) {
    T one = 1.0;
    T half = 0.5;
    T alpha = M_SQRT1_2; // 1 / sqrt(2)
    return half * x * (one + erfg(x * alpha));
}

// Derivative of the exact (erf-based) GeLU with respect to x.
//
// With gelu(x) = left(x) * right(x), where
//     left(x)  = 0.5 * x
//     right(x) = 1 + erf(alpha * x),   alpha = 1 / sqrt(2)
// the product rule gives
//     gelu'(x) = 0.5 * right(x) + left(x) * right'(x)
// and, by the chain rule,
//     right'(x) = (2 / sqrt(pi)) * alpha * exp(-(alpha * x)^2)
//               = (2 / sqrt(pi)) * alpha * exp(-0.5 * x^2).
//
// `erfg` / `expg` are defined outside this file section; presumably
// type-generic wrappers around erf / exp -- confirm.
template <typename T> __device__ T accurate_gelu_bwd(T x) {
    T one = 1.0;
    T half = 0.5;
    T alpha = M_SQRT1_2;  // 1 / sqrt(2)
    T scale = M_2_SQRTPI; // 2 / sqrt(pi)
    T x_sq = x * x;
    T arg = -half * x_sq; // -(alpha * x)^2, since alpha^2 == 0.5

    // d/dx erf(alpha * x). The inner derivative `alpha` is required by the
    // chain rule; omitting it makes this term sqrt(2) times too large.
    T norm = scale * alpha * expg(arg);

    T left = half * x;
    T right = one + erfg(alpha * x);

    // Product rule: (left * right)' = left' * right + left * right'.
    T left_derivative = half * right;
    T right_derivative = left * norm;

    return left_derivative + right_derivative;
}

UNARY_OP(__half, accurate_gelu_fwd_f16, accurate_gelu_bwd_f16,
Expand Down

0 comments on commit e7ebf8e

Please sign in to comment.