Skip to content

Commit

Permalink
Load video as grayscale if models have one channel
Browse files Browse the repository at this point in the history
  • Loading branch information
ntabris committed Jan 21, 2020
1 parent 87dec33 commit 7aab3af
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 6 deletions.
17 changes: 17 additions & 0 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ def predict(
if video_kwargs is None:
video_kwargs = dict()

# Detect whether to use grayscale video by looking at trained models.
if self.has_grayscale_models:
video_kwargs["grayscale"] = True

video_ds = utils.VideoLoader(
filename=video_filename, frame_inds=frames, **video_kwargs
)
Expand Down Expand Up @@ -223,6 +227,19 @@ def make_labeled_frames(

return frames

@property
def has_grayscale_models(self):
for policy_name in ("centroid", "topdown", "confmap"):
if policy_name in self.policies:
if (
self.policies[policy_name].inference_model.input_tensor.shape[-1]
== 1
):
return True
else:
return False
return False

@classmethod
def from_cli_args(cls):
parser = cls.make_cli_parser()
Expand Down
17 changes: 11 additions & 6 deletions sleap/nn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ class VideoLoader:
filename: str
dataset: str = None
input_format: str = None
grayscale: bool = False
chunk_size: int = 32
prefetch_chunks: int = 1
frame_inds: Optional[List[int]] = None
Expand Down Expand Up @@ -351,18 +352,22 @@ def tf_dtype(self):

def __attrs_post_init__(self):

self._video = Video.from_filename(
self.filename, dataset=self.dataset, input_format=self.input_format
)
self._video = self._load_video(self.filename)
self._shape = self.video.shape
self._np_dtype = self.video.dtype
self._tf_dtype = tf.dtypes.as_dtype(self.np_dtype)
self._ds = self.make_ds()

def load_frames(self, frame_inds):
local_vid = Video.from_filename(
self.video.filename, dataset=self.dataset, input_format=self.input_format
def _load_video(self, filename) -> "Video":
return Video.from_filename(
filename,
dataset=self.dataset,
input_format=self.input_format,
grayscale=self.grayscale,
)

def load_frames(self, frame_inds):
local_vid = self._load_video(self.video.filename)
imgs = local_vid[np.array(frame_inds).astype("int64")]
return imgs

Expand Down
9 changes: 9 additions & 0 deletions tests/nn/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sleap.nn.utils import VideoLoader


def test_grayscale_video():
vid = VideoLoader(filename="tests/data/videos/small_robot.mp4",)
assert vid.shape[-1] == 3

vid = VideoLoader(filename="tests/data/videos/small_robot.mp4", grayscale=True)
assert vid.shape[-1] == 1

0 comments on commit 7aab3af

Please sign in to comment.