accurate-gelu #813

Merged
merged 16 commits into from
Jul 20, 2023

Conversation

jcrist1
Contributor

@jcrist1 jcrist1 commented Jul 13, 2023

This PR adds an accurate GeLU function (used at least in the GPT-2 model from Hugging Face; see issue #804). Importantly, to keep the operators generic over Dtype, I introduce an Erf trait that lets us call d_type.erf() to get the error function of the value d_type. I'm currently getting a compile error with the cuda feature, and I have a hunch that this trait might be the cause. I'm having trouble debugging the build further because I don't have the CUDA headers anywhere (I'm working on a Mac), so I would appreciate it if someone more familiar with the code could point out the error.

For Github: Resolves #804
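
For readers following along, here is a minimal sketch of the kind of Erf trait described above. Only the trait name and the `erf()` method come from the description; the `erf_approx` helper, the choice of the Abramowitz and Stegun 7.1.26 polynomial, and the `accurate_gelu` signature with its `From<f32>` bound are assumptions made for illustration and are not the code in this PR.

```rust
/// Hypothetical trait: lets generic code call `x.erf()` on any float dtype.
pub trait Erf {
    fn erf(self) -> Self;
}

/// Abramowitz and Stegun formula 7.1.26 (maximum absolute error about 1.5e-7).
/// An assumption for this sketch; a real implementation would more likely
/// delegate to a math library.
fn erf_approx(x: f64) -> f64 {
    let sign = if x < 0.0 { -1.0 } else { 1.0 };
    let x = x.abs();
    let t = 1.0 / (1.0 + 0.3275911 * x);
    // Horner evaluation of a1*t + a2*t^2 + a3*t^3 + a4*t^4 + a5*t^5.
    let poly = t
        * (0.254829592
            + t * (-0.284496736
                + t * (1.421413741 + t * (-1.453152027 + t * 1.061405429))));
    sign * (1.0 - poly * (-x * x).exp())
}

impl Erf for f64 {
    fn erf(self) -> Self {
        erf_approx(self)
    }
}

impl Erf for f32 {
    fn erf(self) -> Self {
        // Widen to f64, evaluate, and narrow back.
        erf_approx(self as f64) as f32
    }
}

/// Accurate GeLU, 0.5 * x * (1 + erf(x / sqrt(2))), written once for every
/// dtype that implements `Erf`.
pub fn accurate_gelu<T>(x: T) -> T
where
    T: Erf + Copy + From<f32> + std::ops::Mul<Output = T> + std::ops::Add<Output = T>,
{
    let half = T::from(0.5_f32);
    let one = T::from(1.0_f32);
    let inv_sqrt2 = T::from(std::f32::consts::FRAC_1_SQRT_2);
    half * x * (one + (x * inv_sqrt2).erf())
}
```

With this in place, `accurate_gelu(1.0_f64)` and `accurate_gelu(1.0_f32)` both go through the same generic body, which is the point of putting `erf` behind a trait.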

@jcrist1 jcrist1 marked this pull request as ready for review July 13, 2023 14:10
@jcrist1
Contributor Author

jcrist1 commented Jul 13, 2023

Also, I'm not at all sure that CUDA actually supports erf, as I can't build with CUDA locally. Here's a link to the CUDA 32-bit error function

@jcrist1
Contributor Author

jcrist1 commented Jul 13, 2023

Finally, the naming is pretty terrible. Would it make sense to make the GeLU activation an enum?

enum Gelu {
   Fast,
   Accurate
}

But then the activation_impls! macro probably couldn't be used.

@nkoppel
Contributor

nkoppel commented Jul 13, 2023

Thank you for contributing! This looks good so far, I just have a few comments.

  • To check the CUDA build without needing to call nvcc, use the command below. This won't be able to check your kernels, but should help you resolve your current CI errors.
cargo +nightly clippy -F cuda,ci-check
  • I think it's best that we name the op/module 'AccurateGeLU' rather than 'GeLUCorrect' because it's easier to read and doesn't imply that the approximate GeLU op is incorrect.
  • Be sure to document the difference between the accurate gelu op and the normal gelu.

@jcrist1
Contributor Author

jcrist1 commented Jul 14, 2023

Okay @nkoppel I updated the name. Accurate GeLU is much better. I think distinguishing that the other GeLU is faster will help. I couldn't find an explicit citation for the fact that it's faster, but it seems to only require a single exponential, while the error function requires a much higher degree polynomial and still an exponential. I also beefed up the docs.

I feel like the changes to the docs for the activations don't fit in the code, but I'm guessing most people using the code will be using them as activations, rather than postfix operations so I added some info there even though it breaks the nice code block.
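
The "single exponential" remark above can be checked from the formula itself: since 0.5 * (1 + tanh(u)) equals 1 / (1 + exp(-2u)), the tanh-based GeLU can be rewritten with one call to exp. The sketch below shows both forms side by side. The function names are hypothetical, and nothing here is a benchmark or a claim about how dfdx evaluates the op.

```rust
/// sqrt(2 / pi), built from std constants: (2 / sqrt(pi)) * (1 / sqrt(2)).
const SQRT_2_OVER_PI: f64 =
    std::f64::consts::FRAC_2_SQRT_PI * std::f64::consts::FRAC_1_SQRT_2;

/// Inner argument shared by both forms: sqrt(2/pi) * (x + 0.044715 * x^3).
fn inner(x: f64) -> f64 {
    SQRT_2_OVER_PI * (x + 0.044715 * x * x * x)
}

/// Fast GeLU exactly as written in the docs: 0.5 * x * (1 + tanh(u)).
pub fn fast_gelu_tanh(x: f64) -> f64 {
    0.5 * x * (1.0 + inner(x).tanh())
}

/// The same function rewritten so that it needs only one exponential:
/// 0.5 * (1 + tanh(u)) == 1 / (1 + exp(-2u)).
pub fn fast_gelu_single_exp(x: f64) -> f64 {
    x / (1.0 + (-2.0 * inner(x)).exp())
}
```

The two functions agree to within floating-point rounding, which supports the intuition that the fast variant costs roughly one exponential per element, whereas the erf-based variant needs an exponential plus a higher-degree polynomial.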

Owner

Are the changes in this file just indentation? Can you revert them if so? Just for easier reviewing 😀

Contributor Author

should be, and will do. LSP automatically updated them.

@nkoppel
Contributor

nkoppel commented Jul 14, 2023

@jcrist1, can you write "Resolves #804" somewhere in your first comment so that GitHub will mark this PR as resolving that issue?

Contributor

@nkoppel nkoppel left a comment

Just a few little fixes to documentation and this should be good to go!

Note that I don't have write access to dfdx, I've just contributed a lot, so @coreylowman gets the final say on everything.

Comment on lines 40 to 57
/// See [gelu]
pub fn fast_gelu(self) -> Self {
    self.try_fast_gelu().unwrap()
}
/// See [gelu]
pub fn try_fast_gelu(self) -> Result<Self, D::Err> {
    try_unary_op(FastGeLUKernelOp, self)
}

#[deprecated(since = "0.12.0", note = "Use `fast_gelu` instead")]
pub fn gelu(self) -> Self {
    self.fast_gelu()
}

#[deprecated(since = "0.12.0", note = "Use `try_fast_gelu` instead")]
pub fn try_gelu(self) -> Result<Self, D::Err> {
    self.try_fast_gelu()
}
Contributor

Top two methods should link to fast_gelu, and deprecated items should have a link to their non-deprecated counterparts.
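
As an illustration of what this request amounts to, here is a small self-contained sketch of a deprecated method that forwards to its replacement and links to it in its docs. The `Values` type is a hypothetical stand-in for a tensor so that the example compiles on its own; it is not a dfdx type, and the doc wording is only one way the links could be written.

```rust
/// Hypothetical stand-in for a tensor; it exists only to make the sketch compile.
#[derive(Debug, Clone, PartialEq)]
pub struct Values(pub Vec<f64>);

impl Values {
    /// Fast (tanh-approximated) GeLU, applied elementwise:
    /// `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`.
    pub fn fast_gelu(self) -> Self {
        let c = (2.0 / std::f64::consts::PI).sqrt();
        Values(
            self.0
                .into_iter()
                .map(|x| 0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh()))
                .collect(),
        )
    }

    /// Deprecated alias kept for backwards compatibility.
    /// Use [`Values::fast_gelu`] instead.
    #[deprecated(since = "0.12.0", note = "Use `fast_gelu` instead")]
    pub fn gelu(self) -> Self {
        self.fast_gelu()
    }
}
```

Callers of the old name still compile but get a deprecation warning, and the intra-doc link sends anyone reading the docs for `gelu` straight to `fast_gelu`.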

#[derive(Debug, Default, Copy, Clone)]
pub struct FastGeLUKernelOp;

/// [Fast Gaussian Linear Unit (GeLU)](https://paperswithcode.com/method/gelu). `0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))`
Contributor

These docs should include a link to AccurateGeLU

/// GeLU(x) ~ 0.5 ∗ x ∗ (1.0 + tanh(sqrt(2.0/π) ∗ (x + 0.044715 ∗ x^3)))
/// ```
///
/// See [gelu](crate::tensor_ops::gelu::gelu) to use this approximation
Contributor

This link needs to be fixed with the new naming

Comment on lines +46 to +48
#[deprecated(since = "0.12.0", note = "please use `FastGeLU` instead")]
#[derive(Default, Debug, Clone, Copy)]
pub struct GeLU;
Contributor

Needs to link to its non-deprecated counterpart.

Contributor Author

link is 3 lines up

@jcrist1 jcrist1 requested a review from nkoppel July 17, 2023 20:02
@jcrist1
Contributor Author

jcrist1 commented Jul 17, 2023

@nkoppel alright should be addressed

@coreylowman
Owner

This looks good to me. Any other updates planned here?

@jcrist1
Contributor Author

jcrist1 commented Jul 19, 2023

I'm not planning on anything new. Just waiting on confirmation from @nkoppel that everything is addressed

@coreylowman
Owner

We can open another PR if there's something else to add/change!

@coreylowman coreylowman merged commit 47deac7 into coreylowman:main Jul 20, 2023
4 checks passed
@jcrist1 jcrist1 mentioned this pull request Jul 25, 2023
Successfully merging this pull request may close these issues.

Support accurate GeLU as well as fast approximate GeLu
3 participants