diff --git a/src/tiledbsoma_ml/pytorch.py b/src/tiledbsoma_ml/pytorch.py
index 8e6c077..970a550 100644
--- a/src/tiledbsoma_ml/pytorch.py
+++ b/src/tiledbsoma_ml/pytorch.py
@@ -114,8 +114,11 @@ def __init__(
         X_name: str,
         obs_column_names: Sequence[str] = ("soma_joinid",),
         batch_size: int = 1,
+        shuffle: bool = True,
         io_batch_size: int = 2**16,
+        shuffle_chunk_size: int = 64,
         return_sparse_X: bool = False,
+        seed: int | None = None,
         use_eager_fetch: bool = True,
     ):
         """
@@ -140,12 +143,25 @@ def __init__(
                 this ``IterableDataset`` to be used with :class:`torch.utils.data.DataLoader` batching, but higher
                 performance can be achieved by performing batching in this class, and setting the ``DataLoader``'s
                 ``batch_size`` parameter to ``None``.
+            shuffle:
+                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
             io_batch_size:
-                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts
-                maximum memory utilization, larger values provide better read performance, but require more memory.
+                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts:
+                1. Maximum memory utilization; larger values provide better read performance but require more memory.
+                2. The number of rows read prior to shuffling (see the ``shuffle`` parameter for details).
+                The default value of 65,536 provides high performance but may need to be reduced on memory-limited
+                hosts or when using a large number of :class:`DataLoader` workers.
+            shuffle_chunk_size:
+                The number of contiguous rows sampled prior to concatenation and shuffling.
+                Larger numbers correspond to less randomness but greater read performance.
+                If ``shuffle == False``, this parameter is ignored.
             return_sparse_X:
                 If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
                 default), will return ``X`` data as a :class:`numpy.ndarray`.
+            seed:
+                The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be
+                specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions
+                are disjoint across worker processes.
             use_eager_fetch:
                 Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk
                 is made available for processing via the iterator. This allows network (or filesystem) requests to be
@@ -162,6 +178,14 @@ def __init__(
         Lifecycle:
             experimental
 
+        .. warning::
+            When using this class in any distributed mode, calling the :meth:`set_epoch` method at
+            the beginning of each epoch **before** creating the :class:`DataLoader` iterator
+            is necessary to make shuffling work properly across multiple epochs. Otherwise,
+            the same ordering will always be used.
+
+            In addition, when using shuffling in a distributed configuration (e.g., ``DDP``), you
+            must provide a seed, ensuring that the same shuffle is used across all replicas.
         """
         super().__init__()
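
The warning above is easier to follow with the call pattern spelled out. Below is a minimal sketch of the intended usage, not part of this diff; the experiment URI is hypothetical, while the class, helper, and parameter names come from the hunks above and below.

```python
# Sketch only: the per-epoch set_epoch pattern described in the warning.
import tiledbsoma as soma

from tiledbsoma_ml.pytorch import (
    ExperimentAxisQueryIterableDataset,
    experiment_dataloader,
)

with soma.Experiment.open("my-experiment-uri") as exp:  # hypothetical URI
    with exp.axis_query(measurement_name="RNA") as query:
        ds = ExperimentAxisQueryIterableDataset(
            query,
            X_name="raw",
            batch_size=128,
            shuffle=True,
            seed=42,  # must be set explicitly when running under DDP
        )
        dl = experiment_dataloader(ds)
        for epoch in range(3):
            ds.set_epoch(epoch)  # before creating the per-epoch iterator
            for X, obs in dl:
                ...  # training step
```
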
""" super().__init__() @@ -175,11 +199,24 @@ def __init__( self.obs_column_names = list(obs_column_names) self.batch_size = batch_size self.io_batch_size = io_batch_size + self.shuffle = shuffle self.return_sparse_X = return_sparse_X self.use_eager_fetch = use_eager_fetch self._obs_joinids: npt.NDArray[np.int64] | None = None self._var_joinids: npt.NDArray[np.int64] | None = None + self.seed = ( + seed if seed is not None else np.random.default_rng().integers(0, 2**32 - 1) + ) + self._user_specified_seed = seed is not None + self.shuffle_chunk_size = shuffle_chunk_size self._initialized = False + self.epoch = 0 + + if self.shuffle: + # round io_batch_size up to a unit of shuffle_chunk_size to simplify code. + self.io_batch_size = ( + ceil(io_batch_size / shuffle_chunk_size) * shuffle_chunk_size + ) if not self.obs_column_names: raise ValueError("Must specify at least one value in `obs_column_names`") @@ -187,7 +224,7 @@ def __init__( def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: """Create iterator over obs id chunks with split size of (roughly) io_batch_size. - As appropriate, will partition per worker. + As appropriate, will chunk, shuffle and apply partitioning per worker. IMPORTANT: in any scenario using torch.distributed, where WORLD_SIZE > 1, this will always partition such that each process has the same number of samples. Where @@ -197,7 +234,8 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: Abstractly, the steps taken: 1. Split the joinids into WORLD_SIZE sections (aka number of GPUS in DDP) 2. Trim the splits to be of equal length - 3. Partition by number of data loader workers (to not generate redundant batches + 3. Chunk and optionally shuffle the chunks + 4. Partition by number of data loader workers (to not generate redundant batches in cases where the DataLoader is running with `n_workers>1`). Private method. @@ -216,11 +254,29 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]: assert 0 <= (np.diff(_gpu_splits).min() - min_len) <= 1 _gpu_split = _gpu_split[:min_len] - obs_joinids_chunked = np.array_split( - _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) - ) + # 3. Chunk and optionally shuffle chunks + if self.shuffle: + assert self.io_batch_size % self.shuffle_chunk_size == 0 + shuffle_split = np.array_split( + _gpu_split, max(1, ceil(len(_gpu_split) / self.shuffle_chunk_size)) + ) - # 3. Partition by DataLoader worker + # Deterministically create RNG - state must be same across all processes, ensuring + # that the joinid partitions are identical across all processes. + rng = np.random.default_rng(self.seed + self.epoch + 99) + rng.shuffle(shuffle_split) + obs_joinids_chunked = list( + np.concatenate(b) + for b in _batched( + shuffle_split, self.io_batch_size // self.shuffle_chunk_size + ) + ) + else: + obs_joinids_chunked = np.array_split( + _gpu_split, max(1, ceil(len(_gpu_split) / self.io_batch_size)) + ) + + # 4. 
@@ -230,7 +286,7 @@ def _create_obs_joinids_partition(self) -> Iterator[npt.NDArray[np.int64]]:
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug(
                 f"Process {os.getpid()} rank={rank}, world_size={world_size}, worker_id={worker_id}, "
-                f"n_workers={n_workers}, "
+                f"n_workers={n_workers}, epoch={self.epoch}, "
                 f"partition_size={sum([len(chunk) for chunk in obs_partition_joinids])}"
             )
 
@@ -246,7 +302,9 @@ def _init_once(self, exp: soma.Experiment | None = None) -> None:
         if self._initialized:
             return
 
-        logger.debug("Initializing ExperimentAxisQueryIterable")
+        logger.debug(
+            f"Initializing ExperimentAxisQueryIterable (shuffle={self.shuffle})"
+        )
 
         if exp is None:
             # If no user-provided Experiment, open/close it ourselves
@@ -291,8 +349,12 @@ def __iter__(self) -> Iterator[XObsDatum]:
         world_size, rank = _get_distributed_world_rank()
         n_workers, worker_id = _get_worker_world_rank()
         logger.debug(
-            f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}"
+            f"Iterator created rank={rank}, world_size={world_size}, worker_id={worker_id}, n_workers={n_workers}, seed={self.seed}, epoch={self.epoch}"
         )
+        if world_size > 1 and self.shuffle and self._user_specified_seed is None:
+            raise ValueError(
+                "ExperimentAxisQueryIterable requires an explicit `seed` when shuffle is used in a multi-process configuration."
+            )
 
         with self.experiment_locator.open_experiment() as exp:
             self._init_once(exp)
@@ -311,6 +373,8 @@ def __iter__(self) -> Iterator[XObsDatum]:
 
             yield from _mini_batch_iter
 
+        self.epoch += 1
+
     def __len__(self) -> int:
         """Return the approximate number of batches this iterable will produce. If run in the context of
         :class:`torch.distributed` or as a multi-process loader (i.e., :class:`torch.utils.data.DataLoader`
         instantiated with num_workers > 0), the obs (cell)
@@ -350,6 +414,18 @@ def shape(self) -> Tuple[int, int]:
         div, rem = divmod(partition_len, self.batch_size)
         return div + bool(rem), len(self._var_joinids)
 
+    def set_epoch(self, epoch: int) -> None:
+        """
+        Set the epoch for this data iterator.
+
+        When :attr:`shuffle=True`, this will ensure that all replicas use a different
+        random ordering for each epoch. Failure to call this method before each epoch
+        will result in the same data ordering across epochs.
+
+        This call must be made before the per-epoch iterator is created.
+        """
+        self.epoch = epoch
+
     def __getitem__(self, index: int) -> XObsDatum:
         raise NotImplementedError(
             "``ExperimentAxisQueryIterable can only be iterated - does not support mapping"
@@ -365,12 +441,16 @@ def _io_batch_iter(
         (X: csr_array, obs: DataFrame).
 
         obs joinids read are controlled by the obs_joinid_iter. Iterator results will
-        be reindexed.
+        be reindexed and shuffled (if shuffling is enabled).
 
         Private method.
         """
         assert self._var_joinids is not None
 
+        # Create RNG - does not need to be identical across processes, but use the seed anyway
+        # for reproducibility.
+        shuffle_rng = np.random.default_rng(self.seed + self.epoch)
+
         obs_column_names = (
             list(self.obs_column_names)
             if "soma_joinid" in self.obs_column_names
@@ -380,7 +460,10 @@ def _io_batch_iter(
         for obs_coords in obs_joinid_iter:
             st_time = time.perf_counter()
-            obs_indexer = soma.IntIndexer(obs_coords, context=X.context)
+            obs_shuffled_coords = (
+                obs_coords if not self.shuffle else shuffle_rng.permuted(obs_coords)
+            )
+            obs_indexer = soma.IntIndexer(obs_shuffled_coords, context=X.context)
             logger.debug(
                 f"Retrieving next SOMA IO batch of length {len(obs_coords)}..."
             )
@@ -405,12 +488,12 @@ def _io_batch_iter(
                 .concat()
                 .to_pandas()
                 .set_index("soma_joinid")
-                .reindex(obs_coords, copy=False)
+                .reindex(obs_shuffled_coords, copy=False)
                 .reset_index(),
             )
             obs_io_batch = obs_io_batch[self.obs_column_names]
 
-            del obs_indexer, obs_coords, X_tbl
+            del obs_indexer, obs_coords, obs_shuffled_coords, X_tbl
             gc.collect()
 
             tm = time.perf_counter() - st_time
@@ -425,7 +508,7 @@ def _mini_batch_iter(
         X: soma.SparseNDArray,
         obs_joinid_iter: Iterator[npt.NDArray[np.int64]],
     ) -> Iterator[XObsDatum]:
-        """Break IO batches into mini-batch-sized chunks.
+        """Break IO batches into shuffled mini-batch-sized chunks.
 
         Private method.
         """
@@ -506,7 +589,10 @@ def __init__(
         X_name: str = "raw",
         obs_column_names: Sequence[str] = ("soma_joinid",),
         batch_size: int = 1,
+        shuffle: bool = True,
+        seed: int | None = None,
         io_batch_size: int = 2**16,
+        shuffle_chunk_size: int = 64,
         return_sparse_X: bool = False,
         use_eager_fetch: bool = True,
     ):
@@ -522,9 +608,12 @@ def __init__(
             X_name=X_name,
             obs_column_names=obs_column_names,
             batch_size=batch_size,
+            shuffle=shuffle,
+            seed=seed,
             io_batch_size=io_batch_size,
             return_sparse_X=return_sparse_X,
             use_eager_fetch=use_eager_fetch,
+            shuffle_chunk_size=shuffle_chunk_size,
         )
 
     def __iter__(self) -> Iterator[XObsDatum]:
@@ -559,6 +648,25 @@ def shape(self) -> Tuple[int, int]:
         """
         return self._exp_iter.shape
 
+    def set_epoch(self, epoch: int) -> None:
+        """
+        Set the epoch for this data iterator.
+
+        When :attr:`shuffle=True`, this will ensure that all replicas use a different
+        random ordering for each epoch. Failure to call this method before each epoch
+        will result in the same data ordering across epochs.
+
+        This call must be made before the per-epoch iterator is created.
+
+        Lifecycle:
+            experimental
+        """
+        self._exp_iter.set_epoch(epoch)
+
+    @property
+    def epoch(self) -> int:
+        return self._exp_iter.epoch
+
 
 class ExperimentAxisQueryIterableDataset(
     torch.utils.data.IterableDataset[XObsDatum]  # type:ignore[misc]
@@ -601,6 +709,19 @@ class ExperimentAxisQueryIterableDataset(
     The ``io_batch_size`` parameter determines the number of rows read, from which mini-batches are yielded. A
     larger value will increase total memory usage and may reduce average read time per row.
 
+    Shuffling support is enabled with the ``shuffle`` parameter, and will normally be more performant than using
+    :class:`DataLoader` shuffling. The shuffling algorithm works as follows:
+
+    1. Rows selected by the query are subdivided into groups of size ``shuffle_chunk_size``, aka a "shuffle chunk".
+    2. A random selection of shuffle chunks is drawn and read as a single I/O buffer (of size ``io_batch_size``).
+    3. The entire I/O buffer is shuffled.
+
+    Put another way, we read randomly selected groups of observations from across all query results, concatenate
+    those into an I/O buffer, and shuffle the buffer before returning mini-batches. The randomness of the shuffle
+    is therefore determined by the ``io_batch_size`` (number of rows read) and the ``shuffle_chunk_size``
+    (number of rows in each draw). Decreasing ``shuffle_chunk_size`` will increase shuffling randomness and
+    decrease I/O performance.
+
     This class will detect when run in a multiprocessing mode, including multi-worker :class:`torch.utils.data.DataLoader`
     and multi-process training such as :class:`torch.nn.parallel.DistributedDataParallel`, and will automatically partition
     data appropriately. In the case of distributed training, sample partitions across all processes must be equal. Any
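
A quick arithmetic illustration of the shuffle granularity implied by the defaults in this diff (values taken from the signatures above); halving ``shuffle_chunk_size`` doubles the number of distinct dataset locations mixed into each buffer, at the cost of smaller and therefore slower reads.

```python
# Illustrative arithmetic only, using the defaults from this diff.
io_batch_size = 2**16        # 65,536 rows per I/O buffer (default)
shuffle_chunk_size = 64      # contiguous rows per shuffle chunk (default)

chunks_per_io_buffer = io_batch_size // shuffle_chunk_size
print(chunks_per_io_buffer)  # 1024 randomly drawn contiguous runs per buffer
```
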
@@ -616,7 +737,10 @@ def __init__(
         X_name: str = "raw",
         obs_column_names: Sequence[str] = ("soma_joinid",),
         batch_size: int = 1,
+        shuffle: bool = True,
+        seed: int | None = None,
         io_batch_size: int = 2**16,
+        shuffle_chunk_size: int = 64,
         return_sparse_X: bool = False,
         use_eager_fetch: bool = True,
     ):
@@ -642,11 +766,26 @@ def __init__(
                 Note that a ``batch_size`` of 1 allows this ``IterableDataset`` to be used with
                 :class:`torch.utils.data.DataLoader` batching, but you will achieve higher performance by performing
                 batching in this class, and setting the ``DataLoader`` batch_size parameter to ``None``.
+            shuffle:
+                Whether to shuffle the ``obs`` and ``X`` data being returned. Defaults to ``True``.
             io_batch_size:
-                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA.
+                The number of ``obs``/``X`` rows to retrieve when reading data from SOMA. This impacts two aspects of
+                this class's behavior: 1) The maximum memory utilization, with larger values providing
+                better read performance but also requiring more memory; 2) The number of rows read prior to shuffling
+                (see ``shuffle`` parameter for details). The default value of 65,536 provides high performance, but
+                may need to be reduced on memory-limited hosts (or where a large number of :class:`DataLoader`
+                workers are employed).
+            shuffle_chunk_size:
+                The number of contiguous rows sampled prior to concatenation and shuffling.
+                Larger numbers correspond to less randomness but greater read performance.
+                If ``shuffle == False``, this parameter is ignored.
             return_sparse_X:
                 If ``True``, will return the ``X`` data as a :class:`scipy.sparse.csr_matrix`. If ``False`` (the
                 default), will return ``X`` data as a :class:`numpy.ndarray`.
+            seed:
+                The random seed used for shuffling. Defaults to ``None`` (no seed). This argument *must* be
+                specified when using :class:`torch.nn.parallel.DistributedDataParallel` to ensure data partitions
+                are disjoint across worker processes.
             use_eager_fetch:
                 Fetch the next SOMA chunk of ``obs`` and ``X`` data immediately after a previously fetched SOMA chunk
                 is made available for processing via the iterator. This allows network (or filesystem) requests to be
@@ -670,9 +809,12 @@ def __init__(
             X_name=X_name,
             obs_column_names=obs_column_names,
             batch_size=batch_size,
+            shuffle=shuffle,
+            seed=seed,
             io_batch_size=io_batch_size,
             return_sparse_X=return_sparse_X,
             use_eager_fetch=use_eager_fetch,
+            shuffle_chunk_size=shuffle_chunk_size,
         )
 
     def __iter__(self) -> Iterator[XObsDatum]:
@@ -721,6 +863,25 @@ def shape(self) -> Tuple[int, int]:
         """
         return self._exp_iter.shape
 
+    def set_epoch(self, epoch: int) -> None:
+        """
+        Set the epoch for this data iterator.
+
+        When :attr:`shuffle=True`, this will ensure that all replicas use a different
+        random ordering for each epoch. Failure to call this method before each epoch
+        will result in the same data ordering across epochs.
+
+        This call must be made before the per-epoch iterator is created.
+
+        Lifecycle:
+            experimental
+        """
+        self._exp_iter.set_epoch(epoch)
+
+    @property
+    def epoch(self) -> int:
+        return self._exp_iter.epoch
+
 
 def experiment_dataloader(
     ds: torchdata.datapipes.iter.IterDataPipe | torch.utils.data.IterableDataset,
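
The test changes below pin ``shuffle=False`` so the pre-existing ordering assertions remain valid, and add a dedicated shuffle test. A natural follow-up, sketched here but not part of this diff (it reuses the suite's ``soma_experiment`` fixture and ``pytorch_seq_x_value_gen``), would assert that ``set_epoch`` changes the order between epochs while remaining reproducible for a fixed seed:

```python
# Sketch of a possible additional test (not in this diff): verify that
# set_epoch() yields a different shuffle order per epoch, and that a
# given (seed, epoch) pair is reproducible.
@pytest.mark.parametrize(
    "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)]
)
def test__shuffle_set_epoch(soma_experiment: Experiment) -> None:
    with soma_experiment.axis_query(measurement_name="RNA") as query:
        dp = ExperimentAxisQueryIterable(query, X_name="raw", shuffle=True, seed=0)

        dp.set_epoch(0)
        epoch0 = [r[1]["soma_joinid"].iloc[0] for r in dp]
        dp.set_epoch(1)
        epoch1 = [r[1]["soma_joinid"].iloc[0] for r in dp]
        dp.set_epoch(0)
        epoch0_again = [r[1]["soma_joinid"].iloc[0] for r in dp]

    assert sorted(epoch0) == sorted(epoch1) == list(range(16))
    assert epoch0 != epoch1          # different order per epoch
    assert epoch0 == epoch0_again    # reproducible for a fixed seed + epoch
```
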
diff --git a/tests/test_pytorch.py b/tests/test_pytorch.py
index 155bbe1..35e1433 100644
--- a/tests/test_pytorch.py
+++ b/tests/test_pytorch.py
@@ -180,6 +180,7 @@ def test_non_batched(
             query,
             X_name="raw",
             obs_column_names=["label"],
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
             return_sparse_X=return_sparse_X,
         )
@@ -199,7 +200,6 @@ def test_non_batched(
             else:
                 assert isinstance(row[0], np.ndarray)
 
-            print(row)
             assert np.squeeze(row[0]).shape == (3,)
             assert np.squeeze(row[0]).tolist() == [0, 1, 0]
 
@@ -227,6 +227,7 @@ def test_uneven_soma_and_result_batches(
             query,
             X_name="raw",
             obs_column_names=["label"],
+            shuffle=False,
             batch_size=3,
             io_batch_size=2,
             use_eager_fetch=use_eager_fetch,
@@ -267,6 +268,7 @@ def test_batching__all_batches_full_size(
             X_name="raw",
             obs_column_names=["label"],
             batch_size=3,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         batch_iter = iter(exp_data_pipe)
@@ -329,6 +331,7 @@ def test_batching__partial_final_batch_size(
             X_name="raw",
             obs_column_names=["label"],
             batch_size=3,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         batch_iter = iter(exp_data_pipe)
@@ -357,6 +360,7 @@ def test_batching__exactly_one_batch(
             X_name="raw",
             obs_column_names=["label"],
             batch_size=3,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         batch_iter = iter(exp_data_pipe)
@@ -409,6 +413,7 @@ def test_sparse_output__non_batched(
             X_name="raw",
             obs_column_names=["label"],
             return_sparse_X=True,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         batch_iter = iter(exp_data_pipe)
@@ -433,6 +438,7 @@ def test_sparse_output__batched(
             obs_column_names=["label"],
             batch_size=3,
             return_sparse_X=True,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         batch_iter = iter(exp_data_pipe)
@@ -534,6 +540,7 @@ def test_distributed__returns_data_partition_for_rank(
             X_name="raw",
             obs_column_names=["soma_joinid"],
             io_batch_size=2,
+            shuffle=False,
         )
         full_result = list(iter(dp))
         soma_joinids = np.concatenate(
@@ -590,6 +597,7 @@ def test_distributed_and_multiprocessing__returns_data_partition_for_rank(
             X_name="raw",
             obs_column_names=["soma_joinid"],
             io_batch_size=2,
+            shuffle=False,
         )
 
         full_result = list(iter(dp))
@@ -623,6 +631,7 @@ def test_experiment_dataloader__non_batched(
             query,
             X_name="raw",
             obs_column_names=["label"],
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         dl = experiment_dataloader(dp)
@@ -653,6 +662,7 @@ def test_experiment_dataloader__batched(
             query,
             X_name="raw",
             batch_size=3,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         dl = experiment_dataloader(dp)
@@ -685,6 +695,7 @@ def test_experiment_dataloader__batched_length(
             X_name="raw",
             obs_column_names=["label"],
             batch_size=3,
+            shuffle=False,
             use_eager_fetch=use_eager_fetch,
         )
         dl = experiment_dataloader(dp)
@@ -724,6 +735,7 @@ def collate_fn(
             X_name="raw",
             obs_column_names=["label"],
             batch_size=batch_size,
+            shuffle=False,
         )
         dl = experiment_dataloader(dp, collate_fn=partial(collate_fn, batch_size))
         assert len(list(dl)) > 0
@@ -772,6 +784,59 @@ def test__pytorch_splitting(
 
     assert len(all_rows) == 7
 
 
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen", [(16, 1, pytorch_seq_x_value_gen)]
+)
+@pytest.mark.parametrize("PipeClass", PipeClassImplementation)
+def test__shuffle(PipeClass: PipeClassType, soma_experiment: Experiment) -> None:
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = PipeClass(
+            query,
+            X_name="raw",
+            shuffle=True,
+        )
+
+        all_rows = list(iter(dp))
+        if PipeClass is ExperimentAxisQueryIterable:
+            assert all(np.squeeze(r[0], axis=0).shape == (1,) for r in all_rows)
+        else:
+            assert all(r[0].shape == (1,) for r in all_rows)
+        soma_joinids = [row[1]["soma_joinid"].iloc[0] for row in all_rows]
+        X_values = [row[0][0].item() for row in all_rows]
+
+        # same elements
+        assert set(soma_joinids) == set(range(16))
+        # not ordered! (...with a `1/16!` probability of being ordered)
+        assert soma_joinids != list(range(16))
+        # randomizes X in same order as obs
+        # note: X values were explicitly set to match obs_joinids to allow for this simple assertion
+        assert X_values == soma_joinids
+
+
+@pytest.mark.parametrize(
+    "obs_range,var_range,X_value_gen", [(6, 3, pytorch_x_value_gen)]
+)
+def test_experiment_axis_query_iterable_error_checks(
+    soma_experiment: Experiment,
+) -> None:
+    with soma_experiment.axis_query(measurement_name="RNA") as query:
+        dp = ExperimentAxisQueryIterable(
+            query,
+            X_name="raw",
+            shuffle=True,
+        )
+        with pytest.raises(NotImplementedError):
+            dp[0]
+
+        with pytest.raises(ValueError):
+            dp = ExperimentAxisQueryIterable(
+                query,
+                obs_column_names=(),
+                X_name="raw",
+                shuffle=True,
+            )
+
+
 def test_experiment_dataloader__unsupported_params__fails() -> None:
     with patch(
         "tiledbsoma_ml.pytorch.ExperimentAxisQueryIterDataPipe"