Source code for botmpy.nodes.linear_filter

# -*- coding: utf-8 -*-
#_____________________________________________________________________________
#
# Copyright (c) 2012-2013, Berlin Institute of Technology
# All rights reserved.
#
# Developed by:	Philipp Meier <pmeier82@gmail.com>
#
#               Neural Information Processing Group (NI)
#               School for Electrical Engineering and Computer Science
#               Berlin Institute of Technology
#               MAR 5-6, Marchstr. 23, 10587 Berlin, Germany
#               http://www.ni.tu-berlin.de/
#
# Repository:   https://github.com/pmeier82/BOTMpy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal with the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimers.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimers in the documentation
#   and/or other materials provided with the distribution.
# * Neither the names of Neural Information Processing Group (NI), Berlin
#   Institute of Technology, nor the names of its contributors may be used to
#   endorse or promote products derived from this Software without specific
#   prior written permission.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# WITH THE SOFTWARE.
#_____________________________________________________________________________
#
# Acknowledgements:
#   Philipp Meier <pmeier82@gmail.com>
#_____________________________________________________________________________
#
# Changelog:
#   * <iso-date> <identity> :: <description>
#_____________________________________________________________________________
#

"""filter classes for linear filters in the time domain"""
__docformat__ = "restructuredtext"
__all__ = ["FilterError", "FilterNode", "MatchedFilterNode", "NormalisedMatchedFilterNode"]

## IMPORTS

import scipy as sp
from collections import deque
from .base_nodes import Node
from ..common import mcvec_from_conc, mcvec_to_conc, TimeSeriesCovE, MxRingBuffer, snr_maha
from ..mcfilter import mcfilter_hist

## CLASSES

class FilterError(Exception):
    """Exception type for errors raised by the filter nodes in this module."""
class FilterNode(Node):
    """linear filter in the time domain

    This node applies a linear filter to the data and returns the filtered
    data. The derivation of the filter (f) from the pattern (xi) is
    specified in the implementing subclass via the 'filter_calculation'
    classmethod. The template will be averaged from a ringbuffer of
    observations. The covariance matrix is supplied from an external
    covariance estimator.
    """

    ## special

    def __init__(self, tf, nc, ce, chan_set=None, rb_cap=None, dtype=None):
        """
        :type tf: int
        :param tf: template length in samples
        :type nc: int
        :param nc: template channel count
        :type ce: TimeSeriesCovE
        :param ce: covariance estimator instance
        :type chan_set: tuple
        :param chan_set: tuple of int designating the subset of channels
            this filter operates on.
            Default=tuple(range(nc))
        :type rb_cap: int
        :param rb_cap: capacity of the xi buffer
            Default=350
        :type dtype: dtype resolvable
        :param dtype: determines the internal dtype
            Default=None
        :raises ValueError: if `tf` or `nc` is not positive, or if `ce` is
            rejected by the `ce` setter
        :raises TypeError: if `ce` is not a TimeSeriesCovE
        """

        # checks
        if tf <= 0:
            raise ValueError("tf <= 0")
        if nc <= 0:
            raise ValueError("nc <= 0")
        if chan_set is None:
            chan_set = tuple(range(nc))

        # super
        super(FilterNode, self).__init__(output_dim=1, dtype=dtype)

        # members
        # the ring buffer must exist before `self.ce` is assigned below,
        # because the `ce` setter reads `self.tf` / `self.nc` from it
        self._xi_buf = MxRingBuffer(capacity=rb_cap or 350,
                                    dimension=(tf, nc),
                                    dtype=self.dtype)
        self._ce = None
        self._f = None
        # filter history carried over between consecutive `_execute` calls
        self._hist = sp.zeros((tf - 1, nc), dtype=self.dtype)
        self._chan_set = tuple(sorted(chan_set))
        self.ce = ce
        self.active = True

    ## properties - not settable

    def get_xi(self):
        """return the mean of the waveforms held in the xi buffer"""
        return self._xi_buf.mean()

    xi = property(get_xi, doc="template (multi-channeled)")

    def get_xi_conc(self):
        """return the template in concatenated representation"""
        return mcvec_to_conc(self._xi_buf.mean())

    xi_conc = property(get_xi_conc, doc="template (concatenated)")

    def get_tf(self):
        """return the first entry of the xi buffer dimension"""
        return self._xi_buf.dimension[0]

    tf = property(get_tf, doc="temporal extend [sample]")

    def get_nc(self):
        """return the second entry of the xi buffer dimension"""
        return self._xi_buf.dimension[1]

    nc = property(get_nc, doc="number of channels")

    def get_f(self):
        """return the current filter, None until `calc_filter` has run"""
        return self._f

    f = property(get_f, doc="filter (multi-channeled)")

    def get_f_conc(self):
        """return the current filter in concatenated representation"""
        return mcvec_to_conc(self._f)

    f_conc = property(get_f_conc, doc="filter (concatenated)")

    ## properties settable

    def get_ce(self):
        """return the covariance estimator in use"""
        return self._ce

    def set_ce(self, value):
        """set the covariance estimator

        The filter is recalculated right away if the xi buffer already
        holds at least one waveform.

        :type value: TimeSeriesCovE
        :param value: covariance estimator instance
        :raises TypeError: if `value` is not a TimeSeriesCovE
        :raises ValueError: if `value` covers fewer samples or channels
            than this filter, or is not initialised
        """
        if not isinstance(value, TimeSeriesCovE):
            raise TypeError("ce is not of type TimeSeriesCovE")
        if value.get_tf_max() < self.tf:
            raise ValueError("tf_max of ce < than filter tf")
        if value.get_nc() < self.nc:
            raise ValueError("nc of cov_est < than filter nc")
        if value.is_initialised is False:
            raise ValueError("ce not initialised!")
        self._ce = value
        if len(self._xi_buf) > 0:
            self.calc_filter()

    ce = property(get_ce, set_ce, doc="covariance estimator")

    def get_snr(self):
        """return `snr_maha` of the concatenated template with respect to
        the inverse covariance matrix of the estimator"""
        return snr_maha(
            sp.array([mcvec_to_conc(self.xi)]),
            self._ce.get_icmx(tf=self.tf, chan_set=self._chan_set))[0]

    snr = property(get_snr,
                   doc="signal to noise ratio (mahalanobis distance)")

    ## mdp.Node interface

    def _execute(self, x):
        """apply the filter to data

        :type x: ndarray
        :param x: input data, the columns listed in the channel set are
            used
        :return: the filter output as returned by `mcfilter_hist`
        """

        # DOC: sp.ascontiguousarray is here to assert contiguous memory for
        # the array. This is important for ctypes/cython implementations.
        # The channel selection is a fancy-indexing operation whose result
        # is not guaranteed to be C-contiguous, so it has to happen
        # *before* the contiguity is enforced, not after.
        x_in = sp.ascontiguousarray(
            sp.asarray(x)[:, self._chan_set], dtype=self.dtype)
        rval, self._hist = mcfilter_hist(x_in, self._f, self._hist)
        return rval

    def is_invertible(self):
        return False

    def is_trainable(self):
        return False

    def _get_supported_dtypes(self):
        return ["float32", "float64"]

    ## filter interface

    def append_xi_buf(self, wf, recalc=False):
        """append one waveform to the xi_buffer

        :type wf: ndarray
        :param wf: waveform data [self.tf, self.nc]
        :type recalc: bool
        :param recalc: if True, call self.calc_filter after appending
        """

        self._xi_buf.append(wf)
        if recalc is True:
            self.calc_filter()

    def extend_xi_buf(self, wfs, recalc=False):
        """append an iterable of waveforms to the xi_buffer

        :type wfs: iterable of ndarray
        :param wfs: waveform data [n][self.tf, self.nc]
        :type recalc: bool
        :param recalc: if True, call self.calc_filter after extending
        """

        self._xi_buf.extend(wfs)
        if recalc is True:
            self.calc_filter()

    def fill_xi_buf(self, wf, recalc=False):
        """fill all of the xi_buffer with wf

        :type wf: ndarray
        :param wf: ndarray of shape (self.tf, self.nc)
        :type recalc: bool
        :param recalc: if True, call self.calc_filter after filling
        """

        self._xi_buf.fill(wf)
        if recalc is True:
            self.calc_filter()

    def reset_history(self):
        """sets the history to all zeros"""

        self._hist[:] = 0.0

    ## plotting methods

    # XXX: delete plotting functions
    def plot_buffer_to_axis(self, axis=None, idx=None, limits=None):
        """plots the current buffer on the passed axis handle

        :param axis: axis to plot on, a new figure is created if None
        :type idx: int
        :param idx: index into the `COLOURS` of spikeplot for the mean
            trace, black is used if None
        :type limits: tuple
        :param limits: passed on to `set_ylim` if not None
        :return: (min, max) of the plotted waveforms, or None if spikeplot
            is not importable
        """

        try:
            from spikeplot import plt, COLOURS
        except ImportError:
            return None

        # init
        ax = axis
        if ax is None:
            f = plt.figure()
            ax = f.add_subplot(111)
        col = 'k'
        if idx is not None:
            col = COLOURS[idx % len(COLOURS)]
        spks = self._xi_buf[:]
        n, s, c = spks.shape
        # lay the channels of each waveform out one after another
        spks = spks.swapaxes(2, 1).reshape(n, s * c)

        # plot
        ax.plot(spks.T, color='gray')
        ax.plot(spks.mean(axis=0), color=col, lw=2)
        # channel separators; `range` works on Python 2 and Python 3
        for i in range(1, c):
            ax.axvline((self.tf * i), ls="dashed", color='y')
        ax.set_xlim(0, s * c)
        if limits is not None:
            ax.set_ylim(*limits)
        ax.set_xlabel("time [samples]")
        ax.set_ylabel("amplitude [mV]")

        return spks.min(), spks.max()

    ## filter calculation

    def calc_filter(self):
        """initiate a calculation of the filter"""

        self._f = self.filter_calculation(self.xi, self._ce, self._chan_set)

    @classmethod
    def filter_calculation(cls, xi, ce, cs, *args, **kwargs):
        """ABSTRACT METHOD FOR FILTER CALCULATION

        Implement this in a meaningful way in any subclass. The method
        should return the filter given the multi-channeled template `xi`,
        the covariance estimator `ce` and the channel set `cs` plus any
        number of optional arguments and keywords. The filter is usually
        the same shape as the pattern `xi`.

        :raises NotImplementedError: always, unless overridden
        """

        raise NotImplementedError

    ## special methods

    def __str__(self):
        return '%s(tf=%s,nc=%s,cs=%s)' % (self.__class__.__name__,
                                          self.tf,
                                          self.nc,
                                          str(self._chan_set))
class MatchedFilterNode(FilterNode):
    """matched filters in the time domain optimise the signal to noise ratio
    (SNR) of the matched pattern with respect to covariance matrix
    describing the noise background (deconvolution).
    """

    @classmethod
    def filter_calculation(cls, xi, ce, cs, *args, **kwargs):
        """return the matched filter for the template `xi`

        :type xi: ndarray
        :param xi: multi-channeled template
        :param ce: covariance estimator, queried for the inverse
            covariance matrix via `get_icmx`
        :param cs: channel set handed to the covariance estimator
        :return: contiguous multi-channeled filter with the dtype of `xi`
        """

        n_sample, n_chan = xi.shape

        # NOTE: the conditioned/loaded inverse (`ce.get_icmx_loaded`) is
        # deliberately not used for now, the plain inverse is taken.
        inv_cov = ce.get_icmx(tf=n_sample, chan_set=cs)

        # filter in concatenated representation: xi_conc * C^-1
        filt_conc = sp.dot(mcvec_to_conc(xi), inv_cov)
        filt = mcvec_from_conc(filt_conc, nc=n_chan)
        return sp.ascontiguousarray(filt, dtype=xi.dtype)
class NormalisedMatchedFilterNode(FilterNode):
    """matched filters in the time domain optimise the signal to noise ratio
    (SNR) of the matched pattern with respect to covariance matrix
    describing the noise background (deconvolution). Here the deconvolution
    output is normalised s.t. the response of the pattern is peak of unit
    amplitude.
    """

    @classmethod
    def filter_calculation(cls, xi, ce, cs, *args, **kwargs):
        """return the normalised matched filter for the template `xi`

        :type xi: ndarray
        :param xi: multi-channeled template
        :param ce: covariance estimator, queried for the inverse
            covariance matrix via `get_icmx`
        :param cs: channel set handed to the covariance estimator
        :return: contiguous multi-channeled filter with the dtype of `xi`
        """

        tf, nc = xi.shape

        # NOTE: the conditioned/loaded inverse (`ce.get_icmx_loaded`) is
        # deliberately not used for now, the plain inverse is taken.
        icmx = ce.get_icmx(tf=tf, chan_set=cs)

        xi_conc = mcvec_to_conc(xi)
        f = sp.dot(xi_conc, icmx)
        # response of the filter to its own template; dividing by it scales
        # that response to one. NOTE: not guarded against a zero value.
        norm_factor = sp.dot(xi_conc, f)

        # keep the dtype of the template (as MatchedFilterNode does) rather
        # than forcing float32, so that the filter matches the dtype of the
        # data and history it is applied to.
        return sp.ascontiguousarray(
            mcvec_from_conc(f / norm_factor, nc=nc), dtype=xi.dtype)
class RateEstimator(object):
    """estimate an event rate from a sliding window of observations

    Every observation is a pair of (event count, length in samples). The
    oldest observations are dropped once the summed length exceeds
    `n_sample_max`.
    """

    def __init__(self, *args, **kwargs):
        """
        Positional arguments are accepted but ignored, options are read
        from the keywords only.

        :keyword n_sample_max: upper bound for the summed length of the
            stored observations, may be infinite for an unbounded window
            Default=2500000
        :keyword sample_rate: sample rate used to scale the estimate
            Default=32000.0
        """

        self._spike_count = deque()
        self._sample_count = deque()
        n_sample_max = kwargs.get('n_sample_max', 2500000)
        if n_sample_max == float('inf'):
            # int(inf) raises OverflowError, an infinite bound is kept as is
            self._n_sample_max = float('inf')
        else:
            self._n_sample_max = int(n_sample_max)
        self._sample_rate = float(kwargs.get('sample_rate', 32000.0))
        self._filled = False

    def estimate(self):
        """return sample_rate * (summed counts / summed lengths), or 0.0 if
        the summed length is zero"""

        try:
            return self._sample_rate * sum(self._spike_count) / \
                float(self.sample_size)
        except ZeroDivisionError:
            return 0.0

    def observation(self, nobs, tlen):
        """store one observation and trim the window

        :param nobs: event count of the observation
        :param tlen: length of the observation
        """

        self._spike_count.append(nobs)
        self._sample_count.append(tlen)
        while sum(self._sample_count) > self._n_sample_max:
            # the window overflowed at least once, it counts as filled
            self._filled = True
            self._spike_count.popleft()
            self._sample_count.popleft()

    def reset(self):
        """drop all observations and clear the filled flag"""

        self._spike_count.clear()
        self._sample_count.clear()
        self._filled = False

    def is_filled(self):
        """return True if the window had to be trimmed since the last
        reset"""

        return self._filled

    filled = property(is_filled)

    def get_sample_size(self):
        """return the summed length of the stored observations"""

        return sum(self._sample_count)

    sample_size = property(get_sample_size)


def _pop_rate_options(kwargs):
    """remove the RateEstimator options from `kwargs` and return them"""

    return dict((key, kwargs.pop(key))
                for key in ('sample_rate', 'n_sample_max')
                if key in kwargs)


class REMF(MatchedFilterNode):
    """MatchedFilterNode with a RateEstimator attached as `self.rate`"""

    def __init__(self, *args, **kwargs):
        """
        :keyword sample_rate: handed to the RateEstimator if given
        :keyword n_sample_max: handed to the RateEstimator if given

        All other arguments go to MatchedFilterNode.
        """

        # The options have to be forwarded as keywords, RateEstimator does
        # not read positional arguments. Options that were not given are
        # left to the RateEstimator defaults, which is what the previous
        # positional call effectively always used.
        rate_options = _pop_rate_options(kwargs)
        super(REMF, self).__init__(*args, **kwargs)
        self.rate = RateEstimator(**rate_options)


class RENMF(NormalisedMatchedFilterNode):
    """NormalisedMatchedFilterNode with a RateEstimator attached as
    `self.rate`"""

    def __init__(self, *args, **kwargs):
        """
        :keyword sample_rate: handed to the RateEstimator if given
        :keyword n_sample_max: handed to the RateEstimator if given

        All other arguments go to NormalisedMatchedFilterNode.
        """

        # see REMF.__init__ for why the options are forwarded as keywords
        rate_options = _pop_rate_options(kwargs)
        super(RENMF, self).__init__(*args, **kwargs)
        self.rate = RateEstimator(**rate_options)

## MAIN

if __name__ == '__main__':
    pass

Project Versions

This Page