diff --git a/brainbox/io/one.py b/brainbox/io/one.py index 175750384..f832fb04c 100644 --- a/brainbox/io/one.py +++ b/brainbox/io/one.py @@ -40,6 +40,7 @@ SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths'] CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids'] +WAVEFORMS_ATTRIBUTES = ['templates'] def load_lfp(eid, one=None, dataset_types=None, **kwargs): @@ -128,6 +129,10 @@ def _channels_alf2bunch(channels, brain_regions=None): 'axial_um': channels['localCoordinates'][:, 1], 'lateral_um': channels['localCoordinates'][:, 0], } + # here if we have some extra keys, they will carry over to the next dictionary + for k in channels: + if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']: + channels_[k] = channels[k] if brain_regions: channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym'] return channels_ @@ -851,14 +856,14 @@ def _load_object(self, *args, **kwargs): @staticmethod def _get_attributes(dataset_types): """returns attributes to load for spikes and clusters objects""" - if dataset_types is None: - return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES - else: - spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] - cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] - spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) - cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) - return spike_attributes, cluster_attributes + dataset_types = [] if dataset_types is None else dataset_types + spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp] + spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes)) + cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl] + cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes)) + waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl] + waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes)) + return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes} def _get_spike_sorting_collection(self, spike_sorter='pykilosort'): """ @@ -891,7 +896,7 @@ def get_version(self, spike_sorter='pykilosort'): return dset[0]['version'] if len(dset) else 'unknown' def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None, - missing='raise', **kwargs): + attribute=None, missing='raise', **kwargs): """ Downloads an ALF object :param obj: object name, str between 'spikes', 'clusters' or 'channels' @@ -899,6 +904,7 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_ :param dataset_types: list of extra dataset types, for example ['spikes.samples'] :param collection: string specifiying the collection, for example 'alf/probe01/pykilosort' :param kwargs: additional arguments to be passed to one.api.One.load_object + :param attribute: list of attributes to load for the object :param missing: 'raise' (default) or 'ignore' :return: """ @@ -907,8 +913,7 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_ self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter) collection = collection or self.collection _logger.debug(f"loading spike sorting object {obj} from {collection}") - spike_attributes, cluster_attributes = self._get_attributes(dataset_types) - attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes} + attributes = self._get_attributes(dataset_types) try: self.files[obj] = self.one.load_object( self.eid, obj=obj, attribute=attributes.get(obj, None), @@ -986,11 +991,10 @@ def load_channels(self, **kwargs): """ # we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore') - if 'electrodeSites' in self.files: - channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) - else: # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology - self.download_spike_sorting_object(obj='channels', **kwargs) - channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) + self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs) + channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards) + if 'electrodeSites' in self.files: # if common dict keys, electrodeSites prevails + channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards) if 'brainLocationIds_ccf_2017' not in channels: _logger.debug(f"loading channels from alyx for {self.files['channels']}") _channels, self.histology = _load_channel_locations_traj( @@ -1000,7 +1004,7 @@ def load_channels(self, **kwargs): else: channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions) self.histology = 'alf' - return channels + return Bunch(channels) def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs): """ diff --git a/ibllib/pipes/ephys_tasks.py b/ibllib/pipes/ephys_tasks.py index 718cd5a3c..15530b49a 100644 --- a/ibllib/pipes/ephys_tasks.py +++ b/ibllib/pipes/ephys_tasks.py @@ -566,17 +566,17 @@ class SpikeSorting(base_tasks.EphysTask, CellQCMixin): force = True SHELL_SCRIPT = Path.home().joinpath( - "Documents/PYTHON/iblscripts/deploy/serverpc/kilosort2/run_pykilosort.sh" + "Documents/PYTHON/iblscripts/deploy/serverpc/iblsorter/run_iblsorter.sh" ) - SPIKE_SORTER_NAME = 'pykilosort' - PYKILOSORT_REPO = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/pykilosort') + SPIKE_SORTER_NAME = 'iblsorter' + PYKILOSORT_REPO = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/ibl-sorter') @property def signature(self): signature = { 'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True), - ('*ap.cbin', f'{self.device_collection}/{self.pname}', True), - ('*ap.ch', f'{self.device_collection}/{self.pname}', True), + ('*ap.*bin', f'{self.device_collection}/{self.pname}', True), + ('*ap.ch', f'{self.device_collection}/{self.pname}', False), ('*sync.npy', f'{self.device_collection}/{self.pname}', True)], 'output_files': [('spike_sorting_pykilosort.log', f'spike_sorters/pykilosort/{self.pname}', True), ('_iblqc_ephysTimeRmsAP.rms.npy', f'{self.device_collection}/{self.pname}', True), @@ -591,14 +591,13 @@ def _sample2v(ap_file): return s2v["ap"][0] @staticmethod - def _fetch_pykilosort_version(repo_path): + def _fetch_iblsorter_version(repo_path): try: - import pykilosort - return f"pykilosort_{pykilosort.__version__}" + import iblsorter + return f"iblsorter_{iblsorter.__version__}" except ImportError: - _logger.info('Pykilosort not in environment, trying to locate the repository') - init_file = Path(repo_path).joinpath('pykilosort', '__init__.py') - version = SpikeSorting._fetch_ks2_commit_hash(repo_path) # default + _logger.info('IBL-sorter not in environment, trying to locate the repository') + init_file = Path(repo_path).joinpath('ibl-sorter', '__init__.py') try: with open(init_file) as fid: lines = fid.readlines() @@ -607,10 +606,10 @@ def _fetch_pykilosort_version(repo_path): version = line.split('=')[-1].strip().replace('"', '').replace("'", '') except Exception: pass - return f"pykilosort_{version}" + return f"iblsorter_{version}" @staticmethod - def _fetch_pykilosort_run_version(log_file): + def _fetch_iblsorter_run_version(log_file): """ Parse the following line (2 formats depending on version) from the log files to get the version '\x1b[0m15:39:37.919 [I] ibl:90 Starting Pykilosort version ibl_1.2.1, output in gnagga^[[0m\n' @@ -623,36 +622,21 @@ def _fetch_pykilosort_run_version(log_file): version = re.sub('\\^[[0-9]+m', '', version.group(1)) # removes the coloring tags return version - @staticmethod - def _fetch_ks2_commit_hash(repo_path): - command2run = f"git --git-dir {repo_path}/.git rev-parse --verify HEAD" - process = subprocess.Popen( - command2run, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE - ) - info, error = process.communicate() - if process.returncode != 0: - _logger.error( - f"Can't fetch pykilsort commit hash, will still attempt to run \n" - f"Error: {error.decode('utf-8')}" - ) - return "" - return info.decode("utf-8").strip() - - def _run_pykilosort(self, ap_file): + def _run_iblsort(self, ap_file): """ Runs the ks2 matlab spike sorting for one probe dataset the raw spike sorting output is in session_path/spike_sorters/{self.SPIKE_SORTER_NAME}/probeXX folder (discontinued support for old spike sortings in the probe folder <1.5.5) :return: path of the folder containing ks2 spike sorting output """ - self.version = self._fetch_pykilosort_version(self.PYKILOSORT_REPO) + self.version = self._fetch_iblsorter_version(self.PYKILOSORT_REPO) label = ap_file.parts[-2] # this is usually the probe name sorter_dir = self.session_path.joinpath("spike_sorters", self.SPIKE_SORTER_NAME, label) self.FORCE_RERUN = False if not self.FORCE_RERUN: log_file = sorter_dir.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log") if log_file.exists(): - run_version = self._fetch_pykilosort_run_version(log_file) + run_version = self._fetch_iblsorter_run_version(log_file) if packaging.version.parse(run_version) >= packaging.version.parse('1.7.0'): _logger.info(f"Already ran: spike_sorting_{self.SPIKE_SORTER_NAME}.log" f" found in {sorter_dir}, skipping.") @@ -673,8 +657,8 @@ def _run_pykilosort(self, ap_file): check_nvidia_driver() try: # if pykilosort is in the environment, use the installed version within the task - import pykilosort.ibl # noqa - pykilosort.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=temp_dir) + import iblsorter.ibl # noqa + iblsorter.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=temp_dir) except ImportError: command2run = f"{self.SHELL_SCRIPT} {ap_file} {temp_dir}" _logger.info(command2run) @@ -717,7 +701,7 @@ def _run(self): assert len(ap_files) == 1, f"Several bin files found for the same probe {ap_files}" ap_file, label = ap_files[0] out_files = [] - ks2_dir = self._run_pykilosort(ap_file) # runs ks2, skips if it already ran + ks2_dir = self._run_iblsort(ap_file) # runs the sorter, skips if it already ran probe_out_path = self.session_path.joinpath("alf", label, self.SPIKE_SORTER_NAME) shutil.rmtree(probe_out_path, ignore_errors=True) probe_out_path.mkdir(parents=True, exist_ok=True)