Skip to content

Commit

Permalink
Merge pull request #227 from ssciwr/fix_223_mouse_reset_before_trial_…
Browse files Browse the repository at this point in the history
…and_freeze

Don't reset mouse to centre before each trial and reduce processing b…
  • Loading branch information
lkeegan committed Jun 26, 2023
2 parents 82623de + 7d041ad commit 56e5122
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 35 deletions.
Binary file modified benchmarks/profile_task.prof
Binary file not shown.
6 changes: 5 additions & 1 deletion benchmarks/profile_task.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import cProfile
import ctypes

from psychopy import core
from vstt.experiment import Experiment
from vstt.task import MotorTask

xlib = ctypes.cdll.LoadLibrary("libX11.so")
xlib.XInitThreads()

experiment = Experiment()
experiment.trial_list[0]["weight"] = 1
experiment.trial_list[0]["weight"] = 3
experiment.trial_list[0]["inter_target_duration"] = 0
experiment.trial_list[0]["post_block_display_results"] = False
task = MotorTask(experiment)
Expand Down
Binary file modified benchmarks/snakeviz.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
92 changes: 58 additions & 34 deletions src/vstt/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import Union

import numpy as np
import vstt.vtypes
from psychopy.clock import Clock
from psychopy.data import TrialHandlerExt
from psychopy.event import Mouse
Expand Down Expand Up @@ -60,8 +61,7 @@ def __init__(self, trial: Dict[str, Any], rng: np.random.Generator):
class TrialManager:
"""Stores the drawable elements and other objects needed during a trial"""

def __init__(self, win: Window, trial: Dict[str, Any], rng: np.random.Generator):
self.data = TrialData(trial, rng)
def __init__(self, win: Window, trial: vstt.vtypes.Trial):
self.targets = vis.make_targets(
win,
trial["num_targets"],
Expand All @@ -70,15 +70,16 @@ def __init__(self, win: Window, trial: Dict[str, Any], rng: np.random.Generator)
trial["add_central_target"],
trial["central_target_size"],
)
self.target_labels = vis.make_target_labels(
win,
trial["num_targets"],
trial["target_distance"],
trial["target_size"],
trial["target_labels"],
)
self.drawables: List[Union[BaseVisualStim, ElementArrayStim]] = [self.targets]
self.target_labels = None
if trial["show_target_labels"]:
self.target_labels = vis.make_target_labels(
win,
trial["num_targets"],
trial["target_distance"],
trial["target_size"],
trial["target_labels"],
)
self.drawables.extend(self.target_labels)
self.cursor = vis.make_cursor(win, trial["cursor_size"])
self.cursor.setPos(np.array([0.0, 0.0]))
Expand Down Expand Up @@ -130,6 +131,10 @@ def __init__(self, experiment: Experiment, win: Optional[Window] = None):
self.win = win
if not experiment.trial_list:
return
self.trial_managers = {
condition_index: TrialManager(self.win, trial)
for condition_index, trial in enumerate(experiment.trial_list)
}
self.trial_handler = experiment.create_trialhandler()
self.mouse = Mouse(visible=False, win=win)
self.kb = Keyboard()
Expand Down Expand Up @@ -164,13 +169,16 @@ def _do_trials(self) -> None:
current_condition_first_trial_index = 0
current_condition_clock = Clock()
current_condition_max_time = 0.0
self.mouse.setPos((0.0, 0.0))
current_cursor_pos = (0.0, 0.0)
for trial in self.trial_handler:
if self.trial_handler.thisIndex != current_condition_index:
# starting a new set of conditions
current_condition_clock.reset()
current_condition_max_time = trial["condition_timeout"]
current_condition_index = self.trial_handler.thisIndex
current_condition_first_trial_index = self.trial_handler.thisTrialN
trial_manager = self.trial_managers[current_condition_index]
condition_trial_indices[self.trial_handler.thisIndex].append(
self.trial_handler.thisTrialN
)
Expand All @@ -179,7 +187,9 @@ def _do_trials(self) -> None:
or current_condition_clock.getTime() < current_condition_max_time
):
# only do the trial if there is still time left for this condition
self._do_trial(trial)
current_cursor_pos = self._do_trial(
trial, trial_manager, current_cursor_pos
)
is_final_trial_of_block = (
len(condition_trial_indices[self.trial_handler.thisIndex])
== trial["weight"]
Expand Down Expand Up @@ -207,24 +217,31 @@ def _do_splash_screen(self) -> None:
win=self.win,
)

def _do_trial(self, trial: Dict[str, Any]) -> None:
def _do_trial(
self,
trial: Dict[str, Any],
trial_manager: TrialManager,
initial_cursor_pos: Tuple[float, float],
) -> Tuple[float, float]:
if trial["use_joystick"] and self.js is None:
raise RuntimeError("Use joystick option is enabled, but no joystick found.")
tm = TrialManager(self.win, trial, self.rng)
self.mouse.setPos(tm.cursor.pos)
trial_manager.cursor.setPos(initial_cursor_pos)
trial_data = TrialData(trial, self.rng)
self.win.recordFrameIntervals = True
tm.clock.reset()
vis.update_target_colors(tm.targets, trial["show_inactive_targets"], None)
if trial["show_target_labels"]:
trial_manager.clock.reset()
vis.update_target_colors(
trial_manager.targets, trial["show_inactive_targets"], None
)
if trial["show_target_labels"] and trial_manager.target_labels is not None:
vis.update_target_label_colors(
tm.target_labels, trial["show_inactive_targets"], None
trial_manager.target_labels, trial["show_inactive_targets"], None
)
for index in tm.data.target_indices:
self._do_target(trial, index, tm)
for index in trial_data.target_indices:
self._do_target(trial, index, trial_manager, trial_data)
self.win.recordFrameIntervals = False
if trial["automove_cursor_to_center"]:
tm.data.to_center_success = [True] * trial["num_targets"]
add_trial_data_to_trial_handler(tm.data, self.trial_handler)
trial_data.to_center_success = [True] * trial["num_targets"]
add_trial_data_to_trial_handler(trial_data, self.trial_handler)
if trial["post_trial_delay"] > 0:
vis.display_results(
trial["post_trial_delay"],
Expand All @@ -236,14 +253,17 @@ def _do_trial(self, trial: Dict[str, Any]) -> None:
False,
self.win,
)
return trial_manager.cursor.pos

def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> None:
def _do_target(
self, trial: Dict[str, Any], index: int, tm: TrialManager, trial_data: TrialData
) -> None:
minimum_window_for_flip = 1.0 / 60.0
mouse_pos = tm.cursor.pos
stop_waiting_time = 0.0
stop_target_time = 0.0
if trial["fixed_target_intervals"]:
num_completed_targets = len(tm.data.to_target_timestamps)
num_completed_targets = len(trial_data.to_target_timestamps)
stop_waiting_time = (num_completed_targets + 1) * trial["target_duration"]
stop_target_time = stop_waiting_time + trial["target_duration"]
for target_index in _get_target_indices(index, trial):
Expand All @@ -264,7 +284,7 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non
vis.update_target_colors(
tm.targets, trial["show_inactive_targets"], None
)
if trial["show_target_labels"]:
if trial["show_target_labels"] and tm.target_labels is not None:
vis.update_target_label_colors(
tm.target_labels, trial["show_inactive_targets"], None
)
Expand Down Expand Up @@ -296,17 +316,21 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non
vis.update_target_colors(
tm.targets, trial["show_inactive_targets"], target_index
)
if trial["show_target_labels"]:
if trial["show_target_labels"] and tm.target_labels is not None:
vis.update_target_label_colors(
tm.target_labels, trial["show_inactive_targets"], target_index
)
if trial["play_sound"]:
Sound("A", secs=0.3, blockSize=1024, stereo=True).play()
if is_central_target:
tm.data.to_center_num_timestamps_before_visible.append(len(mouse_times))
trial_data.to_center_num_timestamps_before_visible.append(
len(mouse_times)
)
else:
tm.data.target_pos.append(tm.targets.xys[target_index])
tm.data.to_target_num_timestamps_before_visible.append(len(mouse_times))
trial_data.target_pos.append(tm.targets.xys[target_index])
trial_data.to_target_num_timestamps_before_visible.append(
len(mouse_times)
)
target_size = trial["target_size"]
if is_central_target:
target_size = trial["central_target_size"]
Expand Down Expand Up @@ -351,15 +375,15 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non
and tm.clock.getTime() + minimum_window_for_flip < stop_target_time
)
if is_central_target:
tm.data.to_center_success.append(success)
trial_data.to_center_success.append(success)
else:
tm.data.to_target_success.append(success)
trial_data.to_target_success.append(success)
if is_central_target:
tm.data.to_center_timestamps.append(np.array(mouse_times))
tm.data.to_center_mouse_positions.append(np.array(mouse_positions))
trial_data.to_center_timestamps.append(np.array(mouse_times))
trial_data.to_center_mouse_positions.append(np.array(mouse_positions))
else:
tm.data.to_target_timestamps.append(np.array(mouse_times))
tm.data.to_target_mouse_positions.append(np.array(mouse_positions))
trial_data.to_target_timestamps.append(np.array(mouse_times))
trial_data.to_target_mouse_positions.append(np.array(mouse_positions))

def _clean_up_and_return(self, return_value: bool) -> bool:
if self.win is not None and self.close_window_when_done:
Expand Down

0 comments on commit 56e5122

Please sign in to comment.