diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index a99823f9e..32d7e3d48 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -73,6 +73,10 @@ def predict( if video_kwargs is None: video_kwargs = dict() + # Detect whether to use grayscale video by looking at trained models. + if self.has_grayscale_models: + video_kwargs["grayscale"] = True + video_ds = utils.VideoLoader( filename=video_filename, frame_inds=frames, **video_kwargs ) @@ -223,6 +227,19 @@ def make_labeled_frames( return frames + @property + def has_grayscale_models(self): + for policy_name in ("centroid", "topdown", "confmap"): + if policy_name in self.policies: + if ( + self.policies[policy_name].inference_model.input_tensor.shape[-1] + == 1 + ): + return True + else: + return False + return False + @classmethod def from_cli_args(cls): parser = cls.make_cli_parser() diff --git a/sleap/nn/utils.py b/sleap/nn/utils.py index 6733a5587..39831f991 100644 --- a/sleap/nn/utils.py +++ b/sleap/nn/utils.py @@ -308,6 +308,7 @@ class VideoLoader: filename: str dataset: str = None input_format: str = None + grayscale: bool = False chunk_size: int = 32 prefetch_chunks: int = 1 frame_inds: Optional[List[int]] = None @@ -351,18 +352,22 @@ def tf_dtype(self): def __attrs_post_init__(self): - self._video = Video.from_filename( - self.filename, dataset=self.dataset, input_format=self.input_format - ) + self._video = self._load_video(self.filename) self._shape = self.video.shape self._np_dtype = self.video.dtype self._tf_dtype = tf.dtypes.as_dtype(self.np_dtype) self._ds = self.make_ds() - def load_frames(self, frame_inds): - local_vid = Video.from_filename( - self.video.filename, dataset=self.dataset, input_format=self.input_format + def _load_video(self, filename) -> "Video": + return Video.from_filename( + filename, + dataset=self.dataset, + input_format=self.input_format, + grayscale=self.grayscale, ) + + def load_frames(self, frame_inds): + local_vid = self._load_video(self.video.filename) imgs = local_vid[np.array(frame_inds).astype("int64")] return imgs diff --git a/tests/nn/test_utils.py b/tests/nn/test_utils.py new file mode 100644 index 000000000..8701a77a5 --- /dev/null +++ b/tests/nn/test_utils.py @@ -0,0 +1,9 @@ +from sleap.nn.utils import VideoLoader + + +def test_grayscale_video(): + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4",) + assert vid.shape[-1] == 3 + + vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", grayscale=True) + assert vid.shape[-1] == 1