Skip to content

Commit

Permalink
Fix topdown not accounting for different scale in centroids vs confmaps
Browse files Browse the repository at this point in the history
- Some optimization could probably be done to increase performance when
  both models have the same preprocessing settings, but currently the
  preprocessing is done independently for each model.
  • Loading branch information
talmo committed Apr 6, 2020
1 parent 29ca862 commit 01da28a
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 16 deletions.
1 change: 1 addition & 0 deletions sleap/nn/data/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def norm_instance(example):

pts = example[self.peaks_key]
pts += bboxes_x1y1
pts /= example["scale"]

example[self.new_centroid_key] = centroids
example[self.new_centroid_confidence_key] = example[
Expand Down
9 changes: 7 additions & 2 deletions sleap/nn/data/instance_cropping.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,7 @@ class PredictedInstanceCropper:
centroids_key: Text = "predicted_centroids"
centroid_confidences_key: Text = "predicted_centroid_confidences"
full_image_key: Text = "full_image"
full_image_scale_key: Text = "full_image_scale"
keep_full_image: bool = False
keep_instances_gt: bool = False

Expand All @@ -432,9 +433,9 @@ def input_keys(self) -> List[Text]:
"""Return the keys that incoming elements are expected to have."""
input_keys = [
self.full_image_key,
self.full_image_scale_key,
self.centroids_key,
self.centroid_confidences_key,
"scale",
"video_ind",
"frame_ind",
]
Expand Down Expand Up @@ -466,6 +467,7 @@ def output_keys(self) -> List[Text]:
def transform_dataset(self, input_ds: tf.data.Dataset) -> tf.data.Dataset:
"""Create a dataset that contains instance cropped data."""
keys_to_expand = ["scale", "video_ind", "frame_ind"]
# keys_to_expand = ["video_ind", "frame_ind"]
if self.keep_full_image:
keys_to_expand.append(self.full_image_key)
if self.keep_instances_gt:
Expand All @@ -475,10 +477,13 @@ def crop_instances(frame_data):
"""Local processing function for dataset mapping."""
# Make bounding boxes from centroids.
full_centroids = frame_data[self.centroids_key] / frame_data["scale"]
full_centroids = full_centroids * frame_data[self.full_image_scale_key]
bboxes = make_centered_bboxes(
full_centroids, box_height=self.crop_height, box_width=self.crop_width
)

frame_data["scale"] = frame_data[self.full_image_scale_key]

# Crop images from bounding boxes.
instance_images = crop_bboxes(frame_data[self.full_image_key], bboxes)
n_instances = tf.shape(bboxes)[0]
Expand All @@ -488,7 +493,7 @@ def crop_instances(frame_data):
"instance_image": instance_images,
"bbox": bboxes,
"center_instance_ind": tf.range(n_instances, dtype=tf.int32),
"centroid": frame_data[self.centroids_key],
"centroid": full_centroids,
"centroid_confidence": frame_data[self.centroid_confidences_key],
"full_image_height": tf.repeat(
tf.shape(frame_data[self.full_image_key])[0], n_instances
Expand Down
5 changes: 3 additions & 2 deletions sleap/nn/data/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,17 +305,18 @@ class Normalizer:
)

@classmethod
def from_config(cls, config: PreprocessingConfig) -> "Normalizer":
def from_config(cls, config: PreprocessingConfig, image_key: Text = "image") -> "Normalizer":
"""Build an instance of this class from its configuration options.
Args:
config: An `PreprocessingConfig` instance with the desired parameters.
image_key: String name of the key containing the images to normalize.
Returns:
An instance of this class.
"""
return cls(
image_key="image",
image_key=image_key,
ensure_float=True,
ensure_rgb=config.ensure_rgb,
ensure_grayscale=config.ensure_grayscale,
Expand Down
13 changes: 10 additions & 3 deletions sleap/nn/data/resizing.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class Resizer:
Attributes:
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
points_key: String name of the key containing points to adjust for the resizing
operation.
scale: Scalar float specifying scaling factor to resize images by.
Expand All @@ -118,6 +119,7 @@ class Resizer:
"""

image_key: Text = "image"
scale_key: Text = "scale"
points_key: Optional[Text] = "instances"
scale: float = 1.0
pad_to_stride: int = 1
Expand All @@ -128,6 +130,8 @@ class Resizer:
def from_config(
cls,
config: PreprocessingConfig,
image_key: Text = "image",
scale_key: Text = "scale",
pad_to_stride: Optional[int] = None,
keep_full_image: bool = False,
full_image_key: Text = "full_image",
Expand All @@ -139,6 +143,8 @@ def from_config(
config: An `PreprocessingConfig` instance with the desired parameters. If
`config.pad_to_stride` is not an explicit integer, the `pad_to_stride`
parameter must be provided.
image_key: String name of the key containing the images to resize.
scale_key: String name of the key containing the scale of the images.
pad_to_stride: An integer specifying the `pad_to_stride` if
`config.pad_to_stride` is not an explicit integer (e.g., set to None).
keep_full_image: If True, keeps the (original size) full image in the
Expand All @@ -163,8 +169,9 @@ def from_config(
)

return cls(
image_key="image",
image_key=image_key,
points_key=points_key,
scale_key=scale_key,
scale=config.input_scaling,
pad_to_stride=pad_to_stride,
keep_full_image=keep_full_image,
Expand All @@ -174,7 +181,7 @@ def from_config(
@property
def input_keys(self) -> List[Text]:
"""Return the keys that incoming elements are expected to have."""
input_keys = [self.image_key, "scale"]
input_keys = [self.image_key, self.scale_key]
if self.points_key is not None:
input_keys.append(self.points_key)
return input_keys
Expand Down Expand Up @@ -222,7 +229,7 @@ def resize(example):
)
if self.points_key:
example[self.points_key] = example[self.points_key] * self.scale
example["scale"] = example["scale"] * self.scale
example[self.scale_key] = example[self.scale_key] * self.scale

if self.pad_to_stride > 1:
example[self.image_key] = pad_to_stride(
Expand Down
26 changes: 17 additions & 9 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
GlobalPeakFinder,
MockGlobalPeakFinder,
KeyFilter,
KeyRenamer,
PredictedCenterInstanceNormalizer,
PartAffinityFieldInstanceGrouper,
PointsRescaler,
Expand Down Expand Up @@ -240,18 +241,23 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
if data_provider is not None:
pipeline.providers = [data_provider]

if self.centroid_config is not None:
preprocessing_config = self.centroid_config.data.preprocessing
else:
preprocessing_config = self.confmap_config.data.preprocessing
pipeline += Normalizer.from_config(preprocessing_config)
pipeline += Resizer.from_config(
preprocessing_config, keep_full_image=True, points_key=None,
)

pipeline += Prefetcher()

pipeline += KeyRenamer(old_key_names=["image", "scale"], new_key_names=["full_image", "full_image_scale"], drop_old=False)
if self.confmap_config is not None:
pipeline += Normalizer.from_config(self.confmap_config.data.preprocessing, image_key="full_image")

points_key = "instances" if self.centroid_model is None else None
pipeline += Resizer.from_config(
self.confmap_config.data.preprocessing, keep_full_image=False, points_key=points_key, image_key="full_image", scale_key="full_image_scale"
)

if self.centroid_model is not None:
pipeline += Normalizer.from_config(self.centroid_config.data.preprocessing, image_key="image")
pipeline += Resizer.from_config(
self.centroid_config.data.preprocessing, keep_full_image=False, points_key=None,
)

# Predict centroids using model.
pipeline += KerasModelPredictor(
keras_model=self.centroid_model.keras_model,
Expand Down Expand Up @@ -287,6 +293,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
centroids_key="predicted_centroids",
centroid_confidences_key="predicted_centroid_confidences",
full_image_key="full_image",
full_image_scale_key="full_image_scale",
keep_instances_gt=self.confmap_model is None,
)

Expand All @@ -298,6 +305,7 @@ def make_pipeline(self, data_provider: Optional[Provider] = None) -> Pipeline:
anchor_part_names=anchor_part,
skeletons=data_provider.labels.skeletons,
)
pipeline += KeyRenamer(old_key_names=["full_image", "full_image_scale"], new_key_names=["image", "scale"], drop_old=True)
pipeline += InstanceCropper(
crop_width=self.confmap_config.data.instance_cropping.crop_size,
crop_height=self.confmap_config.data.instance_cropping.crop_size,
Expand Down

0 comments on commit 01da28a

Please sign in to comment.