diff --git a/scvelo/core/_anndata.py b/scvelo/core/_anndata.py index eac66cb6..2a05666f 100644 --- a/scvelo/core/_anndata.py +++ b/scvelo/core/_anndata.py @@ -1,4 +1,5 @@ import re +import warnings from typing import List, Literal, Optional, Union import numpy as np @@ -178,6 +179,13 @@ def get_df( :class:`pd.DataFrame` A dataframe. """ + warnings.warn( + "`get_df` is deprecated since scvelo==0.4.0 and will be removed in a future version " + "of scVelo. Please `AnnData::get_df` or Scanpy's `scanpy.get.obs_df` or `scanpy.get.var_df`.", + DeprecationWarning, + stacklevel=2, + ) + if precision is not None: pd.set_option("display.precision", precision) @@ -188,8 +196,6 @@ def get_df( keys, key_add = ( keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None) ) - keys = [keys] if isinstance(keys, str) else keys - key = keys[0] s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"] d_keys = [ @@ -207,62 +213,68 @@ def get_df( if keys is None: df = data.to_df() - elif key in data.var_names: - df = obs_df(data, keys, layer=layer) - elif key in data.obs_names: - df = var_df(data, keys, layer=layer) else: - if keys_split is not None: - keys = [ - k - for k in list(data.obs.keys()) + list(data.var.keys()) - if key in k and keys_split in k - ] - key = keys[0] - s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key] - if len(s_key) == 0: - raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.") - if len(s_key) > 1: - logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.") - - s_key = s_key[-1] - df = getattr(data, s_key)[keys if len(keys) > 1 else key] - if key_add is not None: - df = df[key_add] - if index is None: - index = ( - data.var_names - if s_key == "varm" - else data.obs_names - if s_key in {"obsm", "layers"} - else None - ) - if index is None and s_key == "uns" and hasattr(df, "shape"): - key_cats = np.array( - [ - key - for key in data.obs.keys() - if is_categorical_dtype(data.obs[key]) - ] - ) - num_cats = [ - len(data.obs[key].cat.categories) == df.shape[0] - for key in key_cats + keys = [keys] if isinstance(keys, str) else keys + key = keys[0] + + if key in data.var_names: + df = obs_df(data, keys, layer=layer) + elif key in data.obs_names: + df = var_df(data, keys, layer=layer) + else: + if keys_split is not None: + keys = [ + k + for k in list(data.obs.keys()) + list(data.var.keys()) + if key in k and keys_split in k ] - if np.sum(num_cats) == 1: - index = data.obs[key_cats[num_cats][0]].cat.categories - if ( - columns is None - and len(df.shape) > 1 - and df.shape[0] == df.shape[1] - ): - columns = index - elif isinstance(index, str) and index in data.obs.keys(): - index = pd.Categorical(data.obs[index]).categories - if columns is None and s_key == "layers": - columns = data.var_names - elif isinstance(columns, str) and columns in data.obs.keys(): - columns = pd.Categorical(data.obs[columns]).categories + key = keys[0] + s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key] + if len(s_key) == 0: + raise ValueError( + f"'{key}' not found in any of {', '.join(s_keys)}." + ) + if len(s_key) > 1: + logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.") + + s_key = s_key[-1] + df = getattr(data, s_key)[keys if len(keys) > 1 else key] + if key_add is not None: + df = df[key_add] + if index is None: + index = ( + data.var_names + if s_key == "varm" + else data.obs_names + if s_key in {"obsm", "layers"} + else None + ) + if index is None and s_key == "uns" and hasattr(df, "shape"): + key_cats = np.array( + [ + key + for key in data.obs.keys() + if is_categorical_dtype(data.obs[key]) + ] + ) + num_cats = [ + len(data.obs[key].cat.categories) == df.shape[0] + for key in key_cats + ] + if np.sum(num_cats) == 1: + index = data.obs[key_cats[num_cats][0]].cat.categories + if ( + columns is None + and len(df.shape) > 1 + and df.shape[0] == df.shape[1] + ): + columns = index + elif isinstance(index, str) and index in data.obs.keys(): + index = pd.Categorical(data.obs[index]).categories + if columns is None and s_key == "layers": + columns = data.var_names + elif isinstance(columns, str) and columns in data.obs.keys(): + columns = pd.Categorical(data.obs[columns]).categories elif isinstance(data, pd.DataFrame): if isinstance(keys, str) and "*" in keys: keys, keys_split = keys.split("*") diff --git a/tests/core/test_anndata.py b/tests/core/test_anndata.py index 5a0abd16..c1c72e21 100644 --- a/tests/core/test_anndata.py +++ b/tests/core/test_anndata.py @@ -573,6 +573,20 @@ def test_data_as_array( else: assert (df.columns == ["col_1", "col_2"]).all() + @given( + adata=get_adata( + max_obs=5, + max_vars=5, + layer_keys=["layer_1", "layer_2"], + ), + modality=st.sampled_from([None, "X", "layer_1", "layer_2"]), + ) + def test_default(self, adata: AnnData, modality: Optional[None]): + df = get_df(adata, layer=modality) + + assert isinstance(df, pd.DataFrame) + np.testing.assert_equal(adata.to_df().values, df.values) + class TestGetInitialSize(TestBase): @given(