Skip to content

Commit

Permalink
Merge pull request #258 from murthylab/nn-interface
Browse files Browse the repository at this point in the history
Nn interface
  • Loading branch information
ntabris committed Dec 12, 2019
2 parents 03a0e0c + 4114244 commit 4b4f773
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 16 deletions.
19 changes: 13 additions & 6 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def __init__(self, labels_path: Optional[str] = None, *args, **kwargs):
self.state["filename"] = None
self.state["show labels"] = True
self.state["show edges"] = True
self.state["edge style"] = "Line"
self.state["fit"] = False
self.state["show trails"] = False
self.state["color predicted"] = False
Expand Down Expand Up @@ -374,6 +375,14 @@ def prev_vid():

add_menu_check_item(viewMenu, "show labels", "Show Node Names")
add_menu_check_item(viewMenu, "show edges", "Show Edges")

add_submenu_choices(
menu=viewMenu,
title="Edge Style",
options=("Line", "Wedge"),
key="edge style",
)

add_menu_check_item(viewMenu, "show trails", "Show Trails")

add_submenu_choices(
Expand Down Expand Up @@ -817,12 +826,10 @@ def overlay_state_connect(overlay, state_key, overlay_attribute=None):
self.state.connect("palette", lambda x: self.updateSeekbarMarks())

# update the skeleton tables since we may want to redraw colors
self.state.connect(
"palette", lambda x: self.on_data_update([UpdateTopic.skeleton])
)
self.state.connect(
"distinctly_color", lambda x: self.on_data_update([UpdateTopic.skeleton])
)
for state_var in ("palette", "distinctly_color", "edge style"):
self.state.connect(
state_var, lambda x: self.on_data_update([UpdateTopic.skeleton])
)

# Set defaults
self.state["trail_length"] = 10
Expand Down
22 changes: 22 additions & 0 deletions sleap/gui/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,28 @@ def view_datagen(self):
)
# negative_samples = form_data.get("negative_samples", 0)

# Augment dataset
aug_params = dict(
# rotate=conf_job.trainer.augment_rotate,
# rotation_min_angle=-conf_job.trainer.augment_rotation,
# rotation_max_angle=conf_job.trainer.augment_rotation,
scale=form_data.get("scale", conf_job.trainer.scale),
# scale_min=conf_job.trainer.augment_scale_min,
# scale_max=conf_job.trainer.augment_scale_max,
# uniform_noise=conf_job.trainer.augment_uniform_noise,
# min_noise_val=conf_job.trainer.augment_uniform_noise_min_val,
# max_noise_val=conf_job.trainer.augment_uniform_noise_max_val,
# gaussian_noise=conf_job.trainer.augment_gaussian_noise,
# gaussian_noise_mean=conf_job.trainer.augment_gaussian_noise_mean,
# gaussian_noise_stddev=conf_job.trainer.augment_gaussian_noise_stddev,
contrast=conf_job.trainer.augment_contrast,
contrast_min_gamma=conf_job.trainer.augment_contrast_min_gamma,
contrast_max_gamma=conf_job.trainer.augment_contrast_max_gamma,
brightness=conf_job.trainer.augment_brightness,
brightness_val=conf_job.trainer.augment_brightness_val,
)
ds = data.augment_dataset(ds, **aug_params)

if bounding_box_size is None or bounding_box_size <= 0:
bounding_box_size = data.estimate_instance_crop_size(
training_data.points,
Expand Down
62 changes: 52 additions & 10 deletions sleap/gui/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
QGraphicsScene,
)
from PySide2.QtGui import QImage, QPixmap, QPainter, QPainterPath, QTransform
from PySide2.QtGui import QPen, QBrush, QColor, QFont
from PySide2.QtGui import QPen, QBrush, QColor, QFont, QPolygonF
from PySide2.QtGui import QKeyEvent
from PySide2.QtCore import Qt, QRectF, QPointF, QMarginsF
from PySide2.QtCore import Qt, QRectF, QPointF, QMarginsF, QLineF

import math

Expand All @@ -37,6 +37,7 @@
QGraphicsLineItem,
QGraphicsTextItem,
QGraphicsRectItem,
QGraphicsPolygonItem,
)

from sleap.skeleton import Node
Expand Down Expand Up @@ -1105,6 +1106,13 @@ def calls(self):
if callable(callback):
callback(self)

@property
def visible_radius(self):
if self.point.visible:
return self.radius / self.player.view.zoomFactor
else:
return self.radius / (2.0 * self.player.view.zoomFactor)

def updatePoint(self, user_change: bool = False):
"""
Method to update data for node/edge when node position is manipulated.
Expand Down Expand Up @@ -1132,6 +1140,7 @@ def updatePoint(self, user_change: bool = False):
self.setBrush(self.brush_missing)
if not self.show_non_visible:
self.hide()

self.setRect(-radius, -radius, radius * 2, radius * 2)

for edge in self.edges:
Expand Down Expand Up @@ -1231,7 +1240,7 @@ def mouseDoubleClickEvent(self, event):
view.instanceDoubleClicked.emit(self.parentObject().instance)


class QtEdge(QGraphicsLineItem):
class QtEdge(QGraphicsPolygonItem):
"""
QGraphicsLineItem to handle display of edge between skeleton instance nodes.
Expand All @@ -1253,28 +1262,61 @@ def __init__(
**kwargs,
):
self.parent = parent
self.player = player
self.src = src
self.dst = dst
self.show_non_visible = show_non_visible

super(QtEdge, self).__init__(
self.src.point.x,
self.src.point.y,
self.dst.point.x,
self.dst.point.y,
parent=parent,
*args,
**kwargs,
polygon=QPolygonF(), parent=parent, *args, **kwargs,
)

self.setLine(
QLineF(
self.src.point.x, self.src.point.y, self.dst.point.x, self.dst.point.y,
)
)

edge_pair = (src.node, dst.node)
color = player.color_manager.get_item_color(edge_pair, parent.instance)
pen_width = player.color_manager.get_item_pen_width(edge_pair, parent.instance)
pen = QPen(QColor(*color), pen_width)
pen.setCosmetic(True)

brush = QBrush(QColor(*color, a=128))

self.setPen(pen)
self.setBrush(brush)
self.full_opacity = 1

def line(self):
return self._line

def setLine(self, line):
self._line = line
polygon = QPolygonF()

if self.player.state.get("edge style", default="").lower() == "wedge":

r = self.src.visible_radius / 2.0

norm_a = line.normalVector()
norm_a.setLength(r)

norm_b = line.normalVector()
norm_b.setLength(-r)

polygon.append(norm_a.p2())
polygon.append(line.p2())
polygon.append(norm_b.p2())
polygon.append(norm_a.p2())

else:
polygon.append(line.p1())
polygon.append(line.p2())

self.setPolygon(polygon)

def connected_to(self, node: QtNode):
"""
Return the other node along the edge.
Expand Down
13 changes: 13 additions & 0 deletions sleap/nn/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,11 @@ def augment_dataset(
gaussian_noise: bool = False,
gaussian_noise_mean: float = 0.05,
gaussian_noise_stddev: float = 0.02,
contrast: bool = False,
contrast_min_gamma: float = 0.5,
contrast_max_gamma: float = 2.0,
brightness: bool = False,
brightness_val: float = 0.0,
) -> tf.data.Dataset:
"""Augments a pair of image and points dataset.
Expand Down Expand Up @@ -355,6 +360,14 @@ def augment_dataset(
if scale:
aug_stack.append(iaa.Affine(scale=(scale_min, scale_max)))

if contrast:
aug_stack.append(
iaa.GammaContrast(gamma=(contrast_min_gamma, contrast_max_gamma))
)

if brightness:
aug_stack.append(iaa.Add(value=(-brightness_val, brightness_val)))

if uniform_noise:
aug_stack.append(
iaa.AddElementwise(value=(min_noise_val * 255, max_noise_val * 255))
Expand Down
5 changes: 5 additions & 0 deletions sleap/nn/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ class TrainerConfig:
augment_gaussian_noise: bool = False
augment_gaussian_noise_mean: float = 0.05
augment_gaussian_noise_stddev: float = 0.1
augment_contrast: bool = False
augment_contrast_min_gamma: float = 0.5
augment_contrast_max_gamma: float = 2.0
augment_brightness: bool = False
augment_brightness_val: float = 0.0

# Optimization:
optimizer: str = "adam"
Expand Down
5 changes: 5 additions & 0 deletions sleap/nn/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ def setup_data(
gaussian_noise=self.training_job.trainer.augment_gaussian_noise,
gaussian_noise_mean=self.training_job.trainer.augment_gaussian_noise_mean,
gaussian_noise_stddev=self.training_job.trainer.augment_gaussian_noise_stddev,
contrast=self.training_job.trainer.augment_contrast,
contrast_min_gamma=self.training_job.trainer.augment_contrast_min_gamma,
contrast_max_gamma=self.training_job.trainer.augment_contrast_max_gamma,
brightness=self.training_job.trainer.augment_brightness,
brightness_val=self.training_job.trainer.augment_brightness_val,
)
ds_train = data.augment_dataset(ds_train, **aug_params)
ds_val = data.augment_dataset(ds_val, **aug_params)
Expand Down

0 comments on commit 4b4f773

Please sign in to comment.