Skip to content

Commit

Permalink
Merge branch 'theislab:main' into fix/print_versions
Browse files Browse the repository at this point in the history
  • Loading branch information
Oisin-M authored Jun 20, 2024
2 parents 10ea61a + d3e81d9 commit 24e019a
Show file tree
Hide file tree
Showing 78 changed files with 199 additions and 222 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip build
python -m pip install --upgrade build
- name: Build package
run: |
Expand Down
14 changes: 0 additions & 14 deletions debug.py

This file was deleted.

15 changes: 8 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
[build-system]
build-backend = "hatchling.build"
requires = ["hatchling"]

requires = ["setuptools>=61", "setuptools-scm[toml]>=6.2"]
build-backend = "setuptools.build_meta"

[project]
name = "scvelo"
version = "0.3.1"
dynamic = ["version"]
description = "RNA velocity generalized through dynamical modeling"
readme = {file = "README.md", content-type="text/markdown"}
requires-python = ">=3.8"
Expand Down Expand Up @@ -46,7 +45,7 @@ dependencies = [
"loompy>=2.0.12",
"umap-learn>=0.3.10",
"numba>=0.41.0",
"numpy>=1.17",
"numpy>=1.17, <2.0.0",
"pandas>=1.1.1, !=1.4.0",
"scipy>=1.4.1",
"scikit-learn>=0.21.2",
Expand Down Expand Up @@ -98,8 +97,10 @@ docs = [
"nbsphinx>=0.7,<0.8.7"
]

[tool.hatch.build.targets.wheel]
packages = ["scvelo"]
[tool.setuptools]
include-package-data = true

[tool.setuptools_scm]

[tool.coverage.run]
source = ["scvelo"]
Expand Down
124 changes: 68 additions & 56 deletions scvelo/core/_anndata.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
import warnings
from typing import List, Literal, Optional, Union

import numpy as np
Expand Down Expand Up @@ -178,6 +179,13 @@ def get_df(
:class:`pd.DataFrame`
A dataframe.
"""
warnings.warn(
"`get_df` is deprecated since scvelo==0.4.0 and will be removed in a future version "
"of scVelo. Please `AnnData::get_df` or Scanpy's `scanpy.get.obs_df` or `scanpy.get.var_df`.",
DeprecationWarning,
stacklevel=2,
)

if precision is not None:
pd.set_option("display.precision", precision)

Expand All @@ -188,8 +196,6 @@ def get_df(
keys, key_add = (
keys.split("/") if isinstance(keys, str) and "/" in keys else (keys, None)
)
keys = [keys] if isinstance(keys, str) else keys
key = keys[0]

s_keys = ["obs", "var", "obsm", "varm", "uns", "layers"]
d_keys = [
Expand All @@ -207,62 +213,68 @@ def get_df(

if keys is None:
df = data.to_df()
elif key in data.var_names:
df = obs_df(data, keys, layer=layer)
elif key in data.obs_names:
df = var_df(data, keys, layer=layer)
else:
if keys_split is not None:
keys = [
k
for k in list(data.obs.keys()) + list(data.var.keys())
if key in k and keys_split in k
]
key = keys[0]
s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
if len(s_key) == 0:
raise ValueError(f"'{key}' not found in any of {', '.join(s_keys)}.")
if len(s_key) > 1:
logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")

s_key = s_key[-1]
df = getattr(data, s_key)[keys if len(keys) > 1 else key]
if key_add is not None:
df = df[key_add]
if index is None:
index = (
data.var_names
if s_key == "varm"
else data.obs_names
if s_key in {"obsm", "layers"}
else None
)
if index is None and s_key == "uns" and hasattr(df, "shape"):
key_cats = np.array(
[
key
for key in data.obs.keys()
if is_categorical_dtype(data.obs[key])
]
)
num_cats = [
len(data.obs[key].cat.categories) == df.shape[0]
for key in key_cats
keys = [keys] if isinstance(keys, str) else keys
key = keys[0]

if key in data.var_names:
df = obs_df(data, keys, layer=layer)
elif key in data.obs_names:
df = var_df(data, keys, layer=layer)
else:
if keys_split is not None:
keys = [
k
for k in list(data.obs.keys()) + list(data.var.keys())
if key in k and keys_split in k
]
if np.sum(num_cats) == 1:
index = data.obs[key_cats[num_cats][0]].cat.categories
if (
columns is None
and len(df.shape) > 1
and df.shape[0] == df.shape[1]
):
columns = index
elif isinstance(index, str) and index in data.obs.keys():
index = pd.Categorical(data.obs[index]).categories
if columns is None and s_key == "layers":
columns = data.var_names
elif isinstance(columns, str) and columns in data.obs.keys():
columns = pd.Categorical(data.obs[columns]).categories
key = keys[0]
s_key = [s for (s, d_key) in zip(s_keys, d_keys) if key in d_key]
if len(s_key) == 0:
raise ValueError(
f"'{key}' not found in any of {', '.join(s_keys)}."
)
if len(s_key) > 1:
logg.warn(f"'{key}' found multiple times in {', '.join(s_key)}.")

s_key = s_key[-1]
df = getattr(data, s_key)[keys if len(keys) > 1 else key]
if key_add is not None:
df = df[key_add]
if index is None:
index = (
data.var_names
if s_key == "varm"
else data.obs_names
if s_key in {"obsm", "layers"}
else None
)
if index is None and s_key == "uns" and hasattr(df, "shape"):
key_cats = np.array(
[
key
for key in data.obs.keys()
if is_categorical_dtype(data.obs[key])
]
)
num_cats = [
len(data.obs[key].cat.categories) == df.shape[0]
for key in key_cats
]
if np.sum(num_cats) == 1:
index = data.obs[key_cats[num_cats][0]].cat.categories
if (
columns is None
and len(df.shape) > 1
and df.shape[0] == df.shape[1]
):
columns = index
elif isinstance(index, str) and index in data.obs.keys():
index = pd.Categorical(data.obs[index]).categories
if columns is None and s_key == "layers":
columns = data.var_names
elif isinstance(columns, str) and columns in data.obs.keys():
columns = pd.Categorical(data.obs[columns]).categories
elif isinstance(data, pd.DataFrame):
if isinstance(keys, str) and "*" in keys:
keys, keys_split = keys.split("*")
Expand Down
16 changes: 12 additions & 4 deletions scvelo/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import numpy as np
import pandas as pd
from numpy.core._exceptions import UFuncTypeError
from pandas import Index
from scipy import stats
from scipy.sparse import issparse
Expand All @@ -24,6 +23,12 @@
from scvelo.tools.utils import strings_to_categoricals
from . import palettes

try:
from numpy.core._exceptions import UFuncTypeError
except ModuleNotFoundError:
from numpy._core._exceptions import UFuncTypeError


"""helper functions"""


Expand Down Expand Up @@ -1460,6 +1465,7 @@ def hist(
ax=None,
dpi=None,
show=True,
save=None,
**kwargs,
):
"""Plot a histogram.
Expand Down Expand Up @@ -1530,6 +1536,9 @@ def hist(
Figure dpi.
show: `bool`, optional (default: `True`)
Show the plot, do not return axis.
save: `bool` or `str`, optional (default: `None`)
If `True` or a `str`, save the figure. A string is appended to the default filename.
Infer the filetype if ending on {'.pdf', '.png', '.svg'}.
Returns
-------
Expand Down Expand Up @@ -1658,10 +1667,9 @@ def log_fmt(x, pos):
if rcParams["savefig.transparent"]:
ax.patch.set_alpha(0)

if not show:
savefig_or_show(dpi=dpi, save=save, show=show)
if show is False:
return ax
else:
pl.show()


# TODO: Add docstrings
Expand Down
8 changes: 8 additions & 0 deletions scvelo/preprocessing/moments.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import numpy as np
from scipy.sparse import csr_matrix, issparse

Expand Down Expand Up @@ -60,6 +62,12 @@ def moments(
normalize_per_cell(adata)

if n_neighbors is not None and n_neighbors > get_n_neighs(adata):
warnings.warn(
"Automatic neighbor calculation is deprecated since scvelo==0.4.0 and will be removed in a future version "
"of scVelo. Please compute neighbors first with Scanpy.",
DeprecationWarning,
stacklevel=2,
)
neighbors(
adata,
n_neighbors=n_neighbors,
Expand Down
14 changes: 13 additions & 1 deletion scvelo/preprocessing/neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def _get_scanpy_neighbors(adata: AnnData, **kwargs):
with warnings.catch_warnings(): # ignore numba warning (umap/issues/252)
warnings.simplefilter("ignore")
neighbors = Neighbors(adata)
neighbors.compute_neighbors(write_knn_indices=True, **kwargs)
neighbors.compute_neighbors(**kwargs)
logg.switch_verbosity("on", module="scanpy")

return neighbors
Expand Down Expand Up @@ -124,6 +124,12 @@ def _set_pca(adata, n_pcs: Optional[int], use_highly_variable: bool):
or n_pcs is not None
and n_pcs > adata.obsm["X_pca"].shape[1]
):
warnings.warn(
"Automatic computation of PCA is deprecated since scvelo==0.4.0 and will be removed in a future version "
"of scVelo. Please compute PCA with Scanpy first.",
DeprecationWarning,
stacklevel=2,
)
if use_highly_variable and "highly_variable" in adata.var.keys():
n_vars = np.sum(adata.var["highly_variable"])
else:
Expand Down Expand Up @@ -213,6 +219,12 @@ def neighbors(
distances : `.obsp`
Sparse matrix of distances for each pair of neighbors.
"""
warnings.warn(
"`neighbors` is deprecated since scvelo==0.4.0 and will be removed in a future version "
"of scVelo. Please compute neighbors with Scanpy.",
DeprecationWarning,
stacklevel=2,
)
adata = adata.copy() if copy else adata

use_rep = _get_rep(adata=adata, use_rep=use_rep, n_pcs=n_pcs)
Expand Down
5 changes: 4 additions & 1 deletion scvelo/tools/_em_model_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,7 @@ def recover_dynamics(
copy=False,
n_jobs=None,
backend="loky",
show_progress_bar: bool = True,
**kwargs,
):
"""Recovers the full splicing kinetics of specified genes.
Expand Down Expand Up @@ -476,6 +477,8 @@ def recover_dynamics(
backend: `str` (default: "loky")
Backend used for multiprocessing. See :class:`joblib.Parallel` for valid
options.
show_progress_bar
Whether to show a progress bar.
Returns
-------
Expand Down Expand Up @@ -557,7 +560,7 @@ def recover_dynamics(
unit="gene",
as_array=False,
backend=backend,
show_progress_bar=len(var_names) > 9,
show_progress_bar=show_progress_bar,
)(
adata=adata,
use_raw=use_raw,
Expand Down
15 changes: 10 additions & 5 deletions scvelo/tools/_vi_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def train(
max_epochs: Optional[int] = 500,
lr: float = 1e-2,
weight_decay: float = 1e-2,
use_gpu: Optional[Union[str, int, bool]] = None,
accelerator: str = "auto",
devices: Union[int, List[int], str] = "auto",
train_size: float = 0.9,
Expand All @@ -149,9 +148,14 @@ def train(
Learning rate for optimization
weight_decay
Weight decay for optimization
use_gpu
Use default GPU if available (if None or True), or index of GPU to use (if int),
or name of GPU (if str, e.g., `'cuda:0'`), or use CPU (if False).
accelerator
Supports passing different accelerator types `("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto")` as well as
custom accelerator instances.
devices
The devices to use. Can be set to a non-negative index (`int` or `str`), a sequence of device indices
(`list` or comma-separated `str`), the value `-1` to indicate all available devices, or `"auto"` for
automatic selection based on the chosen `accelerator`. If set to `"auto"` and `accelerator` is not
determined to be `"cpu"`, then `devices` will be set to the first available device.
train_size
Size of training set in the range [0.0, 1.0].
validation_size
Expand Down Expand Up @@ -195,7 +199,8 @@ def train(
training_plan=training_plan,
data_splitter=data_splitter,
max_epochs=max_epochs,
use_gpu=use_gpu,
accelerator=accelerator,
devices=devices,
**trainer_kwargs,
)
return runner()
Expand Down
Loading

0 comments on commit 24e019a

Please sign in to comment.