diff --git a/sleap/gui/app.py b/sleap/gui/app.py index a85744b4d..f471470dd 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -1181,9 +1181,9 @@ def _frames_for_prediction(self): values are {video: list of frame indices} dictionaries. """ - def remove_user_labeled( - video, frame_idxs, user_labeled_frames=self.labels.user_labeled_frames - ): + user_labeled_frames = self.labels.user_labeled_frames + + def remove_user_labeled(video, frame_idxs): if len(frame_idxs) == 0: return frame_idxs video_user_labeled_frame_idxs = { @@ -1216,6 +1216,19 @@ def remove_user_labeled( for video in self.labels.videos } + if len(self.labels.videos) > 1: + selection["random_video"] = { + current_video: remove_user_labeled( + current_video, random.sample(range(current_video.frames), min(20, current_video.frames)) + ) + } + + if user_labeled_frames: + selection["user"] = { + video: [lf.frame_idx for lf in user_labeled_frames if lf.video == video] + for video in self.labels.videos + } + return selection def showLearningDialog(self, mode: str): diff --git a/sleap/gui/learning/training.py b/sleap/gui/learning/training.py index d3cba5371..c77c8f342 100644 --- a/sleap/gui/learning/training.py +++ b/sleap/gui/learning/training.py @@ -143,16 +143,24 @@ def count_total_frames(videos_frames): total_random = 0 total_suggestions = 0 + total_user = 0 + random_video = 0 clip_length = 0 video_length = 0 # Determine which options are available given _frame_selection if "random" in self._frame_selection: total_random = count_total_frames(self._frame_selection["random"]) + if "random_video" in self._frame_selection: + random_video = count_total_frames(self._frame_selection["random_video"]) if "suggestions" in self._frame_selection: total_suggestions = count_total_frames( self._frame_selection["suggestions"] ) + if "user" in self._frame_selection: + total_user = count_total_frames( + self._frame_selection["user"] + ) if "clip" in self._frame_selection: clip_length = count_total_frames(self._frame_selection["clip"]) if "video" in self._frame_selection: @@ -168,11 +176,19 @@ def count_total_frames(videos_frames): prediction_options.append(option) default_option = option + if random_video > 0: + option = f"random frames in current video ({random_video} frames)" + prediction_options.append(option) + if total_suggestions > 0: option = f"suggested frames ({total_suggestions} total frames)" prediction_options.append(option) default_option = option + if total_user > 0: + option = f"user labeled frames ({total_user} total frames)" + prediction_options.append(option) + if clip_length > 0: option = f"selected clip ({clip_length} frames)" prediction_options.append(option)