Skip to content

Commit

Permalink
Merge pull request #276 from murthylab/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
ntabris committed Jan 21, 2020
2 parents 4b4f773 + 7aab3af commit 33dc409
Show file tree
Hide file tree
Showing 24 changed files with 860 additions and 151 deletions.
2 changes: 0 additions & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ dependencies:
- attrs
- jsonpickle=1.2
- networkx
- cudatoolkit=10.0.*
- cudnn
- tensorflow-gpu=2.0
- scikit-learn
- h5py
Expand Down
1 change: 1 addition & 0 deletions sleap/config/shortcuts.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ prev video: QKeySequence.Back
save as: QKeySequence.SaveAs
save: Ctrl+S
select next: '`'
select to frame: Ctrl+Shift+J
show edges: Ctrl+Shift+Tab
show labels: Ctrl+Tab
show trails:
Expand Down
2 changes: 1 addition & 1 deletion sleap/config/training_editor.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ model:
- name: output_type
label: Training For
type: list
options: confmaps,topdown,pafs,centroids
options: confmaps,topdown_confidence_maps,pafs,centroids
default: confmaps

- name: arch # backbone_name
Expand Down
25 changes: 22 additions & 3 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,9 @@ def add_submenu_choices(menu, title, options, key):

fileMenu.addSeparator()
add_menu_item(fileMenu, "add videos", "Add Videos...", self.commands.addVideo)
add_menu_item(
fileMenu, "replace videos", "Replace Videos...", self.commands.replaceVideo
)

fileMenu.addSeparator()
add_menu_item(fileMenu, "save", "Save", self.commands.saveProject)
Expand Down Expand Up @@ -343,6 +346,9 @@ def prev_vid():
goMenu.addSeparator()

add_menu_item(goMenu, "goto frame", "Go to Frame...", self.commands.gotoFrame)
add_menu_item(
goMenu, "select to frame", "Select to Frame...", self.commands.selectToFrame
)

### View Menu ###

Expand Down Expand Up @@ -1010,25 +1016,38 @@ def updateStatusMessage(self, message: Optional[str] = None):
current_video = self.state["video"]
frame_idx = self.state["frame_idx"] or 0

spacer = " "

if message is None:
message = f"Frame: {frame_idx+1}/{len(current_video)}"
message = f"Frame: {frame_idx+1:,}/{len(current_video):,}"
if self.player.seekbar.hasSelection():
start, end = self.state["frame_range"]
message += f" (selection: {start}-{end})"
message += f" (selection: {start+1:,}-{end+1:,})"

if len(self.labels.videos) > 1:
message += f" of video {self.labels.videos.index(current_video)}"

message += f" Labeled Frames: "
message += f"{spacer}Labeled Frames: "
if current_video is not None:
message += (
f"{len(self.labels.get_video_user_labeled_frames(current_video))}"
)

if len(self.labels.videos) > 1:
message += " in video, "
if len(self.labels.videos) > 1:
message += f"{len(self.labels.user_labeled_frames)} in project"

if current_video is not None:
pred_frame_count = len(
self.labels.get_video_predicted_frames(current_video)
)
if pred_frame_count:
message += f"{spacer}Predicted Frames: {pred_frame_count:,}"
message += (
f" ({pred_frame_count/current_video.num_frames*100:.2f}%)"
)

self.statusBar().showMessage(message)

def loadProjectFile(self, filename: Optional[str] = None):
Expand Down
58 changes: 58 additions & 0 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,10 @@ def gotoFrame(self):
"""Shows gui to go to frame by number."""
self.execute(GoFrameGui)

def selectToFrame(self):
"""Shows gui to go to frame by number."""
self.execute(SelectToFrameGui)

def gotoVideoAndFrame(self, video: Video, frame_idx: int):
"""Activates video and goes to frame."""
NavCommand.go_to(self, frame_idx, video)
Expand All @@ -259,6 +263,10 @@ def addVideo(self):
"""Shows gui for adding videos to project."""
self.execute(AddVideo)

def replaceVideo(self):
"""Shows gui for replacing videos to project."""
self.execute(ReplaceVideo)

def removeVideo(self):
"""Removes selected video from project."""
self.execute(RemoveVideo)
Expand Down Expand Up @@ -846,6 +854,29 @@ def ask(cls, context: "CommandContext", params: dict) -> bool:
return okay


class SelectToFrameGui(NavCommand):
@classmethod
def do_action(cls, context: "CommandContext", params: dict):
context.app.player.setSeekbarSelection(
params["from_frame_idx"], params["to_frame_idx"]
)

@classmethod
def ask(cls, context: "CommandContext", params: dict) -> bool:
frame_number, okay = QtWidgets.QInputDialog.getInt(
context.app,
"Select To Frame...",
"Frame Number:",
context.state["frame_idx"] + 1,
1,
context.state["video"].frames,
)
params["from_frame_idx"] = context.state["frame_idx"]
params["to_frame_idx"] = frame_number - 1

return okay


# Editing Commands


Expand Down Expand Up @@ -882,6 +913,33 @@ def ask(context: CommandContext, params: dict) -> bool:
return len(params["import_list"]) > 0


class ReplaceVideo(EditCommand):
topics = [UpdateTopic.video]

@staticmethod
def do_action(context: CommandContext, params: dict):
new_paths = params["new_video_paths"]

for video, new_path in zip(context.labels.videos, new_paths):
if new_path != video.backend.filename:
video.backend.filename = new_path
video.backend.reset()

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
"""Shows gui for replacing videos in project."""
paths = [video.backend.filename for video in context.labels.videos]

okay = MissingFilesDialog(filenames=paths, replace=True).exec_()

if not okay:
return False

params["new_video_paths"] = paths

return True


class RemoveVideo(EditCommand):
topics = [UpdateTopic.video]

Expand Down
3 changes: 3 additions & 0 deletions sleap/gui/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,9 @@ def _get_current_training_jobs(self) -> Dict[ModelOutputType, TrainingJob]:
# Use already trained model if desired
if form_data.get(f"_use_trained_{str(model_type)}", default_use_trained):
job.use_trained_model = True
elif model_type == ModelOutputType.TOPDOWN_CONFIDENCE_MAP:
if form_data.get(f"_use_trained_confmaps", default_use_trained):
job.use_trained_model = True

# Clear parameters that shouldn't be copied
job.val_set_filename = None
Expand Down
18 changes: 13 additions & 5 deletions sleap/gui/missingfiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

class MissingFilesDialog(QtWidgets.QDialog):
def __init__(
self, filenames: List[str], missing: List[bool] = None, *args, **kwargs
self,
filenames: List[str],
missing: List[bool] = None,
replace: bool = False,
*args,
**kwargs,
):
"""
Creates dialog window for finding missing files.
Expand Down Expand Up @@ -42,10 +47,13 @@ def __init__(

layout = QtWidgets.QVBoxLayout()

info_text = (
f"{missing_count} file(s) which could not be found. "
"Please double-click on a file to locate it..."
)
if replace:
info_text = "Double-click on a file to replace it..."
else:
info_text = (
f"{missing_count} file(s) which could not be found. "
"Please double-click on a file to locate it..."
)
info_label = QtWidgets.QLabel(info_text)
layout.addWidget(info_label)

Expand Down
61 changes: 33 additions & 28 deletions sleap/gui/overlays/tracks.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import attr
import itertools

from typing import List
from typing import Iterable, List

from PySide2 import QtCore, QtGui

MAX_NODES_IN_TRAIL = 30


@attr.s(auto_attribs=True)
class TrackTrailOverlay:
Expand All @@ -37,42 +39,48 @@ class TrackTrailOverlay:
trail_length: int = 10
show: bool = False

def get_track_trails(self, frame_selection, track: Track):
def get_track_trails(self, frame_selection: Iterable["LabeledFrame"]):
"""Get data needed to draw track trail.
Args:
frame_selection: an interable with the :class:`LabeledFrame`
frame_selection: an iterable with the :class:`LabeledFrame`
objects to include in trail.
track: the :class:`Track` for which to get trail
Returns:
list of lists of (x, y) tuples
Dictionary keyed by track, value is list of lists of (x, y) tuples
i.e., for every node in instance, we get a list of positions
"""

all_trails = [[] for _ in range(len(self.labels.nodes))]
all_track_trails = dict()

nodes = self.labels.nodes
if len(nodes) > MAX_NODES_IN_TRAIL:
nodes = nodes[:MAX_NODES_IN_TRAIL]

for frame in frame_selection:
frame_idx = frame.frame_idx

inst_on_track = [instance for instance in frame if instance.track == track]
if inst_on_track:
# just use the first instance from this track in this frame
inst = inst_on_track[0]
# loop through all nodes
for node_i, node in enumerate(self.labels.nodes):
for inst in frame:
if inst.track is not None:
if inst.track not in all_track_trails:
all_track_trails[inst.track] = [[] for _ in range(len(nodes))]

# loop through all nodes
for node_i, node in enumerate(nodes):

if node in inst.nodes and inst[node].visible:
point = (inst[node].x, inst[node].y)

if node in inst.nodes and inst[node].visible:
point = (inst[node].x, inst[node].y)
elif len(all_trails[node_i]):
point = all_trails[node_i][-1]
else:
point = None
# Add last location of node so that we can easily
# calculate trail length (since we adjust opacity).
elif len(all_track_trails[inst.track][node_i]):
point = all_track_trails[inst.track][node_i][-1]
else:
point = None

if point is not None:
all_trails[node_i].append(point)
if point is not None:
all_track_trails[inst.track][node_i].append(point)

return all_trails
return all_track_trails

def get_frame_selection(self, video: Video, frame_idx: int):
"""
Expand Down Expand Up @@ -116,17 +124,14 @@ def add_to_scene(self, video: Video, frame_idx: int):
video: current video
frame_idx: index of the frame to which the trail is attached
"""
if not self.show:
if not self.show or self.trail_length == 0:
return

frame_selection = self.get_frame_selection(video, frame_idx)
tracks_in_frame = self.get_tracks_in_frame(
video, frame_idx, include_trails=True
)

for track in tracks_in_frame:
all_track_trails = self.get_track_trails(frame_selection)

trails = self.get_track_trails(frame_selection, track)
for track, trails in all_track_trails.items():

color = QtGui.QColor(*self.player.color_manager.get_track_color(track))
pen = QtGui.QPen()
Expand Down
Loading

0 comments on commit 33dc409

Please sign in to comment.