iblsort release (#792)
* spike sorting loader: low-level changes to load waveforms

* allows the loading of additional channel object attributes (qc labels)

* WIP: iblsort task

---------

Co-authored-by: olivier <olivier.winter@hotmail.fr>
Co-authored-by: chris-langfield <christopher.langfield@internationalbrainlab.org>
3 people committed Jun 28, 2024
1 parent db14a3e commit 8e5d6f8
Showing 2 changed files with 39 additions and 51 deletions.
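
For orientation, here is a minimal sketch of how the two changes below surface to a user of `brainbox.io.one`. The eid and probe name are placeholders, and the exact call signatures should be checked against the loader rather than taken from this sketch:

```python
# Hedged usage sketch; eid/pname are placeholders, not from this commit.
from one.api import ONE
from brainbox.io.one import SpikeSortingLoader

one = ONE()
ssl = SpikeSortingLoader(eid='<eid>', pname='probe00', one=one)
# the attribute tables below now include a 'waveforms' entry,
# so waveform objects can be fetched like spikes and clusters
ssl.download_spike_sorting_object(obj='waveforms', missing='ignore')
channels = ssl.load_channels()  # extra attributes (e.g. qc labels) carry over
```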
38 changes: 21 additions & 17 deletions brainbox/io/one.py
@@ -40,6 +40,7 @@

SPIKES_ATTRIBUTES = ['clusters', 'times', 'amps', 'depths']
CLUSTERS_ATTRIBUTES = ['channels', 'depths', 'metrics', 'uuids']
+WAVEFORMS_ATTRIBUTES = ['templates']


def load_lfp(eid, one=None, dataset_types=None, **kwargs):
@@ -128,6 +129,10 @@ def _channels_alf2bunch(channels, brain_regions=None):
'axial_um': channels['localCoordinates'][:, 1],
'lateral_um': channels['localCoordinates'][:, 0],
}
+    # here if we have some extra keys, they will carry over to the next dictionary
+    for k in channels:
+        if k not in list(channels_.keys()) + ['mlapdv', 'brainLocationIds_ccf_2017', 'localCoordinates']:
+            channels_[k] = channels[k]
if brain_regions:
channels_['acronym'] = brain_regions.get(channels_['atlas_id'])['acronym']
return channels_
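
The new loop copies through any attribute on the raw ALF channels object beyond the remapped position keys, which is what lets qc labels reach the output. A toy illustration of the pass-through, with invented values:

```python
import numpy as np

# Toy input in the ALF layout handled above; 'labels' stands in for an
# extra attribute such as a channel qc label (invented values).
channels = {
    'mlapdv': np.zeros((4, 3)),
    'brainLocationIds_ccf_2017': np.zeros(4, dtype=int),
    'localCoordinates': np.c_[np.zeros(4), np.arange(4) * 20.],
    'labels': np.array(['ok', 'ok', 'dead', 'ok']),  # extra key
}
# _channels_alf2bunch(channels) remaps the first three keys to
# x/y/z, atlas_id, axial_um, lateral_um, and passes 'labels' through.
```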
@@ -851,14 +856,14 @@ def _load_object(self, *args, **kwargs):
@staticmethod
def _get_attributes(dataset_types):
"""returns attributes to load for spikes and clusters objects"""
-        if dataset_types is None:
-            return SPIKES_ATTRIBUTES, CLUSTERS_ATTRIBUTES
-        else:
-            spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
-            cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
-            spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
-            cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
-            return spike_attributes, cluster_attributes
+        dataset_types = [] if dataset_types is None else dataset_types
+        spike_attributes = [sp.split('.')[1] for sp in dataset_types if 'spikes.' in sp]
+        spike_attributes = list(set(SPIKES_ATTRIBUTES + spike_attributes))
+        cluster_attributes = [cl.split('.')[1] for cl in dataset_types if 'clusters.' in cl]
+        cluster_attributes = list(set(CLUSTERS_ATTRIBUTES + cluster_attributes))
+        waveform_attributes = [cl.split('.')[1] for cl in dataset_types if 'waveforms.' in cl]
+        waveform_attributes = list(set(WAVEFORMS_ATTRIBUTES + waveform_attributes))
+        return {'spikes': spike_attributes, 'clusters': cluster_attributes, 'waveforms': waveform_attributes}
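
The helper now always returns a dict keyed by object name rather than a tuple, with user-supplied dataset types merged into the defaults via a set union. Expected behaviour, roughly ('waveforms.channels' is just an illustrative extra type):

```python
attrs = SpikeSortingLoader._get_attributes(['spikes.samples', 'waveforms.channels'])
assert 'samples' in attrs['spikes']            # default SPIKES_ATTRIBUTES plus extras
assert set(attrs['clusters']) == set(CLUSTERS_ATTRIBUTES)
assert set(attrs['waveforms']) == {'templates', 'channels'}  # order not guaranteed
```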

def _get_spike_sorting_collection(self, spike_sorter='pykilosort'):
"""
@@ -891,14 +896,15 @@ def get_version(self, spike_sorter='pykilosort'):
return dset[0]['version'] if len(dset) else 'unknown'

def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_types=None, collection=None,
-                                      missing='raise', **kwargs):
+                                      attribute=None, missing='raise', **kwargs):
"""
Downloads an ALF object
        :param obj: object name, one of 'spikes', 'clusters' or 'channels'
        :param spike_sorter: (defaults to 'pykilosort')
        :param dataset_types: list of extra dataset types, for example ['spikes.samples']
        :param collection: string specifying the collection, for example 'alf/probe01/pykilosort'
:param kwargs: additional arguments to be passed to one.api.One.load_object
+        :param attribute: list of attributes to load for the object
:param missing: 'raise' (default) or 'ignore'
:return:
"""
@@ -907,8 +913,7 @@ def download_spike_sorting_object(self, obj, spike_sorter='pykilosort', dataset_
self.collection = self._get_spike_sorting_collection(spike_sorter=spike_sorter)
collection = collection or self.collection
_logger.debug(f"loading spike sorting object {obj} from {collection}")
-        spike_attributes, cluster_attributes = self._get_attributes(dataset_types)
-        attributes = {'spikes': spike_attributes, 'clusters': cluster_attributes}
+        attributes = self._get_attributes(dataset_types)
try:
self.files[obj] = self.one.load_object(
self.eid, obj=obj, attribute=attributes.get(obj, None),
@@ -986,11 +991,10 @@ def load_channels(self, **kwargs):
"""
# we do not specify the spike sorter on purpose here: the electrode sites do not depend on the spike sorting
self.download_spike_sorting_object(obj='electrodeSites', collection=f'alf/{self.pname}', missing='ignore')
-        if 'electrodeSites' in self.files:
-            channels = self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
-        else:  # otherwise, we try to load the channel object from the spike sorting folder - this may not contain histology
-            self.download_spike_sorting_object(obj='channels', **kwargs)
-            channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
+        self.download_spike_sorting_object(obj='channels', missing='ignore', **kwargs)
+        channels = self._load_object(self.files['channels'], wildcards=self.one.wildcards)
+        if 'electrodeSites' in self.files:  # if common dict keys, electrodeSites prevails
+            channels = channels | self._load_object(self.files['electrodeSites'], wildcards=self.one.wildcards)
if 'brainLocationIds_ccf_2017' not in channels:
_logger.debug(f"loading channels from alyx for {self.files['channels']}")
_channels, self.histology = _load_channel_locations_traj(
@@ -1000,7 +1004,7 @@
else:
channels = _channels_alf2bunch(channels, brain_regions=self.atlas.regions)
self.histology = 'alf'
-        return channels
+        return Bunch(channels)
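
`load_channels` now always loads the spike-sorting channels object first, then overlays `electrodeSites` with the dict union operator, whose right-hand operand wins on shared keys (hence the "electrodeSites prevails" comment). In isolation:

```python
# PEP 584 dict union: the right-hand operand wins on key conflicts.
channels = {'atlas_id': 'from channels', 'extra_qc': 'kept'}
electrode_sites = {'atlas_id': 'from electrodeSites'}
merged = channels | electrode_sites
assert merged == {'atlas_id': 'from electrodeSites', 'extra_qc': 'kept'}
```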

def load_spike_sorting(self, spike_sorter='pykilosort', revision=None, enforce_version=True, good_units=False, **kwargs):
"""
52 changes: 18 additions & 34 deletions ibllib/pipes/ephys_tasks.py
@@ -566,17 +566,17 @@ class SpikeSorting(base_tasks.EphysTask, CellQCMixin):
force = True

SHELL_SCRIPT = Path.home().joinpath(
"Documents/PYTHON/iblscripts/deploy/serverpc/kilosort2/run_pykilosort.sh"
"Documents/PYTHON/iblscripts/deploy/serverpc/iblsorter/run_iblsorter.sh"
)
-    SPIKE_SORTER_NAME = 'pykilosort'
-    PYKILOSORT_REPO = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/pykilosort')
+    SPIKE_SORTER_NAME = 'iblsorter'
+    PYKILOSORT_REPO = Path.home().joinpath('Documents/PYTHON/SPIKE_SORTING/ibl-sorter')

@property
def signature(self):
signature = {
'input_files': [('*ap.meta', f'{self.device_collection}/{self.pname}', True),
-                            ('*ap.cbin', f'{self.device_collection}/{self.pname}', True),
-                            ('*ap.ch', f'{self.device_collection}/{self.pname}', True),
+                            ('*ap.*bin', f'{self.device_collection}/{self.pname}', True),
+                            ('*ap.ch', f'{self.device_collection}/{self.pname}', False),
('*sync.npy', f'{self.device_collection}/{self.pname}', True)],
'output_files': [('spike_sorting_pykilosort.log', f'spike_sorters/pykilosort/{self.pname}', True),
('_iblqc_ephysTimeRmsAP.rms.npy', f'{self.device_collection}/{self.pname}', True),
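
Each signature entry reads as a (filename pattern, collection, required) tuple, so the change requires any `*ap.*bin` raw data file, compressed or not, while demoting the `*ap.ch` compression index to optional. Assuming glob-style matching, which these patterns appear to use:

```python
from fnmatch import fnmatch

# '*ap.*bin' now matches compressed and uncompressed raw data alike
assert fnmatch('_spikeglx_ephysData_g0_t0.imec0.ap.cbin', '*ap.*bin')
assert fnmatch('_spikeglx_ephysData_g0_t0.imec0.ap.bin', '*ap.*bin')
# the old '*ap.cbin' pattern only accepted the compressed flavour
assert not fnmatch('_spikeglx_ephysData_g0_t0.imec0.ap.bin', '*ap.cbin')
```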
@@ -591,14 +591,13 @@ def _sample2v(ap_file):
return s2v["ap"][0]

@staticmethod
-    def _fetch_pykilosort_version(repo_path):
+    def _fetch_iblsorter_version(repo_path):
try:
-            import pykilosort
-            return f"pykilosort_{pykilosort.__version__}"
+            import iblsorter
+            return f"iblsorter_{iblsorter.__version__}"
except ImportError:
-            _logger.info('Pykilosort not in environment, trying to locate the repository')
-            init_file = Path(repo_path).joinpath('pykilosort', '__init__.py')
-            version = SpikeSorting._fetch_ks2_commit_hash(repo_path)  # default
+            _logger.info('IBL-sorter not in environment, trying to locate the repository')
+            init_file = Path(repo_path).joinpath('ibl-sorter', '__init__.py')
try:
with open(init_file) as fid:
lines = fid.readlines()
@@ -607,10 +606,10 @@ def _fetch_pykilosort_version(repo_path):
version = line.split('=')[-1].strip().replace('"', '').replace("'", '')
except Exception:
pass
return f"pykilosort_{version}"
return f"iblsorter_{version}"

@staticmethod
-    def _fetch_pykilosort_run_version(log_file):
+    def _fetch_iblsorter_run_version(log_file):
"""
Parse the following line (2 formats depending on version) from the log files to get the version
'\x1b[0m15:39:37.919 [I] ibl:90 Starting Pykilosort version ibl_1.2.1, output in gnagga^[[0m\n'
@@ -623,36 +622,21 @@ def _fetch_pykilosort_run_version(log_file):
version = re.sub('\\^[[0-9]+m', '', version.group(1)) # removes the coloring tags
return version
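
The run-version parser has to cope with ANSI colour codes embedded in the log line quoted in the docstring. This is not the task's exact regex (only part of it is visible in this hunk), but the idea in isolation:

```python
import re

line = '\x1b[0m15:39:37.919 [I] ibl:90 Starting Pykilosort version ibl_1.2.1, output in gnagga\x1b[0m\n'
clean = re.sub(r'\x1b\[[0-9;]*m', '', line)   # strip ANSI colour escapes
match = re.search(r'version (\S+),', clean)   # grab the token after 'version'
assert match.group(1) == 'ibl_1.2.1'
```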

-    @staticmethod
-    def _fetch_ks2_commit_hash(repo_path):
-        command2run = f"git --git-dir {repo_path}/.git rev-parse --verify HEAD"
-        process = subprocess.Popen(
-            command2run, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE
-        )
-        info, error = process.communicate()
-        if process.returncode != 0:
-            _logger.error(
-                f"Can't fetch pykilsort commit hash, will still attempt to run \n"
-                f"Error: {error.decode('utf-8')}"
-            )
-            return ""
-        return info.decode("utf-8").strip()
-
-    def _run_pykilosort(self, ap_file):
+    def _run_iblsort(self, ap_file):
"""
        Runs the iblsorter spike sorting for one probe dataset
the raw spike sorting output is in session_path/spike_sorters/{self.SPIKE_SORTER_NAME}/probeXX folder
(discontinued support for old spike sortings in the probe folder <1.5.5)
        :return: path of the folder containing the spike sorting output
"""
-        self.version = self._fetch_pykilosort_version(self.PYKILOSORT_REPO)
+        self.version = self._fetch_iblsorter_version(self.PYKILOSORT_REPO)
label = ap_file.parts[-2] # this is usually the probe name
sorter_dir = self.session_path.joinpath("spike_sorters", self.SPIKE_SORTER_NAME, label)
self.FORCE_RERUN = False
if not self.FORCE_RERUN:
log_file = sorter_dir.joinpath(f"spike_sorting_{self.SPIKE_SORTER_NAME}.log")
if log_file.exists():
-                run_version = self._fetch_pykilosort_run_version(log_file)
+                run_version = self._fetch_iblsorter_run_version(log_file)
if packaging.version.parse(run_version) >= packaging.version.parse('1.7.0'):
_logger.info(f"Already ran: spike_sorting_{self.SPIKE_SORTER_NAME}.log"
f" found in {sorter_dir}, skipping.")
@@ -673,8 +657,8 @@ def _fetch_pykilosort_run_version(log_file):
check_nvidia_driver()
try:
            # if iblsorter is in the environment, use the installed version within the task
-            import pykilosort.ibl  # noqa
-            pykilosort.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=temp_dir)
+            import iblsorter.ibl  # noqa
+            iblsorter.ibl.run_spike_sorting_ibl(bin_file=ap_file, scratch_dir=temp_dir)
except ImportError:
command2run = f"{self.SHELL_SCRIPT} {ap_file} {temp_dir}"
_logger.info(command2run)
@@ -717,7 +701,7 @@ def _run(self):
assert len(ap_files) == 1, f"Several bin files found for the same probe {ap_files}"
ap_file, label = ap_files[0]
out_files = []
-        ks2_dir = self._run_pykilosort(ap_file)  # runs ks2, skips if it already ran
+        ks2_dir = self._run_iblsort(ap_file)  # runs the sorter, skips if it already ran
probe_out_path = self.session_path.joinpath("alf", label, self.SPIKE_SORTER_NAME)
shutil.rmtree(probe_out_path, ignore_errors=True)
probe_out_path.mkdir(parents=True, exist_ok=True)
