From 01da28a915d7ed30dcc0fff8e23d0744bce529e0 Mon Sep 17 00:00:00 2001 From: Talmo Date: Sun, 5 Apr 2020 22:21:49 -0400 Subject: [PATCH] Fix topdown not accounting for different scale in centroids vs confmaps - Some optimization could probably be done to increase performance when both models have the same preprocessing settings, but currently preprocessing is done independently for each model. --- sleap/nn/data/inference.py | 1 + sleap/nn/data/instance_cropping.py | 9 +++++++-- sleap/nn/data/normalization.py | 5 +++-- sleap/nn/data/resizing.py | 13 ++++++++++--- sleap/nn/inference.py | 26 +++++++++++++++++--------- 5 files changed, 38 insertions(+), 16 deletions(-) diff --git a/sleap/nn/data/inference.py b/sleap/nn/data/inference.py index 2e309bacf..2b8f36a8d 100644 --- a/sleap/nn/data/inference.py +++ b/sleap/nn/data/inference.py @@ -556,6 +556,7 @@ def norm_instance(example): pts = example[self.peaks_key] pts += bboxes_x1y1 + pts /= example["scale"] example[self.new_centroid_key] = centroids example[self.new_centroid_confidence_key] = example[ diff --git a/sleap/nn/data/instance_cropping.py b/sleap/nn/data/instance_cropping.py index c93f15303..de68c45eb 100644 --- a/sleap/nn/data/instance_cropping.py +++ b/sleap/nn/data/instance_cropping.py @@ -424,6 +424,7 @@ class PredictedInstanceCropper: centroids_key: Text = "predicted_centroids" centroid_confidences_key: Text = "predicted_centroid_confidences" full_image_key: Text = "full_image" + full_image_scale_key: Text = "full_image_scale" keep_full_image: bool = False keep_instances_gt: bool = False @@ -432,9 +433,9 @@ def input_keys(self) -> List[Text]: """Return the keys that incoming elements are expected to have.""" input_keys = [ self.full_image_key, + self.full_image_scale_key, self.centroids_key, self.centroid_confidences_key, - "scale", "video_ind", "frame_ind", ] @@ -466,6 +467,7 @@ def output_keys(self) -> List[Text]: def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset: """Create a dataset that contains instance 
cropped data.""" keys_to_expand = ["scale", "video_ind", "frame_ind"] + # keys_to_expand = ["video_ind", "frame_ind"] if self.keep_full_image: keys_to_expand.append(self.full_image_key) if self.keep_instances_gt: @@ -475,10 +477,13 @@ def crop_instances(frame_data): """Local processing function for dataset mapping.""" # Make bounding boxes from centroids. full_centroids = frame_data[self.centroids_key] / frame_data["scale"] + full_centroids = full_centroids * frame_data[self.full_image_scale_key] bboxes = make_centered_bboxes( full_centroids, box_height=self.crop_height, box_width=self.crop_width ) + frame_data["scale"] = frame_data[self.full_image_scale_key] + # Crop images from bounding boxes. instance_images = crop_bboxes(frame_data[self.full_image_key], bboxes) n_instances = tf.shape(bboxes)[0] @@ -488,7 +493,7 @@ def crop_instances(frame_data): "instance_image": instance_images, "bbox": bboxes, "center_instance_ind": tf.range(n_instances, dtype=tf.int32), - "centroid": frame_data[self.centroids_key], + "centroid": full_centroids, "centroid_confidence": frame_data[self.centroid_confidences_key], "full_image_height": tf.repeat( tf.shape(frame_data[self.full_image_key])[0], n_instances diff --git a/sleap/nn/data/normalization.py b/sleap/nn/data/normalization.py index 4324fe97a..a104d3764 100644 --- a/sleap/nn/data/normalization.py +++ b/sleap/nn/data/normalization.py @@ -305,17 +305,18 @@ class Normalizer: ) @classmethod - def from_config(cls, config: PreprocessingConfig) -> "Normalizer": + def from_config(cls, config: PreprocessingConfig, image_key: Text = "image") -> "Normalizer": """Build an instance of this class from its configuration options. Args: config: An `PreprocessingConfig` instance with the desired parameters. + image_key: String name of the key containing the images to normalize. Returns: An instance of this class. 
""" return cls( - image_key="image", + image_key=image_key, ensure_float=True, ensure_rgb=config.ensure_rgb, ensure_grayscale=config.ensure_grayscale, diff --git a/sleap/nn/data/resizing.py b/sleap/nn/data/resizing.py index 61ae8aca5..74a72b66e 100644 --- a/sleap/nn/data/resizing.py +++ b/sleap/nn/data/resizing.py @@ -105,6 +105,7 @@ class Resizer: Attributes: image_key: String name of the key containing the images to resize. + scale_key: String name of the key containing the scale of the images. points_key: String name of the key containing points to adjust for the resizing operation. scale: Scalar float specifying scaling factor to resize images by. @@ -118,6 +119,7 @@ class Resizer: """ image_key: Text = "image" + scale_key: Text = "scale" points_key: Optional[Text] = "instances" scale: float = 1.0 pad_to_stride: int = 1 @@ -128,6 +130,8 @@ class Resizer: def from_config( cls, config: PreprocessingConfig, + image_key: Text = "image", + scale_key: Text = "scale", pad_to_stride: Optional[int] = None, keep_full_image: bool = False, full_image_key: Text = "full_image", @@ -139,6 +143,8 @@ def from_config( config: An `PreprocessingConfig` instance with the desired parameters. If `config.pad_to_stride` is not an explicit integer, the `pad_to_stride` parameter must be provided. + image_key: String name of the key containing the images to resize. + scale_key: String name of the key containing the scale of the images. pad_to_stride: An integer specifying the `pad_to_stride` if `config.pad_to_stride` is not an explicit integer (e.g., set to None). 
keep_full_image: If True, keeps the (original size) full image in the @@ -163,8 +169,9 @@ def from_config( ) return cls( - image_key="image", + image_key=image_key, points_key=points_key, + scale_key=scale_key, scale=config.input_scaling, pad_to_stride=pad_to_stride, keep_full_image=keep_full_image, @@ -174,7 +181,7 @@ def from_config( @property def input_keys(self) -> List[Text]: """Return the keys that incoming elements are expected to have.""" - input_keys = [self.image_key, "scale"] + input_keys = [self.image_key, self.scale_key] if self.points_key is not None: input_keys.append(self.points_key) return input_keys @@ -222,7 +229,7 @@ def resize(example): ) if self.points_key: example[self.points_key] = example[self.points_key] * self.scale - example["scale"] = example["scale"] * self.scale + example[self.scale_key] = example[self.scale_key] * self.scale if self.pad_to_stride > 1: example[self.image_key] = pad_to_stride( diff --git a/sleap/nn/inference.py b/sleap/nn/inference.py index 68eea54ab..d609f3852 100644 --- a/sleap/nn/inference.py +++ b/sleap/nn/inference.py @@ -32,6 +32,7 @@ GlobalPeakFinder, MockGlobalPeakFinder, KeyFilter, + KeyRenamer, PredictedCenterInstanceNormalizer, PartAffinityFieldInstanceGrouper, PointsRescaler, @@ -240,18 +241,23 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: if data_provider is not None: pipeline.providers = [data_provider] - if self.centroid_config is not None: - preprocessing_config = self.centroid_config.data.preprocessing - else: - preprocessing_config = self.confmap_config.data.preprocessing - pipeline += Normalizer.from_config(preprocessing_config) - pipeline += Resizer.from_config( - preprocessing_config, keep_full_image=True, points_key=None, - ) - pipeline += Prefetcher() + pipeline += KeyRenamer(old_key_names=["image", "scale"], new_key_names=["full_image", "full_image_scale"], drop_old=False) + if self.confmap_config is not None: + pipeline += 
Normalizer.from_config(self.confmap_config.data.preprocessing, image_key="full_image") + + points_key = "instances" if self.centroid_model is None else None + pipeline += Resizer.from_config( + self.confmap_config.data.preprocessing, keep_full_image=False, points_key=points_key, image_key="full_image", scale_key="full_image_scale" + ) + if self.centroid_model is not None: + pipeline += Normalizer.from_config(self.centroid_config.data.preprocessing, image_key="image") + pipeline += Resizer.from_config( + self.centroid_config.data.preprocessing, keep_full_image=False, points_key=None, + ) + # Predict centroids using model. pipeline += KerasModelPredictor( keras_model=self.centroid_model.keras_model, @@ -287,6 +293,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: centroids_key="predicted_centroids", centroid_confidences_key="predicted_centroid_confidences", full_image_key="full_image", + full_image_scale_key="full_image_scale", keep_instances_gt=self.confmap_model is None, ) @@ -298,6 +305,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline: anchor_part_names=anchor_part, skeletons=data_provider.labels.skeletons, ) + pipeline += KeyRenamer(old_key_names=["full_image", "full_image_scale"], new_key_names=["image", "scale"], drop_old=True) pipeline += InstanceCropper( crop_width=self.confmap_config.data.instance_cropping.crop_size, crop_height=self.confmap_config.data.instance_cropping.crop_size,