Proofing and augmentor exception handling. #17

Merged
merged 2 commits on Jan 12, 2024
6 changes: 3 additions & 3 deletions README.md
@@ -31,10 +31,10 @@ Detail documentation regarding the code base can be found in the [GitPages](http
#### Sample Output of StainTools
![Screenshot](https://raw.githubusercontent.com/CielAl/torch-staintools/main/showcases/sample_out_staintools.png)

-## Usecase
+## Use case
* For details, follow the example in demo.py
* Normalizers are wrapped as `torch.nn.Module`, working similarly to a standalone neural network. This means that for a workflow involving dataloader with multiprocessing, the normalizer
-(Note that CUDA has poor support in multiprocessing and therefore it may not be the best practice to perform GPU-accelerated on-the-fly stain transformation in pytorch's dataset/dataloader)
+(Note that CUDA has poor support in multiprocessing, and therefore it may not be the best practice to perform GPU-accelerated on-the-fly stain transformation in pytorch's dataset/dataloader)


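A minimal sketch of the dataloader workflow described above: apply the normalizer to whole batches on the GPU after loading rather than inside worker processes. The `NormalizerBuilder` import path, the builder arguments, the `fit()` call, and the dummy tensors below are assumptions for illustration; demo.py holds the repository's own example.

```python
import torch
from torch_staintools.normalizer import NormalizerBuilder  # import path is an assumption

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# build on the target device up front; moving a fitted normalizer with .to()
# afterwards resets its cached state (see the CachedRNGModule note later in this PR)
normalizer = NormalizerBuilder.build('vahadane', use_cache=True, device=device)

# fit once on a reference (target) image: B x C x H x W float tensor in [0, 1]
reference = torch.rand(1, 3, 512, 512, device=device)  # stand-in for a real H&E tile
normalizer.fit(reference)

# apply to whole batches *after* the dataloader, on the GPU, instead of inside
# the dataset/dataloader workers where CUDA and multiprocessing interact poorly
batch = torch.rand(8, 3, 512, 512)  # stand-in for one dataloader batch
normalized = normalizer(batch.to(device))
```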
@@ -143,7 +143,7 @@ timeit. Comparison between torch_stain_tools in CPU/GPU mode, as well as that of
|:---------|:-------|:-------|:-------------|
| Vahadane | 119 | 7.5 | 20.9 |
| Macenko | 5.57 | 0.479 | 20.7 |
-| Reinhard | 0.840 |0.024 | 0.414 |
+| Reinhard | 0.840 | 0.024 | 0.414 |

### Fitting
| Method | CPU[s] | GPU[s] | StainTool[s] |
36 changes: 21 additions & 15 deletions torch_staintools/augmentor/base.py
@@ -3,8 +3,8 @@
from ..functional.stain_extraction.factory import build_from_name
from ..functional.optimization.dict_learning import get_concentrations
from ..functional.stain_extraction.extractor import BaseExtractor
-from ..functional.utility.implementation import transpose_trailing, img_from_concentration, default_rng
-from ..functional.tissue_mask import get_tissue_mask
+from ..functional.utility.implementation import transpose_trailing, img_from_concentration
+from ..functional.tissue_mask import get_tissue_mask, TissueMaskException
from ..cache.tensor_cache import TensorCache
from ..base_module.base import CachedRNGModule
from ..loggers import GlobalLoggers
@@ -13,6 +13,7 @@

TYPE_RNG = Optional[int | torch.Generator]


class Augmentor(CachedRNGModule):
"""Basic augmentation object as a nn.Module with stain matrices cache.

@@ -234,17 +235,22 @@ def forward(self, target: torch.Tensor, cache_keys: Optional[List[Hashable]] = N
        # B x num_stains x num_pixel_in_mask
        concentration = get_concentrations(target, target_stain_matrix, regularizer=self.regularizer,
                                           algorithm=self.reconst_method, rng=self.rng)
-        tissue_mask = get_tissue_mask(target, luminosity_threshold=self.luminosity_threshold, throw_error=False,
-                                      true_when_empty=False)
-        concentration_aug = Augmentor.augment(target_concentration=concentration,
-                                              tissue_mask=tissue_mask,
-                                              target_stain_idx=self.target_stain_idx,
-                                              inplace=False, rng=self.rng, sigma_alpha=self.sigma_alpha,
-                                              sigma_beta=self.sigma_beta)
-        # transpose to B x num_pixel x num_stains
-
-        concentration_aug = transpose_trailing(concentration_aug)
-        return img_from_concentration(concentration_aug, target_stain_matrix, img_shape=target.shape, out_range=(0, 1))
+        try:
+            tissue_mask = get_tissue_mask(target, luminosity_threshold=self.luminosity_threshold, throw_error=True,
+                                          true_when_empty=False)
+            concentration_aug = Augmentor.augment(target_concentration=concentration,
+                                                  tissue_mask=tissue_mask,
+                                                  target_stain_idx=self.target_stain_idx,
+                                                  inplace=False, rng=self.rng, sigma_alpha=self.sigma_alpha,
+                                                  sigma_beta=self.sigma_beta)
+            # transpose to B x num_pixel x num_stains
+
+            concentration_aug = transpose_trailing(concentration_aug)
+            return img_from_concentration(concentration_aug, target_stain_matrix,
+                                          img_shape=target.shape, out_range=(0, 1))
+        except TissueMaskException:
+            logger.error(f"Empty mask encountered. Dismiss and return the clone of input. Cache Key: {cache_keys}")
+            return target.clone()

@classmethod
def build(cls,
@@ -265,14 +271,14 @@ def build(cls,
Args:
method: algorithm name to extract stain - support 'vahadane' or 'macenko'
reconst_method: algorithm to compute concentration. default ista
-rng: a optional seed (either an int or a torch.Generator) to determine the random number generation.
+rng: an optional seed (either an int or a torch.Generator) to determine the random number generation.
target_stain_idx: what stains to augment: e.g., for HE cases, it can be either or both from [0, 1]
sigma_alpha: alpha is uniformly randomly selected from (1-sigma_alpha, 1+sigma_alpha)
sigma_beta: beta is uniformly randomly selected from (-sigma_beta, sigma_beta)
luminosity_threshold: luminosity threshold to find tissue regions (smaller than but positive)
a pixel is considered as being tissue if the intensity falls in the open interval of (0, threshold).
regularizer: regularization term in ISTA algorithm
-use_cache: whether use cache to save the stain matrix to avoid recomputation
+use_cache: whether to use cache to save the stain matrix to avoid re-computation
cache_size_limit: size limit of the cache. negative means no limits.
device: what device to hold the cache.
load_path: If specified, then stain matrix cache will be loaded from the file path. See the `cache`
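For context beyond the diff, here is a minimal usage sketch of how the new exception handling behaves from the caller's side. The argument values are hypothetical, and direct use of `Augmentor.build` is assumed here rather than taken from the repository's demo.

```python
import torch
from torch_staintools.augmentor.base import Augmentor

# hypothetical settings; see the build() docstring above for the full argument list
augmentor = Augmentor.build(method='vahadane',
                            sigma_alpha=0.2,
                            sigma_beta=0.2,
                            luminosity_threshold=0.8,
                            use_cache=False,
                            rng=314)

batch = torch.rand(4, 3, 256, 256)  # stand-in for a B x C x H x W batch in [0, 1]
augmented = augmentor(batch)

# With this PR, an input whose tissue mask turns out empty (e.g., pure white
# background) no longer raises from forward(): the TissueMaskException is
# logged and a clone of the input is returned unchanged.
```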
4 changes: 2 additions & 2 deletions torch_staintools/base_module/base.py
@@ -28,7 +28,7 @@
class CachedRNGModule(torch.nn.Module):
"""Optionally cache the stain matrices and manage the rng

-Note that using .to to move the module across GPU/cpu device will reset the states.
+Note that using nn.Module.to(device) to move the module across GPU/cpu device will reset the states.


"""
@@ -121,7 +121,7 @@ def tensor_from_cache(self,
if not self.cache_initialized() or cache_keys is None:
logger.debug(f'{self.cache_initialized()} + {cache_keys is None} - no cache')
return func_partial(target)
-# if use cache
+# if using cache
assert self.cache_initialized(), f"Attempt to fetch data from cache but cache is not initialized"
assert cache_keys is not None, f"Attempt to fetch data from cache but key is not given"
# move fetched stain matrix to the same device of the target
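A short sketch of the practical consequence of the device note above, assuming the `device` argument documented in the build() methods elsewhere in this PR:

```python
import torch
from torch_staintools.augmentor.base import Augmentor

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# pass the device at build time so the stain-matrix cache and rng live there
# from the start, instead of building on CPU and calling .to(device) afterwards,
# which (per the note above) resets those states
augmentor = Augmentor.build(method='macenko', use_cache=True, device=device)
```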
2 changes: 1 addition & 1 deletion torch_staintools/cache/base.py
@@ -168,7 +168,7 @@ def write_batch(self, keys: List[Hashable], batch: V):
"""Write a batch of data to the cache.

Args:
-keys: list of keys corresponding to individual data points in the batch
+keys: list of keys corresponding to individual data points in the batch.
batch: batch data to cache.

Returns:
2 changes: 1 addition & 1 deletion torch_staintools/functional/conversion/od.py
@@ -15,7 +15,7 @@ def rgb2od(image: torch.Tensor):
image: Image RGB. Input scale does not matter.

Returns:
-Optical denisty RGB image.
+Optical density RGB image.
"""
# to [0, 255]
image = convert_image_dtype(image, torch.uint8)
2 changes: 1 addition & 1 deletion torch_staintools/functional/optimization/solver.py
@@ -137,7 +137,7 @@ def rss_grad(z_k):
except RuntimeError as e:
print(e)
print('lr error ', lr, 'did not update z')
-z_next = z_prev # if there a failure just reset state.
+z_next = z_prev # if there is a failure just reset state.

# check convergence
if (z - z_next).abs().sum() <= tol:
13 changes: 9 additions & 4 deletions torch_staintools/functional/stain_extraction/utils.py
@@ -2,17 +2,22 @@


def percentile(t: torch.Tensor, q: float, dim: int) -> torch.Tensor:
""" Author: adapted from https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30
"""Author: adapted from https://gist.github.com/spezold/42a451682422beb42bc43ad0c0967a30

Return the ``q``-th percentile of the flattened input tensor's data.

CAUTION:
* Needs PyTorch >= 1.1.0, as ``torch.kthvalue()`` is used.
* Values are not interpolated, which corresponds to
``numpy.percentile(..., interpolation="nearest")``.

-:param t: Input tensor.
-:param q: Percentile to compute, which must be between 0 and 100 inclusive.
-:return: Resulting value (scalar).
+Args:
+    t: Input tensor.
+    q: Percentile to compute, which must be between 0 and 100 inclusive.
+    dim: which dim to operate for function `tensor.kthvalue`.
+
+Returns:
+    Resulting value (scalar).
"""
# Note that ``kthvalue()`` works one-based, i.e. the first sorted value
# indeed corresponds to k=1, not k=0! Use float(q) instead of q directly,
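Since the body of `percentile` is not shown in this diff, here is a minimal sketch of the nearest-value percentile via `torch.kthvalue` that the docstring describes; the exact rounding used by the repository may differ.

```python
import torch

def percentile_sketch(t: torch.Tensor, q: float, dim: int) -> torch.Tensor:
    """Nearest-value q-th percentile along `dim`, in the spirit of
    numpy.percentile(..., interpolation="nearest") as noted above."""
    # kthvalue is one-based: k=1 returns the smallest element along dim.
    # float(q) avoids integer rounding surprises for large q or sizes.
    k = 1 + round(0.01 * float(q) * (t.size(dim) - 1))
    return t.kthvalue(k, dim=dim).values

x = torch.arange(12.0).reshape(3, 4)
print(percentile_sketch(x, 50, dim=1))  # -> tensor([2., 6., 10.])
```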
4 changes: 2 additions & 2 deletions torch_staintools/functional/stain_extraction/vahadane.py
@@ -30,8 +30,8 @@ def get_stain_matrix_from_od(od: torch.Tensor, tissue_mask: torch.Tensor, *,
steps: max number of steps if still not converged
constrained: whether to force dictionary to be positive
persist: whether retain the previous z value for its update or initialize every time in the iteration.
-init: init method of the codes a in X = D x a. Selected from `ridge`, `zero`, `unif` (uniformly random), or
-    `transpose`. Details see torch_staintools.functional.optimization.sparse_util.initialize_code
+init: init method of the codes `a` in `X = D x a`. Selected from `ridge`, `zero`, `unif` (uniformly random),
+    or `transpose`. Details see torch_staintools.functional.optimization.sparse_util.initialize_code
verbose: whether to print progress messages.
rng: torch.Generator for any random initializations incurred (e.g., if `init` is set to be unif)

2 changes: 1 addition & 1 deletion torch_staintools/functional/tissue_mask/__init__.py
@@ -60,7 +60,7 @@ def get_tissue_mask_np(I: np.ndarray, luminosity_threshold: float = 0.8, throw_
A numpy version for preprocessing purposes. Note that both Macenko and Vahadane may fail due to mathematical
instability to process image that is mostly bright background and no tissue at all.

-Typically we use to identify tissue in the image and exclude the bright white background.
+Typically, we use to identify tissue in the image and exclude the bright white background.

Args:
I: numpy image. H x W x C. Input will be automatically converted to uint8 format and range [0, 255]
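The mask computation itself is outside this diff; below is a generic sketch of the luminosity rule described above (a pixel counts as tissue when its luminosity lies strictly between 0 and the threshold). It is not the repository's exact implementation, which may first convert to a perceptual colour space.

```python
import torch

def luminosity_tissue_mask_sketch(img: torch.Tensor, luminosity_threshold: float = 0.8) -> torch.Tensor:
    """img: B x C x H x W RGB in [0, 1]. Returns a B x H x W boolean tissue mask."""
    luminosity = img.mean(dim=1)  # crude per-pixel luminosity proxy
    # tissue pixels are darker than the bright white background but not pure black
    return (luminosity > 0) & (luminosity < luminosity_threshold)

mask = luminosity_tissue_mask_sketch(torch.rand(2, 3, 64, 64))
print(mask.float().mean())  # fraction of pixels flagged as tissue
```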
2 changes: 1 addition & 1 deletion torch_staintools/functional/utility/implementation.py
@@ -90,4 +90,4 @@ def nanstd(data: torch.Tensor, dim: Optional[int | tuple] = None,
# \Sigma (x - mean)^2 --> any x that is nan will be filtered by using nansum
sum_dev2 = ((data - mean) ** 2).nansum(dim=dim, keepdim=True)
# sqrt and normalize by corrected degrees of freedom
-return torch.sqrt(sum_dev2 / (non_nan_count - correction))
+return torch.sqrt(sum_dev2 / (non_nan_count - correction))
5 changes: 2 additions & 3 deletions torch_staintools/normalizer/factory.py
@@ -2,7 +2,6 @@
from .base import Normalizer
from .reinhard import ReinhardNormalizer
from .separation import StainSeparation
-from functools import partial
import torch
TYPE_REINHARD = Literal['reinhard']
TYPE_VAHADANE = Literal['vahadane']
@@ -37,7 +36,7 @@ def build(method: TYPE_SUPPORTED,
num_stains: number of stains to separate. Currently, Macenko only supports 2. Only applies to `macenko` and
'vahadane' methods.
luminosity_threshold: luminosity threshold to ignore the background. None means all regions are considered
-as tissue. Scale of luminiosty threshold is within [0, 1]. Only applies to `macenko` and
+as tissue. Scale of luminosity threshold is within [0, 1]. Only applies to `macenko` and
'vahadane' methods.
regularizer: regularizer term in ISTA for stain separation and concentration computation. Only applies
to `macenko` and 'vahadane' methods if 'ista' is used.
@@ -57,7 +56,7 @@
norm_method: Callable
match method:
case 'reinhard':
-return ReinhardNormalizer.build(luminosity_threshold=luminosity_threshold)
+return ReinhardNormalizer.build(luminosity_threshold=luminosity_threshold)
case 'macenko' | 'vahadane':
return StainSeparation.build(method=method, reconst_method=reconst_method,
num_stains=num_stains,
4 changes: 2 additions & 2 deletions torch_staintools/normalizer/separation.py
@@ -16,7 +16,7 @@
class StainSeparation(Normalizer):
"""Stain Separation-based normalizer's interface: Macenko and Vahadane

-The stain matrix of the reference image (i.e., target image) will be dumped to the state_dict should torch.save
+The stain matrix of the reference image (i.e., target image) will be dumped to the state_dict should torch.save().
is used to export the normalizer's state dict.

"""
@@ -189,7 +189,7 @@ def build(cls, method: str,
luminosity_threshold: luminosity threshold to ignore the background. None means all regions are considered
as tissue.
regularizer: regularizer term in ista for stain separation and concentration computation.
-rng: seed or torch.Generator for any random initialization might incurred.
+rng: seed or torch.Generator for any random initialization might incur.
use_cache: whether to use cache to save the stain matrix of input image to normalize
cache_size_limit: size limit of the cache. negative means no limits.
device: what device to hold the cache and the normalizer. If none the device is set to cpu.
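A small sketch of the export/import round trip that the StainSeparation docstring above refers to; the `NormalizerBuilder` import path, the `fit()` call, and the file name are assumptions for illustration.

```python
import torch
from torch_staintools.normalizer import NormalizerBuilder  # import path is an assumption

normalizer = NormalizerBuilder.build('macenko')
normalizer.fit(torch.rand(1, 3, 512, 512))  # stand-in reference image

# the stain matrix fitted on the reference image travels with the state dict
torch.save(normalizer.state_dict(), 'normalizer_state.pt')

restored = NormalizerBuilder.build('macenko')
restored.load_state_dict(torch.load('normalizer_state.pt'))
```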
2 changes: 1 addition & 1 deletion torch_staintools/version.py
@@ -1 +1 @@
-__version__ = '1.0.3'
+__version__ = '1.0.4'