Source code for botmpy.nodes.filter_bank

# -*- coding: utf-8 -*-
#_____________________________________________________________________________
#
# Copyright (c) 2012-2013, Berlin Institute of Technology
# All rights reserved.
#
# Developed by:	Philipp Meier <pmeier82@gmail.com>
#
#               Neural Information Processing Group (NI)
#               School for Electrical Engineering and Computer Science
#               Berlin Institute of Technology
#               MAR 5-6, Marchstr. 23, 10587 Berlin, Germany
#               http://www.ni.tu-berlin.de/
#
# Repository:   https://github.com/pmeier82/BOTMpy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal with the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimers.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimers in the documentation
#   and/or other materials provided with the distribution.
# * Neither the names of Neural Information Processing Group (NI), Berlin
#   Institute of Technology, nor the names of its contributors may be used to
#   endorse or promote products derived from this Software without specific
#   prior written permission.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# WITH THE SOFTWARE.
#_____________________________________________________________________________
#
# Acknowledgements:
#   Philipp Meier <pmeier82@gmail.com>
#_____________________________________________________________________________
#
# Changelog:
#   * <iso-date> <identity> :: <description>
#_____________________________________________________________________________
#

"""implementation of a filter bank consisting of a set of filters"""

from Queue import Queue
import ctypes as ctypes
import platform
from threading import Thread

__docformat__ = 'restructuredtext'
__all__ = ['FilterBankError', 'FilterBankNode']

## IMPORTS

import logging
import scipy as sp
from .base_nodes import Node
from .linear_filter import FilterNode, REMF
from ..common import (TimeSeriesCovE, xi_vs_f, VERBOSE)

## CLASSES

[docs]class FilterBankError(Exception): pass
[docs]class FilterBankNode(Node): """abstract class that handles filter instances and their outputs All filters constituting the filter bank have to be of the same temporal extend (Tf) and process the same channel set. There are two different index sets. One is abbreviated "idx" and one "key". The "idx" the index of filter in `self.bank` and thus a unique, hashable identifier. Where as the "key" an index in a subset of idx. Ex.: the index for list(self._idx_active_set) would be a "key". """ ## constructor def __init__(self, **kwargs): """see `mdp.Node` :type ce: TimeSeriesCovE :keyword ce: covariance estimator instance, if None a new instance will be created and initialised with the identity matrix corresponding to the template size. required :type chan_set: tuple :keyword chan_set: tuple of int designating the subset of channels this filter bank operates on. Defaults to all the channels of the input data, as determined by the max chan_set of the covariance estimator. Default=tuple(range(nc)) :type filter_cls: FilterNode :keyword filter_cls: the class of filter node to use for the filter bank, this must be a subclass of 'FilterNode'. required :type rb_cap: int :keyword rb_cap: capacity of the ringbuffer that stored observations for the filters to calculate the mean template. Default=350 :type tf: int :keyword tf: temporal extend of the filters in the filter bank in samples. Default=47 :type verbose: int :keyword verbose: verbosity level, 0:none, >1: print .. ref `VERBOSE` Default=0 """ # kwargs ce = kwargs.pop('ce', None) chan_set = kwargs.pop('chan_set', None) filter_cls = kwargs.pop('filter_cls', REMF) rb_cap = kwargs.pop('rb_cap', 350) tf = kwargs.pop('tf', 47) verbose = kwargs.pop('verbose', 0) # everything not popped goes to mdp.Node.__init__ via super # checks if not issubclass(ce.__class__, TimeSeriesCovE): raise TypeError('\'ce\' of type TimeSeriesCovE is required!') if not issubclass(filter_cls, FilterNode): raise TypeError('\'filter_cls\' of type FilterNode is required!') if chan_set is None: chan_set = tuple(range(ce.get_nc())) # super super(FilterBankNode, self).__init__(**kwargs) # members self._tf = int(tf) self._nc = None self._chan_set = None self._xcorrs = None self._ce = None self._filter_cls = filter_cls self._rb_cap = int(rb_cap) self._idx_active_set = set() self.bank = {} self.verbose = VERBOSE(verbose) # set members self.cs = chan_set self.ce = ce ## properties static or protected
[docs] def get_tf(self): return self._tf
tf = property(get_tf, doc='temporal filter extend [samples]')
[docs] def get_nc(self): return self._nc
nc = property(get_nc, doc='number of channels')
[docs] def get_nf(self, active=True): if active: return len(self._idx_active_set) else: return len(self.bank)
nf = property(get_nf, doc='number of filters')
[docs] def get_template_set(self, active=True, mc=True): key_set = self._idx_active_set if active else set(self.bank.keys()) if not key_set: shape = (0, self._tf, self._nc) if mc else (0, self._tf * self._nc) return sp.zeros(shape, dtype=self.dtype) f_list = self._get_idx_set(key_set) return sp.asarray([f.xi if mc else f.xi_conc for f in f_list])
template_set = property(get_template_set, doc='template set of active filters')
[docs] def get_filter_set(self, active=True, mc=True): key_set = self._idx_active_set if active else set(self.bank.keys()) if not key_set: shape = (0, self._tf, self._nc) if mc else (0, self._tf * self._nc) return sp.zeros(shape, dtype=self.dtype) f_list = self._get_idx_set(key_set) return sp.asarray([f.f if mc else f.f_conc for f in f_list])
filter_set = property(get_filter_set, doc='filter set of active filters')
[docs] def get_xcorrs(self): return self._xcorrs
xcorrs = property(get_xcorrs, doc='cross correlation tensor for active filters')
[docs] def get_xcorrs_at(self, idx0, idx1=None, shift=0): if self._xcorrs is None: return None return self._xcorrs[idx0, idx1 or idx0, self._tf - 1 + shift]
[docs] def get_idx_for(self, key): return list(self._idx_active_set)[key]
def _get_idx_set(self, key_set): return [self.bank[k] for k in key_set] ## properties public
[docs] def get_chan_set(self): return self._chan_set
[docs] def set_chan_set(self, value): self._chan_set = tuple(sorted(value)) self._nc = len(self._chan_set)
cs = property(get_chan_set, set_chan_set)
[docs] def get_ce(self): return self._ce
[docs] def set_ce(self, value): if not issubclass(value.__class__, TimeSeriesCovE): raise TypeError('Has to be of type %s' % TimeSeriesCovE) if value.tf_max < self._tf: raise ValueError('tf_max of cov_est is < than filter bank tf') if self._chan_set not in value.get_chan_set(): raise FilterBankError('\'chan_set\' not present at \'ce\'!') # TODO: not sure how to solve this #if value.get_nc() < self._nc: # raise ValueError('nc of cov_est is < than the filter bank nc') self._ce = value self._check_internals()
ce = property(get_ce, set_ce) ## filter bank interface
[docs] def reset_history(self): """sets the history to all zeros for all filters""" for filt in self.bank.values(): filt.reset_history()
[docs] def reset_rates(self): """resets the rate estimators for all filters (if applicable)""" for filt in self.bank.values(): if hasattr(filt, 'rate'): filt.rate.reset()
[docs] def create_filter(self, xi, check=True): """adds a new filter to the filter bank :type xi: ndarray :param xi: template to build the filter for """ # check input xi = sp.asarray(xi, dtype=self.dtype) if xi.ndim != 2 or xi.shape != (self._tf, self._nc): raise FilterBankError( 'template does not match the filter banks filter shape of %s' % str((self._tf, self._nc))) # build filter and add to filter bank new_f = self._filter_cls(self._tf, self._nc, self._ce, rb_cap=self._rb_cap, chan_set=self._chan_set, dtype=self.dtype) #new_f.fill_xi_buf(xi) new_f.append_xi_buf(xi) idx = 0 if len(self.bank): idx = max(self.bank.keys()) + 1 self.bank[idx] = new_f self._idx_active_set.add(idx) # return and check internals rval = True if check is True: rval = self._check_internals() return rval
[docs] def deactivate(self, idx, check=False): """deactivates a filter in the filter bank Filters are never deleted, but can be de-/reactivated and will be used respecting there activation state for the filter output of the filter bank. No effect if idx not in self.bank. """ if idx in self.bank: self.bank[idx].active = False self._idx_active_set.discard(idx) if check is True: self._check_internals() else: logging.warn('no idx=%s in filter bank!' % idx)
[docs] def activate(self, idx, check=False): """activates a filter in the filter bank Filters are never deleted, but can be de-/reactivated and will be used respecting there activation state for the filter output of the filter bank. No effect if idx not in self.bank. """ if idx in self.bank: self.bank[idx].active = True self._idx_active_set.add(idx) if check is True: self._check_internals() else: logging.warn('no idx=%s in filter bank!' % idx)
def _check_internals(self): """triggers filter recalculation and rebuild xcorr tensor""" # check if self.verbose.has_print: print '_check_internals' if not self.bank: return # build filters for i in self._idx_active_set: self.bank[i].calc_filter() # build cross-correlation tensor self._xcorrs = xi_vs_f( self.get_template_set(mc=False), self.get_filter_set(mc=False), nc=self._nc) ## mpd.Node interface
[docs] def is_invertible(self): return False
[docs] def is_trainable(self): return False
def _execute(self, x): if not self._idx_active_set: return sp.zeros((x.shape[0], 0), dtype=self.dtype) rval = sp.empty((x.shape[0], self.nf)) for k, i in enumerate(self._idx_active_set): rval[:, k] = self.bank[i](x) return rval ## plotting methods
[docs] def plot_xvft(self, ph=None, show=False): """plot the Xi vs F Tensor of the filter bank""" # get plotting tools try: from spikeplot import xvf_tensor, plt except ImportError: return None # check if self.nf == 0: logging.warn('skipping plot, no active units!') return None # init inlist = [self.get_template_set(mc=False), self.get_filter_set(mc=False), self._xcorrs] return xvf_tensor(inlist, nc=self._nc, plot_handle=ph, show=show)
[docs] def plot_template_set(self, ph=None, show=False): """plot the template set in a waveform plot""" # get plotting tools try: from spikeplot import waveforms, plt except ImportError: return None # checks if self.nf == 0: logging.warn('skipping plot, no active units!') return None # init units = {} for k in self._idx_active_set: units[k] = self.bank[k]._xi_buf[:] return waveforms( units, tf=self._tf, plot_separate=True, plot_mean=True, plot_single_waveforms=True, plot_handle=ph, show=show)
[docs] def plot_template_set2(self, show=False): """plot the template set in a waveform plot""" # get plotting tools try: from spikeplot import plt except ImportError: return None # checks if self.nf == 0: logging.warn('skipping plot, no active units!') return None # init f = plt.figure() y_min, y_max = 0, 0 share = None for k, i in enumerate(self._idx_active_set): ax = f.add_subplot(self.nf, 1, k + 1, sharex=share, sharey=share) a, b = self.bank[i].plot_buffer_to_axis(axis=ax, idx=i) y_min = min(y_min, a) y_max = max(y_max, b) share = ax f.axes[0].set_ylim(y_min, y_max) if show is True: plt.show() return f
[docs] def plot_filter_set(self, ph=None, show=False): """plot the filter set in a waveform plot""" # get plotting tools try: from spikeplot import waveforms except ImportError: return None # checks if self.nf == 0: logging.warn('skipping plot, no active units!') return None # init units = {} for k in self._idx_active_set: units[k] = sp.atleast_2d(self.bank[k].f_conc) return waveforms( units, tf=self._tf, plot_separate=True, plot_mean=False, plot_single_waveforms=False, plot_handle=ph, show=show) ## special methods
__len__ = get_nf ## MAIN
if __name__ == '__main__': pass

Project Versions

This Page