# -*- coding: utf-8 -*-
#_____________________________________________________________________________
#
# Copyright (c) 2012-2013, Berlin Institute of Technology
# All rights reserved.
#
# Developed by: Philipp Meier <pmeier82@gmail.com>
#
# Neural Information Processing Group (NI)
# School for Electrical Engineering and Computer Science
# Berlin Institute of Technology
# MAR 5-6, Marchstr. 23, 10587 Berlin, Germany
# http://www.ni.tu-berlin.de/
#
# Repository: https://github.com/pmeier82/BOTMpy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal with the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimers.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimers in the documentation
# and/or other materials provided with the distribution.
# * Neither the names of Neural Information Processing Group (NI), Berlin
# Institute of Technology, nor the names of its contributors may be used to
# endorse or promote products derived from this Software without specific
# prior written permission.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# WITH THE SOFTWARE.
#_____________________________________________________________________________
#
# Acknowledgements:
# Philipp Meier <pmeier82@gmail.com>
#_____________________________________________________________________________
#
# Changelog:
# * <iso-date> <identity> :: <description>
#_____________________________________________________________________________
#
"""functions for spike sorting"""
__docformat__ = 'restructuredtext'
__all__ = [
'threshold_detection', 'merge_epochs', 'invert_epochs',
'epochs_from_binvec', 'epochs_from_spiketrain',
'epochs_from_spiketrain_set', 'chunk_data', 'extract_spikes',
'get_cut', 'snr_maha', 'snr_peak', 'snr_power', 'overlaps']
## IMPORTS
import scipy as sp
from .util import *
from .funcs_general import sortrows
## FUNCTIONS
# event detection
def threshold_detection(data, th, min_dist=1, mode='gt', find_max=True):
    """detect events by applying a threshold to the data

    :type data: ndarray
    :param data: the 2d-data to apply the threshold on. channels are in the
        second dimension (columns).
        Required
    :type th: ndarray or list
    :param th: list of threshold values, one value per channel in the `data`
        Required
    :type min_dist: int
    :param min_dist: minimal distance two successive events have to be
        separated in samples, else the event is ignored.
        Default=1
    :type mode: str
    :param mode: one of 'gt' for greater than or 'lt' for less than. will
        determine how the threshold is applied.
        Default='gt'
    :type find_max: bool
    :param find_max: if True, will find the maximum for each event epoch, else
        will find the start for each event epoch.
        Default=True
    :rtype: ndarray
    :returns: event samples
    """
    # checks
    data = sp.asarray(data)
    if data.ndim != 2:
        if data.ndim == 1:
            data = sp.atleast_2d(data).T
        else:
            raise ValueError('data.ndim != 2')
    th = sp.asarray(th)
    if th.ndim != 1:
        raise ValueError('th.ndim != 1')
    if th.size != data.shape[1]:
        raise ValueError('thresholds have to match the data channel count')
    if mode not in ['gt', 'lt']:
        raise ValueError('unknown mode, use one of \'lt\' or \'gt\'')
    if min_dist < 1:
        min_dist = 1

    # inits
    rval = []
    ep_func = {'gt': lambda d, t: epochs_from_binvec(d > t).tolist(),
               'lt': lambda d, t: epochs_from_binvec(d < t).tolist(),
               }[mode]

    # per channel detection
    for c in range(data.shape[1]):
        for start, stop in ep_func(data[:, c], th[c]):
            ev = start
            if find_max is True:
                # NOTE(review): argmax is used for both modes; for 'lt'
                # epochs this picks the largest (least negative) sample, not
                # the trough -- looks suspicious, confirm intended behavior
                ev += data[start:stop + 1, c].argmax()
            rval.append(ev)
    rval = sp.asarray(rval, dtype=INDEX_DTYPE)

    # do we have events?
    if rval.size == 0:
        return rval

    # drop event duplicates by sorting and checking for min_dist
    # BUGFIX: seed the diff with rval[0]-min_dist (not 0) so the earliest
    # event is always kept; previously an event closer than min_dist to
    # sample 0 was silently discarded.
    rval.sort()
    keep = sp.diff(sp.concatenate(([rval[0] - min_dist], rval))) >= min_dist
    return rval[keep]
## epoch handling functions
def merge_epochs(*args, **kwargs):
    """combine several epoch sets into one without overlaps

    For a set of epoch sets, check if the combined set of epochs overlap
    and merge to one set with no overlapping epochs and no epochs of
    negative length.

    :param args: arbitrary count of epoch sets [[start, end]]
    :keyword min_dist: int - If present and greater than zero, this integer
        will be taken as the minimum distance in between epochs that is
        allowed. Should the gap in between two epochs smaller than min_dist,
        they are merged including the gap. This might reduce the
        segmentation of the data.
    :returns: ndarray - merged epoch set [[start, end]]
    """
    # input validation
    for ep_set in args:
        if not isinstance(ep_set, (list, sp.ndarray)):
            raise ValueError('wrong inputs! lists and ndarrays allowed')

    # stack and sort all epochs by start sample
    stacked = sortrows(sp.vstack(args)).tolist()
    if not stacked:
        return sp.zeros((0, 2), dtype=INDEX_DTYPE)

    # fold overlapping epochs into the running tail epoch
    merged = [stacked[0]]
    for ep in stacked[1:]:
        if ep[0] <= merged[-1][1] - 1:
            merged[-1] = [min(ep[0], merged[-1][0]),
                          max(ep[1], merged[-1][1])]
        else:
            merged.append(ep)

    # optionally also close gaps smaller than min_dist
    min_dist = int(kwargs.get('min_dist', 0))
    if min_dist > 0:
        closed = [merged[0]]
        for ep in merged[1:]:
            if ep[0] - closed[-1][1] < min_dist:
                closed[-1][1] = ep[1]
            else:
                closed.append(ep)
        merged = closed

    # clamp negative starts and drop non-positive-length epochs
    rval = sp.asarray(merged, dtype=INDEX_DTYPE)
    rval[rval[:, 0] < 0, :] = 0
    rval = rval[rval[:, 1] - rval[:, 0] > 0, :]
    return rval
def invert_epochs(epochs, end=None):
    """return the inverse of an epoch set

    The first epoch will be mapped to [0, start] and the last will be mapped
    to [end of last epoch, :end:]. Epochs that accidentally become negative
    or zero-length will be omitted.

    :type epochs: ndarray
    :param epochs: epoch set to invert
    :type end: int
    :param end: If not None, it is taken for the end of the last epoch,
        else max(index-dtype) is taken instead.
        Default=None
    :returns: ndarray - inverted epoch set
    """
    # resolve the end of the final gap
    if end is None:
        end = sp.iinfo(INDEX_DTYPE).max
    else:
        end = INDEX_DTYPE.type(end)

    # gaps start at 0 and at each epoch's end; they stop at each epoch's
    # start and finally at :end:
    gap_starts = sp.concatenate(([0], epochs[:, 1]))
    gap_stops = sp.concatenate((epochs[:, 0], [end]))
    rval = sp.vstack((gap_starts, gap_stops)).T

    # keep only gaps with positive length
    return (rval[gap_stops - gap_starts > 0]).astype(INDEX_DTYPE)
def epochs_from_binvec(binvec):
    """returns the discrete epochs where the :binvec: is true

    :type binvec: ndarray
    :param binvec: one-dimensional boolean ndarray.
    :returns: ndarray - epoch set where :binvec: is True [[start, end]]
    """
    # early exit
    if not binvec.any():
        return sp.zeros((0, 2))

    # zero-pad the indicator so transitions at the borders are seen;
    # a +1 step marks an epoch start, a -1 step marks one-past-the-end
    trans = sp.diff(sp.concatenate(([0], binvec.astype(int), [0])))
    starts = (trans == 1).nonzero()[0]
    ends = (trans == -1).nonzero()[0] - 1
    return sp.vstack((starts, ends)).T
def epochs_from_spiketrain(st, cut, end=None, with_corrected_st=False):
    """yields epoch set, given a spiketrain and cut parameter

    :type st: ndarray
    :param st: spiketrains as 1d array
    :type cut: tuple
    :param cut: 2-tuple of cutting parameters, (cut_left,cut_right) spike
        epochs will be generated by using cut_left and cut_right on the spike
        time. If an int is given, a symmetric cut tuple is assumed.
    :type end: int
    :param end: to determine potential problems with epochs overlapping data
        boundaries. If an event in the spiketrain is closer to 0 than
        :cut[0]: or closer to :end: than :cut[1]: the corresponding epoch
        will be omitted. If None, :end: will be set to max(INDEX_DTYPE)
        Default=None
    :type with_corrected_st: bool
    :param with_corrected_st: if True, return the corrected spiketrain by
        omitting spike events that cannot generate valid spike epochs given
        the passed cut settings.
        Default=False
    :returns: ndarray - epoch set of valid spike epochs, and if
        :with_corrected_st: is True additionally the corrected spike train
    """
    # normalize inputs
    st = sp.asarray(st)
    cut = get_cut(cut)
    if end is None:
        bound = sp.iinfo(INDEX_DTYPE).max
    else:
        bound = INDEX_DTYPE.type(end)

    # keep only spikes whose cut window fits inside [0, bound)
    valid = (st >= cut[0]) * (st < bound - cut[1])
    spikes = st[valid]
    rval = sp.vstack((spikes - cut[0], spikes + cut[1])).T.astype(INDEX_DTYPE)

    if with_corrected_st is True:
        return rval, spikes
    else:
        return rval
def epochs_from_spiketrain_set(sts, cut, end=None):
    """yields epoch sets, given a spiketrain set and cut parameter

    one set for each unit plus one for the noise epochs in a dict

    :type sts: dict
    :param sts: dict with the spiketrains for each unit in the set. none of
        the units in the spiketrain set may have the key 'noise'!
    :type cut: tuple
    :param cut: 2-tuple of cutting parameters, (cut_left, cut_right) spike
        epochs will be generated by using cut_left and cut_right on the spike
        time. If an int is given, a symmetric cut tuple is assumed.
    :param end: to determine potential problems with epochs overlapping data
        boundaries. If an event in the spiketrain is closer to 0 than
        :cut[0]: or closer to :end: than :cut[1]: the corresponding epoch
        will be omitted. If None, :end: will be set to max(INDEX_DTYPE)
        Default=None
    :returns: dict - one epoch set per spike train plus the merged noise
        epoch set.
    """
    # checks
    if not isinstance(sts, dict):
        raise ValueError('sts has to be a set of spiketrains in a dict')

    # one epoch set per unit
    rval = {unit: epochs_from_spiketrain(train, cut, end=end)
            for unit, train in sts.items()}

    # everything outside the merged unit epochs is noise
    rval['noise'] = invert_epochs(merge_epochs(*rval.values()), end=end)
    return rval
## spike and data extraction
def chunk_data(data, epochs=None, invert=False):
    """returns a generator of chunks from data given epochs

    :type data: ndarray
    :param data: signal data [[samples, channels]]
    :type epochs: ndarray
    :param epochs: epoch set, positive mask
    :type invert: bool
    :param invert: invert epochs, negative mask instead of positive mask
    :returns: generator - data chunks as per :epochs:
    """
    # normalize data to 2d [samples, channels]
    data = sp.asarray(data)
    if data.ndim != 2:
        data = sp.atleast_2d(data).T

    # validate / derive the epoch set
    if epochs is not None and epochs.ndim != 2:
        raise ValueError('epochs has to be ndim=2 like [[start,end]]')
    if invert is True and epochs is not None:
        epochs = invert_epochs(epochs, end=data.shape[0])
    if epochs is None or len(epochs) == 0:
        # fall back to one chunk spanning all samples
        epochs = [[0, data.shape[0]]]

    # yield data chunks
    for span in epochs:
        yield data[span[0]:span[1], :], list(span)
def get_cut(tf, off=0):
    """cut 2-tuple (cut_left,cut_right) generating function

    Used to generate epochs from events. Per default the epoch will be
    placed symmetrically around the event sample. :off: can be used to
    influence the placement. For odd tf values the extension of the
    cut_right part will be favored.

    :type tf: int
    :param tf: length of the waveform in samples
    :type off: int
    :param off: offset for epoch start/end
        Default=0
    """
    if isinstance(tf, tuple):
        # explicit (left, right) pair, just shifted by off
        if len(tf) != 2:
            raise ValueError('tuples have to be of length==2 for get_cut')
        return tf[0] - int(off), tf[1] + int(off)
    if isinstance(tf, int):
        # symmetric split; the remainder of odd tf goes to the right side
        half = int(tf / 2.0)
        return half - int(off), half + tf % 2 + int(off)
    raise TypeError('only int or tuple are allowed for get_cut')
## SNR functions - added by Felix 10. aug 2009
def snr_peak(waveforms, noise_var):
    """SNR from instantaneous variance

    Definition of signal to noise ratio (SNR) as the ratio between the peak
    of a waveform and the noise standard deviation.

    :type waveforms: ndarray
    :param waveforms: waveform data (signal), one per row
    :type noise_var: float
    :param noise_var: instantaneous variance of the noise (noise)
    :returns: ndarray - SNR per waveform
    """
    # per-row absolute peak over noise standard deviation
    peaks = sp.absolute(waveforms).max(axis=1)
    return peaks / sp.sqrt(noise_var)
def snr_power(waveforms, noise_var):
    """SNR from signal energy

    Definition of signal to noise ratio (SNR) using the waveforms energy as
    defined by Rutishauser et al (2006)

    :type waveforms: ndarray
    :param waveforms: waveform data (signal), one per row
    :type noise_var: float
    :param noise_var: instantaneous variance of the noise (noise)
    :returns: ndarray - SNR per waveform
    """
    # mean signal energy per sample relative to the noise variance
    energy = (waveforms * waveforms).sum(axis=1)
    return sp.sqrt(energy / (waveforms.shape[1] * noise_var))
def snr_maha(waveforms, invC, mu=None):
    """SNR from Mahalanobis distance (generalised euclidean distance)

    Definition of signal to noise ratio (SNR) as derived from the
    Mahalanobis distance. For C=eye this is equivalent to snr_power.

    :type waveforms: ndarray
    :param waveforms: waveform data (signal), one per row
    :type invC: ndarray
    :param invC: inverted noise covariance matrix (a block toeplitz matrix)
    :type mu: ndarray
    :param mu: mean correction. Usually we assume zero-mean waveforms,
        so if this is None it will be ignored.
        Default=None
    :returns: ndarray - SNR per waveform
    """
    # inits and checks
    n, dim = waveforms.shape
    if dim != invC.shape[0] or dim != invC.shape[1]:
        raise ValueError('dimension mismatch for waveforms and covariance')

    # correct for mu
    if mu is not None:
        if mu.shape != (dim,):
            raise ValueError('dimension mismatch for waveforms and mu')
        # BUGFIX: subtract out-of-place; the original `waveforms -= mu`
        # mutated the caller's array as a side effect
        waveforms = waveforms - mu

    # row-wise quadratic form w * invC * w.T, vectorized instead of the
    # previous per-row python loop, normalized by the dimensionality
    quad = (sp.dot(waveforms, invC) * waveforms).sum(axis=1) / float(dim)
    return sp.sqrt(quad)
## data processing algorithms
def overlaps(sts, window):
    """produces dict of boolean sequences indicating for all spikes in all
    spike trains in :sts: if it participates in an overlap event.

    :type sts: dict
    :param sts: spike train set
    :type window: int
    :param window: overlap window size
    :returns: tuple - (boolean spike train set, per-unit counter dict)
    """
    # BUGFIX: sts.keys()[i] is not indexable on python 3 (dict views);
    # materialize the key order once and reuse it for the pair loop
    keys = list(sts.keys())
    n = len(keys)

    # inits
    ovlp = {k: sp.zeros(sts[k].shape, dtype=bool) for k in keys}
    # NOTE(review): ovlp_nums is initialized but was never updated in the
    # original implementation; it is returned as all zeros. Kept as-is to
    # preserve the interface -- confirm intended semantics.
    ovlp_nums = {k: 0 for k in keys}

    # run over all pairs of spike trains in sts
    for i in range(n):
        train_i = sts[keys[i]]
        for j in range(i + 1, n):
            # for every pair run over all spikes in i and check if a spike
            # in j overlaps within +/- window samples
            train_j = sts[keys[j]]
            for idx, spike in enumerate(train_i):
                hits = sp.absolute(train_j - spike) < window
                if not hits.any():
                    continue
                ovlp[keys[i]][idx] = True
                ovlp[keys[j]][hits] = True
    return ovlp, ovlp_nums
## MAIN
if __name__ == '__main__':
    # library module: no standalone entry point
    pass