diff --git a/environment.yml b/environment.yml index a86365a11..c41f4fee0 100644 --- a/environment.yml +++ b/environment.yml @@ -5,8 +5,6 @@ dependencies: - attrs - jsonpickle=1.2 - networkx -- cudatoolkit=10.0.* -- cudnn - tensorflow-gpu=2.0 - scikit-learn - h5py diff --git a/sleap/config/shortcuts.yaml b/sleap/config/shortcuts.yaml index 4e64205ed..4d08f96a8 100644 --- a/sleap/config/shortcuts.yaml +++ b/sleap/config/shortcuts.yaml @@ -26,6 +26,7 @@ prev video: QKeySequence.Back save as: QKeySequence.SaveAs save: Ctrl+S select next: '`' +select to frame: Ctrl+Shift+J show edges: Ctrl+Shift+Tab show labels: Ctrl+Tab show trails: diff --git a/sleap/config/training_editor.yaml b/sleap/config/training_editor.yaml index c126abd93..007af5a2d 100644 --- a/sleap/config/training_editor.yaml +++ b/sleap/config/training_editor.yaml @@ -3,7 +3,7 @@ model: - name: output_type label: Training For type: list - options: confmaps,topdown,pafs,centroids + options: confmaps,topdown_confidence_maps,pafs,centroids default: confmaps - name: arch # backbone_name diff --git a/sleap/gui/app.py b/sleap/gui/app.py index 759f388d7..af6a82b53 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -280,6 +280,9 @@ def add_submenu_choices(menu, title, options, key): fileMenu.addSeparator() add_menu_item(fileMenu, "add videos", "Add Videos...", self.commands.addVideo) + add_menu_item( + fileMenu, "replace videos", "Replace Videos...", self.commands.replaceVideo + ) fileMenu.addSeparator() add_menu_item(fileMenu, "save", "Save", self.commands.saveProject) @@ -343,6 +346,9 @@ def prev_vid(): goMenu.addSeparator() add_menu_item(goMenu, "goto frame", "Go to Frame...", self.commands.gotoFrame) + add_menu_item( + goMenu, "select to frame", "Select to Frame...", self.commands.selectToFrame + ) ### View Menu ### @@ -1010,25 +1016,38 @@ def updateStatusMessage(self, message: Optional[str] = None): current_video = self.state["video"] frame_idx = self.state["frame_idx"] or 0 + spacer = " " + if message 
is None: - message = f"Frame: {frame_idx+1}/{len(current_video)}" + message = f"Frame: {frame_idx+1:,}/{len(current_video):,}" if self.player.seekbar.hasSelection(): start, end = self.state["frame_range"] - message += f" (selection: {start}-{end})" + message += f" (selection: {start+1:,}-{end+1:,})" if len(self.labels.videos) > 1: message += f" of video {self.labels.videos.index(current_video)}" - message += f" Labeled Frames: " + message += f"{spacer}Labeled Frames: " if current_video is not None: message += ( f"{len(self.labels.get_video_user_labeled_frames(current_video))}" ) + if len(self.labels.videos) > 1: message += " in video, " if len(self.labels.videos) > 1: message += f"{len(self.labels.user_labeled_frames)} in project" + if current_video is not None: + pred_frame_count = len( + self.labels.get_video_predicted_frames(current_video) + ) + if pred_frame_count: + message += f"{spacer}Predicted Frames: {pred_frame_count:,}" + message += ( + f" ({pred_frame_count/current_video.num_frames*100:.2f}%)" + ) + self.statusBar().showMessage(message) def loadProjectFile(self, filename: Optional[str] = None): diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 71011aeb6..d754513cd 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -249,6 +249,10 @@ def gotoFrame(self): """Shows gui to go to frame by number.""" self.execute(GoFrameGui) + def selectToFrame(self): + """Shows gui to go to frame by number.""" + self.execute(SelectToFrameGui) + def gotoVideoAndFrame(self, video: Video, frame_idx: int): """Activates video and goes to frame.""" NavCommand.go_to(self, frame_idx, video) @@ -259,6 +263,10 @@ def addVideo(self): """Shows gui for adding videos to project.""" self.execute(AddVideo) + def replaceVideo(self): + """Shows gui for replacing videos to project.""" + self.execute(ReplaceVideo) + def removeVideo(self): """Removes selected video from project.""" self.execute(RemoveVideo) @@ -846,6 +854,29 @@ def ask(cls, context: "CommandContext", 
params: dict) -> bool: return okay +class SelectToFrameGui(NavCommand): + @classmethod + def do_action(cls, context: "CommandContext", params: dict): + context.app.player.setSeekbarSelection( + params["from_frame_idx"], params["to_frame_idx"] + ) + + @classmethod + def ask(cls, context: "CommandContext", params: dict) -> bool: + frame_number, okay = QtWidgets.QInputDialog.getInt( + context.app, + "Select To Frame...", + "Frame Number:", + context.state["frame_idx"] + 1, + 1, + context.state["video"].frames, + ) + params["from_frame_idx"] = context.state["frame_idx"] + params["to_frame_idx"] = frame_number - 1 + + return okay + + # Editing Commands @@ -882,6 +913,33 @@ def ask(context: CommandContext, params: dict) -> bool: return len(params["import_list"]) > 0 +class ReplaceVideo(EditCommand): + topics = [UpdateTopic.video] + + @staticmethod + def do_action(context: CommandContext, params: dict): + new_paths = params["new_video_paths"] + + for video, new_path in zip(context.labels.videos, new_paths): + if new_path != video.backend.filename: + video.backend.filename = new_path + video.backend.reset() + + @staticmethod + def ask(context: CommandContext, params: dict) -> bool: + """Shows gui for replacing videos in project.""" + paths = [video.backend.filename for video in context.labels.videos] + + okay = MissingFilesDialog(filenames=paths, replace=True).exec_() + + if not okay: + return False + + params["new_video_paths"] = paths + + return True + + class RemoveVideo(EditCommand): topics = [UpdateTopic.video] diff --git a/sleap/gui/inference.py b/sleap/gui/inference.py index 8ff8493a9..b47c4f786 100644 --- a/sleap/gui/inference.py +++ b/sleap/gui/inference.py @@ -363,6 +363,9 @@ def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]: # Use already trained model if desired if form_data.get(f"_use_trained_{str(model_type)}", default_use_trained): job.use_trained_model = True + elif model_type == ModelOutputType.TOPDOWN_CONFIDENCE_MAP: + if 
form_data.get(f"_use_trained_confmaps", default_use_trained): + job.use_trained_model = True # Clear parameters that shouldn't be copied job.val_set_filename = None diff --git a/sleap/gui/missingfiles.py b/sleap/gui/missingfiles.py index 683e33147..073d17f65 100644 --- a/sleap/gui/missingfiles.py +++ b/sleap/gui/missingfiles.py @@ -14,7 +14,12 @@ class MissingFilesDialog(QtWidgets.QDialog): def __init__( - self, filenames: List[str], missing: List[bool] = None, *args, **kwargs + self, + filenames: List[str], + missing: List[bool] = None, + replace: bool = False, + *args, + **kwargs, ): """ Creates dialog window for finding missing files. @@ -42,10 +47,13 @@ def __init__( layout = QtWidgets.QVBoxLayout() - info_text = ( - f"{missing_count} file(s) which could not be found. " - "Please double-click on a file to locate it..." - ) + if replace: + info_text = "Double-click on a file to replace it..." + else: + info_text = ( + f"{missing_count} file(s) which could not be found. " + "Please double-click on a file to locate it..." + ) info_label = QtWidgets.QLabel(info_text) layout.addWidget(info_label) diff --git a/sleap/gui/overlays/tracks.py b/sleap/gui/overlays/tracks.py index c4521bc65..033efeb6a 100644 --- a/sleap/gui/overlays/tracks.py +++ b/sleap/gui/overlays/tracks.py @@ -9,10 +9,12 @@ import attr import itertools -from typing import List +from typing import Iterable, List from PySide2 import QtCore, QtGui +MAX_NODES_IN_TRAIL = 30 + @attr.s(auto_attribs=True) class TrackTrailOverlay: @@ -37,42 +39,48 @@ class TrackTrailOverlay: trail_length: int = 10 show: bool = False - def get_track_trails(self, frame_selection, track: Track): + def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]): """Get data needed to draw track trail. Args: - frame_selection: an interable with the :class:`LabeledFrame` + frame_selection: an iterable with the :class:`LabeledFrame` objects to include in trail. 
- track: the :class:`Track` for which to get trail Returns: - list of lists of (x, y) tuples + Dictionary keyed by track, value is list of lists of (x, y) tuples i.e., for every node in instance, we get a list of positions """ - all_trails = [[] for _ in range(len(self.labels.nodes))] + all_track_trails = dict() + + nodes = self.labels.nodes + if len(nodes) > MAX_NODES_IN_TRAIL: + nodes = nodes[:MAX_NODES_IN_TRAIL] for frame in frame_selection: - frame_idx = frame.frame_idx - inst_on_track = [instance for instance in frame if instance.track == track] - if inst_on_track: - # just use the first instance from this track in this frame - inst = inst_on_track[0] - # loop through all nodes - for node_i, node in enumerate(self.labels.nodes): + for inst in frame: + if inst.track is not None: + if inst.track not in all_track_trails: + all_track_trails[inst.track] = [[] for _ in range(len(nodes))] + + # loop through all nodes + for node_i, node in enumerate(nodes): + + if node in inst.nodes and inst[node].visible: + point = (inst[node].x, inst[node].y) - if node in inst.nodes and inst[node].visible: - point = (inst[node].x, inst[node].y) - elif len(all_trails[node_i]): - point = all_trails[node_i][-1] - else: - point = None + # Add last location of node so that we can easily + # calculate trail length (since we adjust opacity). 
+ elif len(all_track_trails[inst.track][node_i]): + point = all_track_trails[inst.track][node_i][-1] + else: + point = None - if point is not None: - all_trails[node_i].append(point) + if point is not None: + all_track_trails[inst.track][node_i].append(point) - return all_trails + return all_track_trails def get_frame_selection(self, video: Video, frame_idx: int): """ @@ -116,17 +124,14 @@ def add_to_scene(self, video: Video, frame_idx: int): video: current video frame_idx: index of the frame to which the trail is attached """ - if not self.show: + if not self.show or self.trail_length == 0: return frame_selection = self.get_frame_selection(video, frame_idx) - tracks_in_frame = self.get_tracks_in_frame( - video, frame_idx, include_trails=True - ) - for track in tracks_in_frame: + all_track_trails = self.get_track_trails(frame_selection) - trails = self.get_track_trails(frame_selection, track) + for track, trails in all_track_trails.items(): color = QtGui.QColor(*self.player.color_manager.get_track_color(track)) pen = QtGui.QPen() diff --git a/sleap/gui/slider.py b/sleap/gui/slider.py index 11e672ac8..33ce33c71 100644 --- a/sleap/gui/slider.py +++ b/sleap/gui/slider.py @@ -2,7 +2,7 @@ Drop-in replacement for QSlider with additional features. 
""" -from PySide2 import QtCore, QtWidgets +from PySide2 import QtCore, QtWidgets, QtGui from PySide2.QtGui import QPen, QBrush, QColor, QKeyEvent, QPolygonF, QPainterPath from sleap.gui.color import ColorManager @@ -10,7 +10,7 @@ import attr import itertools import numpy as np -from typing import Dict, Iterable, List, Optional, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union @attr.s(auto_attribs=True, cmp=False) @@ -25,6 +25,8 @@ class SliderMark: * "open" (single value) * "predicted" (single value) * "track" (range of values) + * "tick" (single value) + * "tick_column" (single value) val: Beginning of mark range end_val: End of mark range (for "track" marks) row: The row that the mark goes in; used for tracks. @@ -42,7 +44,14 @@ class SliderMark: @property def color(self): """Returns color of mark.""" - colors = dict(simple="black", filled="blue", open="blue", predicted="yellow") + colors = dict( + simple="black", + filled="blue", + open="blue", + predicted="yellow", + tick="lightGray", + tick_column="gray", + ) if self.type in colors: return colors[self.type] @@ -71,6 +80,39 @@ def filled(self): else: return True + @property + def top_pad(self): + if self.type == "tick_column": + return 40 + if self.type == "tick": + return 0 + return 2 + + @property + def bottom_pad(self): + if self.type == "tick_column": + return 200 + if self.type == "tick": + return 0 + return 2 + + @property + def visual_width(self): + if self.type in ("open", "filled", "tick"): + return 2 + if self.type in ("tick_column"): + return 1 + return 0 + + def get_height(self, container_height): + if self.type == "track": + return 1.5 + height = container_height + # if self.padded: + height -= self.top_pad + self.bottom_pad + + return height + class VideoSlider(QtWidgets.QGraphicsView): """Drop-in replacement for QSlider with additional features. 
@@ -116,7 +158,7 @@ def __init__( marks=None, color_manager: Optional[ColorManager] = None, *args, - **kwargs + **kwargs, ): super(VideoSlider, self).__init__(*args, **kwargs) @@ -125,24 +167,32 @@ def __init__( self.setAlignment(QtCore.Qt.AlignLeft | QtCore.Qt.AlignTop) self.setSizePolicy(QtWidgets.QSizePolicy.Expanding, QtWidgets.QSizePolicy.Fixed) - self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOff) - self.setVerticalScrollBarPolicy( - QtCore.Qt.ScrollBarAlwaysOff - ) # ScrollBarAsNeeded + self.setHorizontalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) + self.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAsNeeded) self._color_manager = color_manager self.zoom_factor = 1 self._track_rows = 0 - self._track_height = 3 - self._header_height = 0 + self._track_height = 5 + self._max_tracks_stacked = 120 + self._track_stack_skip_count = 10 + self._header_label_height = 20 + self._header_graph_height = 30 + self._header_height = self._header_label_height # room for frame labels self._min_height = 19 + self._header_height + self._base_font = QtGui.QFont() + self._base_font.setPixelSize(10) + + self._tick_marks = [] + # Add border rect - outline_rect = QtCore.QRect(0, 0, 200, self._min_height - 3) - self.outlineBox = self.scene.addRect(outline_rect) - self.outlineBox.setPen(QPen(QColor("black"))) + outline_rect = QtCore.QRectF(0, 0, 200, self._min_height - 3) + self.setBoxRect(outline_rect) + # self.outlineBox = self.scene.addRect(outline_rect) + # self.outlineBox.setPen(QPen(QColor("black", alpha=0))) # Add drag handle rect handle_width = 6 @@ -163,6 +213,13 @@ def __init__( self.select_box.setBrush(QColor(80, 80, 255, 128)) self.select_box.hide() + self.zoom_box = self.scene.addRect( + QtCore.QRect(0, 1, 0, outline_rect.height() - 2) + ) + self.zoom_box.setPen(QPen(QColor(80, 80, 80, 64))) + self.zoom_box.setBrush(QColor(80, 80, 80, 64)) + self.zoom_box.hide() + self.scene.setBackgroundBrush(QBrush(QColor(200, 200, 200))) self.clearSelection() @@ -209,6 
+266,10 @@ def setTracksFromLabels(self, labels: "Labels", video: "Video"): track_occupancy = labels.get_track_occupany(video) for track in labels.tracks: if track in track_occupancy and not track_occupancy[track].is_empty: + if track_row > 0 and self.isNewColTrack(track_row): + slider_marks.append( + SliderMark("tick_column", val=track_occupancy[track].start) + ) for occupancy_range in track_occupancy[track].list: slider_marks.append( SliderMark( @@ -252,14 +313,14 @@ def setHeaderSeries(self, series: Optional[Dict[int, float]] = None): None. """ self.headerSeries = [] if series is None else series - self._header_height = 30 + self._header_height = self._header_label_height + self._header_graph_height self.drawHeader() self.updateHeight() def clearHeader(self): """Remove header graph from slider.""" self.headerSeries = [] - self._header_height = 0 + self._header_height = self._header_label_height self.updateHeight() def setTracks(self, track_rows): @@ -271,9 +332,7 @@ def setTracks(self, track_rows): self._track_rows = track_rows self.updateHeight() - def updateHeight(self): - """Update the height of the slider.""" - + def getMinMaxHeights(self): tracks = self._track_rows if tracks == 0: min_height = self._min_height @@ -286,12 +345,19 @@ def updateHeight(self): # Add height for tracks min_height += self._track_height * min(tracks, 20) - max_height += self._track_height * tracks + max_height += self._track_height * min(tracks, self._max_tracks_stacked) # Make sure min/max height is at least 19, even if few tracks min_height = max(self._min_height, min_height) max_height = max(self._min_height, max_height) + return min_height, max_height + + def updateHeight(self): + """Update the height of the slider.""" + + min_height, max_height = self.getMinMaxHeights() + self.setMaximumHeight(max_height) self.setMinimumHeight(min_height) @@ -333,7 +399,12 @@ def _toVal(self, x: float, center=False) -> float: def _sliderWidth(self) -> float: """Returns visual width of 
slider.""" - return self.outlineBox.rect().width() - self.handle.rect().width() + return self.getBoxRect().width() - self.handle.rect().width() + + @property + def slider_visible_value_range(self) -> float: + """Value range that's visible given current size and zoom.""" + return self._toVal(self.width() - 1) def value(self) -> float: """Returns value of slider.""" @@ -353,6 +424,10 @@ def setMaximum(self, max: float) -> float: """Sets maximum value for slider.""" self._val_max = max + @property + def value_range(self) -> float: + return self._val_max - self._val_min + def setEnabled(self, val: float) -> float: """Set whether the slider is enabled.""" self._enabled = val @@ -398,6 +473,11 @@ def endSelection(self, val, update: bool = False): # Emit signal (even if user selected same region as before) self.selectionChanged.emit(*self.getSelection()) + def setSelection(self, start_val, end_val): + """Selects clip from start_val to end_val.""" + self.startSelection(start_val) + self.endSelection(end_val, update=True) + def hasSelection(self) -> bool: """Returns True if a clip is selected, False otherwise.""" a, b = self.getSelection() @@ -413,9 +493,16 @@ def getSelection(self): return start, end def drawSelection(self, a: float, b: float): - """Draws selection box on slider. + self.updateSelectionBoxPositions(self.select_box, a, b) + + def drawZoomBox(self, a: float, b: float): + self.updateSelectionBoxPositions(self.zoom_box, a, b) + + def updateSelectionBoxPositions(self, box_object, a: float, b: float): + """Update box item on slider. 
Args: + box_object: The box to update a: one endpoint value b: other endpoint value @@ -426,12 +513,24 @@ def drawSelection(self, a: float, b: float): end = max(a, b) start_pos = self._toPos(start, center=True) end_pos = self._toPos(end, center=True) - selection_rect = QtCore.QRect( - start_pos, 1, end_pos - start_pos, self.outlineBox.rect().height() - 2 + box_rect = QtCore.QRect( + start_pos, + self._header_height, + end_pos - start_pos, + self.getBoxRect().height(), ) - self.select_box.setRect(selection_rect) - self.select_box.show() + box_object.setRect(box_rect) + box_object.show() + + def updateSelectionBoxesOnResize(self): + for box_object in (self.select_box, self.zoom_box): + rect = box_object.rect() + rect.setHeight(self._handleHeight()) + box_object.setRect(rect) + + if self.select_box.isVisible(): + self.drawSelection(*self.getSelection()) def moveSelectionAnchor(self, x: float, y: float): """ @@ -445,7 +544,7 @@ def moveSelectionAnchor(self, x: float, y: float): None. """ x = max(x, 0) - x = min(x, self.outlineBox.rect().width()) + x = min(x, self.getBoxRect().width()) anchor_val = self._toVal(x, center=True) if len(self._selection) % 2 == 0: @@ -465,17 +564,65 @@ def releaseSelectionAnchor(self, x, y): None. 
""" x = max(x, 0) - x = min(x, self.outlineBox.rect().width()) + x = min(x, self.getBoxRect().width()) anchor_val = self._toVal(x) self.endSelection(anchor_val) + def moveZoomDrag(self, x: float, y: float): + if getattr(self, "_zoom_start_val", None) is None: + self._zoom_start_val = self._toVal(x, center=True) + + current_val = self._toVal(x, center=True) + + self.drawZoomBox(current_val, self._zoom_start_val) + + def releaseZoomDrag(self, x, y): + + self.zoom_box.hide() + + val_a = self._zoom_start_val + val_b = self._toVal(x, center=True) + + val_start = min(val_a, val_b) + val_end = max(val_a, val_b) + + # pad the zoom + val_range = val_end - val_start + val_start -= val_range * 0.05 + val_end += val_range * 0.05 + + self.setZoomRange(val_start, val_end) + + self._zoom_start_val = None + + def setZoomRange(self, start_val: float, end_val: float): + + zoom_val_range = end_val - start_val + if zoom_val_range > 0: + self.zoom_factor = self.value_range / zoom_val_range + else: + self.zoom_factor = 1 + + self.resizeEvent() + + center_val = start_val + zoom_val_range / 2 + center_pos = self._toPos(center_val) + + self.centerOn(center_pos, 0) + def clearMarks(self): """Clears all marked values for slider.""" if hasattr(self, "_mark_items"): for item in self._mark_items.values(): self.scene.removeItem(item) + + if hasattr(self, "_mark_labels"): + for item in self._mark_labels.values(): + self.scene.removeItem(item) + self._marks = set() # holds mark position self._mark_items = dict() # holds visual Qt object for plotting mark + self._mark_labels = dict() def setMarks(self, marks: Iterable[Union[SliderMark, int]]): """Sets all marked values for the slider. @@ -487,15 +634,61 @@ def setMarks(self, marks: Iterable[Union[SliderMark, int]]): None. 
""" self.clearMarks() + + # Add tick marks first so they're behind other marks + self._add_tick_marks() + if marks is not None: for mark in marks: if not isinstance(mark, SliderMark): mark = SliderMark("simple", mark) self.addMark(mark, update=False) + self.updatePos() - def getMarks(self): + def setTickMarks(self): + """Resets which tick marks to show.""" + self._clear_tick_marks() + self._add_tick_marks() + + def _clear_tick_marks(self): + if not hasattr(self, "_tick_marks"): + return + + for mark in self._tick_marks: + self.removeMark(mark) + + def _add_tick_marks(self): + val_range = self.slider_visible_value_range + + val_order = 10 + while val_range // val_order > 24: + val_order *= 10 + + self._tick_marks = [] + + for tick_pos in range(self._val_min + val_order - 1, self._val_max, val_order): + self._tick_marks.append(SliderMark("tick", tick_pos)) + + for tick_mark in self._tick_marks: + self.addMark(tick_mark, update=False) + + def removeMark(self, mark: SliderMark): + """Removes an individual mark.""" + if mark in self._mark_labels: + self.scene.removeItem(self._mark_labels[mark]) + del self._mark_labels[mark] + if mark in self._mark_items: + self.scene.removeItem(self._mark_items[mark]) + del self._mark_items[mark] + if mark in self._marks: + self._marks.remove(mark) + + def getMarks(self, type: str = ""): """Returns list of marks.""" + if type: + return [mark for mark in self._marks if mark.type == type] + return self._marks def addMark(self, new_mark: SliderMark, update: bool = True): @@ -516,19 +709,20 @@ def addMark(self, new_mark: SliderMark, update: bool = True): self._marks.add(new_mark) - v_top_pad = 3 + self._header_height - v_bottom_pad = 3 + v_top_pad = self._header_height + 1 + v_bottom_pad = 1 + v_top_pad += new_mark.top_pad + v_bottom_pad += new_mark.bottom_pad - width = 0 + width = new_mark.visual_width + v_offset = v_top_pad if new_mark.type == "track": - v_offset = v_top_pad + (self._track_height * new_mark.row) - height = 1 - else: - 
v_offset = v_top_pad - height = self.outlineBox.rect().height() - (v_offset + v_bottom_pad) + v_offset += self.getTrackVerticalPos(*self.getTrackColRow(new_mark.row)) - width = 2 if new_mark.type in ("open", "filled") else 0 + height = new_mark.get_height( + container_height=self.getBoxRect().height() - self._header_height + ) color = new_mark.QColor pen = QPen(color, 0.5) @@ -537,9 +731,50 @@ def addMark(self, new_mark: SliderMark, update: bool = True): line = self.scene.addRect(-width // 2, v_offset, width, height, pen, brush) self._mark_items[new_mark] = line + + if new_mark.type == "tick": + # Show tick mark behind other slider marks + self._mark_items[new_mark].setZValue(0) + + # Add a text label to show in header area + mark_label_text = f"{new_mark.val + 1:g}" # sci notation if large + self._mark_labels[new_mark] = self.scene.addSimpleText( + mark_label_text, self._base_font + ) + else: + # Show in front of tick marks + self._mark_items[new_mark].setZValue(1) + if update: self.updatePos() + def getTrackColRow(self, raw_row: int) -> Tuple[int, int]: + if raw_row < self._max_tracks_stacked: + return 0, raw_row + + else: + rows_after_first_col = raw_row - self._max_tracks_stacked + rows_per_later_cols = ( + self._max_tracks_stacked - self._track_stack_skip_count + ) + + rows_down = rows_after_first_col % rows_per_later_cols + col = (rows_after_first_col // rows_per_later_cols) + 1 + + return col, rows_down + + def getTrackVerticalPos(self, col: int, row: int) -> int: + if col == 0: + return row * self._track_height + else: + return (self._track_height * self._track_stack_skip_count) + ( + self._track_height * row + ) + + def isNewColTrack(self, row: int) -> bool: + _, row_down = self.getTrackColRow(row) + return row_down == 0 + def updatePos(self): """Update the visual x position of handle and slider annotations.""" x = self._toPos(self.value()) @@ -547,19 +782,29 @@ def updatePos(self): for mark in self._mark_items.keys(): - width = 0 if mark.type == "track": 
width_in_frames = mark.end_val - mark.val width = max(2, self._toPos(width_in_frames)) - elif mark.type in ("open", "filled"): - width = 2 + else: + width = mark.visual_width x = self._toPos(mark.val, center=True) self._mark_items[mark].setPos(x, 0) + if mark in self._mark_labels: + label_x = max( + 0, x - self._mark_labels[mark].boundingRect().width() // 2 + ) + self._mark_labels[mark].setPos(label_x, 4) + rect = self._mark_items[mark].rect() rect.setWidth(width) + rect.setHeight( + mark.get_height( + container_height=self.getBoxRect().height() - self._header_height + ) + ) self._mark_items[mark].setRect(rect) @@ -604,7 +849,7 @@ def drawHeader(self): series_min = np.min(sampled) - 1 series_max = np.max(sampled) - series_scale = (self._header_height - 5) / (series_max - series_min) + series_scale = (self._header_graph_height) / (series_max - series_min) def toYPos(val): return self._header_height - ((val - series_min) * series_scale) @@ -636,7 +881,7 @@ def moveHandle(self, x, y): """ x -= self.handle.rect().width() / 2.0 x = max(x, 0) - x = min(x, self.outlineBox.rect().width() - self.handle.rect().width()) + x = min(x, self.getBoxRect().width() - self.handle.rect().width()) val = self._toVal(x) @@ -657,33 +902,133 @@ def moveHandle(self, x, y): if old != val: self.valueChanged.emit(self._val_main) + def contiguousSelectionMarksAroundVal(self, val): + """Selects contiguously marked frames around value.""" + if not self.isMarkedVal(val): + return + + dec_val = self.getStartContiguousMark(val) + inc_val = self.getEndContiguousMark(val) + + self.setSelection(dec_val, inc_val) + + def getStartContiguousMark(self, val): + last_val = val + dec_val = self.decrementContiguousMarkedVal(last_val) + while dec_val < last_val and dec_val > self._val_min: + last_val = dec_val + dec_val = self.decrementContiguousMarkedVal(last_val) + + return dec_val + + def getEndContiguousMark(self, val): + last_val = val + inc_val = self.incrementContiguousMarkedVal(last_val) + while inc_val 
> last_val and inc_val < self._val_max: + last_val = inc_val + inc_val = self.incrementContiguousMarkedVal(last_val) + + return inc_val + + def isMarkedVal(self, val): + """Returns whether value has mark.""" + if val in [mark.val for mark in self._marks]: + return True + if any( + mark.val <= val < mark.end_val + for mark in self._marks + if mark.type == "track" + ): + return True + return False + + def decrementContiguousMarkedVal(self, val): + """Decrements value within contiguously marked range if possible.""" + dec_val = min( + ( + mark.val + for mark in self._marks + if mark.type == "track" and mark.val < val <= mark.end_val + ), + default=val, + ) + if dec_val < val: + return dec_val + + if val - 1 in [mark.val for mark in self._marks]: + return val - 1 + + # Return original value if we can't decrement it w/in contiguous range + return val + + def incrementContiguousMarkedVal(self, val): + """Increments value within contiguously marked range if possible.""" + inc_val = max( + ( + mark.end_val - 1 + for mark in self._marks + if mark.type == "track" and mark.val <= val < mark.end_val + ), + default=val, + ) + if inc_val > val: + return inc_val + + if val + 1 in [mark.val for mark in self._marks]: + return val + 1 + + # Return original value if we can't decrement it w/in contiguous range + return val + + def getBoxRect(self): + # return self.outlineBox.rect() + return self._box_rect + + def setBoxRect(self, rect): + # self.outlineBox.setRect(rect) + self._box_rect = rect + + # Update the scene rect so that it matches how much space we + # currently want for drawing everything. + rect.setWidth(rect.width() - 1) + self.setSceneRect(rect) + + def getMarkAreaHeight(self): + _, max_height = self.getMinMaxHeights() + return max_height - 3 - self._header_height + def resizeEvent(self, event=None): """Override method to update visual size when necessary. 
Args: event """ - height = self.size().height() - outline_rect = self.outlineBox.rect() + outline_rect = self.getBoxRect() handle_rect = self.handle.rect() - select_box_rect = self.select_box.rect() - outline_rect.setHeight(height - 3) + outline_rect.setHeight(self.getMarkAreaHeight() + self._header_height) + if event is not None: visual_width = event.size().width() - 1 - outline_rect.setWidth(visual_width * self.zoom_factor) - self.outlineBox.setRect(outline_rect) + else: + visual_width = self.width() - 1 + + drawn_width = visual_width * self.zoom_factor + + outline_rect.setWidth(drawn_width) + self.setBoxRect(outline_rect) handle_rect.setTop(self._handleTop()) handle_rect.setHeight(self._handleHeight()) self.handle.setRect(handle_rect) - select_box_rect.setHeight(self._handleHeight()) - self.select_box.setRect(select_box_rect) + self.updateSelectionBoxesOnResize() + self.setTickMarks() self.updatePos() self.drawHeader() + super(VideoSlider, self).resizeEvent(event) def _handleTop(self) -> float: @@ -702,14 +1047,7 @@ def _handleHeight(self, outline_rect=None) -> float: Returns: Height of handle in pixels. """ - if outline_rect is None: - outline_rect = self.outlineBox.rect() - - handle_bottom_offset = 1 - handle_height = outline_rect.height() - ( - self._handleTop() + handle_bottom_offset - ) - return handle_height + return self.getMarkAreaHeight() def mousePressEvent(self, event): """Override method to move handle for mouse press/drag. 
@@ -723,7 +1061,7 @@ def mousePressEvent(self, event): if not self.enabled(): return # Do nothing if click outside slider area - if not self.outlineBox.rect().contains(scenePos): + if not self.getBoxRect().contains(scenePos): return move_function = None @@ -739,6 +1077,13 @@ def mousePressEvent(self, event): move_function = self.moveHandle release_function = None + elif event.modifiers() == QtCore.Qt.AltModifier: + move_function = self.moveZoomDrag + release_function = self.releaseZoomDrag + + else: + event.accept() # mouse events shouldn't be passed to video widgets + # Connect to signals if move_function is not None: self.mouseMoved.connect(move_function) @@ -746,7 +1091,8 @@ def mousePressEvent(self, event): def done(x, y): if release_function is not None: release_function(x, y) - self.mouseMoved.disconnect(move_function) + if move_function is not None: + self.mouseMoved.disconnect(move_function) self.mouseReleased.disconnect(done) self.mouseReleased.connect(done) @@ -765,6 +1111,24 @@ def mouseReleaseEvent(self, event): scenePos = self.mapToScene(event.pos()) self.mouseReleased.emit(scenePos.x(), scenePos.y()) + def mouseDoubleClickEvent(self, event): + """Override method to move handle for mouse double-click. 
+ + Args: + event + """ + scenePos = self.mapToScene(event.pos()) + + # Do nothing if not enabled + if not self.enabled(): + return + # Do nothing if click outside slider area + if not self.getBoxRect().contains(scenePos): + return + + if event.modifiers() == QtCore.Qt.ShiftModifier: + self.contiguousSelectionMarksAroundVal(self._toVal(scenePos.x())) + def keyPressEvent(self, event): """Catch event and emit signal so something else can handle event.""" self.keyPress.emit(event) @@ -777,7 +1141,7 @@ def keyReleaseEvent(self, event): def boundingRect(self) -> QtCore.QRectF: """Method required by Qt.""" - return self.outlineBox.rect() + return self.getBoxRect() def paint(self, *args, **kwargs): """Method required by Qt.""" diff --git a/sleap/gui/training_editor.py b/sleap/gui/training_editor.py index 72684945f..fed8fcab2 100644 --- a/sleap/gui/training_editor.py +++ b/sleap/gui/training_editor.py @@ -130,7 +130,7 @@ def _save_as(self): confmaps=ModelOutputType.CONFIDENCE_MAP, pafs=ModelOutputType.PART_AFFINITY_FIELD, centroids=ModelOutputType.CENTROIDS, - topdown=ModelOutputType.TOPDOWN_CONFIDENCE_MAP, + topdown_confidence_maps=ModelOutputType.TOPDOWN_CONFIDENCE_MAP, )[model_data["output_type"]] backbone_kwargs = { diff --git a/sleap/gui/video.py b/sleap/gui/video.py index 2ba437ae7..e5196d7bb 100644 --- a/sleap/gui/video.py +++ b/sleap/gui/video.py @@ -12,6 +12,9 @@ """ +FORCE_REQUEST_AFTER_TIME_IN_SECONDS = 1 + + from PySide2 import QtWidgets, QtCore from PySide2.QtWidgets import ( @@ -26,7 +29,9 @@ from PySide2.QtGui import QKeyEvent from PySide2.QtCore import Qt, QRectF, QPointF, QMarginsF, QLineF +import atexit import math +import time from typing import Callable, Dict, List, Optional, Tuple, Union @@ -34,7 +39,6 @@ from PySide2.QtWidgets import ( QGraphicsEllipseItem, - QGraphicsLineItem, QGraphicsTextItem, QGraphicsRectItem, QGraphicsPolygonItem, @@ -50,6 +54,72 @@ import qimage2ndarray +class LoadImageWorker(QtCore.QObject): + """ + Object to load video frames 
in background thread. + """ + + result = QtCore.Signal(QImage) + process = QtCore.Signal() + + load_queue = [] + video = None + _last_process_time = 0 + + def __init__(self, *args, **kwargs): + super(LoadImageWorker, self).__init__(*args, **kwargs) + + # Connect signal to processing function so that we can add processing + # event to event queue from the request handler. + self.process.connect(self.doProcessing) + + # Start timer which will trigger processing events when we're free + self.timer = QtCore.QTimer() + self.timer.timeout.connect(self.doProcessing) + self.timer.start(0) + + def doProcessing(self): + self._last_process_time = time.time() + + if not self.load_queue: + return + + frame_idx = self.load_queue[-1] + self.load_queue = [] + + # print(f"\t{frame_idx} starting to load") # DEBUG + + try: + # Get image data + frame = self.video.get_frame(frame_idx) + except: + frame = None + + if frame is not None: + # Convert ndarray to QImage + qimage = qimage2ndarray.array2qimage(frame) + + # print(f"\t{frame_idx} result") # DEBUG + + # Emit result + self.result.emit(qimage) + + def request(self, frame_idx): + # Add request to the queue so that we can just process the most recent. + self.load_queue.append(frame_idx) + + # If we haven't processed a request for a certain amount of time, + # then trigger a processing event now. This helps when the user has been + # continuously changing frames for a while (i.e., dragging on seekbar + # or holding down arrow key). + + since_last = time.time() - self._last_process_time + + if since_last > FORCE_REQUEST_AFTER_TIME_IN_SECONDS: + self._last_process_time = time.time() + self.process.emit() + + class QtVideoPlayer(QWidget): """ Main QWidget for displaying video with skeleton instances. 
@@ -65,6 +135,7 @@ class QtVideoPlayer(QWidget): """ changedPlot = QtCore.Signal(QWidget, int, Instance) + requestImage = QtCore.Signal(int) def __init__( self, @@ -110,6 +181,20 @@ def __init__( lambda e: self.state.set("frame_idx", self.seekbar.value()) ) + # Make worker thread to load images in the background + self._loader_thread = QtCore.QThread() + self._video_image_loader = LoadImageWorker() + self._video_image_loader.moveToThread(self._loader_thread) + self._loader_thread.start() + + # Connect signal so that image will be shown after it's loaded + self._video_image_loader.result.connect( + lambda qimage: self.view.setImage(qimage) + ) + + # Connect request signals from self to worker + self.requestImage.connect(self._video_image_loader.request) + def update_selection_state(a, b): self.state.set("frame_range", (a, b)) self.state.set("has_frame_range", (a < b)) @@ -127,9 +212,34 @@ def update_selection_state(a, b): self.view.show() + # Call cleanup method when application exits to end worker thread + self.destroyed.connect(self.cleanup) + atexit.register(self.cleanup) + if video is not None: self.load_video(video) + def cleanup(self): + self._loader_thread.quit() + self._loader_thread.wait() + + def _load_and_show_requested_image(self, frame_idx): + # Get image data + try: + frame = self.video.get_frame(frame_idx) + except: + frame = None + + if frame is not None: + # Convert ndarray to QImage + qimage = qimage2ndarray.array2qimage(frame) + + # Display image + self.view.setImage(qimage) + + def setSeekbarSelection(self, a: int, b: int): + self.seekbar.setSelection(a, b) + def show_contextual_menu(self, where: QtCore.QPoint): if not self.is_menu_enabled: return @@ -238,24 +348,15 @@ def plot(self, *args): idx = self.state["frame_idx"] or 0 - # Get image data - try: - frame = self.video.get_frame(idx) - except: - frame = None - - if frame is not None: - # Clear existing objects - self.view.clear() - - # Convert ndarray to QImage - image = 
qimage2ndarray.array2qimage(frame)
+        # Clear existing objects before drawing instances
+        self.view.clear()
 
-            # Display image
-            self.view.setImage(image)
+        # Emit signal for the instances to be drawn for this frame
+        self.changedPlot.emit(self, idx, self.state["instance"])
 
-            # Emit signal
-            self.changedPlot.emit(self, idx, self.state["instance"])
+        # Emit signal for the image to be loaded and shown for this frame
+        self._video_image_loader.video = self.video
+        self.requestImage.emit(idx)
 
     def showLabels(self, show):
         """ Show/hide node labels for all instances in viewer.
@@ -447,6 +548,14 @@ def keyPressEvent(self, event: QKeyEvent):
         elif event.key() == Qt.Key.Key_Escape:
             self.view.click_mode = ""
             self.state["instance"] = None
+        elif event.key() == Qt.Key.Key_K:
+            self.state["frame_idx"] = self.seekbar.getEndContiguousMark(
+                self.state["frame_idx"]
+            )
+        elif event.key() == Qt.Key.Key_J:
+            self.state["frame_idx"] = self.seekbar.getStartContiguousMark(
+                self.state["frame_idx"]
+            )
         elif event.key() == Qt.Key.Key_QuoteLeft:
             self.state.increment_in_list("instance", self.selectable_instances)
         elif event.key() < 128 and chr(event.key()).isnumeric():
@@ -543,9 +652,17 @@ def hasImage(self) -> bool:
 
     def clear(self):
         """ Clears the displayed frame from the scene. """
-        self._pixmapHandle = None
+
+        if self._pixmapHandle:
+            # get the pixmap currently shown
+            pixmap = self._pixmapHandle.pixmap()
+
+        self.scene.clear()
+
+        if self._pixmapHandle:
+            # add the pixmap back
+            self._pixmapHandle = self.scene.addPixmap(pixmap)
+
     def setImage(self, image: Union[QImage, QPixmap]):
         """ Set the scene's current image pixmap to the input QImage or QPixmap.
@@ -571,7 +688,12 @@ def setImage(self, image: Union[QImage, QPixmap]):
             self._pixmapHandle.setPixmap(pixmap)
         else:
             self._pixmapHandle = self.scene.addPixmap(pixmap)
-        self.setSceneRect(QRectF(pixmap.rect()))  # Set scene size to image size.
+
+        # Ensure that image is behind everything else
+        self._pixmapHandle.setZValue(-1)
+
+        # Set scene size to image size.
+        self.setSceneRect(QRectF(pixmap.rect()))
         self.updateViewer()
 
     def updateViewer(self):
diff --git a/sleap/instance.py b/sleap/instance.py
index f7188af86..0db6fa068 100644
--- a/sleap/instance.py
+++ b/sleap/instance.py
@@ -1123,6 +1123,11 @@ def has_user_instances(self) -> bool:
         """Whether the frame contains any user instances."""
         return len(self.user_instances) > 0
 
+    @property
+    def has_predicted_instances(self) -> bool:
+        """Whether the frame contains any predicted instances."""
+        return len(self.predicted_instances) > 0
+
     @property
     def unused_predictions(self) -> List[Instance]:
         """
diff --git a/sleap/io/dataset.py b/sleap/io/dataset.py
index 2f13e14e9..be94d795e 100644
--- a/sleap/io/dataset.py
+++ b/sleap/io/dataset.py
@@ -491,6 +491,16 @@ def get_video_user_labeled_frames(self, video: Video) -> List[LabeledFrame]:
             if lf.has_user_instances and lf.video == video
         ]
 
+    def get_video_predicted_frames(self, video: Video) -> List[LabeledFrame]:
+        """
+        Returns labeled frames for given video with predicted instances.
+        """
+        return [
+            lf
+            for lf in self.labeled_frames
+            if lf.has_predicted_instances and lf.video == video
+        ]
+
     # Methods for instances
 
     def instance_count(self, video: Video, frame_idx: int) -> int:
@@ -1407,6 +1417,8 @@ def load_json(
             :class:`Video` objects. Usually you'll want to pass
             a callback created by :meth:`make_video_callback`
             or :meth:`make_gui_video_callback`.
+            Alternately, if you pass a list of strings we'll construct a
+            non-gui callback with those strings as the search paths.
         match_to: If given, we'll replace particular objects in the
             data dictionary with *matching* objects in the match_to
             :class:`Labels` object. 
This ensures that the newly @@ -1486,6 +1498,13 @@ def load_json( tmp_dir, vid["backend"]["filename"] ) + if hasattr(video_callback, "__iter__"): + # If the callback is an iterable, then we'll expect it to be a + # list of strings and build a non-gui callback with those as + # the search paths. + search_paths = [path for path in video_callback] + video_callback = cls.make_video_callback(search_paths) + # Use the callback if given to handle missing videos if callable(video_callback): abort = video_callback(dicts["videos"]) @@ -1838,6 +1857,8 @@ def load_hdf5( :class:`Video` objects. Usually you'll want to pass a callback created by :meth:`make_video_callback` or :meth:`make_gui_video_callback`. + Alternately, if you pass a list of strings we'll construct a + non-gui callback with those strings as the search paths. match_to: If given, we'll replace particular objects in the data dictionary with *matching* objects in the match_to :class:`Labels` object. This ensures that the newly @@ -1867,6 +1888,13 @@ def load_hdf5( if video_item["backend"]["filename"] == ".": video_item["backend"]["filename"] = filename + if hasattr(video_callback, "__iter__"): + # If the callback is an iterable, then we'll expect it to be a + # list of strings and build a non-gui callback with those as + # the search paths. + search_paths = [path for path in video_callback] + video_callback = cls.make_video_callback(search_paths) + # Use the callback if given to handle missing videos if callable(video_callback): video_callback(dicts["videos"]) diff --git a/sleap/io/video.py b/sleap/io/video.py index f0a872b18..228c58a8c 100644 --- a/sleap/io/video.py +++ b/sleap/io/video.py @@ -188,6 +188,11 @@ def last_frame_idx(self) -> int: return last_key return self.frames - 1 + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, idx) -> np.ndarray: """ Get a frame from the underlying HDF5 video data. 
@@ -343,6 +348,10 @@ def dtype(self): """See :class:`Video`.""" return self.__test_frame.dtype + def reset(self): + """Reloads the video.""" + self._reader_ = None + def get_frame(self, idx: int, grayscale: bool = None) -> np.ndarray: """See :class:`Video`.""" if self.__reader.get(cv2.CAP_PROP_POS_FRAMES) != idx: @@ -436,6 +445,11 @@ def dtype(self): """See :class:`Video`.""" return self.__data.dtype + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, idx): """See :class:`Video`.""" return self.__data[idx] @@ -557,6 +571,11 @@ def last_frame_idx(self) -> int: return self.__store.frame_max return self.frames - 1 + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, frame_number: int) -> np.ndarray: """ Get a frame from the underlying ImgStore video data. @@ -733,6 +752,11 @@ def dtype(self): """See :class:`Video`.""" return self.__data.dtype + def reset(self): + """Reloads the video.""" + # TODO + pass + def get_frame(self, idx): """See :class:`Video`.""" if idx not in self.__data: diff --git a/sleap/nn/data.py b/sleap/nn/data.py index 320c73453..e0481fea3 100644 --- a/sleap/nn/data.py +++ b/sleap/nn/data.py @@ -1,6 +1,13 @@ """This module contains utilities for data I/O and generating training data.""" import numpy as np + +# Monkey patch for: https://github.com/aleju/imgaug/issues/537 +# TODO: Fix when new version of imgaug is available on PyPI. +import numpy +if hasattr(numpy.random, "_bit_generator"): + numpy.random.bit_generator = numpy.random._bit_generator + import h5py import tensorflow as tf import imgaug as ia @@ -355,7 +362,7 @@ def augment_dataset( # Setup augmenter. 
aug_stack = [] if rotate: - aug_stack.append(iaa.Affine(rotate=(-rotation_min_angle, rotation_max_angle))) + aug_stack.append(iaa.Affine(rotate=(rotation_min_angle, rotation_max_angle))) if scale: aug_stack.append(iaa.Affine(scale=(scale_min, scale_max))) diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index db7987177..32d7e3d48 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -73,6 +73,10 @@ def predict( if video_kwargs is None: video_kwargs = dict() + # Detect whether to use grayscale video by looking at trained models. + if self.has_grayscale_models: + video_kwargs["grayscale"] = True + video_ds = utils.VideoLoader( filename=video_filename, frame_inds=frames, **video_kwargs ) @@ -100,6 +104,7 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): """Runs the inference components of pipeline for a chunk.""" if "centroid" in self.policies: + # Detect centroids and pull out region proposals. centroid_predictor = self.policies["centroid"] region_proposal_extractor = self.policies["region"] @@ -114,21 +119,27 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): region_proposal.RegionProposalSet.from_uncropped(images=img_chunk) ] - # for region_ind in range(len(region_proposal_sets)): - # region_proposal_sets[region_ind].sample_inds += chunk_ind * chunk_size - if "paf" not in self.policies: + # If we don't have PAFs, we must be doing topdown or single-instance. if "topdown" in self.policies: topdown_peak_finder = self.policies["topdown"] else: topdown_peak_finder = self.policies["confmap"] + if len(region_proposal_sets) == 0: + # No region proposals were found, so just return empty result. + return defaultdict(list) + + # Only one scale region proposals in topdown inference. rps = region_proposal_sets[0] + # Find peaks in each sample of the region proposal set. 
sample_peak_pts, sample_peak_vals = topdown_peak_finder.predict_rps(rps) - sample_peak_pts = sample_peak_pts.to_tensor().numpy() - sample_peak_vals = sample_peak_vals.to_tensor().numpy() + sample_peak_pts = sample_peak_pts.to_tensor(default_value=np.nan).numpy() + sample_peak_vals = sample_peak_vals.to_tensor(default_value=np.nan).numpy() + + # Gather instances across all region proposals for this chunk. predicted_instances_chunk = topdown.make_sample_grouped_predicted_instances( sample_peak_pts, sample_peak_vals, @@ -140,16 +151,19 @@ def predict_chunk(self, img_chunk, chunk_ind, chunk_size): confmap_peak_finder = self.policies["confmap"] paf_grouper = self.policies["paf"] + # Find multi-instance peaks in each region proposal set. region_peak_sets = [] for rps in region_proposal_sets: region_peaks = confmap_peak_finder.predict_rps(rps) region_peak_sets.append(region_peaks) + # Group peaks into instances using PAFs. region_instance_sets = [] for rps, region_peaks in zip(region_proposal_sets, region_peak_sets): region_instances = paf_grouper.predict_rps(rps, region_peaks) region_instance_sets.append(region_instances) + # Gather instances across all region proposals for this chunk. predicted_instances_chunk = defaultdict(list) for region_instance_set in region_instance_sets: for sample, region_instances in region_instance_set.items(): @@ -213,6 +227,19 @@ def make_labeled_frames( return frames + @property + def has_grayscale_models(self): + for policy_name in ("centroid", "topdown", "confmap"): + if policy_name in self.policies: + if ( + self.policies[policy_name].inference_model.input_tensor.shape[-1] + == 1 + ): + return True + else: + return False + return False + @classmethod def from_cli_args(cls): parser = cls.make_cli_parser() diff --git a/sleap/nn/job.py b/sleap/nn/job.py index a1407615c..548035699 100644 --- a/sleap/nn/job.py +++ b/sleap/nn/job.py @@ -222,15 +222,17 @@ def model_path(self): return model_path # Raise error if all fail. 
- raise ValueError(f"Could not find a saved model in run path: {self.run_path}") + raise FileNotFoundError(f"Could not find a saved model in run path: {self.run_path}") @property def is_trained(self): if self.run_path is None: return False - if os.path.exists(self.model_path): - return True - return False + + try: + return os.path.exists(self.model_path) + except FileNotFoundError: + return False @staticmethod def _to_dicts(training_job: "TrainingJob"): diff --git a/sleap/nn/peak_finding.py b/sleap/nn/peak_finding.py index 0a2a019a0..913de5d1a 100644 --- a/sleap/nn/peak_finding.py +++ b/sleap/nn/peak_finding.py @@ -404,12 +404,11 @@ def postproc(self, confmaps): def predict_rps(self, rps: "RegionProposalSet") -> RegionPeakSet: - imgs = self.preproc(rps.patches) - - confmaps = utils.batched_call(self.inference, imgs, batch_size=self.batch_size) - peak_subs_and_vals, batch_inds = utils.batched_call( - self.postproc, confmaps, batch_size=self.batch_size, return_batch_inds=True + lambda imgs: self.postproc(self.inference(self.preproc(imgs))), + rps.patches, + batch_size=self.batch_size, + return_batch_inds=True, ) # Split. diff --git a/sleap/nn/region_proposal.py b/sleap/nn/region_proposal.py index b57cce433..3b6b0b847 100644 --- a/sleap/nn/region_proposal.py +++ b/sleap/nn/region_proposal.py @@ -330,8 +330,11 @@ def postproc(self, centroid_confmaps): return centroids, centroid_vals def predict(self, imgs): - imgs = self.preproc(imgs) - confmaps = utils.batched_call(self.inference, imgs, batch_size=self.batch_size) + confmaps = utils.batched_call( + lambda imgs: self.inference(self.preproc(imgs)), + imgs, + batch_size=self.batch_size, + ) return self.postproc(confmaps) @@ -464,6 +467,10 @@ def extract_region_proposal_sets( sample_inds = size_grouped_sample_inds[box_size] bboxes = size_grouped_bboxes[box_size] + if len(bboxes) == 0: + # Skip region proposal set if we found no bboxes. + continue + # Extract image patches for all regions in the set. 
patches = extract_patches( imgs, tf.cast(bboxes, tf.float32), tf.cast(sample_inds, tf.int32) diff --git a/sleap/nn/utils.py b/sleap/nn/utils.py index 6733a5587..39831f991 100644 --- a/sleap/nn/utils.py +++ b/sleap/nn/utils.py @@ -308,6 +308,7 @@ class VideoLoader: filename: str dataset: str = None input_format: str = None + grayscale: bool = False chunk_size: int = 32 prefetch_chunks: int = 1 frame_inds: Optional[List[int]] = None @@ -351,18 +352,22 @@ def tf_dtype(self): def __attrs_post_init__(self): - self._video = Video.from_filename( - self.filename, dataset=self.dataset, input_format=self.input_format - ) + self._video = self._load_video(self.filename) self._shape = self.video.shape self._np_dtype = self.video.dtype self._tf_dtype = tf.dtypes.as_dtype(self.np_dtype) self._ds = self.make_ds() - def load_frames(self, frame_inds): - local_vid = Video.from_filename( - self.video.filename, dataset=self.dataset, input_format=self.input_format + def _load_video(self, filename) -> "Video": + return Video.from_filename( + filename, + dataset=self.dataset, + input_format=self.input_format, + grayscale=self.grayscale, ) + + def load_frames(self, frame_inds): + local_vid = self._load_video(self.video.filename) imgs = local_vid[np.array(frame_inds).astype("int64")] return imgs diff --git a/tests/gui/test_slider.py b/tests/gui/test_slider.py index 0d05b057b..788d21265 100644 --- a/tests/gui/test_slider.py +++ b/tests/gui/test_slider.py @@ -23,8 +23,8 @@ def test_slider(qtbot, centered_pair_predictions): assert slider.maximumHeight() != initial_height slider.setTracksFromLabels(labels, labels.videos[0]) - assert len(slider.getMarks()) == 40 + assert len(slider.getMarks("track")) == 40 slider.moveSelectionAnchor(5, 5) slider.releaseSelectionAnchor(100, 15) - assert slider.getSelection() == (31, 619) + assert slider.getSelection() == (slider._toVal(5), slider._toVal(100)) diff --git a/tests/gui/test_tracks.py b/tests/gui/test_tracks.py index 7e03117e1..7889eb1b3 100644 --- 
a/tests/gui/test_tracks.py +++ b/tests/gui/test_tracks.py @@ -15,10 +15,13 @@ def test_track_trails(centered_pair_predictions): assert tracks[0].name == "1" assert tracks[1].name == "2" - tracks_with_trails = trail_manager.get_tracks_in_frame(labels.videos[0], 27, include_trails=True) + tracks_with_trails = trail_manager.get_tracks_in_frame( + labels.videos[0], 27, include_trails=True + ) assert len(tracks_with_trails) == 13 - trails = trail_manager.get_track_trails(frames, tracks[0]) + all_trails = trail_manager.get_track_trails(frames) + trails = all_trails[tracks[0]] assert len(trails) == 24 diff --git a/tests/io/test_dataset.py b/tests/io/test_dataset.py index dcda548a4..4e733937b 100644 --- a/tests/io/test_dataset.py +++ b/tests/io/test_dataset.py @@ -781,8 +781,7 @@ def test_many_tracks_hdf5(tmpdir): labels = Labels() filename = os.path.join(tmpdir, "test.h5") - labels.tracks = [Track(spawned_on=i, name=f"track {i}") - for i in range(4000)] + labels.tracks = [Track(spawned_on=i, name=f"track {i}") for i in range(4000)] Labels.save_hdf5(filename=filename, labels=labels) @@ -791,8 +790,7 @@ def test_many_videos_hdf5(tmpdir): labels = Labels() filename = os.path.join(tmpdir, "test.h5") - labels.videos = [Video.from_filename(f"video {i}.mp4") - for i in range(3000)] + labels.videos = [Video.from_filename(f"video {i}.mp4") for i in range(3000)] Labels.save_hdf5(filename=filename, labels=labels) @@ -806,3 +804,20 @@ def test_many_suggestions_hdf5(tmpdir): labels.suggestions = [SuggestionFrame(video, i) for i in range(3000)] Labels.save_hdf5(filename=filename, labels=labels) + + +def test_path_fix(tmpdir): + labels = Labels() + filename = os.path.join(tmpdir, "test.h5") + + # Add a video without a full path + labels.add_video(Video.from_filename("small_robot.mp4")) + + Labels.save_hdf5(filename=filename, labels=labels) + + # Pass the path to the directory with the video + labels = Labels.load_file(filename, video_callback=["tests/data/videos/"]) + + # Make sure we 
got the actual video path by searching that directory + assert len(labels.videos) == 1 + assert labels.videos[0].filename == "tests/data/videos/small_robot.mp4" diff --git a/tests/nn/test_utils.py b/tests/nn/test_utils.py new file mode 100644 index 000000000..8701a77a5 --- /dev/null +++ b/tests/nn/test_utils.py @@ -0,0 +1,9 @@ +from sleap.nn.utils import VideoLoader + + +def test_grayscale_video(): + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4",) + assert vid.shape[-1] == 3 + + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", grayscale=True) + assert vid.shape[-1] == 1