From 1e60350b325d7cefe367e8a6ce6045fc453d5599 Mon Sep 17 00:00:00 2001 From: Talmo Date: Thu, 12 Dec 2019 20:05:31 -0500 Subject: [PATCH 01/25] Better chaining of batchwise ops for inference classes - Slightly slower but much more memory efficient --- sleap/nn/peak_finding.py | 9 ++++----- sleap/nn/region_proposal.py | 7 +++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index 0a2a019a0..913de5d1a 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -404,12 +404,11 @@ def postproc(self, confmaps): def predict_rps(self, rps: "RegionProposalSet") -> RegionPeakSet: - imgs = self.preproc(rps.patches) - - confmaps = utils.batched_call(self.inference, imgs, batch_size=self.batch_size) - peak_subs_and_vals, batch_inds = utils.batched_call( - self.postproc, confmaps, batch_size=self.batch_size, return_batch_inds=True + lambda imgs: self.postproc(self.inference(self.preproc(imgs))), + rps.patches, + batch_size=self.batch_size, + return_batch_inds=True, ) # Split. diff --git a/sleap/nn/region_proposal.py b/sleap/nn/region_proposal.py index b57cce433..71905ee7a 100644 --- a/sleap/nn/region_proposal.py +++ b/sleap/nn/region_proposal.py @@ -330,8 +330,11 @@ def postproc(self, centroid_confmaps): return centroids, centroid_vals def predict(self, imgs): - imgs = self.preproc(imgs) - confmaps = utils.batched_call(self.inference, imgs, batch_size=self.batch_size) + confmaps = utils.batched_call( + lambda imgs: self.inference(self.preproc(imgs)), + imgs, + batch_size=self.batch_size, + ) return self.postproc(confmaps) From 5cd416d0fa195b68fce74b8cd89008cabdda3275 Mon Sep 17 00:00:00 2001 From: Talmo Date: Fri, 13 Dec 2019 19:03:14 -0500 Subject: [PATCH 02/25] TrainingJob.is_trained returns False if model is not found - Previously it would throw an exception from TrainingJob.model_path, which would break the inference GUI if there were empty folders from failed runs with existing training json files - Change TrainingJob.model_path final exception type to more appropriate FileNotFoundError --- sleap/nn/job.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/sleap/nn/job.py b/sleap/nn/job.py index a1407615c..548035699 100644 --- a/sleap/nn/job.py +++ b/sleap/nn/job.py @@ -222,15 +222,17 @@ def model_path(self): return model_path # Raise error if all fail. - raise ValueError(f"Could not find a saved model in run path: {self.run_path}") + raise FileNotFoundError(f"Could not find a saved model in run path: {self.run_path}") @property def is_trained(self): if self.run_path is None: return False - if os.path.exists(self.model_path): - return True - return False + + try: + return os.path.exists(self.model_path) + except FileNotFoundError: + return False @staticmethod def _to_dicts(training_job: "TrainingJob"): From 59ac6f039b83fb951c7ddbab4c3b08996a1e9cd6 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 7 Jan 2020 10:58:03 -0500 Subject: [PATCH 03/25] Use thousands separators in frame numbers. --- sleap/gui/app.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 759f388d7..87b019500 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1011,10 +1011,10 @@ def updateStatusMessage(self, message: Optional[str] = None): frame_idx = self.state["frame_idx"] or 0 if message is None: - message = f"Frame: {frame_idx+1}/{len(current_video)}" + message = f"Frame: {frame_idx+1:,}/{len(current_video):,}" if self.player.seekbar.hasSelection(): start, end = self.state["frame_range"] - message += f" (selection: {start}-{end})" + message += f" (selection: {start:,}-{end:,})" if len(self.labels.videos) > 1: message += f" of video {self.labels.videos.index(current_video)}" From 491c13971ff831f4ed0ddc31aecb085513c3ab90 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 8 Jan 2020 16:19:47 -0500 Subject: [PATCH 04/25] Add command to select from current to frame N. --- sleap/config/shortcuts.yaml | 1 + sleap/gui/app.py | 5 ++++- sleap/gui/commands.py | 27 +++++++++++++++++++++++++++ sleap/gui/video.py | 4 ++++ 4 files changed, 36 insertions(+), 1 deletion(-) diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index 4e64205ed..4d08f96a8 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -26,6 +26,7 @@ prev video: QKeySequence.Back save as: QKeySequence.SaveAs save: Ctrl+S select next: '`' +select to frame: Ctrl+Shift+J show edges: Ctrl+Shift+Tab show labels: Ctrl+Tab show trails: diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 87b019500..9360a30e0 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -343,6 +343,9 @@ def prev_vid(): goMenu.addSeparator() add_menu_item(goMenu, "goto frame", "Go to Frame...", self.commands.gotoFrame) + add_menu_item( + goMenu, "select to frame", "Select to Frame...", self.commands.selectToFrame + ) ### View Menu ### @@ -1014,7 +1017,7 @@ def updateStatusMessage(self, message: Optional[str] = None): message = f"Frame: {frame_idx+1:,}/{len(current_video):,}" if self.player.seekbar.hasSelection(): start, end = self.state["frame_range"] - message += f" (selection: {start:,}-{end:,})" + message += f" (selection: {start+1:,}-{end+1:,})" if len(self.labels.videos) > 1: message += f" of video {self.labels.videos.index(current_video)}" diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 71011aeb6..9e18341d6 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -249,6 +249,10 @@ def gotoFrame(self): """Shows gui to go to frame by number.""" self.execute(GoFrameGui) + def selectToFrame(self): + """Shows gui to go to frame by number.""" + self.execute(SelectToFrameGui) + def gotoVideoAndFrame(self, video: Video, frame_idx: int): """Activates video and goes to frame.""" NavCommand.go_to(self, frame_idx, video) @@ -846,6 +850,29 @@ def ask(cls, context: "CommandContext", params: dict) -> bool: return okay +class SelectToFrameGui(NavCommand): + @classmethod + def do_action(cls, context: "CommandContext", params: dict): + context.app.player.setSeekbarSelection( + params["from_frame_idx"], params["to_frame_idx"] + ) + + @classmethod + def ask(cls, context: "CommandContext", params: dict) -> bool: + frame_number, okay = QtWidgets.QInputDialog.getInt( + context.app, + "Select To Frame...", + "Frame Number:", + context.state["frame_idx"] + 1, + 1, + context.state["video"].frames, + ) + params["from_frame_idx"] = context.state["frame_idx"] + params["to_frame_idx"] = frame_number - 1 + + return okay + + # Editing Commands diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 2ba437ae7..a34c5bf7f 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -130,6 +130,10 @@ def update_selection_state(a, b): if video is not None: self.load_video(video) + def setSeekbarSelection(self, a: int, b: int): + self.seekbar.startSelection(a) + self.seekbar.endSelection(b, update=True) + def show_contextual_menu(self, where: QtCore.QPoint): if not self.is_menu_enabled: return From fa0fb160f55ef2d065bbaca3d89566a10bad3f26 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Wed, 8 Jan 2020 16:21:04 -0500 Subject: [PATCH 05/25] Improvements to seekbar slider UI. - Zoom by alt/option drag over range of frames - Tick marks w/ frame number labels - Limit to how many tracks to stack vertically before wrapping to top --- sleap/gui/slider.py | 297 +++++++++++++++++++++++++++++++-------- tests/gui/test_slider.py | 4 +- 2 files changed, 242 insertions(+), 59 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 11e672ac8..ed576e14c 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -2,7 +2,7 @@ Drop-in replacement for QSlider with additional features. """ -from PySide2 import QtCore, QtWidgets +from PySide2 import QtCore, QtWidgets, QtGui from PySide2.QtGui import QPen, QBrush, QColor, QKeyEvent, QPolygonF, QPainterPath from sleap.gui.color import ColorManager @@ -25,6 +25,7 @@ class SliderMark: * "open" (single value) * "predicted" (single value) * "track" (range of values) + * "tick" (single value) val: Beginning of mark range end_val: End of mark range (for "track" marks) row: The row that the mark goes in; used for tracks. @@ -42,7 +43,13 @@ class SliderMark: @property def color(self): """Returns color of mark.""" - colors = dict(simple="black", filled="blue", open="blue", predicted="yellow") + colors = dict( + simple="black", + filled="blue", + open="blue", + predicted="yellow", + tick="lightGray", + ) if self.type in colors: return colors[self.type] @@ -71,6 +78,27 @@ def filled(self): else: return True + @property + def padded(self): + """Returns whether mark has top and bottom padding.""" + if self.type == "tick": + return False + else: + return True + + @property + def visual_width(self): + return 2 if self.type in ("open", "filled", "tick") else 0 + + def get_height(self, box_height): + if self.type == "track": + return 1.5 + height = box_height + if self.padded: + height -= 4 # 2 (top) + 2 (bottom) + + return height + class VideoSlider(QtWidgets.QGraphicsView): """Drop-in replacement for QSlider with additional features. @@ -116,7 +144,7 @@ def __init__( marks=None, color_manager: Optional[ColorManager] = None, *args, - **kwargs + **kwargs, ): super(VideoSlider, self).__init__(*args, **kwargs) @@ -125,7 +153,7 @@ def __init__( self.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop) self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) - self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) + self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) self.setVerticalScrollBarPolicy( QtCore.Qt.ScrollBarAlwaysOff ) # ScrollBarAsNeeded @@ -135,14 +163,24 @@ def __init__( self.zoom_factor = 1 self._track_rows = 0 - self._track_height = 3 - self._header_height = 0 + self._track_height = 5 + self._max_tracks_stacked = 120 + self._track_stack_skip_count = 10 + self._header_label_height = 20 + self._header_graph_height = 30 + self._header_height = self._header_label_height # room for frame labels self._min_height = 19 + self._header_height + self._base_font = QtGui.QFont() + self._base_font.setPixelSize(10) + + self._tick_marks = [] + # Add border rect - outline_rect = QtCore.QRect(0, 0, 200, self._min_height - 3) - self.outlineBox = self.scene.addRect(outline_rect) - self.outlineBox.setPen(QPen(QColor("black"))) + outline_rect = QtCore.QRectF(0, 0, 200, self._min_height - 3) + self.setBoxRect(outline_rect) + # self.outlineBox = self.scene.addRect(outline_rect) + # self.outlineBox.setPen(QPen(QColor("black", alpha=0))) # Add drag handle rect handle_width = 6 @@ -163,6 +201,13 @@ def __init__( self.select_box.setBrush(QColor(80, 80, 255, 128)) self.select_box.hide() + self.zoom_box = self.scene.addRect( + QtCore.QRect(0, 1, 0, outline_rect.height() - 2) + ) + self.zoom_box.setPen(QPen(QColor(80, 80, 80, 64))) + self.zoom_box.setBrush(QColor(80, 80, 80, 64)) + self.zoom_box.hide() + self.scene.setBackgroundBrush(QBrush(QColor(200, 200, 200))) self.clearSelection() @@ -252,14 +297,14 @@ def setHeaderSeries(self, series: Optional[Dict[int, float]] = None): None. """ self.headerSeries = [] if series is None else series - self._header_height = 30 + self._header_height = self._header_label_height + self._header_graph_height self.drawHeader() self.updateHeight() def clearHeader(self): """Remove header graph from slider.""" self.headerSeries = [] - self._header_height = 0 + self._header_height = self._header_label_height self.updateHeight() def setTracks(self, track_rows): @@ -271,9 +316,7 @@ def setTracks(self, track_rows): self._track_rows = track_rows self.updateHeight() - def updateHeight(self): - """Update the height of the slider.""" - + def getMinMaxHeights(self): tracks = self._track_rows if tracks == 0: min_height = self._min_height @@ -286,12 +329,19 @@ def updateHeight(self): # Add height for tracks min_height += self._track_height * min(tracks, 20) - max_height += self._track_height * tracks + max_height += self._track_height * min(tracks, self._max_tracks_stacked) # Make sure min/max height is at least 19, even if few tracks min_height = max(self._min_height, min_height) max_height = max(self._min_height, max_height) + return min_height, max_height + + def updateHeight(self): + """Update the height of the slider.""" + + min_height, max_height = self.getMinMaxHeights() + self.setMaximumHeight(max_height) self.setMinimumHeight(min_height) @@ -333,7 +383,7 @@ def _toVal(self, x: float, center=False) -> float: def _sliderWidth(self) -> float: """Returns visual width of slider.""" - return self.outlineBox.rect().width() - self.handle.rect().width() + return self.getBoxRect().width() - self.handle.rect().width() def value(self) -> float: """Returns value of slider.""" @@ -353,6 +403,10 @@ def setMaximum(self, max: float) -> float: """Sets maximum value for slider.""" self._val_max = max + @property + def value_range(self) -> float: + return self._val_max - self._val_min + def setEnabled(self, val: float) -> float: """Set whether the slider is enabled.""" self._enabled = val @@ -413,9 +467,16 @@ def getSelection(self): return start, end def drawSelection(self, a: float, b: float): - """Draws selection box on slider. + self.updateSelectionBoxPositions(self.select_box, a, b) + + def drawZoomBox(self, a: float, b: float): + self.updateSelectionBoxPositions(self.zoom_box, a, b) + + def updateSelectionBoxPositions(self, box_object, a: float, b: float): + """Update box item on slider. Args: + box_object: The box to update a: one endpoint value b: other endpoint value @@ -426,12 +487,24 @@ def drawSelection(self, a: float, b: float): end = max(a, b) start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) - selection_rect = QtCore.QRect( - start_pos, 1, end_pos - start_pos, self.outlineBox.rect().height() - 2 + box_rect = QtCore.QRect( + start_pos, + self._header_height, + end_pos - start_pos, + self.getBoxRect().height(), ) - self.select_box.setRect(selection_rect) - self.select_box.show() + box_object.setRect(box_rect) + box_object.show() + + def updateSelectionBoxesOnResize(self): + for box_object in (self.select_box, self.zoom_box): + rect = box_object.rect() + rect.setHeight(self._handleHeight()) + box_object.setRect(rect) + + if self.select_box.isVisible(): + self.drawSelection(*self.getSelection()) def moveSelectionAnchor(self, x: float, y: float): """ @@ -445,7 +518,7 @@ def moveSelectionAnchor(self, x: float, y: float): None. """ x = max(x, 0) - x = min(x, self.outlineBox.rect().width()) + x = min(x, self.getBoxRect().width()) anchor_val = self._toVal(x, center=True) if len(self._selection) % 2 == 0: @@ -465,17 +538,65 @@ def releaseSelectionAnchor(self, x, y): None. """ x = max(x, 0) - x = min(x, self.outlineBox.rect().width()) + x = min(x, self.getBoxRect().width()) anchor_val = self._toVal(x) self.endSelection(anchor_val) + def moveZoomDrag(self, x: float, y: float): + if getattr(self, "_zoom_start_val", None) is None: + self._zoom_start_val = self._toVal(x, center=True) + + current_val = self._toVal(x, center=True) + + self.drawZoomBox(current_val, self._zoom_start_val) + + def releaseZoomDrag(self, x, y): + + self.zoom_box.hide() + + val_a = self._zoom_start_val + val_b = self._toVal(x, center=True) + + val_start = min(val_a, val_b) + val_end = max(val_a, val_b) + + # pad the zoom + val_range = val_end - val_start + val_start -= val_range * 0.05 + val_end += val_range * 0.05 + + self.setZoomRange(val_start, val_end) + + self._zoom_start_val = None + + def setZoomRange(self, start_val: float, end_val: float): + + zoom_val_range = end_val - start_val + if zoom_val_range > 0: + self.zoom_factor = self.value_range / zoom_val_range + else: + self.zoom_factor = 1 + + self.resizeEvent() + + center_val = start_val + zoom_val_range / 2 + center_pos = self._toPos(center_val) + + self.centerOn(center_pos, 0) + def clearMarks(self): """Clears all marked values for slider.""" if hasattr(self, "_mark_items"): for item in self._mark_items.values(): self.scene.removeItem(item) + + if hasattr(self, "_mark_labels"): + for item in self._mark_labels.values(): + self.scene.removeItem(item) + self._marks = set() # holds mark position self._mark_items = dict() # holds visual Qt object for plotting mark + self._mark_labels = dict() def setMarks(self, marks: Iterable[Union[SliderMark, int]]): """Sets all marked values for the slider. @@ -487,15 +608,37 @@ def setMarks(self, marks: Iterable[Union[SliderMark, int]]): None. """ self.clearMarks() + + # Add tick marks first so they're behind other marks + self._add_tick_marks() + if marks is not None: for mark in marks: if not isinstance(mark, SliderMark): mark = SliderMark("simple", mark) self.addMark(mark, update=False) + self.updatePos() - def getMarks(self): + def _add_tick_marks(self): + val_range = self._val_max - self._val_min + val_order = 10 + while val_range // val_order > 10: + val_order *= 10 + + self._tick_marks = [] + + for tick_pos in range(self._val_min + val_order - 1, self._val_max, val_order): + self._tick_marks.append(SliderMark("tick", tick_pos)) + + for tick_mark in self._tick_marks: + self.addMark(tick_mark, update=False) + + def getMarks(self, type: str = ""): """Returns list of marks.""" + if type: + return [mark for mark in self._marks if mark.type == type] + return self._marks def addMark(self, new_mark: SliderMark, update: bool = True): @@ -516,19 +659,25 @@ def addMark(self, new_mark: SliderMark, update: bool = True): self._marks.add(new_mark) - v_top_pad = 3 + self._header_height - v_bottom_pad = 3 + v_top_pad = self._header_height + 1 + v_bottom_pad = 1 + if new_mark.padded: + v_top_pad += 2 + v_bottom_pad += 2 - width = 0 + width = new_mark.visual_width + v_offset = v_top_pad if new_mark.type == "track": - v_offset = v_top_pad + (self._track_height * new_mark.row) - height = 1 - else: - v_offset = v_top_pad - height = self.outlineBox.rect().height() - (v_offset + v_bottom_pad) + if new_mark.row < self._max_tracks_stacked: + v_offset += self._track_height * new_mark.row + else: + rows_down = new_mark.row - self._max_tracks_stacked + rows_down %= self._max_tracks_stacked - self._track_stack_skip_count + v_offset += self._track_height * self._track_stack_skip_count + v_offset += self._track_height * rows_down - width = 2 if new_mark.type in ("open", "filled") else 0 + height = new_mark.get_height(box_height=self.getBoxRect().height()) color = new_mark.QColor pen = QPen(color, 0.5) @@ -537,6 +686,13 @@ def addMark(self, new_mark: SliderMark, update: bool = True): line = self.scene.addRect(-width // 2, v_offset, width, height, pen, brush) self._mark_items[new_mark] = line + + if new_mark.type == "tick": + mark_label_text = f"{new_mark.val + 1:g}" # sci notation if large + self._mark_labels[new_mark] = self.scene.addSimpleText( + mark_label_text, self._base_font + ) + if update: self.updatePos() @@ -547,19 +703,25 @@ def updatePos(self): for mark in self._mark_items.keys(): - width = 0 if mark.type == "track": width_in_frames = mark.end_val - mark.val width = max(2, self._toPos(width_in_frames)) - elif mark.type in ("open", "filled"): - width = 2 + else: + width = mark.visual_width x = self._toPos(mark.val, center=True) self._mark_items[mark].setPos(x, 0) + if mark in self._mark_labels: + label_x = max( + 0, x - self._mark_labels[mark].boundingRect().width() // 2 + ) + self._mark_labels[mark].setPos(label_x, 4) + rect = self._mark_items[mark].rect() rect.setWidth(width) + rect.setHeight(mark.get_height(box_height=self.getBoxRect().height())) self._mark_items[mark].setRect(rect) @@ -604,7 +766,7 @@ def drawHeader(self): series_min = np.min(sampled) - 1 series_max = np.max(sampled) - series_scale = (self._header_height - 5) / (series_max - series_min) + series_scale = (self._header_graph_height) / (series_max - series_min) def toYPos(val): return self._header_height - ((val - series_min) * series_scale) @@ -636,7 +798,7 @@ def moveHandle(self, x, y): """ x -= self.handle.rect().width() / 2.0 x = max(x, 0) - x = min(x, self.outlineBox.rect().width() - self.handle.rect().width()) + x = min(x, self.getBoxRect().width() - self.handle.rect().width()) val = self._toVal(x) @@ -657,33 +819,53 @@ def moveHandle(self, x, y): if old != val: self.valueChanged.emit(self._val_main) + def getBoxRect(self): + # return self.outlineBox.rect() + return self._box_rect + + def setBoxRect(self, rect): + # self.outlineBox.setRect(rect) + self._box_rect = rect + + # Update the scene rect so that it matches how much space we + # currently want for drawing everything. + rect.setWidth(rect.width() - 1) + self.setSceneRect(rect) + + def getMarkAreaHeight(self): + _, max_height = self.getMinMaxHeights() + return max_height - 3 - self._header_height + def resizeEvent(self, event=None): """Override method to update visual size when necessary. Args: event """ - height = self.size().height() - outline_rect = self.outlineBox.rect() + outline_rect = self.getBoxRect() handle_rect = self.handle.rect() - select_box_rect = self.select_box.rect() - outline_rect.setHeight(height - 3) + outline_rect.setHeight(self.getMarkAreaHeight()) if event is not None: visual_width = event.size().width() - 1 - outline_rect.setWidth(visual_width * self.zoom_factor) - self.outlineBox.setRect(outline_rect) + else: + visual_width = self.width() - 1 + + drawn_width = visual_width * self.zoom_factor + + outline_rect.setWidth(drawn_width) + self.setBoxRect(outline_rect) handle_rect.setTop(self._handleTop()) handle_rect.setHeight(self._handleHeight()) self.handle.setRect(handle_rect) - select_box_rect.setHeight(self._handleHeight()) - self.select_box.setRect(select_box_rect) + self.updateSelectionBoxesOnResize() self.updatePos() self.drawHeader() + super(VideoSlider, self).resizeEvent(event) def _handleTop(self) -> float: @@ -702,14 +884,7 @@ def _handleHeight(self, outline_rect=None) -> float: Returns: Height of handle in pixels. """ - if outline_rect is None: - outline_rect = self.outlineBox.rect() - - handle_bottom_offset = 1 - handle_height = outline_rect.height() - ( - self._handleTop() + handle_bottom_offset - ) - return handle_height + return self.getMarkAreaHeight() def mousePressEvent(self, event): """Override method to move handle for mouse press/drag. @@ -723,7 +898,7 @@ def mousePressEvent(self, event): if not self.enabled(): return # Do nothing if click outside slider area - if not self.outlineBox.rect().contains(scenePos): + if not self.getBoxRect().contains(scenePos): return move_function = None @@ -739,6 +914,13 @@ def mousePressEvent(self, event): move_function = self.moveHandle release_function = None + elif event.modifiers() == QtCore.Qt.AltModifier: + move_function = self.moveZoomDrag + release_function = self.releaseZoomDrag + + else: + event.accept() # mouse events shouldn't be passed to video widgets + # Connect to signals if move_function is not None: self.mouseMoved.connect(move_function) @@ -746,7 +928,8 @@ def mousePressEvent(self, event): def done(x, y): if release_function is not None: release_function(x, y) - self.mouseMoved.disconnect(move_function) + if move_function is not None: + self.mouseMoved.disconnect(move_function) self.mouseReleased.disconnect(done) self.mouseReleased.connect(done) @@ -777,7 +960,7 @@ def keyReleaseEvent(self, event): def boundingRect(self) -> QtCore.QRectF: """Method required by Qt.""" - return self.outlineBox.rect() + return self.getBoxRect() def paint(self, *args, **kwargs): """Method required by Qt.""" diff --git a/tests/gui/test_slider.py b/tests/gui/test_slider.py index 0d05b057b..788d21265 100644 --- a/tests/gui/test_slider.py +++ b/tests/gui/test_slider.py @@ -23,8 +23,8 @@ def test_slider(qtbot, centered_pair_predictions): assert slider.maximumHeight() != initial_height slider.setTracksFromLabels(labels, labels.videos[0]) - assert len(slider.getMarks()) == 40 + assert len(slider.getMarks("track")) == 40 slider.moveSelectionAnchor(5, 5) slider.releaseSelectionAnchor(100, 15) - assert slider.getSelection() == (31, 619) + assert slider.getSelection() == (slider._toVal(5), slider._toVal(100)) From dd9398aac1dde16669c7f9828b98ff236111b091 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 9 Jan 2020 10:12:53 -0500 Subject: [PATCH 06/25] Line in seekbar marking new track column. --- sleap/gui/slider.py | 76 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 15 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index ed576e14c..6ac36014c 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -10,7 +10,7 @@ import attr import itertools import numpy as np -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union @attr.s(auto_attribs=True, cmp=False) @@ -26,6 +26,7 @@ class SliderMark: * "predicted" (single value) * "track" (range of values) * "tick" (single value) + * "tick_column" (single value) val: Beginning of mark range end_val: End of mark range (for "track" marks) row: The row that the mark goes in; used for tracks. @@ -49,6 +50,7 @@ def color(self): open="blue", predicted="yellow", tick="lightGray", + tick_column="gray", ) if self.type in colors: @@ -81,21 +83,41 @@ def filled(self): @property def padded(self): """Returns whether mark has top and bottom padding.""" - if self.type == "tick": + if self.type in ("tick", "tick_column"): return False else: return True + @property + def top_pad(self): + if self.type == "tick_column": + return 40 + if self.type == "tick": + return 0 + return 2 + + @property + def bottom_pad(self): + if self.type == "tick_column": + return 200 + if self.type == "tick": + return 0 + return 2 + @property def visual_width(self): - return 2 if self.type in ("open", "filled", "tick") else 0 + if self.type in ("open", "filled", "tick"): + return 2 + if self.type in ("tick_column"): + return 1 + return 0 def get_height(self, box_height): if self.type == "track": return 1.5 height = box_height - if self.padded: - height -= 4 # 2 (top) + 2 (bottom) + # if self.padded: + height -= self.top_pad + self.bottom_pad return height @@ -254,6 +276,10 @@ def setTracksFromLabels(self, labels: "Labels", video: "Video"): track_occupancy = labels.get_track_occupany(video) for track in labels.tracks: if track in track_occupancy and not track_occupancy[track].is_empty: + if track_row > 0 and self.isNewColTrack(track_row): + slider_marks.append( + SliderMark("tick_column", val=track_occupancy[track].start) + ) for occupancy_range in track_occupancy[track].list: slider_marks.append( SliderMark( @@ -661,21 +687,14 @@ def addMark(self, new_mark: SliderMark, update: bool = True): v_top_pad = self._header_height + 1 v_bottom_pad = 1 - if new_mark.padded: - v_top_pad += 2 - v_bottom_pad += 2 + v_top_pad += new_mark.top_pad + v_bottom_pad += new_mark.bottom_pad width = new_mark.visual_width v_offset = v_top_pad if new_mark.type == "track": - if new_mark.row < self._max_tracks_stacked: - v_offset += self._track_height * new_mark.row - else: - rows_down = new_mark.row - self._max_tracks_stacked - rows_down %= self._max_tracks_stacked - self._track_stack_skip_count - v_offset += self._track_height * self._track_stack_skip_count - v_offset += self._track_height * rows_down + v_offset += self.getTrackVerticalPos(*self.getTrackColRow(new_mark.row)) height = new_mark.get_height(box_height=self.getBoxRect().height()) @@ -696,6 +715,33 @@ def addMark(self, new_mark: SliderMark, update: bool = True): if update: self.updatePos() + def getTrackColRow(self, raw_row: int) -> Tuple[int, int]: + if raw_row < self._max_tracks_stacked: + return 0, raw_row + + else: + rows_after_first_col = raw_row - self._max_tracks_stacked + rows_per_later_cols = ( + self._max_tracks_stacked - self._track_stack_skip_count + ) + + rows_down = rows_after_first_col % rows_per_later_cols + col = (rows_after_first_col // rows_per_later_cols) + 1 + + return col, rows_down + + def getTrackVerticalPos(self, col: int, row: int) -> int: + if col == 0: + return row * self._track_height + else: + return (self._track_height * self._track_stack_skip_count) + ( + self._track_height * row + ) + + def isNewColTrack(self, row: int) -> bool: + _, row_down = self.getTrackColRow(row) + return row_down == 0 + def updatePos(self): """Update the visual x position of handle and slider annotations.""" x = self._toPos(self.value()) From 0c2469822e517ea4f8c51664271e782927cc5dbc Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 9 Jan 2020 11:19:52 -0500 Subject: [PATCH 07/25] Update tick marks when zoom changes. --- sleap/gui/slider.py | 53 ++++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 6ac36014c..b4320017d 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -80,14 +80,6 @@ def filled(self): else: return True - @property - def padded(self): - """Returns whether mark has top and bottom padding.""" - if self.type in ("tick", "tick_column"): - return False - else: - return True - @property def top_pad(self): if self.type == "tick_column": @@ -176,9 +168,7 @@ def __init__( self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) - self.setVerticalScrollBarPolicy( - QtCore.Qt.ScrollBarAlwaysOff - ) # ScrollBarAsNeeded + self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) self._color_manager = color_manager @@ -411,6 +401,11 @@ def _sliderWidth(self) -> float: """Returns visual width of slider.""" return self.getBoxRect().width() - self.handle.rect().width() + @property + def slider_visible_value_range(self) -> float: + """Value range that's visible given current size and zoom.""" + return self._toVal(self.width() - 1) + def value(self) -> float: """Returns value of slider.""" return self._val_main @@ -646,10 +641,23 @@ def setMarks(self, marks: Iterable[Union[SliderMark, int]]): self.updatePos() + def setTickMarks(self): + """Resets which tick marks to show.""" + self._clear_tick_marks() + self._add_tick_marks() + + def _clear_tick_marks(self): + if not hasattr(self, "_tick_marks"): + return + + for mark in self._tick_marks: + self.removeMark(mark) + def _add_tick_marks(self): - val_range = self._val_max - self._val_min + val_range = self.slider_visible_value_range + val_order = 10 - while val_range // val_order > 10: + while val_range // val_order > 24: val_order *= 10 self._tick_marks = [] @@ -660,6 +668,17 @@ def _add_tick_marks(self): for tick_mark in self._tick_marks: self.addMark(tick_mark, update=False) + def removeMark(self, mark: SliderMark): + """Removes an individual mark.""" + if mark in self._mark_labels: + self.scene.removeItem(self._mark_labels[mark]) + del self._mark_labels[mark] + if mark in self._mark_items: + self.scene.removeItem(self._mark_items[mark]) + del self._mark_items[mark] + if mark in self._marks: + self._marks.remove(mark) + def getMarks(self, type: str = ""): """Returns list of marks.""" if type: @@ -707,10 +726,17 @@ def addMark(self, new_mark: SliderMark, update: bool = True): self._mark_items[new_mark] = line if new_mark.type == "tick": + # Show tick mark behind other slider marks + self._mark_items[new_mark].setZValue(0) + + # Add a text label to show in header area mark_label_text = f"{new_mark.val + 1:g}" # sci notation if large self._mark_labels[new_mark] = self.scene.addSimpleText( mark_label_text, self._base_font ) + else: + # Show in front of tick marks + self._mark_items[new_mark].setZValue(1) if update: self.updatePos() @@ -909,6 +935,7 @@ def resizeEvent(self, event=None): self.updateSelectionBoxesOnResize() + self.setTickMarks() self.updatePos() self.drawHeader() From c6ad4fe2e7b6cda04f812141d4b71c098adaa5e1 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 9 Jan 2020 13:23:55 -0500 Subject: [PATCH 08/25] Show number/percent of predicted frames in status bar. --- sleap/gui/app.py | 15 ++++++++++++++- sleap/instance.py | 5 +++++ sleap/io/dataset.py | 10 ++++++++++ 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 9360a30e0..ecc304c25 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1013,6 +1013,8 @@ def updateStatusMessage(self, message: Optional[str] = None): current_video = self.state["video"] frame_idx = self.state["frame_idx"] or 0 + spacer = " " + if message is None: message = f"Frame: {frame_idx+1:,}/{len(current_video):,}" if self.player.seekbar.hasSelection(): @@ -1022,16 +1024,27 @@ def updateStatusMessage(self, message: Optional[str] = None): if len(self.labels.videos) > 1: message += f" of video {self.labels.videos.index(current_video)}" - message += f" Labeled Frames: " + message += f"{spacer}Labeled Frames: " if current_video is not None: message += ( f"{len(self.labels.get_video_user_labeled_frames(current_video))}" ) + if len(self.labels.videos) > 1: message += " in video, " if len(self.labels.videos) > 1: message += f"{len(self.labels.user_labeled_frames)} in project" + if current_video is not None: + pred_frame_count = len( + self.labels.get_video_predicted_frames(current_video) + ) + if pred_frame_count: + message += f"{spacer}Predicted Frames: {pred_frame_count:,}" + message += ( + f" ({pred_frame_count/current_video.num_frames*100:.2f}%)" + ) + self.statusBar().showMessage(message) def loadProjectFile(self, filename: Optional[str] = None): diff --git a/sleap/instance.py b/sleap/instance.py index f7188af86..0db6fa068 100644 --- a/sleap/instance.py +++ b/sleap/instance.py @@ -1123,6 +1123,11 @@ def has_user_instances(self) -> bool: """Whether the frame contains any user instances.""" return len(self.user_instances) > 0 + @property + def has_predicted_instances(self) -> bool: + """Whether the frame contains any predicted instances.""" + return len(self.predicted_instances) > 0 + @property def unused_predictions(self) -> List[Instance]: """ diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index 2f13e14e9..f401f7844 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -491,6 +491,16 @@ def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]: if lf.has_user_instances and lf.video == video ] + def get_video_predicted_frames(self, video: Video) -> List[LabeledFrame]: + """ + Returns labeled frames for given video with user instances. + """ + return [ + lf + for lf in self.labeled_frames + if lf.has_predicted_instances and lf.video == video + ] + # Methods for instances def instance_count(self, video: Video, frame_idx: int) -> int: From 0ac3ee6fa5d88219612f871e0f112da5261b35b4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 9 Jan 2020 13:26:43 -0500 Subject: [PATCH 09/25] Double-click to select contiguously occupied frames. --- sleap/gui/slider.py | 92 +++++++++++++++++++++++++++++++++++++++++++++ sleap/gui/video.py | 3 +- 2 files changed, 93 insertions(+), 2 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index b4320017d..b213eceb8 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -473,6 +473,11 @@ def endSelection(self, val, update: bool = False): # Emit signal (even if user selected same region as before) self.selectionChanged.emit(*self.getSelection()) + def setSelection(self, start_val, end_val): + """Selects clip from start_val to end_val.""" + self.startSelection(start_val) + self.endSelection(end_val, update=True) + def hasSelection(self) -> bool: """Returns True if a clip is selected, False otherwise.""" a, b = self.getSelection() @@ -891,6 +896,75 @@ def moveHandle(self, x, y): if old != val: self.valueChanged.emit(self._val_main) + def contiguousSelectionMarksAroundVal(self, val): + """Selects contiguously marked frames around value.""" + if not self.isMarkedVal(val): + return + + last_val = val + dec_val = self.decrementContiguousMarkedVal(last_val) + while dec_val < last_val and dec_val > self._val_min: + last_val = dec_val + dec_val = self.decrementContiguousMarkedVal(last_val) + + last_val = val + inc_val = self.incrementContiguousMarkedVal(last_val) + while inc_val > last_val and inc_val < self._val_max: + last_val = inc_val + inc_val = self.incrementContiguousMarkedVal(last_val) + + self.setSelection(dec_val, inc_val) + + def isMarkedVal(self, val): + """Returns whether value has mark.""" + if val in [mark.val for mark in self._marks]: + return True + if any( + mark.val <= val < mark.end_val + for mark in self._marks + if mark.type == "track" + ): + return True + return False + + def decrementContiguousMarkedVal(self, val): + """Decrements value within contiguously marked range if possible.""" + dec_val = min( + ( + mark.val + for mark in self._marks + if mark.type == "track" and mark.val < val <= mark.end_val + ), + default=val, + ) + if dec_val < val: + return dec_val + + if val - 1 in [mark.val for mark in self._marks]: + return val - 1 + + # Return original value if we can't decrement it w/in contiguous range + return val + + def incrementContiguousMarkedVal(self, val): + """Increments value within contiguously marked range if possible.""" + inc_val = max( + ( + mark.end_val - 1 + for mark in self._marks + if mark.type == "track" and mark.val <= val < mark.end_val + ), + default=val, + ) + if inc_val > val: + return inc_val + + if val + 1 in [mark.val for mark in self._marks]: + return val + 1 + + # Return original value if we can't decrement it w/in contiguous range + return val + def getBoxRect(self): # return self.outlineBox.rect() return self._box_rect @@ -1021,6 +1095,24 @@ def mouseReleaseEvent(self, event): scenePos = self.mapToScene(event.pos()) self.mouseReleased.emit(scenePos.x(), scenePos.y()) + def mouseDoubleClickEvent(self, event): + """Override method to move handle for mouse double-click. + + Args: + event + """ + scenePos = self.mapToScene(event.pos()) + + # Do nothing if not enabled + if not self.enabled(): + return + # Do nothing if click outside slider area + if not self.getBoxRect().contains(scenePos): + return + + if event.modifiers() == QtCore.Qt.ShiftModifier: + self.contiguousSelectionMarksAroundVal(self._toVal(scenePos.x())) + def keyPressEvent(self, event): """Catch event and emit signal so something else can handle event.""" self.keyPress.emit(event) diff --git a/sleap/gui/video.py b/sleap/gui/video.py index a34c5bf7f..28098bcf7 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -131,8 +131,7 @@ def update_selection_state(a, b): self.load_video(video) def setSeekbarSelection(self, a: int, b: int): - self.seekbar.startSelection(a) - self.seekbar.endSelection(b, update=True) + self.seekbar.setSelection(a, b) def show_contextual_menu(self, where: QtCore.QPoint): if not self.is_menu_enabled: From 6cac93e16dd0df50099c9eb3e961adfaaa3384fe Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Thu, 9 Jan 2020 13:51:17 -0500 Subject: [PATCH 10/25] Keys to skip to first/last contiguously occupied frame. --- sleap/gui/slider.py | 11 ++++++++++- sleap/gui/video.py | 8 ++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index b213eceb8..589af8a2c 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -901,19 +901,28 @@ def contiguousSelectionMarksAroundVal(self, val): if not self.isMarkedVal(val): return + dec_val = self.getStartContiguousMark(val) + inc_val = self.getEndContiguousMark(val) + + self.setSelection(dec_val, inc_val) + + def getStartContiguousMark(self, val): last_val = val dec_val = self.decrementContiguousMarkedVal(last_val) while dec_val < last_val and dec_val > self._val_min: last_val = dec_val dec_val = self.decrementContiguousMarkedVal(last_val) + return dec_val + + def getEndContiguousMark(self, val): last_val = val inc_val = self.incrementContiguousMarkedVal(last_val) while inc_val > last_val and inc_val < self._val_max: last_val = inc_val inc_val = self.incrementContiguousMarkedVal(last_val) - self.setSelection(dec_val, inc_val) + return inc_val def isMarkedVal(self, val): """Returns whether value has mark.""" diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 28098bcf7..ebde5010b 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -450,6 +450,14 @@ def keyPressEvent(self, event: QKeyEvent): elif event.key() == Qt.Key.Key_Escape: self.view.click_mode = "" self.state["instance"] = None + elif event.key() == Qt.Key.Key_K: + self.state["frame_idx"] = self.seekbar.getEndContiguousMark( + self.state["frame_idx"] + ) + elif event.key() == Qt.Key.Key_J: + self.state["frame_idx"] = self.seekbar.getStartContiguousMark( + self.state["frame_idx"] + ) elif event.key() == Qt.Key.Key_QuoteLeft: self.state.increment_in_list("instance", self.selectable_instances) elif event.key() < 128 and chr(event.key()).isnumeric(): From 52e0025069e06a309cfd1ec9948f314c4a707de2 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 10 Jan 2020 08:59:53 -0500 Subject: [PATCH 11/25] Add UI to replace videos in project. Resolves issue #263. --- sleap/gui/app.py | 3 +++ sleap/gui/commands.py | 31 +++++++++++++++++++++++++++++++ sleap/gui/missingfiles.py | 18 +++++++++++++----- sleap/io/video.py | 24 ++++++++++++++++++++++++ 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/sleap/gui/app.py b/sleap/gui/app.py index ecc304c25..af6a82b53 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -280,6 +280,9 @@ def add_submenu_choices(menu, title, options, key): fileMenu.addSeparator() add_menu_item(fileMenu, "add videos", "Add Videos...", self.commands.addVideo) + add_menu_item( + fileMenu, "replace videos", "Replace Videos...", self.commands.replaceVideo + ) fileMenu.addSeparator() add_menu_item(fileMenu, "save", "Save", self.commands.saveProject) diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 9e18341d6..d754513cd 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -263,6 +263,10 @@ def addVideo(self): """Shows gui for adding videos to project.""" self.execute(AddVideo) + def replaceVideo(self): + """Shows gui for replacing videos to project.""" + self.execute(ReplaceVideo) + def removeVideo(self): """Removes selected video from project.""" self.execute(RemoveVideo) @@ -909,6 +913,33 @@ def ask(context: CommandContext, params: dict) -> bool: return len(params["import_list"]) > 0 +class ReplaceVideo(EditCommand): + topics = [UpdateTopic.video] + + @staticmethod + def do_action(context: CommandContext, params: dict): + new_paths = params["new_video_paths"] + + for video, new_path in zip(context.labels.videos, new_paths): + if new_path != video.backend.filename: + video.backend.filename = new_path + video.backend.reset() + + @staticmethod + def ask(context: CommandContext, params: dict) -> bool: + """Shows gui for replacing videos in project.""" + paths = [video.backend.filename for video in context.labels.videos] + + okay = MissingFilesDialog(filenames=paths, replace=True).exec_() + + if not okay: + return False + + params["new_video_paths"] = paths + + return True + + class RemoveVideo(EditCommand): topics = [UpdateTopic.video] diff --git a/sleap/gui/missingfiles.py b/sleap/gui/missingfiles.py index 683e33147..073d17f65 100644 --- a/sleap/gui/missingfiles.py +++ b/sleap/gui/missingfiles.py @@ -14,7 +14,12 @@ class MissingFilesDialog(QtWidgets.QDialog): def __init__( - self, filenames: List[str], missing: List[bool] = None, *args, **kwargs + self, + filenames: List[str], + missing: List[bool] = None, + replace: bool = False, + *args, + **kwargs, ): """ Creates dialog window for finding missing files. @@ -42,10 +47,13 @@ def __init__( layout = QtWidgets.QVBoxLayout() - info_text = ( - f"{missing_count} file(s) which could not be found. " - "Please double-click on a file to locate it..." - ) + if replace: + info_text = "Double-click on a file to replace it..." + else: + info_text = ( + f"{missing_count} file(s) which could not be found. " + "Please double-click on a file to locate it..." + ) info_label = QtWidgets.QLabel(info_text) layout.addWidget(info_label) diff --git a/sleap/io/video.py b/sleap/io/video.py index f0a872b18..228c58a8c 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -188,6 +188,11 @@ def last_frame_idx(self) -> int: return last_key return self.frames - 1 + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, idx) -> np.ndarray: """ Get a frame from the underlying HDF5 video data. @@ -343,6 +348,10 @@ def dtype(self): """See :class:`Video`.""" return self.__test_frame.dtype + def reset(self): + """Reloads the video.""" + self._reader_ = None + def get_frame(self, idx: int, grayscale: bool = None) -> np.ndarray: """See :class:`Video`.""" if self.__reader.get(cv2.CAP_PROP_POS_FRAMES) != idx: @@ -436,6 +445,11 @@ def dtype(self): """See :class:`Video`.""" return self.__data.dtype + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, idx): """See :class:`Video`.""" return self.__data[idx] @@ -557,6 +571,11 @@ def last_frame_idx(self) -> int: return self.__store.frame_max return self.frames - 1 + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, frame_number: int) -> np.ndarray: """ Get a frame from the underlying ImgStore video data. @@ -733,6 +752,11 @@ def dtype(self): """See :class:`Video`.""" return self.__data.dtype + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, idx): """See :class:`Video`.""" if idx not in self.__data: From 1697695312a6264f77848ceb5de760df0eb91df6 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 10 Jan 2020 10:53:30 -0500 Subject: [PATCH 12/25] Bug fix to heights for slider w/o tracks. --- sleap/gui/slider.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 589af8a2c..33ce33c71 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -104,10 +104,10 @@ def visual_width(self): return 1 return 0 - def get_height(self, box_height): + def get_height(self, container_height): if self.type == "track": return 1.5 - height = box_height + height = container_height # if self.padded: height -= self.top_pad + self.bottom_pad @@ -720,7 +720,9 @@ def addMark(self, new_mark: SliderMark, update: bool = True): if new_mark.type == "track": v_offset += self.getTrackVerticalPos(*self.getTrackColRow(new_mark.row)) - height = new_mark.get_height(box_height=self.getBoxRect().height()) + height = new_mark.get_height( + container_height=self.getBoxRect().height() - self._header_height + ) color = new_mark.QColor pen = QPen(color, 0.5) @@ -798,7 +800,11 @@ def updatePos(self): rect = self._mark_items[mark].rect() rect.setWidth(width) - rect.setHeight(mark.get_height(box_height=self.getBoxRect().height())) + rect.setHeight( + mark.get_height( + container_height=self.getBoxRect().height() - self._header_height + ) + ) self._mark_items[mark].setRect(rect) @@ -1001,7 +1007,8 @@ def resizeEvent(self, event=None): outline_rect = self.getBoxRect() handle_rect = self.handle.rect() - outline_rect.setHeight(self.getMarkAreaHeight()) + outline_rect.setHeight(self.getMarkAreaHeight() + self._header_height) + if event is not None: visual_width = event.size().width() - 1 else: From ffd4109563e5130ea8232b54948c6ca8d6891bbc Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 10 Jan 2020 16:27:10 -0500 Subject: [PATCH 13/25] Use worker thread to load images in background. Resolves issues #264. --- sleap/gui/video.py | 149 +++++++++++++++++++++++++++++++++++++++------ 1 file changed, 130 insertions(+), 19 deletions(-) diff --git a/sleap/gui/video.py b/sleap/gui/video.py index ebde5010b..e5196d7bb 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -12,6 +12,9 @@ """ +FORCE_REQUEST_AFTER_TIME_IN_SECONDS = 1 + + from PySide2 import QtWidgets, QtCore from PySide2.QtWidgets import ( @@ -26,7 +29,9 @@ from PySide2.QtGui import QKeyEvent from PySide2.QtCore import Qt, QRectF, QPointF, QMarginsF, QLineF +import atexit import math +import time from typing import Callable, Dict, List, Optional, Tuple, Union @@ -34,7 +39,6 @@ from PySide2.QtWidgets import ( QGraphicsEllipseItem, - QGraphicsLineItem, QGraphicsTextItem, QGraphicsRectItem, QGraphicsPolygonItem, @@ -50,6 +54,72 @@ import qimage2ndarray +class LoadImageWorker(QtCore.QObject): + """ + Object to load video frames in background thread. + """ + + result = QtCore.Signal(QImage) + process = QtCore.Signal() + + load_queue = [] + video = None + _last_process_time = 0 + + def __init__(self, *args, **kwargs): + super(LoadImageWorker, self).__init__(*args, **kwargs) + + # Connect signal to processing function so that we can add processing + # event to event queue from the request handler. + self.process.connect(self.doProcessing) + + # Start timer which will trigger processing events when we're free + self.timer = QtCore.QTimer() + self.timer.timeout.connect(self.doProcessing) + self.timer.start(0) + + def doProcessing(self): + self._last_process_time = time.time() + + if not self.load_queue: + return + + frame_idx = self.load_queue[-1] + self.load_queue = [] + + # print(f"\t{frame_idx} starting to load") # DEBUG + + try: + # Get image data + frame = self.video.get_frame(frame_idx) + except: + frame = None + + if frame is not None: + # Convert ndarray to QImage + qimage = qimage2ndarray.array2qimage(frame) + + # print(f"\t{frame_idx} result") # DEBUG + + # Emit result + self.result.emit(qimage) + + def request(self, frame_idx): + # Add request to the queue so that we can just process the most recent. + self.load_queue.append(frame_idx) + + # If we haven't processed a request for a certain amount of time, + # then trigger a processing event now. This helps when the user has been + # continuously changing frames for a while (i.e., dragging on seekbar + # or holding down arrow key). + + since_last = time.time() - self._last_process_time + + if since_last > FORCE_REQUEST_AFTER_TIME_IN_SECONDS: + self._last_process_time = time.time() + self.process.emit() + + class QtVideoPlayer(QWidget): """ Main QWidget for displaying video with skeleton instances. @@ -65,6 +135,7 @@ class QtVideoPlayer(QWidget): """ changedPlot = QtCore.Signal(QWidget, int, Instance) + requestImage = QtCore.Signal(int) def __init__( self, @@ -110,6 +181,20 @@ def __init__( lambda e: self.state.set("frame_idx", self.seekbar.value()) ) + # Make worker thread to load images in the background + self._loader_thread = QtCore.QThread() + self._video_image_loader = LoadImageWorker() + self._video_image_loader.moveToThread(self._loader_thread) + self._loader_thread.start() + + # Connect signal so that image will be shown after it's loaded + self._video_image_loader.result.connect( + lambda qimage: self.view.setImage(qimage) + ) + + # Connect request signals from self to worker + self.requestImage.connect(self._video_image_loader.request) + def update_selection_state(a, b): self.state.set("frame_range", (a, b)) self.state.set("has_frame_range", (a < b)) @@ -127,9 +212,31 @@ def update_selection_state(a, b): self.view.show() + # Call cleanup method when application exits to end worker thread + self.destroyed.connect(self.cleanup) + atexit.register(self.cleanup) + if video is not None: self.load_video(video) + def cleanup(self): + self._loader_thread.quit() + self._loader_thread.wait() + + def _load_and_show_requested_image(self, frame_idx): + # Get image data + try: + frame = self.video.get_frame(frame_idx) + except: + frame = None + + if frame is not None: + # Convert ndarray to QImage + qimage = qimage2ndarray.array2qimage(frame) + + # Display image + self.view.setImage(qimage) + def setSeekbarSelection(self, a: int, b: int): self.seekbar.setSelection(a, b) @@ -241,24 +348,15 @@ def plot(self, *args): idx = self.state["frame_idx"] or 0 - # Get image data - try: - frame = self.video.get_frame(idx) - except: - frame = None - - if frame is not None: - # Clear existing objects - self.view.clear() - - # Convert ndarray to QImage - image = qimage2ndarray.array2qimage(frame) + # Clear exiting objects before drawing instances + self.view.clear() - # Display image - self.view.setImage(image) + # Emit signal for the instances to be drawn for this frame + self.changedPlot.emit(self, idx, self.state["instance"]) - # Emit signal - self.changedPlot.emit(self, idx, self.state["instance"]) + # Emit signal for the image to loaded and shown for this frame + self._video_image_loader.video = self.video + self.requestImage.emit(idx) def showLabels(self, show): """ Show/hide node labels for all instances in viewer. @@ -554,9 +652,17 @@ def hasImage(self) -> bool: def clear(self): """ Clears the displayed frame from the scene. """ - self._pixmapHandle = None + + if self._pixmapHandle: + # get the pixmap currently shown + pixmap = self._pixmapHandle.pixmap() + self.scene.clear() + if self._pixmapHandle: + # add the pixmap back + self._pixmapHandle = self.scene.addPixmap(pixmap) + def setImage(self, image: Union[QImage, QPixmap]): """ Set the scene's current image pixmap to the input QImage or QPixmap. @@ -582,7 +688,12 @@ def setImage(self, image: Union[QImage, QPixmap]): self._pixmapHandle.setPixmap(pixmap) else: self._pixmapHandle = self.scene.addPixmap(pixmap) - self.setSceneRect(QRectF(pixmap.rect())) # Set scene size to image size. + + # Ensure that image is behind everything else + self._pixmapHandle.setZValue(-1) + + # Set scene size to image size. + self.setSceneRect(QRectF(pixmap.rect())) self.updateViewer() def updateViewer(self): From f020d0562ecf26d95c77872581da2b9125bc73ab Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Fri, 10 Jan 2020 16:46:24 -0500 Subject: [PATCH 14/25] Accept list of search paths instead of callback. If you pass a list of strings as the video_callback to load_file(), we'll create a non-gui callback with these strings as the search paths for finding missing videos. --- sleap/io/dataset.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py index f401f7844..be94d795e 100644 --- a/sleap/io/dataset.py +++ b/sleap/io/dataset.py @@ -1417,6 +1417,8 @@ def load_json( :class:`Video` objects. Usually you'll want to pass a callback created by :meth:`make_video_callback` or :meth:`make_gui_video_callback`. + Alternately, if you pass a list of strings we'll construct a + non-gui callback with those strings as the search paths. match_to: If given, we'll replace particular objects in the data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly @@ -1496,6 +1498,13 @@ def load_json( tmp_dir, vid["backend"]["filename"] ) + if hasattr(video_callback, "__iter__"): + # If the callback is an iterable, then we'll expect it to be a + # list of strings and build a non-gui callback with those as + # the search paths. + search_paths = [path for path in video_callback] + video_callback = cls.make_video_callback(search_paths) + # Use the callback if given to handle missing videos if callable(video_callback): abort = video_callback(dicts["videos"]) @@ -1848,6 +1857,8 @@ def load_hdf5( :class:`Video` objects. Usually you'll want to pass a callback created by :meth:`make_video_callback` or :meth:`make_gui_video_callback`. + Alternately, if you pass a list of strings we'll construct a + non-gui callback with those strings as the search paths. match_to: If given, we'll replace particular objects in the data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly @@ -1877,6 +1888,13 @@ def load_hdf5( if video_item["backend"]["filename"] == ".": video_item["backend"]["filename"] = filename + if hasattr(video_callback, "__iter__"): + # If the callback is an iterable, then we'll expect it to be a + # list of strings and build a non-gui callback with those as + # the search paths. + search_paths = [path for path in video_callback] + video_callback = cls.make_video_callback(search_paths) + # Use the callback if given to handle missing videos if callable(video_callback): video_callback(dicts["videos"]) From 8fd65484f1e17216a52f4404e0905336590471cf Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 13 Jan 2020 08:47:57 -0500 Subject: [PATCH 15/25] Add test for passing search paths to Labels.load_file(). --- tests/io/test_dataset.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index dcda548a4..4e733937b 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -781,8 +781,7 @@ def test_many_tracks_hdf5(tmpdir): labels = Labels() filename = os.path.join(tmpdir, "test.h5") - labels.tracks = [Track(spawned_on=i, name=f"track {i}") - for i in range(4000)] + labels.tracks = [Track(spawned_on=i, name=f"track {i}") for i in range(4000)] Labels.save_hdf5(filename=filename, labels=labels) @@ -791,8 +790,7 @@ def test_many_videos_hdf5(tmpdir): labels = Labels() filename = os.path.join(tmpdir, "test.h5") - labels.videos = [Video.from_filename(f"video {i}.mp4") - for i in range(3000)] + labels.videos = [Video.from_filename(f"video {i}.mp4") for i in range(3000)] Labels.save_hdf5(filename=filename, labels=labels) @@ -806,3 +804,20 @@ def test_many_suggestions_hdf5(tmpdir): labels.suggestions = [SuggestionFrame(video, i) for i in range(3000)] Labels.save_hdf5(filename=filename, labels=labels) + + +def test_path_fix(tmpdir): + labels = Labels() + filename = os.path.join(tmpdir, "test.h5") + + # Add a video without a full path + labels.add_video(Video.from_filename("small_robot.mp4")) + + Labels.save_hdf5(filename=filename, labels=labels) + + # Pass the path to the directory with the video + labels = Labels.load_file(filename, video_callback=["tests/data/videos/"]) + + # Make sure we got the actual video path by searching that directory + assert len(labels.videos) == 1 + assert labels.videos[0].filename == "tests/data/videos/small_robot.mp4" From c16a966365a9e1d651ae3e4d30c6393641ba717e Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 13 Jan 2020 09:59:22 -0500 Subject: [PATCH 16/25] Remove cuda from explicit environments. We added these so that we could pip install tensorflow on Windows, but since anaconda now has tf 2 we're back to using conda for tf. Explicitly requiring cuda means we download them even when we don't want gpu support--i.e., during CI testing. --- environment.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/environment.yml b/environment.yml index a86365a11..c41f4fee0 100644 --- a/environment.yml +++ b/environment.yml @@ -5,8 +5,6 @@ dependencies: - attrs - jsonpickle=1.2 - networkx -- cudatoolkit=10.0.* -- cudnn - tensorflow-gpu=2.0 - scikit-learn - h5py From 174487ba237eb6d6824aa4c6cf90414c50d44361 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Mon, 13 Jan 2020 11:44:10 -0500 Subject: [PATCH 17/25] Speed-up to track overlay. We were looping over each track for each frame, now we loop over all instances for each frame for adding them to appropriate track (which is much faster if there are lots of tracks for the video). --- sleap/gui/overlays/tracks.py | 61 +++++++++++++++++++----------------- tests/gui/test_tracks.py | 7 +++-- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index c4521bc65..033efeb6a 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -9,10 +9,12 @@ import attr import itertools -from typing import List +from typing import Iterable, List from PySide2 import QtCore, QtGui +MAX_NODES_IN_TRAIL = 30 + @attr.s(auto_attribs=True) class TrackTrailOverlay: @@ -37,42 +39,48 @@ class TrackTrailOverlay: trail_length: int = 10 show: bool = False - def get_track_trails(self, frame_selection, track: Track): + def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]): """Get data needed to draw track trail. Args: - frame_selection: an interable with the :class:`LabeledFrame` + frame_selection: an iterable with the :class:`LabeledFrame` objects to include in trail. - track: the :class:`Track` for which to get trail Returns: - list of lists of (x, y) tuples + Dictionary keyed by track, value is list of lists of (x, y) tuples i.e., for every node in instance, we get a list of positions """ - all_trails = [[] for _ in range(len(self.labels.nodes))] + all_track_trails = dict() + + nodes = self.labels.nodes + if len(nodes) > MAX_NODES_IN_TRAIL: + nodes = nodes[:MAX_NODES_IN_TRAIL] for frame in frame_selection: - frame_idx = frame.frame_idx - inst_on_track = [instance for instance in frame if instance.track == track] - if inst_on_track: - # just use the first instance from this track in this frame - inst = inst_on_track[0] - # loop through all nodes - for node_i, node in enumerate(self.labels.nodes): + for inst in frame: + if inst.track is not None: + if inst.track not in all_track_trails: + all_track_trails[inst.track] = [[] for _ in range(len(nodes))] + + # loop through all nodes + for node_i, node in enumerate(nodes): + + if node in inst.nodes and inst[node].visible: + point = (inst[node].x, inst[node].y) - if node in inst.nodes and inst[node].visible: - point = (inst[node].x, inst[node].y) - elif len(all_trails[node_i]): - point = all_trails[node_i][-1] - else: - point = None + # Add last location of node so that we can easily + # calculate trail length (since we adjust opacity). + elif len(all_track_trails[inst.track][node_i]): + point = all_track_trails[inst.track][node_i][-1] + else: + point = None - if point is not None: - all_trails[node_i].append(point) + if point is not None: + all_track_trails[inst.track][node_i].append(point) - return all_trails + return all_track_trails def get_frame_selection(self, video: Video, frame_idx: int): """ @@ -116,17 +124,14 @@ def add_to_scene(self, video: Video, frame_idx: int): video: current video frame_idx: index of the frame to which the trail is attached """ - if not self.show: + if not self.show or self.trail_length == 0: return frame_selection = self.get_frame_selection(video, frame_idx) - tracks_in_frame = self.get_tracks_in_frame( - video, frame_idx, include_trails=True - ) - for track in tracks_in_frame: + all_track_trails = self.get_track_trails(frame_selection) - trails = self.get_track_trails(frame_selection, track) + for track, trails in all_track_trails.items(): color = QtGui.QColor(*self.player.color_manager.get_track_color(track)) pen = QtGui.QPen() diff --git a/tests/gui/test_tracks.py b/tests/gui/test_tracks.py index 7e03117e1..7889eb1b3 100644 --- a/tests/gui/test_tracks.py +++ b/tests/gui/test_tracks.py @@ -15,10 +15,13 @@ def test_track_trails(centered_pair_predictions): assert tracks[0].name == "1" assert tracks[1].name == "2" - tracks_with_trails = trail_manager.get_tracks_in_frame(labels.videos[0], 27, include_trails=True) + tracks_with_trails = trail_manager.get_tracks_in_frame( + labels.videos[0], 27, include_trails=True + ) assert len(tracks_with_trails) == 13 - trails = trail_manager.get_track_trails(frames, tracks[0]) + all_trails = trail_manager.get_track_trails(frames) + trails = all_trails[tracks[0]] assert len(trails) == 24 From 96d7308cd9ff096d25a35c1658cf5cab5db4c062 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 21 Jan 2020 12:41:07 -0500 Subject: [PATCH 18/25] Hotfix for case where there are no centroids detected in inference. --- sleap/nn/inference.py | 15 ++++++++++++--- sleap/nn/region_proposal.py | 4 ++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index db7987177..b6b067f31 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -100,6 +100,7 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): """Runs the inference components of pipeline for a chunk.""" if "centroid" in self.policies: + # Detect centroids and pull out region proposals. centroid_predictor = self.policies["centroid"] region_proposal_extractor = self.policies["region"] @@ -114,21 +115,26 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): region_proposal.RegionProposalSet.from_uncropped(images=img_chunk) ] - # for region_ind in range(len(region_proposal_sets)): - # region_proposal_sets[region_ind].sample_inds += chunk_ind * chunk_size - if "paf" not in self.policies: + # If we don't have PAFs, we must be doing topdown or single-instance. if "topdown" in self.policies: topdown_peak_finder = self.policies["topdown"] else: topdown_peak_finder = self.policies["confmap"] + if len(region_proposal_sets) == 0: + # No region proposals were found, so just return empty result. + return defaultdict(list) + + # Only one scale region proposals in topdown inference. rps = region_proposal_sets[0] + # Find peaks in each sample of the region proposal set. sample_peak_pts, sample_peak_vals = topdown_peak_finder.predict_rps(rps) sample_peak_pts = sample_peak_pts.to_tensor().numpy() sample_peak_vals = sample_peak_vals.to_tensor().numpy() + # Gather instances across all region proposals for this chunk. predicted_instances_chunk = topdown.make_sample_grouped_predicted_instances( sample_peak_pts, sample_peak_vals, @@ -140,16 +146,19 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): confmap_peak_finder = self.policies["confmap"] paf_grouper = self.policies["paf"] + # Find multi-instance peaks in each region proposal set. region_peak_sets = [] for rps in region_proposal_sets: region_peaks = confmap_peak_finder.predict_rps(rps) region_peak_sets.append(region_peaks) + # Group peaks into instances using PAFs. region_instance_sets = [] for rps, region_peaks in zip(region_proposal_sets, region_peak_sets): region_instances = paf_grouper.predict_rps(rps, region_peaks) region_instance_sets.append(region_instances) + # Gather instances across all region proposals for this chunk. predicted_instances_chunk = defaultdict(list) for region_instance_set in region_instance_sets: for sample, region_instances in region_instance_set.items(): diff --git a/sleap/nn/region_proposal.py b/sleap/nn/region_proposal.py index 71905ee7a..3b6b0b847 100644 --- a/sleap/nn/region_proposal.py +++ b/sleap/nn/region_proposal.py @@ -467,6 +467,10 @@ def extract_region_proposal_sets( sample_inds = size_grouped_sample_inds[box_size] bboxes = size_grouped_bboxes[box_size] + if len(bboxes) == 0: + # Skip region proposal set if we found no bboxes. + continue + # Extract image patches for all regions in the set. patches = extract_patches( imgs, tf.cast(bboxes, tf.float32), tf.cast(sample_inds, tf.int32) From 3353ba267fac73628d1b5fc3c5990df3971eb7f4 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 21 Jan 2020 13:03:51 -0500 Subject: [PATCH 19/25] Bug fix, topdown now returns (nan, nan) points. The points should have been at (nan, nan) but were being returned as (0, 0). The problem happened when the RaggedTensor was converted to a regular tensor (and by default 0 is used for missing values). Resolves issue #272. --- sleap/nn/inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index b6b067f31..a99823f9e 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -131,8 +131,9 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): # Find peaks in each sample of the region proposal set. sample_peak_pts, sample_peak_vals = topdown_peak_finder.predict_rps(rps) - sample_peak_pts = sample_peak_pts.to_tensor().numpy() - sample_peak_vals = sample_peak_vals.to_tensor().numpy() + + sample_peak_pts = sample_peak_pts.to_tensor(default_value=np.nan).numpy() + sample_peak_vals = sample_peak_vals.to_tensor(default_value=np.nan).numpy() # Gather instances across all region proposals for this chunk. predicted_instances_chunk = topdown.make_sample_grouped_predicted_instances( From d49e986a09c40c06ad4b7f7c32acb7d6d8d0dbf2 Mon Sep 17 00:00:00 2001 From: Talmo Date: Tue, 21 Jan 2020 13:55:20 -0500 Subject: [PATCH 20/25] Tiny typo in data augmentation generator leading to no rotation aug. --- sleap/nn/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sleap/nn/data.py b/sleap/nn/data.py index 320c73453..3c5367a58 100644 --- a/sleap/nn/data.py +++ b/sleap/nn/data.py @@ -355,7 +355,7 @@ def augment_dataset( # Setup augmenter. aug_stack = [] if rotate: - aug_stack.append(iaa.Affine(rotate=(-rotation_min_angle, rotation_max_angle))) + aug_stack.append(iaa.Affine(rotate=(rotation_min_angle, rotation_max_angle))) if scale: aug_stack.append(iaa.Affine(scale=(scale_min, scale_max))) From 35b0aecc506fc842073caf9fa85fb15d916a6842 Mon Sep 17 00:00:00 2001 From: Talmo Date: Tue, 21 Jan 2020 14:02:48 -0500 Subject: [PATCH 21/25] Temporary monkey patch for imgaug bug with numpy >= 1.18. --- sleap/nn/data.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sleap/nn/data.py b/sleap/nn/data.py index 3c5367a58..21cb4b243 100644 --- a/sleap/nn/data.py +++ b/sleap/nn/data.py @@ -1,6 +1,12 @@ """This module contains utilities for data I/O and generating training data.""" import numpy as np + +# Monkey patch for: https://github.com/aleju/imgaug/issues/537 +# TODO: Fix when new version of imgaug is available on PyPI. +import numpy +numpy.random.bit_generator = numpy.random._bit_generator + import h5py import tensorflow as tf import imgaug as ia From 9d35e66d950b1572527adb55ec7215db97d39ab0 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 21 Jan 2020 14:28:18 -0500 Subject: [PATCH 22/25] Bug fix for topdown choice in editor. --- sleap/config/training_editor.yaml | 2 +- sleap/gui/training_editor.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sleap/config/training_editor.yaml b/sleap/config/training_editor.yaml index c126abd93..007af5a2d 100644 --- a/sleap/config/training_editor.yaml +++ b/sleap/config/training_editor.yaml @@ -3,7 +3,7 @@ model: - name: output_type label: Training For type: list - options: confmaps,topdown,pafs,centroids + options: confmaps,topdown_confidence_maps,pafs,centroids default: confmaps - name: arch # backbone_name diff --git a/sleap/gui/training_editor.py b/sleap/gui/training_editor.py index 72684945f..fed8fcab2 100644 --- a/sleap/gui/training_editor.py +++ b/sleap/gui/training_editor.py @@ -130,7 +130,7 @@ def _save_as(self): confmaps=ModelOutputType.CONFIDENCE_MAP, pafs=ModelOutputType.PART_AFFINITY_FIELD, centroids=ModelOutputType.CENTROIDS, - topdown=ModelOutputType.TOPDOWN_CONFIDENCE_MAP, + topdown_confidence_maps=ModelOutputType.TOPDOWN_CONFIDENCE_MAP, )[model_data["output_type"]] backbone_kwargs = { From 281b729743b8ee6bbd3592306a6872f655ecb6d0 Mon Sep 17 00:00:00 2001 From: Talmo Pereira Date: Tue, 21 Jan 2020 14:42:26 -0500 Subject: [PATCH 23/25] Fix for monkey patch backward compatibility (numpy < 1.17) --- sleap/nn/data.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sleap/nn/data.py b/sleap/nn/data.py index 21cb4b243..e0481fea3 100644 --- a/sleap/nn/data.py +++ b/sleap/nn/data.py @@ -5,7 +5,8 @@ # Monkey patch for: https://github.com/aleju/imgaug/issues/537 # TODO: Fix when new version of imgaug is available on PyPI. import numpy -numpy.random.bit_generator = numpy.random._bit_generator +if hasattr(numpy.random, "_bit_generator"): + numpy.random.bit_generator = numpy.random._bit_generator import h5py import tensorflow as tf From f7cc4623e6e6b51bbb602932e74b55c35593c4a1 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 21 Jan 2020 14:44:43 -0500 Subject: [PATCH 24/25] Bug fix, don't always retrain topdown. --- sleap/gui/inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/sleap/gui/inference.py b/sleap/gui/inference.py index 8ff8493a9..b47c4f786 100644 --- a/sleap/gui/inference.py +++ b/sleap/gui/inference.py @@ -363,6 +363,9 @@ def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]: # Use already trained model if desired if form_data.get(f"_use_trained_{str(model_type)}", default_use_trained): job.use_trained_model = True + elif model_type == ModelOutputType.TOPDOWN_CONFIDENCE_MAP: + if form_data.get(f"_use_trained_confmaps", default_use_trained): + job.use_trained_model = True # Clear parameters that shouldn't be copied job.val_set_filename = None From 7aab3afe70132e48a8ba2a230b3e2396c85239c2 Mon Sep 17 00:00:00 2001 From: Nat Tabris Date: Tue, 21 Jan 2020 15:24:06 -0500 Subject: [PATCH 25/25] Load video as grayscale if models have one channel --- sleap/nn/inference.py | 17 +++++++++++++++++ sleap/nn/utils.py | 17 +++++++++++------ tests/nn/test_utils.py | 9 +++++++++ 3 files changed, 37 insertions(+), 6 deletions(-) create mode 100644 tests/nn/test_utils.py diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index a99823f9e..32d7e3d48 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -73,6 +73,10 @@ def predict( if video_kwargs is None: video_kwargs = dict() + # Detect whether to use grayscale video by looking at trained models. + if self.has_grayscale_models: + video_kwargs["grayscale"] = True + video_ds = utils.VideoLoader( filename=video_filename, frame_inds=frames, **video_kwargs ) @@ -223,6 +227,19 @@ def make_labeled_frames( return frames + @property + def has_grayscale_models(self): + for policy_name in ("centroid", "topdown", "confmap"): + if policy_name in self.policies: + if ( + self.policies[policy_name].inference_model.input_tensor.shape[-1] + == 1 + ): + return True + else: + return False + return False + @classmethod def from_cli_args(cls): parser = cls.make_cli_parser() diff --git a/sleap/nn/utils.py b/sleap/nn/utils.py index 6733a5587..39831f991 100644 --- a/sleap/nn/utils.py +++ b/sleap/nn/utils.py @@ -308,6 +308,7 @@ class VideoLoader: filename: str dataset: str = None input_format: str = None + grayscale: bool = False chunk_size: int = 32 prefetch_chunks: int = 1 frame_inds: Optional[List[int]] = None @@ -351,18 +352,22 @@ def tf_dtype(self): def __attrs_post_init__(self): - self._video = Video.from_filename( - self.filename, dataset=self.dataset, input_format=self.input_format - ) + self._video = self._load_video(self.filename) self._shape = self.video.shape self._np_dtype = self.video.dtype self._tf_dtype = tf.dtypes.as_dtype(self.np_dtype) self._ds = self.make_ds() - def load_frames(self, frame_inds): - local_vid = Video.from_filename( - self.video.filename, dataset=self.dataset, input_format=self.input_format + def _load_video(self, filename) -> "Video": + return Video.from_filename( + filename, + dataset=self.dataset, + input_format=self.input_format, + grayscale=self.grayscale, ) + + def load_frames(self, frame_inds): + local_vid = self._load_video(self.video.filename) imgs = local_vid[np.array(frame_inds).astype("int64")] return imgs diff --git a/tests/nn/test_utils.py b/tests/nn/test_utils.py new file mode 100644 index 000000000..8701a77a5 --- /dev/null +++ b/tests/nn/test_utils.py @@ -0,0 +1,9 @@ +from sleap.nn.utils import VideoLoader + + +def test_grayscale_video(): + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4",) + assert vid.shape[-1] == 3 + + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", grayscale=True) + assert vid.shape[-1] == 1