
FlashAttention Triton error on the MosaicBERT models other than base #441

Closed
Taytay opened this issue Jan 2, 2024 · 3 comments

Taytay commented Jan 2, 2024

When I try to run MosaicBERT like this:

import transformers
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048')
# It's essential to set the attention_probs_dropout_prob to 0.1
#config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default of 512
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)

mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")

classifier("I [MASK] to the store yesterday.")

I get this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:1124, in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
   1123 try:
-> 1124     generator.visit(fn.parse())
   1125 except CompilationError as e:

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:1017, in CodeGenerator.visit(self, node)
   1016     last_loc = self.builder.get_loc()
-> 1017 ret = super().visit(node)
   1018 # Reset the location to the last one before the visit

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/ast.py:407, in NodeVisitor.visit(self, node)
    406 visitor = getattr(self, method, self.generic_visit)
--> 407 return visitor(node)

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/site-packages/triton/compiler/code_generator.py:293, in CodeGenerator.visit_Module(self, node)
    292 def visit_Module(self, node):
--> 293     ast.NodeVisitor.generic_visit(self, node)

File ~/YNAB/ML/ai_categorize/.conda-env/lib/python3.9/ast.py:415, in NodeVisitor.generic_visit(self, node)
    414         if isinstance(item, AST):
--> 415             self.visit(item)
    416 elif isinstance(value, AST):
...
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
TypeError("dot() got an unexpected keyword argument 'trans_b'")
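For context, later Triton releases dropped the `trans_a`/`trans_b` keywords from `tl.dot`; the transpose is now written explicitly (roughly `tl.dot(q, tl.trans(k))` — exact migration not verified against the version pinned here). The math the failing line computes is just `q @ k.T`, sketched below in NumPy rather than Triton:

```python
import numpy as np

# The failing kernel line computed attention scores qk = q @ k^T.
# Old Triton API:  tl.dot(q, k, trans_b=True)   (transpose handled by the kernel)
# Newer Triton:    tl.dot(q, tl.trans(k))       (transpose made explicit)
# Equivalent NumPy semantics for one [BLOCK_M, d] x [BLOCK_N, d] tile:
q = np.random.rand(16, 64).astype(np.float32)
k = np.random.rand(32, 64).astype(np.float32)

qk = q @ k.T  # shape [BLOCK_M, BLOCK_N], matching tl.zeros([BLOCK_M, BLOCK_N])
```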

This appears to have been fixed a few days ago by @jacobfulano in the mosaic-bert-base repo:
https://huggingface.co/mosaicml/mosaic-bert-base/blob/ed2a544063a892b78823cba2858d1e098c0e6012/config.json

It looks like that removes FlashAttention? Does that mean that the speed increase from FA is also removed?

Here's how I fixed it in the meantime, in case someone else Googles this and stumbles across it:

import transformers
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base')
# It's essential to set the attention_probs_dropout_prob to 0.1, which mosaic-bert-base does. So we just update the alibi_starting_size:
config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default of 512
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)
mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")

jacobfulano commented Jan 3, 2024

Hi @Taytay,

Thanks for your comment. This code was written in the early days of Triton Flash Attention, and we are in the process of updating to Flash Attention 2 with support for ALiBi (see #440, which updates to PyTorch 2 and the flash-attn package instead of triton). Earlier versions of Flash Attention did not support ALiBi, so we had to come up with a custom ALiBi implementation using Triton Flash Attention that depended on PyTorch 1.13 and triton==2.0.0.dev20221103 (here's the requirements file).

The config value of attention_probs_dropout_prob: 0.1 turns off Triton Flash Attention, while attention_probs_dropout_prob: 0.0 turns it on (you can see this in the source code here).
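The gating behavior described above can be sketched as a one-line predicate (the function name is hypothetical, not from the MosaicBERT source; the rationale comment is an assumption — presumably the custom kernel lacked dropout support):

```python
def uses_triton_flash_attention(attention_probs_dropout_prob: float) -> bool:
    """Sketch of the config gate described above: any nonzero attention
    dropout falls back to the standard (non-Triton) attention path, while
    exactly 0.0 enables the custom Triton Flash Attention kernel."""
    return attention_probs_dropout_prob == 0.0

uses_triton_flash_attention(0.0)  # Triton Flash Attention path
uses_triton_flash_attention(0.1)  # standard attention path
```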

This was understandably confusing for a lot of people, so I have set the default in Hugging Face to attention_probs_dropout_prob: 0.1 (e.g. here) across all MosaicBERT models. It should be noted that most of the benefits of Flash Attention come during training, not necessarily when serving a loaded model.

@jacobfulano

Your code should now work with the config from mosaic-bert-base-seqlen-2048. Note that this does not use our custom Triton Flash Attention implementation.

import transformers
from transformers import AutoModelForMaskedLM, BertTokenizer, pipeline

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = transformers.BertConfig.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048')
config.alibi_starting_size = 2048 # maximum sequence length updated to 2048 from config default
mlm = AutoModelForMaskedLM.from_pretrained('mosaicml/mosaic-bert-base-seqlen-2048', trust_remote_code=True, config=config)
mlm.to("cuda")
classifier = pipeline('fill-mask', model=mlm, tokenizer=tokenizer, device="cuda")
classifier("I [MASK] to the store yesterday.")


Taytay commented Jan 3, 2024

Thank you for the thorough explanation! I tried training yesterday and ran into some more errors, so #440 is especially welcome! (I have a model I need to train on a large number of tokens, so the perf is going to be particularly helpful.)

Taytay closed this as completed Jan 3, 2024