From 20d3725269f73f7aff81be24d133fd85ee8badb1 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Tue, 2 Jan 2024 12:51:14 -0500 Subject: [PATCH] Modernize MosaicBERT --- examples/benchmarks/bert/main.py | 4 +- examples/benchmarks/bert/requirements-cpu.txt | 10 +- examples/benchmarks/bert/requirements.txt | 14 +-- examples/benchmarks/bert/src/bert_layers.py | 96 ++++++++++++++----- examples/benchmarks/bert/src/text_data.py | 8 -- 5 files changed, 90 insertions(+), 42 deletions(-) diff --git a/examples/benchmarks/bert/main.py b/examples/benchmarks/bert/main.py index 277eee2d4..8a90ea80e 100644 --- a/examples/benchmarks/bert/main.py +++ b/examples/benchmarks/bert/main.py @@ -246,7 +246,9 @@ def main(cfg: DictConfig, load_path=cfg.get('load_path', None), load_weights_only=cfg.get('load_weights_only', False), python_log_level=cfg.get('python_log_level', None), - ) + autoresume=cfg.get('autoresume', None), + fsdp_config=cfg.get('fsdp_config', None), + compile_config=cfg.get('compile_config', None)) print('Logging config...') log_config(cfg) diff --git a/examples/benchmarks/bert/requirements-cpu.txt b/examples/benchmarks/bert/requirements-cpu.txt index 6c1038911..196140592 100644 --- a/examples/benchmarks/bert/requirements-cpu.txt +++ b/examples/benchmarks/bert/requirements-cpu.txt @@ -1,6 +1,6 @@ einops==0.5.0 -torch==1.13.1 -mosaicml[nlp,wandb]>=0.14.0,<0.15 -mosaicml-streaming==0.4.1 -omegaconf==2.2.3 -transformers==4.28.1 +torch==2.1.1 +composer[nlp,wandb]>=0.17.0,<0.18 +mosaicml-streaming<=0.7 +omegaconf==2.3.0 +transformers==4.36.2 diff --git a/examples/benchmarks/bert/requirements.txt b/examples/benchmarks/bert/requirements.txt index 9bf635b20..147cfa156 100644 --- a/examples/benchmarks/bert/requirements.txt +++ b/examples/benchmarks/bert/requirements.txt @@ -1,8 +1,10 @@ einops==0.5.0 -torch==1.13.1 -mosaicml[nlp,wandb]>=0.14.0,<0.15 -mosaicml-streaming==0.4.1 -omegaconf==2.2.3 -transformers==4.28.1 +torch==2.1.1 +composer[nlp,wandb]>=0.17.0,<0.18 +mosaicml-streaming<=0.7 +omegaconf==2.3.0 +transformers== 4.36.2 +# need a newer version of FA2 +flash_attn>=2.4.2 # need a newer version of triton -triton==2.0.0.dev20221103 +#triton==2.0.0.dev20221103 diff --git a/examples/benchmarks/bert/src/bert_layers.py b/examples/benchmarks/bert/src/bert_layers.py index af10e9d7d..f882a18e3 100644 --- a/examples/benchmarks/bert/src/bert_layers.py +++ b/examples/benchmarks/bert/src/bert_layers.py @@ -54,11 +54,26 @@ SequenceClassifierOutput) from transformers.models.bert.modeling_bert import BertPreTrainedModel +IMPL_USE_FLASH2 = False try: - import flash_attn_triton as flash_attn_triton - flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func + import importlib + + from flash_attn import flash_attn_qkvpacked_func + installed_version = importlib.metadata.version('flash_attn') + if installed_version < '2.4.2': + raise ImportError('newer version of flash_attn required (>= 2.4.2)') + IMPL_USE_FLASH2 = True except ImportError as e: - flash_attn_qkvpacked_func = None + warnings.warn( + f'Failed to import flash_attn. Will try to import triton implementation: {e}', + stacklevel=2) + try: + import flash_attn_triton as flash_attn_triton + flash_attn_qkvpacked_func = flash_attn_triton.flash_attn_qkvpacked_func + except ImportError as e: + flash_attn_qkvpacked_func = None + warnings.warn(f'Failed to import flash_attn_triton as a fallback: {e}', + stacklevel=2) logger = logging.getLogger(__name__) @@ -183,7 +198,8 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen_in_batch: int, indices: torch.Tensor, - attn_mask: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: + attn_mask: torch.Tensor, bias: torch.Tensor, + slopes: torch.Tensor) -> torch.Tensor: """Perform self-attention. If dropout is zero, then we can use the Triton kernel, so we do that. However, if not, we send through a standard PyTorch @@ -201,6 +217,7 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, indices: (total_nnz,) attn_mask: (batch, max_seqlen_in_batch) bias: (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) + slopes: (heads) or (batch, heads) Returns: attention: (total_nnz, dim) @@ -213,7 +230,8 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, 'b s (t h d) -> b s t h d', t=3, h=self.num_attention_heads) - if self.p_dropout or flash_attn_qkvpacked_func is None: + if (not IMPL_USE_FLASH2 and + self.p_dropout) or flash_attn_qkvpacked_func is None: # if we have nonzero attention dropout (e.g. during fine-tuning) or no Triton, compute attention in PyTorch q = qkv[:, :, 0, :, :].permute(0, 2, 1, 3) # b h s d k = qkv[:, :, 1, :, :].permute(0, 2, 3, 1) # b h d s @@ -226,19 +244,41 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, attention = torch.matmul(attention_probs, v).permute(0, 2, 1, 3) # b s h d else: - # Triton implementation only supports 0 attention dropout - convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16] - if convert_dtype: - # Triton implementation only supports fp16 and bf16 - orig_dtype = qkv.dtype - qkv = qkv.to(torch.float16) - bias_dtype = bias.dtype - bias = bias.to(torch.float16) - attention = flash_attn_qkvpacked_func(qkv, bias) - attention = attention.to(orig_dtype) - bias = bias.to(bias_dtype) + if IMPL_USE_FLASH2: + assert 1 <= len(slopes.shape) <= 2, f'{slopes=}' + assert slopes.shape[ + -1] == self.num_attention_heads, f'{slopes=}' + + # Triton implementation only supports 0 attention dropout + convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16] + if convert_dtype: + # Triton implementation only supports fp16 and bf16 + orig_dtype = qkv.dtype + qkv = qkv.to(torch.float16) + bias_dtype = bias.dtype + bias = bias.to(torch.float16) + + attention = flash_attn_qkvpacked_func( + qkv, dropout_p=self.p_dropout, alibi_slopes=slopes) + attention = attention.to(orig_dtype) + bias = bias.to(bias_dtype) + else: + attention = flash_attn_qkvpacked_func( + qkv, dropout_p=self.p_dropout, alibi_slopes=slopes) else: - attention = flash_attn_qkvpacked_func(qkv, bias) + # Triton implementation only supports 0 attention dropout + convert_dtype = qkv.dtype not in [torch.float16, torch.bfloat16] + if convert_dtype: + # Triton implementation only supports fp16 and bf16 + orig_dtype = qkv.dtype + qkv = qkv.to(torch.float16) + bias_dtype = bias.dtype + bias = bias.to(torch.float16) + attention = flash_attn_qkvpacked_func(qkv, bias) + attention = attention.to(orig_dtype) + bias = bias.to(bias_dtype) + else: + attention = flash_attn_qkvpacked_func(qkv, bias) # attn_mask is 1 for attend and 0 for don't attention = bert_padding_module.unpad_input_only( @@ -291,6 +331,7 @@ def forward( indices: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, + slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass for scaled self-attention without padding. @@ -303,9 +344,11 @@ def forward( indices: None or (total_nnz,) attn_mask: None or (batch, max_seqlen_in_batch) bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) + slopes: None or (batch, heads) or (heads,) """ + assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}' self_output = self.self(input_tensor, cu_seqlens, max_s, indices, - attn_mask, bias) + attn_mask, bias, slopes) if subset_idx is not None: return self.output( bert_padding_module.index_first_axis(self_output, subset_idx), @@ -379,6 +422,7 @@ def forward( indices: Optional[torch.Tensor] = None, attn_mask: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, + slopes: Optional[torch.Tensor] = None, ) -> torch.Tensor: """Forward pass for a BERT layer, including both attention and MLP. @@ -391,9 +435,12 @@ def forward( indices: None or (total_nnz,) attn_mask: None or (batch, max_seqlen_in_batch) bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) + slopes: None or (batch, heads) or (heads,) """ + assert (bias is None) == (slopes is None), f'{bias=}, {slopes=}' attention_output = self.attention(hidden_states, cu_seqlens, seqlen, - subset_idx, indices, attn_mask, bias) + subset_idx, indices, attn_mask, bias, + slopes) layer_output = self.mlp(attention_output) return layer_output @@ -463,6 +510,7 @@ def get_slopes_power_of_2(n_heads: int) -> List[float]: relative_position = relative_position.unsqueeze(0).expand( n_heads, -1, -1) slopes = torch.Tensor(_get_alibi_head_slopes(n_heads)).to(device) + self.slopes = slopes alibi = slopes.unsqueeze(1).unsqueeze(1) * -relative_position # [1, n_heads, max_token_length, max_token_length] alibi = alibi.unsqueeze(0) @@ -504,6 +552,7 @@ def forward( elif self.alibi.device != hidden_states.device: # Device catch-up self.alibi = self.alibi.to(hidden_states.device) + self.slopes = self.slopes.to(hidden_states.device) alibi_bias = self.alibi[:, :, :seqlen, :seqlen] attn_bias = extended_attention_mask[:, :, :seqlen, :seqlen] alibi_attn_mask = attn_bias + alibi_bias @@ -517,7 +566,8 @@ def forward( None, indices, attn_mask=attention_mask, - bias=alibi_attn_mask) + bias=alibi_attn_mask, + slopes=self.slopes) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) # Pad inputs and mask. It will insert back zero-padded tokens. @@ -536,7 +586,8 @@ def forward( None, indices, attn_mask=attention_mask, - bias=alibi_attn_mask) + bias=alibi_attn_mask, + slopes=self.slopes) if output_all_encoded_layers: all_encoder_layers.append(hidden_states) subset_idx = torch.nonzero(subset_mask[attention_mask_bool], @@ -547,7 +598,8 @@ def forward( subset_idx=subset_idx, indices=indices, attn_mask=attention_mask, - bias=alibi_attn_mask) + bias=alibi_attn_mask, + slopes=self.slopes) if not output_all_encoded_layers: all_encoder_layers.append(hidden_states) diff --git a/examples/benchmarks/bert/src/text_data.py b/examples/benchmarks/bert/src/text_data.py index 70a57dbcb..166801005 100644 --- a/examples/benchmarks/bert/src/text_data.py +++ b/examples/benchmarks/bert/src/text_data.py @@ -69,9 +69,6 @@ class StreamingTextDataset(StreamingDataset): keep_zip (bool): Whether to keep or delete the compressed form when decompressing downloaded shards. If ``False``, keep iff remote is local or no remote. Defaults to `False``. - keep_raw (bool): Whether to keep or delete the decompressed form (or only form) - of shards after all their samples have been yielded this epoch. If ``False``, keep iff - remote is local or no remote and no compression. Defaults to ``True``. samples_per_epoch (int, optional): Provide this field iff you are weighting sub-datasets proportionally. Defaults to ``None``. predownload (int, optional): Target number of samples ahead to download the shards of while @@ -99,7 +96,6 @@ def __init__(self, download_timeout: float = 60, validate_hash: Optional[str] = None, keep_zip: bool = False, - keep_raw: bool = True, samples_per_epoch: Optional[int] = None, predownload: int = 100_000, partition_algo: str = 'orig', @@ -140,7 +136,6 @@ def __init__(self, download_timeout=download_timeout, validate_hash=validate_hash, keep_zip=keep_zip, - keep_raw=keep_raw, samples_per_epoch=samples_per_epoch, predownload=predownload, partition_algo=partition_algo, @@ -266,8 +261,6 @@ def build_text_dataloader( cfg.dataset.get('validate_hash', None), keep_zip=stream.get('keep_zip', None) or cfg.dataset.get('keep_zip', False), - keep_raw=stream.get('keep_raw', None) or - cfg.dataset.get('keep_raw', True), )) # build dataset potentially with streams @@ -282,7 +275,6 @@ def build_text_dataloader( download_timeout=cfg.dataset.get('download_timeout', 60), validate_hash=cfg.dataset.get('validate_hash', None), keep_zip=cfg.dataset.get('keep_zip', False), - keep_raw=cfg.dataset.get('keep_raw', True), samples_per_epoch=cfg.dataset.get('samples_per_epoch', None), predownload=cfg.dataset.get('predownload', 100_000), partition_algo=cfg.dataset.get('partition_algo', 'orig'),