Skip to content

Commit

Permalink
allowing timing measurement
Browse files Browse the repository at this point in the history
  • Loading branch information
leoguignard committed Jul 20, 2023
1 parent 08dfa01 commit 93f6424
Showing 1 changed file with 19 additions and 3 deletions.
22 changes: 19 additions & 3 deletions src/sc3D/sc3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from seaborn import scatterplot
import json
from pathlib import Path
from time import time

import anndata
from sc3D.transformations import transformations as tr
Expand Down Expand Up @@ -537,7 +538,7 @@ def build_pairing(self, cs1, cs2, rebuild=False, refine=False, th_d=None):
return pos_ref, pos_flo

def register_cs(
self, cs1, cs2, refine=False, rigid=False, final=False, th_d=None
self, cs1, cs2, refine=False, rigid=False, final=False, th_d=None, timing=False
):
"""
Registers the puck `cs2` onto the puck `cs1`.
Expand All @@ -560,8 +561,16 @@ def register_cs(
Usually used as a float.
"""
if timing:
start = time()
if not hasattr(self, "timing"):
self.timing = {}
current_cs_timing = self.timing.setdefault((cs1, cs2), {})
if self.registered_pos is None:
self.register_with_tissues()
if timing:
current_cs_timing["register_with_tissues"] = time() - start
start = time()
if (self.final is None) and final:
self.final = {
c: self.centered_pos[c]
Expand All @@ -570,9 +579,15 @@ def register_cs(
pos_ref, pos_flo = self.build_pairing(
cs1, cs2, rebuild=False, refine=refine, th_d=th_d
)
if timing:
current_cs_timing["build_pairing"] = time(time() - start)
start = time()
M = self.register(
np.array(pos_ref), np.array(pos_flo), apply=False, rigid=rigid
)
if timing:
current_cs_timing["register"] = time(time() - start)
start = time()
cells_cs2 = self.cells_from_cover_slip[cs2]
if refine:
positions_cs2 = np.array([self.pos_reg_aff[c] for c in cells_cs2])
Expand All @@ -588,6 +603,9 @@ def register_cs(
self.pos_reg_aff.update(zip(cells_cs2, new_pos))
if final:
self.final.update(zip(cells_cs2, new_pos))
if timing:
current_cs_timing["apply"] = time(time() - start)
start = time()
return M

@staticmethod
Expand Down Expand Up @@ -1052,8 +1070,6 @@ def registration_3d(
if self.z_pos is None or set(self.z_pos) != set(self.all_cells):
self.set_zpos()
if timing:
from time import time

start = current_time = time()
times = []
self.trsfs = {}
Expand Down

0 comments on commit 93f6424

Please sign in to comment.