Skip to content

Commit

Permalink
Merge pull request #4 from TariqAHassan/audio-experiments
Browse files Browse the repository at this point in the history
Audio experiments
  • Loading branch information
TariqAHassan committed Jan 28, 2022
2 parents 9381cb0 + 38c6604 commit 8a9eee9
Show file tree
Hide file tree
Showing 9 changed files with 273 additions and 74 deletions.
31 changes: 31 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,37 @@ Notes:
* Batch normalization appears to work best with a "post" normalization strategy, whereas
a "pre" normalization strategy appears to work best with layer normalization.

#### [NSynth](https://magenta.tensorflow.org/datasets/nsynth)

```sh
python train.py \
--dataset=nsynth_short \
--batch_size=-1 \
--val_prop=0.01 \
--max_epochs=150 \
--limit_train_batches=0.025 \
--lr=1e-2 \
--n_blocks=4 \
--pooling=avg_2 \
--d_model=128 \
--weight_decay=0.0 \
--norm_type=batch \
--norm_strategy=post \
--p_dropout=0.1 \
--precision=16 \
--accumulate_grad=4 \
--patience=10
```

**Validation Accuracy**: 39.6% after 5 epochs, 54.1% after 17 epochs (best) <br>
**Speed**: ~1.6 batches/second

Notes:

* The model is tasked with classifying waveforms based on the musical instrument which generated them (10 classes)
* The `nsynth_short` dataset contains waveforms which are truncated after 2 seconds, whereas the `nsynth` dataset contains
the full four-second waveforms.

## Components

### Layer
Expand Down
2 changes: 2 additions & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@ black==21.12b0
fire==0.4.0
isort==5.10.1
pytorch-lightning==1.5.9
requests==2.27.1
torchaudio==0.10.1
torchvision==0.11.2
tqdm==4.62.3
File renamed without changes.
73 changes: 73 additions & 0 deletions experiments/data/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""
Utils
"""
import tarfile
from contextlib import closing
from pathlib import Path

import requests
from tqdm import tqdm


def download(url: str, dst: Path, chunk_size: int = 1024, verbose: bool = True) -> Path:
    """Download a file from a ``url`` to ``dst``.

    Args:
        url (str): URL of the file to download
        dst (Path): download destination. If a directory, the filename
            will be determined using the URL.
        chunk_size (int): size of "chunks" to use when streaming the data
        verbose (bool): if ``True`` display a progress bar

    Returns:
        dst (Path): the path to the downloaded file

    Raises:
        requests.HTTPError: if the server responds with an error status

    """
    filename = Path(url).name
    if dst.is_dir():
        dst = dst.joinpath(filename)

    # The response is used as a context manager so the underlying
    # connection is released even if writing to disk fails part-way.
    with requests.get(url, stream=True) as response:
        # Fail before touching ``dst``; otherwise the body of an error
        # page would be saved as if it were the requested file.
        response.raise_for_status()

        # A missing (or zero) content-length means the size is unknown,
        # in which case the progress bar is disabled below.
        total = int(response.headers.get("content-length", 0)) or None
        with dst.open("wb") as file, tqdm(
            desc=f"Downloading {filename}",
            total=total,
            unit="iB",
            unit_scale=True,
            # Binary unit scaling is independent of how many bytes are
            # read per iteration, so this must not follow ``chunk_size``.
            unit_divisor=1024,
            disable=total is None or not verbose,
        ) as pbar:
            for data in response.iter_content(chunk_size=chunk_size):
                pbar.update(file.write(data))
    return dst


def untar(
    src: Path, dst: Path, delete_src: bool = False, verbose: bool = False
) -> Path:
    """Untar ``src`` into ``dst``.

    Args:
        src (Path): source file to untar
        dst (Path): destination directory
        delete_src (bool): if ``True``, delete ``src`` when complete
        verbose (bool): if ``True``, print a message before extracting

    Returns:
        dst (Path): the directory the archive was extracted into

    Raises:
        OSError: if ``src`` is not a file or ``dst`` is not a directory

    """
    if not src.is_file():
        raise OSError(f"No file {str(src)}")
    if not dst.is_dir():
        raise OSError(f"No directory {str(dst)}")

    if verbose:
        print(f"Untaring {str(src)}...")
    with tarfile.open(src) as f:
        # NOTE(review): ``extractall`` trusts the member paths stored in
        # the archive. Only use this on archives from a trusted source,
        # or pass an extraction filter where the Python version allows.
        f.extractall(dst)

    if delete_src:
        src.unlink()
    return dst
124 changes: 106 additions & 18 deletions experiments/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,33 +8,32 @@
"""
from __future__ import annotations

import json
from functools import cached_property
from pathlib import Path
from typing import Any, Optional

import numpy as np
import torch
import torchaudio
from torch.nn import functional as F
from torchaudio.datasets import SPEECHCOMMANDS as _SpeechCommands # noqa
from torchvision.datasets import CIFAR10, MNIST
from torchvision.transforms import Compose, Lambda, ToTensor

from experiments.data.transforms import build_permute_transform
from experiments.data._utils import download, untar
from experiments.data._transforms import build_permute_transform

_DATASETS_DIRECTORY = Path("~/datasets")


class SequenceDataset:
NAME: Optional[str] = None
SAVE_NAME: Optional[str] = None
classes: list[str | int]
class_names: Optional[list[str | int]] = None

def __init__(
self,
val_prop: float = 0.1,
seed: int = 42,
**kwargs: Any,
) -> None:
super().__init__(self.root_dir, **kwargs)
self.val_prop = val_prop
self.seed = seed
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)

@property
def root_dir(self) -> Path:
Expand All @@ -43,13 +42,21 @@ def root_dir(self) -> Path:
if not isinstance(name, str):
raise TypeError("`NAME` not set")

path = Path("~/datasets").expanduser().joinpath(name)
path = _DATASETS_DIRECTORY.expanduser().joinpath(name)
path.mkdir(parents=True, exist_ok=True)
return path

@property
def classes(self) -> list[str | int]:
"""Names of all classes in the dataset."""
if self.class_names:
return self.class_names
else:
raise AttributeError("Class names not set")

@property
def n_classes(self) -> int:
"""Number of classes in the dataset."""
"""Number of class_names in the dataset."""
return len(self.classes)

@property
Expand All @@ -62,12 +69,20 @@ def shape(self) -> tuple[int, ...]:
"""Shape of the data in the dataset."""
raise NotImplementedError()

def __len__(self) -> int:
raise NotImplementedError()

def __getitem__(self, item: int) -> tuple[torch.Tensor, int]:
raise NotImplementedError()


class SMnistDataset(SequenceDataset, MNIST):
NAME: str = "SMNIST"
class_names: list[int] = list(range(10))

def __init__(self, **kwargs: Any) -> None:
super().__init__(
root=self.root_dir,
download=True,
transform=Compose([ToTensor(), Lambda(lambda t: t.flatten())]),
**kwargs,
Expand All @@ -87,6 +102,7 @@ class PMnistDataset(SequenceDataset, MNIST):

def __init__(self, **kwargs: Any) -> None:
super().__init__(
root=self.root_dir,
download=True,
transform=Compose([ToTensor(), build_permute_transform((28 * 28,))]),
**kwargs,
Expand All @@ -103,10 +119,11 @@ def shape(self) -> tuple[int, ...]:

class SCIFAR10Dataset(SequenceDataset, CIFAR10):
NAME: str = "SCIFAR10"
classes = list(range(10))
class_names: list[int] = list(range(10))

def __init__(self, **kwargs: Any) -> None:
super().__init__(
root=self.root_dir,
download=True,
transform=Compose(
[
Expand All @@ -129,7 +146,7 @@ def shape(self) -> tuple[int, ...]:
class SpeechCommands(SequenceDataset, _SpeechCommands):
NAME: str = "SPEECH_COMMANDS"
SEGMENT_SIZE: int = 16_000
classes = [
class_names: list[str] = [
"bed",
"cat",
"down",
Expand Down Expand Up @@ -168,7 +185,7 @@ class SpeechCommands(SequenceDataset, _SpeechCommands):
]

def __init__(self, **kwargs: Any) -> None:
super().__init__(download=True, **kwargs)
super().__init__(root=self.root_dir, download=True, **kwargs)

self.label_ids = {l: e for e, l in enumerate(self.classes)}
self._walker = [i for i in self._walker if Path(i).parent.name in self.classes]
Expand Down Expand Up @@ -197,7 +214,7 @@ def shape(self) -> tuple[int, ...]:
class SpeechCommands10(SpeechCommands):
NAME: str = "SPEECH_COMMANDS_10"
SAVE_NAME = "SPEECH_COMMANDS"
classes = [
class_names: list[int] = [
"yes",
"no",
"up",
Expand Down Expand Up @@ -226,7 +243,7 @@ def __getitem__(self, item: int) -> tuple[torch.Tensor, int]:
if not hot_idx.any(): # ensure at least one
hot_idx[np.random.choice(self.N_REPEATS - 1)] = True

label = 0
label = -1
chunks = list()
for use_y in hot_idx:
chunks.append(y if use_y else torch.zeros_like(y))
Expand All @@ -238,6 +255,77 @@ def shape(self) -> tuple[int, ...]:
return (self.N_REPEATS * self.SEGMENT_SIZE,) # noqa


class NSynthDataset(SequenceDataset):
    """NSynth waveforms labelled by instrument family.

    When constructed with ``download=True``, every archive listed in
    ``URLS`` is downloaded and extracted into ``root_dir`` unless its
    extracted directory is already present.
    """

    NAME: str = "NSYNTH"
    # Maximum number of entries kept along the first axis of each waveform
    SEGMENT_SIZE: int = 64_000
    URLS: dict[str, str] = {
        "train": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-train.jsonwav.tar.gz",  # noqa
        "valid": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-valid.jsonwav.tar.gz",  # noqa
        "test": "http://download.magenta.tensorflow.org/datasets/nsynth/nsynth-test.jsonwav.tar.gz",  # noqa
    }

    def __init__(self, download: bool = True, verbose: bool = True) -> None:
        """Create the dataset, optionally fetching the data first.

        Args:
            download (bool): if ``True``, call ``fetch_data()``
            verbose (bool): passed to the download and untar helpers

        """
        super().__init__()
        self.download = download
        self.verbose = verbose

        if download:
            self.fetch_data()

    def fetch_data(self, force: bool = False) -> None:
        """Download and extract each archive in ``URLS``.

        Args:
            force (bool): if ``True``, fetch an archive even if its
                extracted directory already exists

        """
        for url in self.URLS.values():
            # e.g. ".../nsynth-train.jsonwav.tar.gz" -> "nsynth-train"
            dirname, *_ = Path(url).stem.split(".")
            if force or not self.root_dir.joinpath(dirname).is_dir():
                untar(
                    download(url, dst=self.root_dir, verbose=self.verbose),
                    dst=self.root_dir,
                    delete_src=True,
                    verbose=self.verbose,
                )

    @cached_property
    def metadata(self) -> dict[str, dict[str, Any]]:
        """Merged contents of every ``*.json`` file under ``root_dir``.

        Each entry additionally gets a ``"split"`` key, taken from the
        text after the last ``-`` in the name of the directory which
        contains the JSON file.
        """
        metadata = dict()
        for path in self.root_dir.rglob("*.json"):
            with path.open("r") as f:
                payload = json.load(f)
            for v in payload.values():
                v["split"] = path.parent.name.split("-")[-1]
            metadata |= payload
        return metadata

    @cached_property
    def classes(self) -> list[str | int]:
        """Sorted, unique ``instrument_family_str`` values in the metadata."""
        return sorted({v["instrument_family_str"] for v in self.metadata.values()})

    @cached_property
    def files(self) -> list[Path]:
        """All ``*.wav`` files under ``root_dir``, in sorted order.

        The paths are sorted because the order in which the filesystem
        is walked is not guaranteed to be stable, and without a fixed
        order the same index could refer to different samples between
        runs or machines (making seeded splits irreproducible).
        """
        return sorted(self.root_dir.rglob("*.wav"))

    @property
    def channels(self) -> int:
        """Number of channels in the data."""
        return 1

    @property
    def shape(self) -> tuple[int, ...]:
        """Shape of the data in the dataset."""
        return (self.SEGMENT_SIZE,)  # noqa

    def __len__(self) -> int:
        """Number of waveform files in the dataset."""
        return len(self.files)

    def __getitem__(self, item: int) -> tuple[torch.Tensor, int]:
        """Load the waveform at index ``item`` along with its label.

        Returns:
            the waveform, truncated to at most ``SEGMENT_SIZE`` entries
            along its first axis, and the ``instrument_family`` value
            stored in the metadata for that file

        """
        path = self.files[item]
        # ``channels_first=False`` is requested so that the truncation
        # below is applied to the leading axis of the loaded tensor.
        y, _ = torchaudio.load(path, normalize=True, channels_first=False)  # noqa
        label = self.metadata[path.stem]["instrument_family"]
        return y[: self.SEGMENT_SIZE, ...], label


class NSynthDatasetShort(NSynthDataset):
    """``NSynthDataset`` with waveforms truncated to half the segment size."""

    NAME: str = "NSYNTH_SHORT"
    # NOTE(review): presumably lets this variant share the files stored
    # for "NSYNTH" -- confirm against ``SequenceDataset.root_dir``.
    SAVE_NAME: str = "NSYNTH"
    SEGMENT_SIZE: int = NSynthDataset.SEGMENT_SIZE // 2


if __name__ == "__main__":
smnist_wrapper = SMnistDataset()

Expand Down
6 changes: 3 additions & 3 deletions experiments/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import torch
from torch import nn
from torch.utils.data import Dataset, random_split

from s4torch.block import S4Block
from s4torch.layer import S4Layer
from torch.utils.data import Dataset, random_split


class OutputPaths:
Expand All @@ -34,9 +34,9 @@ def checkpoints(self) -> Path:
return self._make_dir(f"checkpoints/{self.run_name}")


def count_parameters(model: nn.Module, adjust_complex: bool = True) -> int:
    """Count the number of parameters in ``model``.

    Args:
        model (nn.Module): the model whose parameters will be counted
        adjust_complex (bool): if ``True``, every element of a
            complex-valued parameter is counted twice (real and
            imaginary parts). If ``False``, every element of every
            parameter is counted exactly once.

    Returns:
        int: the total parameter count

    """

    def get_count(param: nn.Parameter) -> int:
        # The multiplier must never drop below one: adding the two
        # booleans together would weight real-valued parameters by
        # zero whenever ``adjust_complex`` is ``False``.
        weight = 2 if (adjust_complex and param.is_complex()) else 1
        return param.numel() * weight

    return sum(get_count(p) for p in model.parameters())

Expand Down
2 changes: 1 addition & 1 deletion s4torch/aux/residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class Residual(nn.Module):
    """Module which combines two tensors by adding them together."""

    def forward(self, y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:  # noqa
        """Return the sum of ``y`` and ``x``."""
        combined = y + x
        return combined


Expand Down
Loading

0 comments on commit 8a9eee9

Please sign in to comment.