Commit 8caa922
Merge branch 'develop' into feature/44-make-flash-attention-configurable
theissenhelen committed Sep 20, 2024
Merge parents: f0a1137 and 0219266
Showing 3 changed files with 4 additions and 0 deletions.
src/anemoi/models/layers/attention.py (1 addition, 0 deletions)
@@ -34,6 +34,7 @@ def __init__(
window_size: Optional[int] = None,
dropout_p: float = 0.0,
softcap: float = 0.0,

):
super().__init__()

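For context on the signature above: the `softcap` argument corresponds to logit soft-capping as exposed by FlashAttention, where attention scores are squashed through a scaled tanh before the softmax, bounding them to (-softcap, softcap). A minimal PyTorch sketch of that transform (the helper name is illustrative and not part of this diff; a value of 0.0 is taken to mean "disabled", matching the default above):

import torch

def apply_softcap(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    # Soft-capping bounds attention logits to (-softcap, softcap) while
    # staying smooth and differentiable; softcap == 0.0 disables it,
    # matching the default in the signature above.
    if softcap <= 0.0:
        return scores
    return softcap * torch.tanh(scores / softcap)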
tests/layers/block/test_block_transformer.py (1 addition, 0 deletions)
@@ -38,6 +38,7 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap

)
assert isinstance(block, TransformerProcessorBlock)

tests/layers/test_attention.py (2 additions, 0 deletions)
@@ -48,6 +48,7 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult
embed_dim = num_heads * embed_dim_multiplier
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, softcap=softcap)


x = torch.randn(batch_size * 2, embed_dim)
shapes = [list(x.shape)]
output = mhsa.forward(x, shapes, batch_size)
@@ -68,6 +69,7 @@ def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_mul
embed_dim = num_heads * embed_dim_multiplier
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, softcap=softcap)


x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
shapes = [list(x.shape)]
output = mhsa.forward(x, shapes, batch_size)
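Taken together, the two tests above exercise the constructor and forward signatures of MultiHeadSelfAttention. A hedged usage sketch outside pytest, mirroring the test pattern (the concrete sizes are illustrative, and the final shape check is an assumption about forward preserving the input shape, not something shown in these hunks):

import torch
from anemoi.models.layers.attention import MultiHeadSelfAttention

num_heads, embed_dim_multiplier, batch_size = 4, 16, 2
embed_dim = num_heads * embed_dim_multiplier  # 64 channels in total
mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=0.1, softcap=0.5)

x = torch.randn(batch_size * 2, embed_dim)  # flattened (batch x sequence) layout, as in the tests
shapes = [list(x.shape)]                    # shard shapes, following the test pattern
output = mhsa.forward(x, shapes, batch_size)
assert output.shape == x.shape  # assumption: self-attention preserves the input shape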
