Skip to content

Commit

Permalink
Remove hard scvi-tools dependency (#1185)
Browse files Browse the repository at this point in the history
* Remove hard `scvi-tools` dependency

* Skip VELOVI tests if `scvi-tools` not installed

* Remove `torch` dependency

* Remove unused `skip_file`
  • Loading branch information
michalk8 authored Feb 14, 2024
1 parent 809632e commit 3f0a612
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-latest]
python-version: [3.9, "3.10"]
python-version: ["3.9", "3.10"]
steps:
- uses: actions/checkout@v2
- name: Set up Python ${{ matrix.python-version }}
Expand Down
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,13 @@ dependencies = [
"pandas>=1.1.1, !=1.4.0",
"scipy>=1.4.1",
"scikit-learn>=0.21.2, <1.2.0",
"scvi-tools>=0.20.1",
"matplotlib>=3.3.0"
]

[project.optional-dependencies]
vi = [
"scvi-tools>=0.20.1",
]
louvain = [
"igraph",
"louvain"
Expand All @@ -78,7 +80,7 @@ dev = [
"pybind11",
"pytest-cov",
"igraph",
"setuptools_scm"
"setuptools_scm",
]
docs = [
# Just until rtd.org understands pyproject.toml
Expand Down
12 changes: 10 additions & 2 deletions scvelo/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import contextlib

from scanpy.tools import diffmap, dpt, louvain, tsne, umap

from ._em_model import ExpectationMaximizationModel
Expand All @@ -11,7 +13,6 @@
recover_latent_time,
)
from ._steady_state_model import SecondOrderSteadyStateModel, SteadyStateModel
from ._vi_model import VELOVI
from .paga import paga
from .rank_velocity_genes import rank_velocity_genes, velocity_clusters
from .score_genes_cell_cycle import score_genes_cell_cycle
Expand All @@ -23,6 +24,10 @@
from .velocity_graph import velocity_graph
from .velocity_pseudotime import velocity_map, velocity_pseudotime

with contextlib.suppress(ImportError):
from ._vi_model import VELOVI


__all__ = [
"align_dynamics",
"differential_kinetic_test",
Expand Down Expand Up @@ -54,5 +59,8 @@
"SteadyStateModel",
"SecondOrderSteadyStateModel",
"ExpectationMaximizationModel",
"VELOVI",
]
if "VELOVI" in locals():
__all__ += ["VELOVI"]

del contextlib
4 changes: 0 additions & 4 deletions scvelo/tools/_core.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
from abc import abstractmethod
from typing import NamedTuple

import torch

from anndata import AnnData


Expand All @@ -13,8 +11,6 @@ class _REGISTRY_KEYS_NT(NamedTuple):

REGISTRY_KEYS = _REGISTRY_KEYS_NT()

DEFAULT_ACTIVATION_FUNCTION = torch.nn.Softplus()


class BaseInference:
"""Base Inference class for all velocity methods."""
Expand Down
4 changes: 3 additions & 1 deletion scvelo/tools/_vi_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@
from scvi.module.base import auto_move_data, BaseModuleClass, LossOutput
from scvi.nn import Encoder, FCLayers

from ._core import DEFAULT_ACTIVATION_FUNCTION, REGISTRY_KEYS
from ._core import REGISTRY_KEYS

DEFAULT_ACTIVATION_FUNCTION = torch.nn.Softplus()

torch.backends.cudnn.benchmark = True

Expand Down
13 changes: 11 additions & 2 deletions tests/tools/test_vi_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,16 @@
from scvi.data import synthetic_iid
import contextlib

import pytest

import scvelo as scv
from scvelo.tools import VELOVI

with contextlib.suppress(ImportError):
from scvi.data import synthetic_iid

from scvelo.tools import VELOVI


_ = pytest.importorskip("scvi")


def test_preprocess_data():
Expand Down

0 comments on commit 3f0a612

Please sign in to comment.