Skip to content

Commit

Permalink
update_transportation_linking (#23)
Browse files Browse the repository at this point in the history
* update_transportation_linking

* fix lint
  • Loading branch information
bgraedel committed Sep 10, 2024
1 parent ab8dd79 commit 54d641b
Show file tree
Hide file tree
Showing 5 changed files with 1,426 additions and 1,097 deletions.
200 changes: 100 additions & 100 deletions arcos4py/tools/_detect_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

import matplotlib.pyplot as plt
import numpy as np
import ot
import pandas as pd
import pulp
from kneed import KneeLocator
from scipy.spatial.distance import cdist
from numba import njit, prange
from skimage.transform import rescale
from sklearn.cluster import DBSCAN, HDBSCAN
from sklearn.neighbors import KDTree
Expand Down Expand Up @@ -147,82 +147,6 @@ def _hdbscan(
return np.empty((0, 0))


def solve_transportation_problem(
    sources: np.ndarray,
    destinations: np.ndarray,
    supply_range: List[int],
    demand_range: List[int],
    max_distance: float,
) -> Tuple[Dict[Tuple[int, int], float], List[int], List[int]]:
    """Solve a transportation problem between two point sets with PuLP.

    Every source/destination pair gets an integer flow variable. The summed
    flow leaving a source may not exceed that source's supply variable, and
    the summed flow entering a destination must equal that destination's
    demand variable. Supply and demand are themselves integer decision
    variables bounded by ``supply_range`` and ``demand_range``.

    Args:
        sources (np.ndarray): Coordinates of the source points, one row per point.
        destinations (np.ndarray): Coordinates of the destination points, one row per point.
        supply_range (List[int]): Lower and upper bound of each supply variable.
        demand_range (List[int]): Lower and upper bound of each demand variable.
        max_distance (float): Pairs farther apart than this receive a large
            penalty cost instead of their Euclidean distance.

    Returns:
        Tuple of
        - a dict mapping ``(source_index, destination_index)`` to the Euclidean
          distance of every pair that carries positive flow,
        - the solved supply values (one per source),
        - the solved demand values (one per destination).
    """
    n_sources = len(sources)
    n_destinations = len(destinations)
    pairs = [(i, j) for i in range(n_sources) for j in range(n_destinations)]

    # Euclidean cost matrix; pairs beyond max_distance are penalised rather
    # than forbidden, so the solver can still use them if it has to.
    costs = cdist(sources, destinations, metric='euclidean')
    costs[costs > max_distance] = 1e6  # you can adjust this value as needed

    problem = pulp.LpProblem("TransportationProblem", pulp.LpMinimize)

    # Integer flow on every source/destination pair.
    x = pulp.LpVariable.dicts("x", pairs, lowBound=0, cat='Integer')

    # Supply and demand are decision variables bounded by the given ranges.
    supply = [
        pulp.LpVariable(f"supply_{i}", lowBound=supply_range[0], upBound=supply_range[1], cat='Integer')
        for i in range(n_sources)
    ]
    demand = [
        pulp.LpVariable(f"demand_{j}", lowBound=demand_range[0], upBound=demand_range[1], cat='Integer')
        for j in range(n_destinations)
    ]

    # Objective: total transported cost.
    problem += pulp.lpSum(costs[i, j] * x[i, j] for i, j in pairs)

    # Flow leaving a source cannot exceed its supply.
    for i in range(n_sources):
        problem += pulp.lpSum(x[i, j] for j in range(n_destinations)) <= supply[i]

    # Flow entering a destination must match its demand exactly.
    for j in range(n_destinations):
        problem += pulp.lpSum(x[i, j] for i in range(n_sources)) == demand[j]

    problem.solve()

    # Keep the pairs with positive flow and report their true Euclidean
    # distance (not the penalised cost). No distance filtering happens here;
    # callers compare the returned distance against their own threshold.
    # Each variable is read once, and a variable without a value (None) is
    # treated as carrying no flow instead of raising on `None > 0`.
    assignments = {}
    for i, j in pairs:
        flow = pulp.value(x[i, j])
        if flow is not None and flow > 0:
            assignments[(i, j)] = np.linalg.norm(sources[i] - destinations[j])

    return (
        assignments,
        [pulp.value(supply[i]) for i in range(n_sources)],
        [pulp.value(demand[j]) for j in range(n_destinations)],
    )


def brute_force_linking(
cluster_labels: np.ndarray,
cluster_coordinates: np.ndarray,
Expand Down Expand Up @@ -269,46 +193,100 @@ def brute_force_linking(
return cluster_labels, max_cluster_label


@njit(parallel=True)
def _compute_filtered_distances(current_coords, memory_coords):
    """Return the pairwise Euclidean distance matrix between two point sets.

    Args:
        current_coords: Indexable collection of points for the current frame.
        memory_coords: Indexable collection of points from memory.

    Returns:
        Array of shape ``(len(current_coords), len(memory_coords))`` where
        entry ``[i, j]`` is the Euclidean distance between
        ``current_coords[i]`` and ``memory_coords[j]``.
    """
    n_current = len(current_coords)
    n_memory = len(memory_coords)
    squared = np.empty((n_current, n_memory))
    for i in prange(n_current):
        # Numba parallelises only the outermost prange; an inner prange is
        # executed as a plain range anyway, so spell it as one.
        for j in range(n_memory):
            squared[i, j] = np.sum((current_coords[i] - memory_coords[j]) ** 2)
    # Take the square root once over the whole matrix.
    return np.sqrt(squared)


@njit
def _assign_labels(matches, current_indices, memory_indices, memory_cluster_labels, cluster_labels_size):
    """Propagate memory cluster labels onto the matched current points.

    Args:
        matches: For each valid current point, the position (into
            ``memory_indices``) of its match, or -1 when it has none.
        current_indices: Position of each valid current point in the output array.
        memory_indices: Position of each candidate in ``memory_cluster_labels``.
        memory_cluster_labels: Cluster labels of the memory points.
        cluster_labels_size: Length of the label array to create.

    Returns:
        Array of length ``cluster_labels_size`` holding -1 everywhere except
        at matched positions, which carry the label of the matched memory point.
    """
    labels = np.full(cluster_labels_size, -1)
    for pos in range(len(matches)):
        matched = matches[pos]
        if matched == -1:
            # Unmatched point: keep the -1 placeholder.
            continue
        labels[current_indices[pos]] = memory_cluster_labels[memory_indices[matched]]
    return labels


def transportation_linking(
    cluster_labels: np.ndarray,
    cluster_coordinates: np.ndarray,
    memory_cluster_labels: np.ndarray,
    memory_coordinates: np.ndarray,
    memory_kdtree: KDTree,
    epsPrev: float,
    max_cluster_label: int,
    reg: float = 1,
    reg_m: float = 10,
    cost_threshold: float = 0,
    **kwargs: Any,
) -> Tuple[np.ndarray, int]:
    """Link clusters across frames with unbalanced optimal transport.

    Current points are matched to memory points by solving an unbalanced
    Sinkhorn problem over the points that have at least one memory neighbour
    within ``epsPrev``. Matched points inherit the label of their memory
    match; every other point receives one shared, newly created label.

    Args:
        cluster_labels (np.ndarray): The cluster labels for the current frame.
            Only its size and dtype are used.
        cluster_coordinates (np.ndarray): The cluster coordinates for the current frame.
        memory_cluster_labels (np.ndarray): The cluster labels for previous frames.
        memory_coordinates (np.ndarray): The coordinates for previous frames.
        memory_kdtree (KDTree): Pre-constructed sklearn KDTree for memory coordinates.
        epsPrev (float): Frame-to-frame distance, used to connect clusters across frames.
        max_cluster_label (int): The maximum label for clusters.
        reg (float): Entropy regularization parameter for Sinkhorn algorithm.
        reg_m (float): Marginal relaxation parameter for unbalanced OT.
        cost_threshold (float): Matches whose transport-plan mass is below this
            value are treated as unmatched.
        **kwargs: Additional keyword arguments; ignored. This also absorbs the
            ``supply_range`` / ``demand_range`` keywords of the previous
            linear-programming implementation, so keyword callers keep working.

    Returns:
        Tuple[np.ndarray, int]: Updated cluster labels and the maximum cluster label.
    """
    # Memory neighbours within the maximum allowed distance, per current point.
    neighbours = memory_kdtree.query_radius(cluster_coordinates, r=epsPrev)
    has_neighbour = np.array([len(ind) > 0 for ind in neighbours], dtype=bool)

    # Nothing can be linked: the whole frame becomes one new cluster.
    # This single guard also covers an empty frame.
    if not has_neighbour.any():
        max_cluster_label += 1
        return np.full_like(cluster_labels, max_cluster_label), max_cluster_label

    # Points taking part in the transport problem, and for each of them the
    # first memory neighbour returned by the tree as its candidate.
    current_indices = np.flatnonzero(has_neighbour)
    memory_indices = np.array([ind[0] for ind in neighbours if len(ind) > 0])

    # Pairwise distances between valid current points and their candidates.
    filtered_distances = _compute_filtered_distances(
        cluster_coordinates[current_indices], memory_coordinates[memory_indices]
    )

    # Uniform mass on both sides.
    a = np.ones(len(current_indices)) / len(current_indices)
    b = np.ones(len(memory_indices)) / len(memory_indices)

    # Solve the unbalanced OT problem.
    ot_plan = ot.unbalanced.sinkhorn_unbalanced(a, b, filtered_distances, reg, reg_m)

    # Each current point follows the candidate that receives most of its mass.
    matches = np.argmax(ot_plan, axis=1)

    # Drop matches whose transported mass is below the threshold.
    matches[ot_plan[np.arange(len(matches)), matches] < cost_threshold] = -1

    new_cluster_labels = _assign_labels(
        matches, current_indices, memory_indices, memory_cluster_labels, cluster_labels.size
    )

    # All unmatched points share one newly created label.
    if np.any(new_cluster_labels == -1):
        max_cluster_label += 1
        new_cluster_labels[new_cluster_labels == -1] = max_cluster_label

    return new_cluster_labels, max_cluster_label


@dataclass
Expand Down Expand Up @@ -538,6 +516,9 @@ def __init__(
linking_method: str = "nearest",
predictor: bool | Callable = True,
n_prev: int = 1,
cost_threshold: float = 0,
reg: float = 1,
reg_m: float = 10,
n_jobs: int = 1,
**kwargs,
):
Expand All @@ -560,6 +541,9 @@ def __init__(
n_prev (int): Number of previous frames the tracking
algorithm looks back to connect collective events.
n_jobs (int): Number of jobs to run in parallel (only for clustering algorithm).
cost_threshold (float): Threshold for filtering low-probability matches (only for transportation linking).
reg (float): Entropy regularization parameter for unbalanced OT algorithm (only for transportation linking).
reg_m (float): Marginal relaxation parameter for unbalanced OT (only for transportation linking).
kwargs (Any): Additional keyword arguments. Includes deprecated parameters for backwards compatibility.
- epsPrev: Deprecated parameter for eps_prev. Use eps_prev instead.
- minClSz: Deprecated parameter for min_clustersize. Use min_clustersize instead.
Expand Down Expand Up @@ -609,6 +593,11 @@ def __init__(
self._eps_prev = eps
else:
self._eps_prev = eps_prev

self._reg = reg
self._reg_m = reg_m
self._cost_threshold = cost_threshold

self._n_jobs = n_jobs
self._validate_input(eps, eps_prev, min_clustersize, min_samples, clustering_method, n_prev, n_jobs)

Expand Down Expand Up @@ -679,10 +668,12 @@ def _link_next_cluster(self, cluster: np.ndarray, cluster_coordinates: np.ndarra
cluster_coordinates=cluster_coordinates,
memory_cluster_labels=self._memory.all_cluster_ids,
memory_coordinates=self._memory.all_coordinates,
memory_kdtree=self._nn_tree,
epsPrev=self._eps_prev,
max_cluster_label=self._memory.max_prev_cluster_id,
supply_range=[1, 10],
demand_range=[1, 10],
reg=self._reg,
reg_m=self._reg_m,
cost_threshold=self._cost_threshold,
)
else:
raise ValueError(f'Linking method must be (for now) in {AVAILABLE_LINKING_METHODS}')
Expand Down Expand Up @@ -1218,6 +1209,9 @@ def track_events_image(
n_prev: int = 1,
predictor: bool | Callable = False,
linking_method: str = 'nearest',
reg: float = 1,
reg_m: float = 10,
cost_threshold: float = 0,
dims: str = "TXY",
downsample: int = 1,
n_jobs: int = 1,
Expand All @@ -1240,6 +1234,9 @@ def track_events_image(
True uses the default predictor. A callable can be passed to use a custom predictor.
See default predictor method for details.
linking_method (str): The method used for linking. Default is 'nearest'.
reg (float): Entropy regularization parameter for unbalanced OT algorithm (only for transportation linking).
reg_m (float): Marginal relaxation parameter for unbalanced OT (only for transportation linking).
cost_threshold (float): Threshold for filtering low-probability matches (only for transportation linking).
dims (str): String of dimensions in order, such as "TXY". Default is "TXY". Possible values are "T", "X", "Y", "Z".
downsample (int): Factor by which to downsample the image. Default is 1.
n_jobs (int): Number of jobs to run in parallel. Default is 1.
Expand Down Expand Up @@ -1303,6 +1300,9 @@ def track_events_image(
linking_method=linking_method,
n_prev=n_prev,
predictor=predictor,
reg=reg,
reg_m=reg_m,
cost_threshold=cost_threshold,
n_jobs=n_jobs,
)
tracker = ImageTracker(linker, downsample=downsample)
Expand Down
Loading

0 comments on commit 54d641b

Please sign in to comment.