πŸš€ Composer v0.7.0

Composer v0.7.0 is released! Install via pip:

pip install --upgrade mosaicml==0.7.0

Alternatively, install Composer with Conda:

conda install -c mosaicml mosaicml=0.7.0

New Features

  1. 🏎️ FFCV Integration

    Composer supports FFCV, a fast dataloader for image datasets. We've found FFCV can speed up ResNet-56 training by 16%, in addition to existing speed-ups already supported by Composer! It's easy to use FFCV with any existing image dataset:

    import ffcv
    from ffcv.fields.decoders import IntDecoder, SimpleRGBImageDecoder
    from torchvision.datasets import ImageFolder
    
    from composer import Trainer
    from composer.datasets.ffcv_utils import write_ffcv_dataset, ffcv_monkey_patches
    
    # Convert the dataset to FFCV format
    # This step needs to be done only once per dataset
    dataset = ImageFolder(...)
    ffcv_dataset_path = "my_ffcv_dataset.ffcv"
    write_ffcv_dataset(dataset=dataset, write_path=ffcv_dataset_path)
    
    # In FFCV v0.0.3, len(dataloader) is expensive. Fix that via a monkeypatch
    ffcv_monkey_patches()
    
    # Construct the train dataloader
    train_dl = ffcv.Loader(
        ffcv_dataset_path,
        ...
    )
    
    # Construct the trainer
    trainer = Trainer(
        train_dataloader=train_dl,
    )
    
    # Train using FFCV!
    trainer.fit()

    See our notebook on training with FFCV for a full example.

  2. βœ… Autoresume from Checkpoints

    When setting autoresume=True, Composer can automatically resume from an existing checkpoint before starting a new training run. Specifically, the trainer will look in the save_folder (and any loggers that save artifacts) for the latest checkpoint; if none is found, then it'll start from the beginning.

    This feature does not require a different entrypoint to distinguish between starting a new training run or automatically resuming from an existing one, making it easy to use Composer on spot preemptable cloud instances. Simply set autoresume=True, point the instance to your training script, and Composer will handle the rest!

    from composer import Trainer
    
    # When using `autoresume`, it is required to specify the
    # `run_name`, so Composer will know which training run to
    # resume
    run_name = "my_autoresume_training_run"
    
    trainer = Trainer(
        ...,
        run_name=run_name,
        # specify where to save checkpoints
        save_folder="./my_autoresume_training_run",
        autoresume=True,
    )
    
    # Train! Composer will handle loading an existing
    # checkpoint or starting a new training run
    trainer.fit()

    See the Trainer API Reference for more information.

  3. ♻️ Reuse the Trainer

    Want to train on multiple dataloaders sequentially? Each trainer object now supports multiple calls to Trainer.fit(), so you can continue training an existing model on a new dataloader, with new schedulers, all while using the same model and trainer object.

    For example:

    from torch.utils.data import DataLoader
    
    from composer import Trainer
    
    train_dl_1 = DataLoader(...)
    trainer = Trainer(
        model=model,
        max_duration='5ep',
        train_dataloader=train_dl_1,
    )
    
    # Train once!
    trainer.fit()
    
    # Train again with a new dataloader for another 5 epochs
    train_dl_2 = DataLoader(...)
    trainer.fit(
        train_dataloader=train_dl_2,
        duration='5ep',
    )

    See the Trainer API Reference for more information.

  4. βš–οΈ Eval or Predict Only? No Problem

    You can evaluate or predict on an existing model, without having to supply a train dataloader or training duration argument -- they're now optional.

    import torchmetrics
    from torch.utils.data import DataLoader
    
    from composer import Trainer
    
    # Construct the trainer
    trainer = Trainer(model=model)
    
    # Evaluate!
    eval_dl = DataLoader(...)
    trainer.eval(
        dataloader=eval_dl,
        metrics=torchmetrics.Accuracy(),
    )
    
    # Examine evaluation metrics
    print("Eval metrics", trainer.state.metrics['eval'])
    
    # Or, predict!
    predict_dl = DataLoader(...)
    trainer.predict(dataloader=predict_dl)

    See the Trainer API Reference for more information.

  5. πŸ›‘ Early Stopper and Threshold Stopper Callbacks

    The Early Stopper and Threshold Stopper callbacks end training early, either when the monitored metric stops improving (EarlyStopper) or once it crosses a specified threshold (ThresholdStopper):

    from composer.callbacks.early_stopper import EarlyStopper
    from torchmetrics.classification.accuracy import Accuracy
    
    # Construct the callback
    early_stopper = EarlyStopper(
        monitor="Accuracy",
        dataloader_label="eval",
        patience=2,
    )
    
    # Construct the trainer
    trainer = Trainer(
        ...,
        callbacks=early_stopper,
        max_duration="100ep",
    )
    
    # Train!
    # Training will end early if the accuracy does not improve
    # over two epochs
    trainer.fit()
  6. πŸͺ΅ Load Checkpoints from Loggers

    It's now possible to restore checkpoints from loggers that support file artifacts (such as the Weights & Biases Logger). No need to download your checkpoints manually anymore.

    from composer import Trainer
    from composer.loggers import WandBLogger
    
    # Configure the W&B Logger
    wandb_logger = WandBLogger(
        # set to True to capture artifacts, like checkpoints
        log_artifacts=True,
        init_params={
            'project': 'my-wandb-project-name',
        },
    )
    
    # Then, to train and save checkpoints to W&B:
    trainer = Trainer(
        ...,
        loggers=wandb_logger,
        save_folder="/tmp/checkpoints",
        save_interval="1ep",
        save_artifact_name="epoch{epoch}.pt",
    )
    
    # Finally, to load checkpoints from W&B
    trainer = Trainer(
        ...,
        load_object_store=wandb_logger,
        load_path="epoch1.pt:latest",
    )
  7. βŒ› Wall Clock, Evaluation, and Prediction Time Tracking

    The timestamp object measures wall clock time via three new fields: total_wct, epoch_wct, and batch_wct. These fields track the total elapsed training time, the elapsed training time of the current epoch, and the time to train the last batch. Read the wall clock time via a callback:

    from composer import Callback, Trainer
    
    class MyCallback(Callback):
        def batch_end(self, state, logger):
            print(f"Total wct: {state.timestamp.total_wct}")
            print(f"Epoch wct: {state.timestamp.epoch_wct}")
            print(f"Batch wct: {state.timestamp.batch_wct}")
    
    # Construct the trainer with this callback
    trainer = Trainer(
        ...,
        callbacks=MyCallback(),
    )
    
    # Train!
    trainer.fit()

    In addition, the training state object has two new fields for tracking time during evaluation and prediction: eval_timestamp and predict_timestamp. These fields, just like any others on the state object, are accessible to algorithms, callbacks, and loggers.
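
    For example, a callback could read the evaluation timestamp as in the sketch below; the callback class and attribute access are illustrative, based on the field names described above:

    from composer import Callback

    class EvalTimeLogger(Callback):
        """Illustrative callback that reports evaluation progress and time."""

        def eval_end(self, state, logger):
            # eval_timestamp tracks progress within the current evaluation pass
            print(f"Eval batches: {state.eval_timestamp.batch}")
            print(f"Eval wall clock time: {state.eval_timestamp.total_wct}")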

  8. Training DeepLabv3+ on the ADE20k Dataset

    DeepLabv3+ is a common baseline model for semantic segmentation tasks. We provide a ComposerModel implementation for DeepLabv3+ built using torchvision and mmsegmentation for the backbone and head, respectively.
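
    As an illustration, a segmentation model built on a torchvision backbone can be wrapped in a ComposerModel along the following lines. This is a minimal sketch, not the library's built-in implementation; the SimpleDeepLabV3 class and its defaults are hypothetical.

    import torch.nn.functional as F
    from torchvision.models.segmentation import deeplabv3_resnet50
    from composer.models import ComposerModel

    class SimpleDeepLabV3(ComposerModel):
        """Hypothetical wrapper, for illustration only."""

        def __init__(self, num_classes: int = 150):
            super().__init__()
            self.model = deeplabv3_resnet50(num_classes=num_classes)

        def forward(self, batch):
            # Batches are assumed to be (images, targets) tuples
            images, _ = batch
            return self.model(images)["out"]

        def loss(self, outputs, batch):
            _, targets = batch
            return F.cross_entropy(outputs, targets, ignore_index=-1)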

    We found the DeepLabv3+ baseline can be significantly improved using the new PyTorch pre-trained weights. Additional gains are made through a hyperparameter sweep.

    We benchmark our DeepLabv3+ model on a single 8xA100 machine using ADE20k, a popular semantic segmentation dataset. The final results on ADE20k are:

    | Model                  | mIoU           | Time-to-Train |
    |------------------------|----------------|---------------|
    | Unoptimized DeepLabv3+ | 44.17 +/- 0.14 | 6.39 hr       |
    | Optimized DeepLabv3+   | 45.78 +/- 0.26 | 4.67 hr       |

    Check out our documentation for more info!

API Changes

  1. πŸͺ Additional Batch Type Support

    Composer v0.7.0 removed the BatchDict and BatchPair types, and now supports any batch type. We're updating our algorithms to support batches of custom formats.
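
    As a sketch of what this enables, a dataloader can now yield, for example, dictionary batches that the model unpacks itself. The dataset and model classes below are hypothetical, for illustration only:

    import torch
    from torch.utils.data import DataLoader, Dataset
    from composer.models import ComposerModel

    class DictDataset(Dataset):
        """Hypothetical dataset that yields dictionary samples."""

        def __len__(self):
            return 128

        def __getitem__(self, idx):
            return {"features": torch.randn(16), "label": torch.randint(0, 2, ())}

    class DictModel(ComposerModel):
        """Hypothetical model that unpacks dictionary batches."""

        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(16, 2)

        def forward(self, batch):
            return self.fc(batch["features"])

        def loss(self, outputs, batch):
            return torch.nn.functional.cross_entropy(outputs, batch["label"])

    train_dl = DataLoader(DictDataset(), batch_size=32)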

  2. 🏎️ Simplified Profiling Arguments

    To simplify the Trainer constructor, the profiling arguments were replaced with a single profiler argument, which takes an instance of the Profiler.

    from composer.trainer import Trainer
    from composer.profiler import Profiler, JSONTraceHandler, cyclic_schedule
    
    # Example directories where the profiling traces will be written
    composer_trace_dir = "composer_profiler_traces"
    torch_trace_dir = "torch_profiler_traces"
    
    trainer = Trainer(
        ...,
        profiler=Profiler(
            trace_handlers=JSONTraceHandler(
                folder=composer_trace_dir,
                overwrite=True,
            ),
            schedule=cyclic_schedule(
                wait=0,
                warmup=1,
                active=4,
                repeat=1,
            ),
            torch_prof_folder=torch_trace_dir,
            torch_prof_overwrite=True,
            ...,
        )
    )

    See the profiling guide for additional information.

  3. πŸšͺ Event.FIT_END and Engine.close()

    With support for reusing the trainer for multiple calls to Trainer.fit, callbacks and loggers are no longer closed at the end of a training run.

    Instead, Event.FIT_END was added, which can be used by Callbacks for anything that should happen at the end of each invocation of Trainer.fit. See the Event Guide for additional information.

    Finally, whenever the trainer is garbage collected or Trainer.close is called, Callback.close and Callback.post_close are invoked, ensuring that they will be called only once per trainer.
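
    For instance, logic that should run at the end of every call to Trainer.fit can live in a callback's fit_end method. The callback below is a minimal, illustrative sketch, not part of the library:

    from composer import Callback

    class FitEndNotifier(Callback):
        """Illustrative callback that reacts to the end of each fit() call."""

        def fit_end(self, state, logger):
            # Runs on Event.FIT_END, i.e. at the end of every Trainer.fit()
            print(f"Finished fit after {state.timestamp.batch} batches")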

  4. βŒ› State.timesamp replaces State.timer

    Removed State.timer and replaced it with State.timestamp, which is now a static Timestamp object. The training loop replaces State.timestamp with a new object on each batch. See the Time Guide for additional information.
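
    A minimal sketch of the new accessor from within a callback (illustrative only):

    from composer import Callback

    class TimestampReader(Callback):
        """Illustrative callback; state.timer is gone, read state.timestamp instead."""

        def batch_end(self, state, logger):
            # state.timestamp is a snapshot that the training loop replaces each batch
            print(f"Epoch: {state.timestamp.epoch}, Batch: {state.timestamp.batch}")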

  5. πŸ’Ώ Data Configuration

    Two new properties, State.dataloader and State.dataloader_label, were added to the state. These properties track the currently active dataloader (e.g. the training dataloader when training; the evaluation dataloader when evaluating).

    In addition, State.subset_num_batches was renamed to State.dataloader_len to reflect the actual dataloader length that will be used for training and evaluation.

    A helper method State.set_dataloader was added to ensure the dataloader properties are updated correctly.
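
    A short sketch of reading these properties from a callback, illustrative only and based on the property names above:

    from composer import Callback

    class DataloaderInspector(Callback):
        """Illustrative callback that reports the active dataloader."""

        def epoch_start(self, state, logger):
            # "train" during training; the evaluator label during evaluation
            print(f"Active dataloader: {state.dataloader_label}")
            # Number of batches that will be used (formerly State.subset_num_batches)
            print(f"Dataloader length: {state.dataloader_len}")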

  6. βš–οΈ Removed the Deprecated Scale Schedule Algorithm

    The scale schedule algorithm class, deprecated in v0.4.0, has been removed. Instead, use the scale_schedule_ratio argument when constructing the trainer.

    from composer import Trainer
    from composer.optim.scheduler import MultiStepScheduler
    
    trainer = Trainer(
        ...,
        max_duration="20ep",
        schedulers=MultiStepScheduler(milestones=["10ep", "16ep"]),
        scale_schedule_ratio=0.5,
    )

    See the Scale Schedule Method Card for additional info.

Bug Fixes

  • Fixed a bug where Event.FIT_END was not being called in the training loop (#1054)
  • Fixed a bug where evaluation would not run at the end of training unless it aligned with the eval_interval (#1045)
  • Fixed a bug where models trained with SWA could not be used with checkpoints (#1015)
  • Fixed a bug where the Speed Monitor included validation time in the training throughput measurements, resulting in slower reported throughput measurements (#1053)
  • Fixed a bug to make the ComposerClassifier compatible with TorchScript (#1036)
  • Fixed a bug where fractional Time Objects were being truncated instead of raising an exception (#1038)
  • Changed the defaults for Selective Backprop to not scale inputs, so the algorithm can work with non-vision workloads (#896)

New Contributors

Changelog

v0.6.1...v0.7.0