diff --git a/benchmarks/profile_task.prof b/benchmarks/profile_task.prof index 0c21647c..2a003cc2 100644 Binary files a/benchmarks/profile_task.prof and b/benchmarks/profile_task.prof differ diff --git a/benchmarks/profile_task.py b/benchmarks/profile_task.py index 444084ff..afa73c02 100644 --- a/benchmarks/profile_task.py +++ b/benchmarks/profile_task.py @@ -1,11 +1,15 @@ import cProfile +import ctypes from psychopy import core from vstt.experiment import Experiment from vstt.task import MotorTask +xlib = ctypes.cdll.LoadLibrary("libX11.so") +xlib.XInitThreads() + experiment = Experiment() -experiment.trial_list[0]["weight"] = 1 +experiment.trial_list[0]["weight"] = 3 experiment.trial_list[0]["inter_target_duration"] = 0 experiment.trial_list[0]["post_block_display_results"] = False task = MotorTask(experiment) diff --git a/benchmarks/snakeviz.png b/benchmarks/snakeviz.png index 0d0d3e93..e2f16d83 100644 Binary files a/benchmarks/snakeviz.png and b/benchmarks/snakeviz.png differ diff --git a/src/vstt/task.py b/src/vstt/task.py index da29a823..ac811566 100644 --- a/src/vstt/task.py +++ b/src/vstt/task.py @@ -10,6 +10,7 @@ from typing import Union import numpy as np +import vstt.vtypes from psychopy.clock import Clock from psychopy.data import TrialHandlerExt from psychopy.event import Mouse @@ -60,8 +61,7 @@ def __init__(self, trial: Dict[str, Any], rng: np.random.Generator): class TrialManager: """Stores the drawable elements and other objects needed during a trial""" - def __init__(self, win: Window, trial: Dict[str, Any], rng: np.random.Generator): - self.data = TrialData(trial, rng) + def __init__(self, win: Window, trial: vstt.vtypes.Trial): self.targets = vis.make_targets( win, trial["num_targets"], @@ -70,15 +70,16 @@ def __init__(self, win: Window, trial: Dict[str, Any], rng: np.random.Generator) trial["add_central_target"], trial["central_target_size"], ) - self.target_labels = vis.make_target_labels( - win, - trial["num_targets"], - trial["target_distance"], - trial["target_size"], - trial["target_labels"], - ) self.drawables: List[Union[BaseVisualStim, ElementArrayStim]] = [self.targets] + self.target_labels = None if trial["show_target_labels"]: + self.target_labels = vis.make_target_labels( + win, + trial["num_targets"], + trial["target_distance"], + trial["target_size"], + trial["target_labels"], + ) self.drawables.extend(self.target_labels) self.cursor = vis.make_cursor(win, trial["cursor_size"]) self.cursor.setPos(np.array([0.0, 0.0])) @@ -130,6 +131,10 @@ def __init__(self, experiment: Experiment, win: Optional[Window] = None): self.win = win if not experiment.trial_list: return + self.trial_managers = { + condition_index: TrialManager(self.win, trial) + for condition_index, trial in enumerate(experiment.trial_list) + } self.trial_handler = experiment.create_trialhandler() self.mouse = Mouse(visible=False, win=win) self.kb = Keyboard() @@ -164,6 +169,8 @@ def _do_trials(self) -> None: current_condition_first_trial_index = 0 current_condition_clock = Clock() current_condition_max_time = 0.0 + self.mouse.setPos((0.0, 0.0)) + current_cursor_pos = (0.0, 0.0) for trial in self.trial_handler: if self.trial_handler.thisIndex != current_condition_index: # starting a new set of conditions @@ -171,6 +178,7 @@ def _do_trials(self) -> None: current_condition_max_time = trial["condition_timeout"] current_condition_index = self.trial_handler.thisIndex current_condition_first_trial_index = self.trial_handler.thisTrialN + trial_manager = self.trial_managers[current_condition_index] condition_trial_indices[self.trial_handler.thisIndex].append( self.trial_handler.thisTrialN ) @@ -179,7 +187,9 @@ def _do_trials(self) -> None: or current_condition_clock.getTime() < current_condition_max_time ): # only do the trial if there is still time left for this condition - self._do_trial(trial) + current_cursor_pos = self._do_trial( + trial, trial_manager, current_cursor_pos + ) is_final_trial_of_block = ( len(condition_trial_indices[self.trial_handler.thisIndex]) == trial["weight"] @@ -207,24 +217,31 @@ def _do_splash_screen(self) -> None: win=self.win, ) - def _do_trial(self, trial: Dict[str, Any]) -> None: + def _do_trial( + self, + trial: Dict[str, Any], + trial_manager: TrialManager, + initial_cursor_pos: Tuple[float, float], + ) -> Tuple[float, float]: if trial["use_joystick"] and self.js is None: raise RuntimeError("Use joystick option is enabled, but no joystick found.") - tm = TrialManager(self.win, trial, self.rng) - self.mouse.setPos(tm.cursor.pos) + trial_manager.cursor.setPos(initial_cursor_pos) + trial_data = TrialData(trial, self.rng) self.win.recordFrameIntervals = True - tm.clock.reset() - vis.update_target_colors(tm.targets, trial["show_inactive_targets"], None) - if trial["show_target_labels"]: + trial_manager.clock.reset() + vis.update_target_colors( + trial_manager.targets, trial["show_inactive_targets"], None + ) + if trial["show_target_labels"] and trial_manager.target_labels is not None: vis.update_target_label_colors( - tm.target_labels, trial["show_inactive_targets"], None + trial_manager.target_labels, trial["show_inactive_targets"], None ) - for index in tm.data.target_indices: - self._do_target(trial, index, tm) + for index in trial_data.target_indices: + self._do_target(trial, index, trial_manager, trial_data) self.win.recordFrameIntervals = False if trial["automove_cursor_to_center"]: - tm.data.to_center_success = [True] * trial["num_targets"] - add_trial_data_to_trial_handler(tm.data, self.trial_handler) + trial_data.to_center_success = [True] * trial["num_targets"] + add_trial_data_to_trial_handler(trial_data, self.trial_handler) if trial["post_trial_delay"] > 0: vis.display_results( trial["post_trial_delay"], @@ -236,14 +253,17 @@ def _do_trial(self, trial: Dict[str, Any]) -> None: False, self.win, ) + return trial_manager.cursor.pos - def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> None: + def _do_target( + self, trial: Dict[str, Any], index: int, tm: TrialManager, trial_data: TrialData + ) -> None: minimum_window_for_flip = 1.0 / 60.0 mouse_pos = tm.cursor.pos stop_waiting_time = 0.0 stop_target_time = 0.0 if trial["fixed_target_intervals"]: - num_completed_targets = len(tm.data.to_target_timestamps) + num_completed_targets = len(trial_data.to_target_timestamps) stop_waiting_time = (num_completed_targets + 1) * trial["target_duration"] stop_target_time = stop_waiting_time + trial["target_duration"] for target_index in _get_target_indices(index, trial): @@ -264,7 +284,7 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non vis.update_target_colors( tm.targets, trial["show_inactive_targets"], None ) - if trial["show_target_labels"]: + if trial["show_target_labels"] and tm.target_labels is not None: vis.update_target_label_colors( tm.target_labels, trial["show_inactive_targets"], None ) @@ -296,17 +316,21 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non vis.update_target_colors( tm.targets, trial["show_inactive_targets"], target_index ) - if trial["show_target_labels"]: + if trial["show_target_labels"] and tm.target_labels is not None: vis.update_target_label_colors( tm.target_labels, trial["show_inactive_targets"], target_index ) if trial["play_sound"]: Sound("A", secs=0.3, blockSize=1024, stereo=True).play() if is_central_target: - tm.data.to_center_num_timestamps_before_visible.append(len(mouse_times)) + trial_data.to_center_num_timestamps_before_visible.append( + len(mouse_times) + ) else: - tm.data.target_pos.append(tm.targets.xys[target_index]) - tm.data.to_target_num_timestamps_before_visible.append(len(mouse_times)) + trial_data.target_pos.append(tm.targets.xys[target_index]) + trial_data.to_target_num_timestamps_before_visible.append( + len(mouse_times) + ) target_size = trial["target_size"] if is_central_target: target_size = trial["central_target_size"] @@ -351,15 +375,15 @@ def _do_target(self, trial: Dict[str, Any], index: int, tm: TrialManager) -> Non and tm.clock.getTime() + minimum_window_for_flip < stop_target_time ) if is_central_target: - tm.data.to_center_success.append(success) + trial_data.to_center_success.append(success) else: - tm.data.to_target_success.append(success) + trial_data.to_target_success.append(success) if is_central_target: - tm.data.to_center_timestamps.append(np.array(mouse_times)) - tm.data.to_center_mouse_positions.append(np.array(mouse_positions)) + trial_data.to_center_timestamps.append(np.array(mouse_times)) + trial_data.to_center_mouse_positions.append(np.array(mouse_positions)) else: - tm.data.to_target_timestamps.append(np.array(mouse_times)) - tm.data.to_target_mouse_positions.append(np.array(mouse_positions)) + trial_data.to_target_timestamps.append(np.array(mouse_times)) + trial_data.to_target_mouse_positions.append(np.array(mouse_positions)) def _clean_up_and_return(self, return_value: bool) -> bool: if self.win is not None and self.close_window_when_done: