diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fd341688..f8058c62 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,7 +45,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest] - python-version: [3.9, "3.10"] + python-version: ["3.9", "3.10"] steps: - uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} diff --git a/pyproject.toml b/pyproject.toml index a0184458..71cd881b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,11 +50,13 @@ dependencies = [ "pandas>=1.1.1, !=1.4.0", "scipy>=1.4.1", "scikit-learn>=0.21.2, <1.2.0", - "scvi-tools>=0.20.1", "matplotlib>=3.3.0" ] [project.optional-dependencies] +vi = [ + "scvi-tools>=0.20.1", +] louvain = [ "igraph", "louvain" @@ -78,7 +80,7 @@ dev = [ "pybind11", "pytest-cov", "igraph", - "setuptools_scm" + "setuptools_scm", ] docs = [ # Just until rtd.org understands pyproject.toml diff --git a/scvelo/tools/__init__.py b/scvelo/tools/__init__.py index b2e3e126..3a40aff1 100644 --- a/scvelo/tools/__init__.py +++ b/scvelo/tools/__init__.py @@ -1,3 +1,5 @@ +import contextlib + from scanpy.tools import diffmap, dpt, louvain, tsne, umap from ._em_model import ExpectationMaximizationModel @@ -11,7 +13,6 @@ recover_latent_time, ) from ._steady_state_model import SecondOrderSteadyStateModel, SteadyStateModel -from ._vi_model import VELOVI from .paga import paga from .rank_velocity_genes import rank_velocity_genes, velocity_clusters from .score_genes_cell_cycle import score_genes_cell_cycle @@ -23,6 +24,10 @@ from .velocity_graph import velocity_graph from .velocity_pseudotime import velocity_map, velocity_pseudotime +with contextlib.suppress(ImportError): + from ._vi_model import VELOVI + + __all__ = [ "align_dynamics", "differential_kinetic_test", @@ -54,5 +59,8 @@ "SteadyStateModel", "SecondOrderSteadyStateModel", "ExpectationMaximizationModel", - "VELOVI", ] +if "VELOVI" in locals(): + __all__ += ["VELOVI"] + +del contextlib diff --git a/scvelo/tools/_core.py b/scvelo/tools/_core.py index df352bc0..92b18212 100644 --- a/scvelo/tools/_core.py +++ b/scvelo/tools/_core.py @@ -1,8 +1,6 @@ from abc import abstractmethod from typing import NamedTuple -import torch - from anndata import AnnData @@ -13,8 +11,6 @@ class _REGISTRY_KEYS_NT(NamedTuple): REGISTRY_KEYS = _REGISTRY_KEYS_NT() -DEFAULT_ACTIVATION_FUNCTION = torch.nn.Softplus() - class BaseInference: """Base Inference class for all velocity methods.""" diff --git a/scvelo/tools/_vi_module.py b/scvelo/tools/_vi_module.py index 4c5ecc81..c211da75 100644 --- a/scvelo/tools/_vi_module.py +++ b/scvelo/tools/_vi_module.py @@ -11,7 +11,9 @@ from scvi.module.base import auto_move_data, BaseModuleClass, LossOutput from scvi.nn import Encoder, FCLayers -from ._core import DEFAULT_ACTIVATION_FUNCTION, REGISTRY_KEYS +from ._core import REGISTRY_KEYS + +DEFAULT_ACTIVATION_FUNCTION = torch.nn.Softplus() torch.backends.cudnn.benchmark = True diff --git a/tests/tools/test_vi_model.py b/tests/tools/test_vi_model.py index 15592072..d8ecc0c8 100644 --- a/tests/tools/test_vi_model.py +++ b/tests/tools/test_vi_model.py @@ -1,7 +1,16 @@ -from scvi.data import synthetic_iid +import contextlib + +import pytest import scvelo as scv -from scvelo.tools import VELOVI + +with contextlib.suppress(ImportError): + from scvi.data import synthetic_iid + + from scvelo.tools import VELOVI + + +_ = pytest.importorskip("scvi") def test_preprocess_data():