diff --git a/examples/benchmarks/bert/main.py b/examples/benchmarks/bert/main.py
index 277eee2d4..91bbe4b5b 100644
--- a/examples/benchmarks/bert/main.py
+++ b/examples/benchmarks/bert/main.py
@@ -11,6 +11,7 @@
 import src.hf_bert as hf_bert_module
 import src.mosaic_bert as mosaic_bert_module
 import src.text_data as text_data_module
+import src.mlm_scheduling as mlm_scheduling_module
 from composer import Trainer, algorithms
 from composer.callbacks import (HealthChecker, LRMonitor, MemoryMonitor,
                                 OptimizerMonitor, RuntimeEstimator,
@@ -52,6 +53,26 @@ def update_batch_size_info(cfg: DictConfig):
     return cfg
 
+
+def update_mlm_schedule(cfg: DictConfig):
+
+    def convert_constant_rate(dataset_cfg: DictConfig):
+        mlm_schedule = dataset_cfg.get('mlm_schedule', None)
+        if mlm_schedule is None:
+            mlm_probability = dataset_cfg.mlm_probability
+            mlm_schedule = om.create({
+                'name': 'constant',
+                'initial_masking_rate': mlm_probability,
+                'final_masking_rate': mlm_probability,
+            })
+        return mlm_schedule
+
+    cfg.train_loader.dataset.mlm_schedule = convert_constant_rate(
+        cfg.train_loader.dataset)
+    cfg.eval_loader.dataset.mlm_schedule = convert_constant_rate(
+        cfg.eval_loader.dataset)
+    return cfg
+
 
 def log_config(cfg: DictConfig):
     print(om.to_yaml(cfg))
     if 'wandb' in cfg.get('loggers', {}):
@@ -174,7 +195,7 @@ def main(cfg: DictConfig,
 
     # Dataloaders
     print('Building train loader...')
-    train_loader = build_dataloader(
+    train_loader, distributed_masking_rate = build_dataloader(
         cfg.train_loader,
         model.tokenizer,
         cfg.global_train_batch_size // dist.get_world_size(),
@@ -182,7 +203,7 @@ def main(cfg: DictConfig,
     print('Building eval loader...')
     global_eval_batch_size = cfg.get('global_eval_batch_size',
                                      cfg.global_train_batch_size)
-    eval_loader = build_dataloader(
+    eval_loader, _ = build_dataloader(
         cfg.eval_loader,
         model.tokenizer,
         global_eval_batch_size // dist.get_world_size(),
@@ -205,6 +226,9 @@ def main(cfg: DictConfig,
         build_callback(name, callback_cfg)
         for name, callback_cfg in cfg.get('callbacks', {}).items()
     ]
+    callbacks.append(
+        mlm_scheduling_module.build_mlm_scheduler_callback(
+            cfg.train_loader.dataset.mlm_schedule, distributed_masking_rate))
 
     # Algorithms
     algorithms = [
@@ -265,5 +289,6 @@ def main(cfg: DictConfig,
         yaml_cfg = om.load(f)
     cli_cfg = om.from_cli(args_list)
     cfg = om.merge(yaml_cfg, cli_cfg)
+    cfg = update_mlm_schedule(cfg)
     cfg = cast(DictConfig, cfg)  # for type checking
     main(cfg)
diff --git a/examples/benchmarks/bert/src/mlm_scheduling.py b/examples/benchmarks/bert/src/mlm_scheduling.py
new file mode 100644
index 000000000..472a936f5
--- /dev/null
+++ b/examples/benchmarks/bert/src/mlm_scheduling.py
@@ -0,0 +1,81 @@
+"""
+Definition of schedulers and callbacks for setting the masking rate dynamically.
+"""
+from typing import Union
+import multiprocessing
+
+from composer import Callback, State, Logger, Event, Time
+from composer.optim.scheduler import (ComposerScheduler, ConstantScheduler,
+                                      CosineAnnealingScheduler, _convert_time,
+                                      LinearScheduler)
+from omegaconf import DictConfig
+
+
+# Special case of step-wise scheduling: the masking rate decays exactly once,
+# so the schedule is defined by its initial and final multipliers.
+class StepScheduler(ComposerScheduler):
+    r"""Decays the masking rate to a new value in a single discrete step.
+    Args:
+        alpha_i (float): Multiplier of the initial masking rate before the step. Default = ``1.0``.
+        alpha_f (float): Multiplier of the initial masking rate after the step. Default = ``0.5``.
+        t_step (str | Time): The time at which to switch from ``alpha_i`` to ``alpha_f``. Default = ``'0.5dur'``.
+    """
+
+    def __init__(self,
+                 alpha_i: float = 1.0,
+                 alpha_f: float = 0.5,
+                 t_step: Union[str, Time] = '0.5dur'):
+        self.alpha_i = alpha_i
+        self.alpha_f = alpha_f
+        self.t_step = t_step
+
+    def __call__(self, state: State, ssr: float = 1.0):
+        t_step = _convert_time(self.t_step, state, ssr=ssr)
+        current_time = state.timestamp.get(t_step.unit)
+
+        if t_step.value > current_time.value:
+            return self.alpha_i
+
+        return self.alpha_f
+
+
+class MaskingRateSetter(Callback):
+
+    def __init__(self, scheduler: ComposerScheduler,
+                 initial_masking_rate: float,
+                 dynamic_masking_rate: multiprocessing.Value):
+        super().__init__()
+        self.scheduler = scheduler
+        self.initial_masking_rate = initial_masking_rate
+        self.dynamic_masking_rate = dynamic_masking_rate
+
+    def run_event(self, event: Event, state: State, logger: Logger):
+        if event == Event.BATCH_END:
+            masking_rate = self.scheduler(state) * self.initial_masking_rate
+
+            self.dynamic_masking_rate.value = masking_rate
+
+            logger.log_metrics({'mlm_schedule/masking_rate': masking_rate})
+
+
+def build_mlm_scheduler_callback(
+        cfg: DictConfig, distributed_masking_rate: multiprocessing.Value):
+    initial_masking_rate = cfg.initial_masking_rate
+    final_masking_rate = cfg.final_masking_rate
+    alpha_f = final_masking_rate / initial_masking_rate  # Multiplier to reach the final rate
+
+    if cfg.name == 'constant':
+        mlm_schedule = ConstantScheduler()
+    elif cfg.name == 'cosine':
+        mlm_schedule = CosineAnnealingScheduler(alpha_f=alpha_f)
+    elif cfg.name == 'linear':
+        mlm_schedule = LinearScheduler(alpha_f=alpha_f)
+    elif cfg.name == 'step':
+        mlm_schedule = StepScheduler(alpha_f=alpha_f)
+    else:
+        raise ValueError(
+            f'Not sure how to build masking rate scheduler: {cfg.name}')
+
+    return MaskingRateSetter(mlm_schedule,
+                             initial_masking_rate=initial_masking_rate,
+                             dynamic_masking_rate=distributed_masking_rate)
diff --git a/examples/benchmarks/bert/src/text_data.py b/examples/benchmarks/bert/src/text_data.py
index 70a57dbcb..27c123a35 100644
--- a/examples/benchmarks/bert/src/text_data.py
+++ b/examples/benchmarks/bert/src/text_data.py
@@ -4,8 +4,9 @@
 """Build a StreamingTextDataset dataset and dataloader for training."""
 
 import os
+import multiprocessing
 from itertools import islice
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import numpy as np
 import torch
@@ -184,6 +185,23 @@ def __getitem__(self, idx: int) -> Union[Dict[str, Any], torch.Tensor]:
         return token_sample
 
+
+class ScheduledDataCollatorForLanguageModeling(
+        transformers.DataCollatorForLanguageModeling):
+
+    def __init__(self, distributed_mlm_probability: multiprocessing.Value,
+                 *args: Tuple[Any], **kwargs: Dict[str, Any]):
+        super().__init__(*args, **kwargs)
+        self.distributed_mlm_probability = distributed_mlm_probability
+
+    @property
+    def mlm_probability(self):
+        return self.distributed_mlm_probability.value
+
+    @mlm_probability.setter
+    def mlm_probability(self, _):
+        return
+
 
 class ConcatenatedSequenceCollatorWrapper:
     """Collator wrapper to add sequence_id to batch."""
 
@@ -293,11 +311,15 @@ def build_text_dataloader(
         shuffle_seed=cfg.dataset.get('shuffle_seed', 9176),
     )
 
-    mlm_probability = cfg.dataset.get('mlm_probability', None)
-    collate_fn = transformers.DataCollatorForLanguageModeling(
+    mlm_schedule = cfg.dataset.get('mlm_schedule', None)
+    distributed_mlm_probability = None
+    if mlm_schedule:
+        distributed_mlm_probability = multiprocessing.Value(
+            'd', mlm_schedule.initial_masking_rate)
+    collate_fn = ScheduledDataCollatorForLanguageModeling(
         tokenizer=dataset.tokenizer,
-        mlm=mlm_probability is not None,
-        mlm_probability=mlm_probability)
+        mlm=mlm_schedule is not None,
+        distributed_mlm_probability=distributed_mlm_probability)
 
     eos_token_id = cfg.dataset.get('eos_token_id')
     bos_token_id = cfg.dataset.get('bos_token_id')
@@ -318,7 +340,7 @@ def build_text_dataloader(
         prefetch_factor=cfg.get('prefetch_factor', 2),
         persistent_workers=cfg.get('persistent_workers', True),
         timeout=cfg.get('timeout', 0),
-    )
+    ), distributed_mlm_probability
 
 
 # Helpful to test if your dataloader is working locally
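For reference, a minimal sketch of how a run config could opt into a non-constant schedule. Where the schedule lives in the YAML is an assumption, but it mirrors exactly what `update_mlm_schedule` reads (`cfg.train_loader.dataset.mlm_schedule`) and the fields consumed by `build_mlm_scheduler_callback` (`name`, `initial_masking_rate`, `final_masking_rate`):

```python
from omegaconf import OmegaConf as om

# Hypothetical config excerpt: cosine-anneal the masking rate from 30% to 15%.
# If `mlm_schedule` is omitted, update_mlm_schedule() wraps the existing
# `mlm_probability` into an equivalent 'constant' schedule.
cfg = om.create({
    'train_loader': {
        'dataset': {
            'mlm_schedule': {
                'name': 'cosine',  # one of: constant, cosine, linear, step
                'initial_masking_rate': 0.30,
                'final_masking_rate': 0.15,
            },
        },
    },
})
print(om.to_yaml(cfg))
```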
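A sketch of how the new pieces fit together outside the Trainer, which may be handy for smoke-testing the collator/callback pair. It assumes the benchmark's `src` package is importable and that a BERT tokenizer can be loaded locally:

```python
import multiprocessing

import transformers
from omegaconf import OmegaConf as om

from src.mlm_scheduling import build_mlm_scheduler_callback
from src.text_data import ScheduledDataCollatorForLanguageModeling

schedule_cfg = om.create({
    'name': 'linear',
    'initial_masking_rate': 0.30,
    'final_masking_rate': 0.15,
})

# Shared double: MaskingRateSetter writes it at every BATCH_END, and every
# dataloader worker reads it through the collator's mlm_probability property.
shared_rate = multiprocessing.Value('d', schedule_cfg.initial_masking_rate)

tokenizer = transformers.AutoTokenizer.from_pretrained('bert-base-uncased')
collate_fn = ScheduledDataCollatorForLanguageModeling(
    distributed_mlm_probability=shared_rate, tokenizer=tokenizer, mlm=True)
callback = build_mlm_scheduler_callback(schedule_cfg, shared_rate)

# Pass collate_fn to the DataLoader and callback to the composer Trainer; the
# callback then anneals shared_rate.value from 0.30 down to 0.15 over training.
```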
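Finally, a pure-Python illustration of the rate arithmetic in `MaskingRateSetter`: the wrapped `ComposerScheduler` returns a multiplier that moves from 1 toward `alpha_f = final / initial`, and the callback multiplies it by the initial rate. The loop below only mirrors the numbers a `LinearScheduler(alpha_f=0.5)` would produce over a full run; it is not a replacement for the scheduler itself.

```python
initial_masking_rate = 0.30
final_masking_rate = 0.15
alpha_f = final_masking_rate / initial_masking_rate  # 0.5, as in build_mlm_scheduler_callback

for frac in (0.0, 0.25, 0.5, 0.75, 1.0):  # fraction of training elapsed
    alpha = 1.0 + (alpha_f - 1.0) * frac  # multiplier a LinearScheduler would return
    print(f'{frac:4.0%} of training -> masking rate {alpha * initial_masking_rate:.4f}')
```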