
Commit

test: add softcap
xfail for MultiHeadSelfAttention
theissenhelen committed Sep 20, 2024
1 parent 0c1d27b commit f0a1137
Showing 3 changed files with 31 additions and 13 deletions.
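
For context: the softcap keyword that these tests thread through TransformerProcessorBlock, TransformerProcessor, and MultiHeadSelfAttention presumably refers to attention-logit soft-capping, where the pre-softmax scores are squashed through a scaled tanh so they stay within (-softcap, softcap); this is the same idea behind the softcap option in kernels such as flash-attn. A minimal sketch of that operation, assuming a plain score tensor; softcap_logits is a hypothetical helper for illustration, not code from this repository.

import torch


def softcap_logits(scores: torch.Tensor, softcap: float) -> torch.Tensor:
    """Soft-cap attention scores to (-softcap, softcap) with a scaled tanh."""
    if softcap <= 0.0:
        # A cap of 0 is commonly treated as "soft-capping disabled".
        return scores
    return softcap * torch.tanh(scores / softcap)


# Usage: scores stay bounded no matter how large the raw logits are.
raw_scores = torch.randn(2, 4, 8, 8) * 100.0
capped = softcap_logits(raw_scores, softcap=0.5)
assert capped.abs().max() <= 0.5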
tests/layers/block/test_block_transformer.py: 17 changes (11 additions, 6 deletions)
@@ -7,6 +7,7 @@

import logging

+ import pytest
import torch
from hypothesis import given
from hypothesis import settings
@@ -30,12 +31,13 @@ class TestTransformerProcessorBlock:
activation=st.sampled_from(["ReLU", "GELU", "Tanh"]),
window_size=st.integers(min_value=1, max_value=512),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
+ softcap=st.floats(min_value=0.0, max_value=1.0),
)
@settings(max_examples=10)
- def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p):
+ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, window_size, dropout_p, softcap):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
- num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
+ num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
)
assert isinstance(block, TransformerProcessorBlock)

@@ -53,7 +55,9 @@ def test_init(self, factor_attention_heads, hidden_dim, num_heads, activation, w
shapes=st.lists(st.integers(min_value=1, max_value=10), min_size=3, max_size=3),
batch_size=st.integers(min_value=1, max_value=40),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
+ softcap=st.floats(min_value=0.0, max_value=1.0),
)
+ @pytest.mark.xfail(raises=TypeError)
@settings(max_examples=10)
def test_forward_output(
self,
@@ -65,15 +69,16 @@ def test_forward_output(
shapes,
batch_size,
dropout_p,
+ softcap,
):
num_channels = num_heads * factor_attention_heads
block = TransformerProcessorBlock(
- num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p
+ num_channels, hidden_dim, num_heads, activation, window_size, dropout_p=dropout_p, softcap=softcap
)

- x = torch.randn((batch_size, num_channels))
-
- output = block.forward(x, shapes, batch_size)
+ x = torch.randn((batch_size, num_channels)) # .to(torch.float16, non_blocking=True)
+ with torch.amp.autocast():
+ output = block.forward(x, shapes, batch_size)
assert isinstance(output, torch.Tensor)
assert output.shape == (batch_size, num_channels)

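On the hunk above: test_forward_output now runs the block under torch.amp.autocast() and is newly marked xfail(raises=TypeError). Two plausible sources of the expected TypeError, both assumptions rather than facts stated in the diff: recent PyTorch requires a device_type argument for torch.amp.autocast, so the bare call fails before the forward pass runs; and the attention backend reached by the block may not yet accept the softcap keyword (see the commit message). A minimal sketch of the usual autocast invocation, with the forward call left as a placeholder:

import torch

# Assumption for illustration: pick whichever device is available.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
with torch.amp.autocast(device_type=device_type):
    # Mixed-precision region; in the test this would wrap
    # block.forward(x, shapes, batch_size).
    pass
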
tests/layers/processor/test_transformer_processor.py: 10 changes (9 additions, 1 deletion)
@@ -22,6 +22,7 @@ def transformer_processor_init():
num_heads = 16
mlp_hidden_ratio = 4
dropout_p = 0.1
+ softcap = 0.5
return (
num_layers,
window_size,
@@ -32,6 +33,7 @@ def transformer_processor_init():
num_heads,
mlp_hidden_ratio,
dropout_p,
+ softcap,
)


@@ -47,6 +49,7 @@ def transformer_processor(transformer_processor_init):
num_heads,
mlp_hidden_ratio,
dropout_p,
+ softcap,
) = transformer_processor_init
return TransformerProcessor(
num_layers=num_layers,
@@ -58,6 +61,7 @@ def transformer_processor(transformer_processor_init):
num_heads=num_heads,
mlp_hidden_ratio=mlp_hidden_ratio,
dropout_p=dropout_p,
+ softcap=softcap,
)


@@ -72,13 +76,15 @@ def test_transformer_processor_init(transformer_processor, transformer_processor
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
+ _softcap,
) = transformer_processor_init
assert isinstance(transformer_processor, TransformerProcessor)
assert transformer_processor.num_chunks == num_chunks
assert transformer_processor.num_channels == num_channels
assert transformer_processor.chunk_size == num_layers // num_chunks


+ @pytest.mark.xfail(raises=TypeError)
def test_transformer_processor_forward(transformer_processor, transformer_processor_init):
(
_num_layers,
@@ -90,13 +96,15 @@ def test_transformer_processor_forward(transformer_processor, transformer_proces
_num_heads,
_mlp_hidden_ratio,
_dropout_p,
+ _softcap,
) = transformer_processor_init
gridsize = 100
batch_size = 1
x = torch.rand(gridsize, num_channels)
shard_shapes = [list(x.shape)]

- output = transformer_processor.forward(x, batch_size, shard_shapes)
+ with torch.amp.autocast():
+ output = transformer_processor.forward(x, batch_size, shard_shapes)
assert output.shape == x.shape

# Generate dummy target and loss function
tests/layers/test_attention.py: 17 changes (11 additions, 6 deletions)
@@ -19,12 +19,13 @@
num_heads=st.integers(min_value=1, max_value=50),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
+ softcap=st.floats(min_value=0.0, max_value=1.0),
)
- def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p):
+ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout_p, softcap):
embed_dim = (
num_heads * embed_dim_multiplier
) # TODO: Make assert in MHSA to check if embed_dim is divisible by num_heads
- mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)
+ mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, softcap=softcap)

assert isinstance(mhsa, nn.Module)
assert mhsa.num_heads == num_heads
@@ -33,17 +34,19 @@ def test_multi_head_self_attention_init(num_heads, embed_dim_multiplier, dropout
assert dropout_p == mhsa.dropout_p


+ @pytest.mark.xfail(raises=TypeError)
@pytest.mark.gpu
@given(
batch_size=st.integers(min_value=1, max_value=64),
num_heads=st.integers(min_value=1, max_value=20),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
+ softcap=st.floats(min_value=0.0, max_value=1.0),
)
@settings(deadline=None)
- def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
+ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap):
embed_dim = num_heads * embed_dim_multiplier
- mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)
+ mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, softcap=softcap)

x = torch.randn(batch_size * 2, embed_dim)
shapes = [list(x.shape)]
@@ -52,16 +55,18 @@ def test_multi_head_self_attention_forward(batch_size, num_heads, embed_dim_mult
assert output.shape == x.shape


+ @pytest.mark.xfail(raises=TypeError)
@pytest.mark.gpu
@given(
batch_size=st.integers(min_value=1, max_value=64),
num_heads=st.integers(min_value=1, max_value=20),
embed_dim_multiplier=st.integers(min_value=1, max_value=10),
dropout_p=st.floats(min_value=0.0, max_value=1.0),
+ softcap=st.floats(min_value=0.0, max_value=1.0),
)
- def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p):
+ def test_multi_head_self_attention_backward(batch_size, num_heads, embed_dim_multiplier, dropout_p, softcap):
embed_dim = num_heads * embed_dim_multiplier
- mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p)
+ mhsa = MultiHeadSelfAttention(num_heads, embed_dim, dropout_p=dropout_p, softcap=softcap)

x = torch.randn(batch_size * 2, embed_dim, requires_grad=True)
shapes = [list(x.shape)]
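The forward and backward tests above are now marked xfail(raises=TypeError), matching the commit message ("xfail for MultiHeadSelfAttention"). A plausible mechanism, assumed rather than confirmed by the diff, is that the underlying attention kernel does not yet accept a softcap keyword, so forwarding it raises a TypeError for an unexpected keyword argument. The self-contained sketch below reproduces that failure mode with a hypothetical stand-in, not the real attention backend:

import pytest
import torch


def attention_without_softcap(q, k, v, dropout_p=0.0):
    # Hypothetical backend that predates the softcap option.
    return torch.softmax(q @ k.transpose(-2, -1), dim=-1) @ v


q = k = v = torch.randn(1, 4, 8)
with pytest.raises(TypeError):
    # The unexpected keyword raises TypeError, which is what the xfail markers anticipate.
    attention_without_softcap(q, k, v, softcap=0.5)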
