diff --git a/README.md b/README.md index a99a73d..f211deb 100644 --- a/README.md +++ b/README.md @@ -156,6 +156,37 @@ Notes: * Batch normalization appears to work best with a "post" normalization strategy, whereas a "pre" normalization strategy appears to work best with layer normalization. +#### [NSynth](https://magenta.tensorflow.org/datasets/nsynth) + +```sh +python train.py \ + --dataset=nsynth_short \ + --batch_size=-1 \ + --val_prop=0.01 \ + --max_epochs=150 \ + --limit_train_batches=0.025 \ + --lr=1e-2 \ + --n_blocks=4 \ + --pooling=avg_2 \ + --d_model=128 \ + --weight_decay=0.0 \ + --norm_type=batch \ + --norm_strategy=post \ + --p_dropout=0.1 \ + --precision=16 \ + --accumulate_grad=4 \ + --patience=10 +``` + +**Validation Accuracy**: 39.6% after 5 epochs, 54.1% after 17 epochs (best)
+**Speed**: ~1.6 batches/second + +Notes: + + * The model is tasked with classifying waveforms based on the musical instrument which generated them (10 classes) + * The `nsynth_short` dataset contains waveforms which are truncated after 2 seconds, whereas the `nsynth` dataset contains + the full four-second waveforms. + ## Components ### Layer diff --git a/dev_requirements.txt b/dev_requirements.txt index cc991bb..c68a476 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -2,5 +2,7 @@ black==21.12b0 fire==0.4.0 isort==5.10.1 pytorch-lightning==1.5.9 +requests==2.27.1 torchaudio==0.10.1 torchvision==0.11.2 +tqdm==4.62.3 diff --git a/experiments/data/transforms.py b/experiments/data/_transforms.py similarity index 100% rename from experiments/data/transforms.py rename to experiments/data/_transforms.py diff --git a/experiments/data/_utils.py b/experiments/data/_utils.py new file mode 100644 index 0000000..418b397 --- /dev/null +++ b/experiments/data/_utils.py @@ -0,0 +1,73 @@ +""" + + Utils + +""" +import tarfile +from contextlib import closing +from pathlib import Path + +import requests +from tqdm import tqdm + + +def download(url: str, dst: Path, chunk_size: int = 1024, verbose: bool = True) -> Path: + """Download a file from a ``url`` to ``dst``. + + Args: + url (str): URL of the file to download + dst (Path): download destination. If a directory, the filename + will be determined using the URL. 
+ chunk_size (int): size of "chunks" to use when streaming the data + verbose (bool): if ``True`` display a progress bar + + Returns: + dst (Path): the path to the downloaded file + + """ + if dst.is_dir(): + dst = dst.joinpath(Path(url).name) + + response = requests.get(url, stream=True) + total = int(response.headers.get("content-length", 0)) or None + with dst.open("wb") as file: + with tqdm( + desc=f"Downloading {Path(url).name}", + total=total, + unit="iB", + unit_scale=True, + unit_divisor=chunk_size, + disable=total is None or not verbose, + ) as pbar: + for data in response.iter_content(chunk_size=chunk_size): + size = file.write(data) + pbar.update(size) + return dst + + +def untar( + src: Path, dst: Path, delete_src: bool = False, verbose: bool = False +) -> Path: + """Untar ``src``. + + Args: + src (Path): source file to untar + dst (Path): destination directory + delete_src (bool): if ``True``, delete ``src`` when complete + + Returns: + None + + """ + if not src.is_file(): + raise OSError(f"No file {str(dst)}") + if not dst.is_dir(): + raise OSError(f"No directory {str(dst)}") + + if verbose: + print(f"Untaring {str(src)}...") + with closing(tarfile.open(src)) as f: + f.extractall(dst) + + if delete_src: + src.unlink() diff --git a/experiments/data/datasets.py b/experiments/data/datasets.py index 1bc7eea..9e5fa57 100644 --- a/experiments/data/datasets.py +++ b/experiments/data/datasets.py @@ -8,33 +8,32 @@ """ from __future__ import annotations +import json +from functools import cached_property from pathlib import Path from typing import Any, Optional import numpy as np import torch +import torchaudio from torch.nn import functional as F from torchaudio.datasets import SPEECHCOMMANDS as _SpeechCommands # noqa from torchvision.datasets import CIFAR10, MNIST from torchvision.transforms import Compose, Lambda, ToTensor -from experiments.data.transforms import build_permute_transform +from experiments.data._utils import download, untar +from 
experiments.data._transforms import build_permute_transform + +_DATASETS_DIRECTORY = Path("~/datasets") class SequenceDataset: NAME: Optional[str] = None SAVE_NAME: Optional[str] = None - classes: list[str | int] + class_names: Optional[list[str | int]] = None - def __init__( - self, - val_prop: float = 0.1, - seed: int = 42, - **kwargs: Any, - ) -> None: - super().__init__(self.root_dir, **kwargs) - self.val_prop = val_prop - self.seed = seed + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) @property def root_dir(self) -> Path: @@ -43,13 +42,21 @@ def root_dir(self) -> Path: if not isinstance(name, str): raise TypeError("`NAME` not set") - path = Path("~/datasets").expanduser().joinpath(name) + path = _DATASETS_DIRECTORY.expanduser().joinpath(name) path.mkdir(parents=True, exist_ok=True) return path + @property + def classes(self) -> list[str | int]: + """Names of all classes in the dataset.""" + if self.class_names: + return self.class_names + else: + raise AttributeError("Class names not set") + @property def n_classes(self) -> int: - """Number of classes in the dataset.""" + """Number of classes in the dataset.""" return len(self.classes) @property @@ -62,12 +69,20 @@ def shape(self) -> tuple[int, ...]: """Shape of the data in the dataset.""" raise NotImplementedError() + def __len__(self) -> int: + raise NotImplementedError() + + def __getitem__(self, item: int) -> tuple[torch.Tensor, int]: + raise NotImplementedError() + class SMnistDataset(SequenceDataset, MNIST): NAME: str = "SMNIST" + class_names: list[int] = list(range(10)) def __init__(self, **kwargs: Any) -> None: super().__init__( + root=self.root_dir, download=True, transform=Compose([ToTensor(), Lambda(lambda t: t.flatten())]), **kwargs, @@ -87,6 +102,7 @@ class PMnistDataset(SequenceDataset, MNIST): def __init__(self, **kwargs: Any) -> None: super().__init__( + root=self.root_dir, download=True, transform=Compose([ToTensor(), 
build_permute_transform((28 * 28,))]), **kwargs, @@ -103,10 +119,11 @@ def shape(self) -> tuple[int, ...]: class SCIFAR10Dataset(SequenceDataset, CIFAR10): NAME: str = "SCIFAR10" - classes = list(range(10)) + class_names: list[int] = list(range(10)) def __init__(self, **kwargs: Any) -> None: super().__init__( + root=self.root_dir, download=True, transform=Compose( [ @@ -129,7 +146,7 @@ def shape(self) -> tuple[int, ...]: class SpeechCommands(SequenceDataset, _SpeechCommands): NAME: str = "SPEECH_COMMANDS" SEGMENT_SIZE: int = 16_000 - classes = [ + class_names: list[str] = [ "bed", "cat", "down", @@ -168,7 +185,7 @@ class SpeechCommands(SequenceDataset, _SpeechCommands): ] def __init__(self, **kwargs: Any) -> None: - super().__init__(download=True, **kwargs) + super().__init__(root=self.root_dir, download=True, **kwargs) self.label_ids = {l: e for e, l in enumerate(self.classes)} self._walker = [i for i in self._walker if Path(i).parent.name in self.classes] @@ -197,7 +214,7 @@ def shape(self) -> tuple[int, ...]: class SpeechCommands10(SpeechCommands): NAME: str = "SPEECH_COMMANDS_10" SAVE_NAME = "SPEECH_COMMANDS" - classes = [ + class_names: list[int] = [ "yes", "no", "up", @@ -226,7 +243,7 @@ def __getitem__(self, item: int) -> tuple[torch.Tensor, int]: if not hot_idx.any(): # ensure at least one hot_idx[np.random.choice(self.N_REPEATS - 1)] = True - label = 0 + label = -1 chunks = list() for use_y in hot_idx: chunks.append(y if use_y else torch.zeros_like(y)) @@ -238,6 +255,77 @@ def shape(self) -> tuple[int, ...]: return (self.N_REPEATS * self.SEGMENT_SIZE,) # noqa +class NSynthDataset(SequenceDataset): + NAME: str = "NSYNTH" + SEGMENT_SIZE: int = 64_000 + URLS: dict[str, str] = { + "train": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-train.jsonwav.tar.gz", # noqa + "valid": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-valid.jsonwav.tar.gz", # noqa + "test": 
"http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-test.jsonwav.tar.gz", # noqa + } + + def __init__(self, download: bool = True, verbose: bool = True) -> None: + super().__init__() + self.download = download + self.verbose = verbose + + if download: + self.fetch_data() + + def fetch_data(self, force: bool = False) -> None: + for url in self.URLS.values(): + dirname, *_ = Path(url).stem.split(".") + if force or not self.root_dir.joinpath(dirname).is_dir(): + untar( + download(url, dst=self.root_dir, verbose=self.verbose), + dst=self.root_dir, + delete_src=True, + verbose=self.verbose, + ) + + @cached_property + def metadata(self) -> dict[str, dict[str, Any]]: + metadata = dict() + for path in self.root_dir.rglob("*.json"): + with path.open("r") as f: + payload = json.load(f) + for v in payload.values(): + v["split"] = path.parent.name.split("-")[-1] + metadata |= payload + return metadata + + @cached_property + def classes(self) -> list[str | int]: + return sorted({v["instrument_family_str"] for v in self.metadata.values()}) + + @cached_property + def files(self) -> list[Path]: + return list(self.root_dir.rglob("*.wav")) + + @property + def channels(self) -> int: + return 1 + + @property + def shape(self) -> tuple[int, ...]: + return (self.SEGMENT_SIZE,) # noqa + + def __len__(self) -> int: + return len(self.files) + + def __getitem__(self, item: int) -> tuple[torch.Tensor, int]: + path = self.files[item] + y, _ = torchaudio.load(path, normalize=True, channels_first=False) # noqa + label = self.metadata[path.stem]["instrument_family"] + return y[: self.SEGMENT_SIZE, ...], label + + +class NSynthDatasetShort(NSynthDataset): + NAME: str = "NSYNTH_SHORT" + SAVE_NAME: str = "NSYNTH" + SEGMENT_SIZE: int = 64_000 // 2 + + if __name__ == "__main__": smnist_wrapper = SMnistDataset() diff --git a/experiments/utils.py b/experiments/utils.py index 8cee5dc..553e48c 100644 --- a/experiments/utils.py +++ b/experiments/utils.py @@ -9,10 +9,10 @@ import torch from 
torch import nn +from torch.utils.data import Dataset, random_split from s4torch.block import S4Block from s4torch.layer import S4Layer -from torch.utils.data import Dataset, random_split class OutputPaths: @@ -34,9 +34,9 @@ def checkpoints(self) -> Path: return self._make_dir(f"checkpoints/{self.run_name}") -def count_parameters(model: nn.Module, ajust_complex: bool = True) -> int: +def count_parameters(model: nn.Module, adjust_complex: bool = True) -> int: def get_count(param: nn.Parameter) -> int: - return param.numel() * (param.is_complex() + ajust_complex) + return param.numel() * (param.is_complex() + adjust_complex) return sum(get_count(p) for p in model.parameters()) diff --git a/s4torch/aux/residual.py b/s4torch/aux/residual.py index b4101b8..f7d34e6 100644 --- a/s4torch/aux/residual.py +++ b/s4torch/aux/residual.py @@ -10,7 +10,7 @@ class Residual(nn.Module): - def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor: # noqa return y + x diff --git a/s4torch/layer.py b/s4torch/layer.py index 5a8100a..e92fa35 100644 --- a/s4torch/layer.py +++ b/s4torch/layer.py @@ -8,6 +8,7 @@ import numpy as np import torch from torch import nn +from torch import view_as_real as as_real from torch.fft import ifft, irfft, rfft from torch.nn import functional as F from torch.nn import init @@ -22,7 +23,7 @@ def _log_step_initializer( return tensor * scale + np.log(dt_min) -def _make_omega_l(l_max: int, dtype: torch.dtype) -> torch.Tensor: +def _make_omega_l(l_max: int, dtype: torch.dtype = torch.complex64) -> torch.Tensor: return torch.arange(l_max).type(dtype).mul(2j * np.pi / l_max).exp() @@ -73,9 +74,9 @@ def _cauchy_dot(v: torch.Tensor, denominator: torch.Tensor) -> torch.Tensor: def _non_circular_convolution(u: torch.Tensor, K: torch.Tensor) -> torch.Tensor: l_max = u.shape[1] - ud = rfft(F.pad(u, pad=(0, 0, 0, l_max, 0, 0)), dim=1) - Kd = rfft(F.pad(K, pad=(0, l_max)), dim=-1) - return 
irfft(ud.transpose(-2, -1) * Kd)[..., :l_max].transpose(-2, -1) + ud = rfft(F.pad(u.float(), pad=(0, 0, 0, l_max, 0, 0)), dim=1) + Kd = rfft(F.pad(K.float(), pad=(0, l_max)), dim=-1) + return irfft(ud.transpose(-2, -1) * Kd)[..., :l_max].transpose(-2, -1).type_as(u) class S4Layer(nn.Module): @@ -104,9 +105,9 @@ def __init__(self, d_model: int, n: int, l_max: int) -> None: self.l_max = l_max p, q, lambda_ = map(lambda t: t.type(torch.complex64), _make_p_q_lambda(n)) - self.p = nn.Parameter(p) - self.q = nn.Parameter(q) - self.lambda_ = nn.Parameter(lambda_.unsqueeze(0).unsqueeze(1)) + self._p = nn.Parameter(as_real(p)) + self._q = nn.Parameter(as_real(q)) + self._lambda_ = nn.Parameter(as_real(lambda_).unsqueeze(0).unsqueeze(1)) self.register_buffer( "omega_l", @@ -120,11 +121,11 @@ def __init__(self, d_model: int, n: int, l_max: int) -> None: ), ) - self.B = nn.Parameter( - init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64)) + self._B = nn.Parameter( + as_real(init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64))) ) - self.Ct = nn.Parameter( - init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64)) + self._Ct = nn.Parameter( + as_real(init.xavier_normal_(torch.empty(d_model, n, dtype=torch.complex64))) ) self.D = nn.Parameter(torch.ones(1, 1, d_model)) self.log_step = nn.Parameter(_log_step_initializer(torch.rand(d_model))) @@ -132,6 +133,26 @@ def __init__(self, d_model: int, n: int, l_max: int) -> None: def extra_repr(self) -> str: return f"d_model={self.d_model}, n={self.n}, l_max={self.l_max}" + @property + def p(self) -> torch.Tensor: + return torch.view_as_complex(self._p) + + @property + def q(self) -> torch.Tensor: + return torch.view_as_complex(self._q) + + @property + def lambda_(self) -> torch.Tensor: + return torch.view_as_complex(self._lambda_) + + @property + def B(self) -> torch.Tensor: + return torch.view_as_complex(self._B) + + @property + def Ct(self) -> torch.Tensor: + return 
torch.view_as_complex(self._Ct) + def _compute_roots(self) -> torch.Tensor: a0, a1 = self.Ct.conj(), self.q.conj() b0, b1 = self.B, self.p @@ -153,7 +174,7 @@ def K(self) -> torch.Tensor: # noqa at_roots = self._compute_roots() out = ifft(at_roots, n=self.l_max, dim=-1) conv = torch.stack([i[self.ifft_order] for i in out]).real - return conv.float().unsqueeze(0) + return conv.unsqueeze(0) def forward(self, u: torch.Tensor) -> torch.Tensor: """Forward pass. diff --git a/train.py b/train.py index 19193fc..eae6f14 100644 --- a/train.py +++ b/train.py @@ -8,29 +8,28 @@ import math from argparse import Namespace from datetime import datetime +from multiprocessing import cpu_count from typing import Any, Optional, Tuple, Type import fire import pytorch_lightning as pl import torch -from multiprocessing import cpu_count from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import TensorBoardLogger from pytorch_lightning.utilities.seed import seed_everything from pytorch_lightning.utilities.types import EPOCH_OUTPUT from torch import nn from torch.optim.lr_scheduler import ReduceLROnPlateau -from torch.utils.data import DataLoader, Dataset -from torch.cuda import is_available as cuda_available +from torch.utils.data import DataLoader from experiments.data.datasets import SequenceDataset from experiments.metrics import compute_accuracy from experiments.utils import ( OutputPaths, enumerate_subclasses, - train_val_split, parse_params_in_s4blocks, to_sequence, + train_val_split, ) from s4torch import S4Model @@ -60,24 +59,6 @@ def _parse_pooling(pooling: Optional[str]) -> Optional[nn.AvgPool1d | nn.MaxPool raise ValueError(f"Unsupported pooling method '{method}'") -def _make_dataloader( - dataset: Dataset, - shuffle: bool, - batch_size: int, - num_workers: int = max(1, cpu_count() - 1), - pin_memory: Optional[bool] = None, - **kwargs: Any, -) -> DataLoader: - return DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - 
num_workers=num_workers, - pin_memory=cuda_available() if pin_memory is None else pin_memory, - **kwargs, - ) - - class LighteningS4Model(pl.LightningModule): def __init__( self, @@ -156,29 +137,25 @@ def configure_optimizers(self) -> dict[str, Any]: "lr_scheduler": {"scheduler": scheduler, "monitor": "val_acc"}, } - def train_dataloader(self) -> DataLoader: - ds_train, _ = train_val_split( + def _make_dataloader(self, train: bool) -> DataLoader: + ds_train, ds_val = train_val_split( self.seq_dataset, - self.seq_dataset.val_prop, - seed=self.seq_dataset.seed, + val_prop=self.hparams.val_prop, + seed=self.hparams.seed, ) - return _make_dataloader( - ds_train, - shuffle=True, + return DataLoader( + ds_train if train else ds_val, batch_size=self.hparams.batch_size, + shuffle=train, + num_workers=max(1, cpu_count() - 1), + pin_memory=torch.cuda.is_available(), ) + def train_dataloader(self) -> DataLoader: + return self._make_dataloader(train=True) + def val_dataloader(self) -> DataLoader: - _, ds_val = train_val_split( - self.seq_dataset, - self.seq_dataset.val_prop, - seed=self.seq_dataset.seed, - ) - return _make_dataloader( - ds_val, - shuffle=False, - batch_size=self.hparams.batch_size, - ) + return self._make_dataloader(train=False) def main( @@ -196,6 +173,7 @@ def main( pooling: Optional[str] = None, # Training max_epochs: Optional[int] = None, + limit_train_batches: int | float = 1.0, lr: float = 1e-2, lr_s4: float = 1e-3, min_lr: float = 1e-6, @@ -205,6 +183,7 @@ def main( patience: int = 5, gpus: int = -1, # Auxiliary + precision: int | str = 32, output_dir: str = "~/s4-output", save_top_k: int = 0, seed: int = 1234, @@ -233,6 +212,8 @@ def main( pooling (str, optional): pooling method to use. Options: ``None``, ``avg_KERNEL_SIZE``, ``max_KERNEL_SIZE``. Example: ``avg_2``. 
max_epochs (int, optional): maximum number of epochs to train for + limit_train_batches (int, float): number (``int``) or proportion (``float``) + of the total number of training batches to use on each epoch lr (float): learning rate for parameters which do not belong to S4 blocks lr_s4 (float): learning rate for parameters which belong to S4 blocks min_lr (float): minimum learning rate to permit ``ReduceLROnPlateau`` to use @@ -243,6 +224,7 @@ def main( patience (int): number of epochs with no improvement to wait before reducing the learning rate gpus (int): number of GPUs to use. If ``-1``, use all available GPUs. + precision (int, str): precision of floating point operations output_dir (str): directory where output (logs and checkpoints) will be saved save_top_k (int): save top k models, as determined by the ``val_acc`` metric. (Defaults to ``0``, which disables model saving.) @@ -257,7 +239,7 @@ def main( run_name = f"s4-model-{datetime.utcnow().isoformat()}" output_paths = OutputPaths(output_dir, run_name=run_name) auto_scale_batch_size = batch_size == -1 - seq_dataset = _get_seq_wrapper(dataset.strip())(val_prop=val_prop, seed=seed) + seq_dataset = _get_seq_wrapper(dataset.strip())() pl_model = LighteningS4Model( S4Model( @@ -280,9 +262,11 @@ def main( trainer = pl.Trainer( max_epochs=max_epochs, gpus=(torch.cuda.device_count() if gpus == -1 else gpus) or None, + precision=precision, stochastic_weight_avg=swa, accumulate_grad_batches=accumulate_grad, auto_scale_batch_size=auto_scale_batch_size, + limit_train_batches=limit_train_batches, logger=TensorBoardLogger(output_paths.logs, name=run_name), callbacks=ModelCheckpoint( dirpath=output_paths.checkpoints,