# -*- coding: utf-8 -*-
#_____________________________________________________________________________
#
# Copyright (c) 2012-2013, Berlin Institute of Technology
# All rights reserved.
#
# Developed by: Philipp Meier <pmeier82@gmail.com>
#
# Neural Information Processing Group (NI)
# School for Electrical Engineering and Computer Science
# Berlin Institute of Technology
# MAR 5-6, Marchstr. 23, 10587 Berlin, Germany
# http://www.ni.tu-berlin.de/
#
# Repository: https://github.com/pmeier82/BOTMpy
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal with the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# * Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimers.
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimers in the documentation
# and/or other materials provided with the distribution.
# * Neither the names of Neural Information Processing Group (NI), Berlin
# Institute of Technology, nor the names of its contributors may be used to
# endorse or promote products derived from this Software without specific
# prior written permission.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# WITH THE SOFTWARE.
#_____________________________________________________________________________
#
# Acknowledgements:
# Philipp Meier <pmeier82@gmail.com>
#_____________________________________________________________________________
#
# Changelog:
# * <iso-date> <identity> :: <description>
#_____________________________________________________________________________
#
"""initialization - alignment of spike waveform sets"""
__docformat__ = 'restructuredtext'
__all__ = ['AlignmentNode']
## IMPORTS
import numpy as np
import scipy as sp
from scipy.signal import resample

from .base_nodes import ResetNode
## CLASSES
class AlignmentNode(ResetNode):
    """aligns a set of spikes on the mean waveform of the set

    Each input row holds one spike with its `nchan` channels concatenated.
    Every channel is zero-padded by `max_tau` samples on both sides, then
    the padded rows are shifted repeatedly: for every spike, the lag in
    [-max_tau, +max_tau] that maximises the absolute dot product with the
    set mean (with the spike's own contribution removed) is applied. The
    sweeps stop after `max_rep` repetitions, or as soon as no more than
    0.5% of the spikes were shifted during a sweep.

    After execution `self.spikes` holds the aligned spikes and `self.tau`
    the accumulated shift of every spike.
    """

    ## constructor

    def __init__(self, nchan=4, max_rep=32, max_tau=10, resample_factor=None,
                 cut_down=True, dtype=np.float32, debug=False):
        """
        :Parameters:
            nchan : int
                channel count
                Default=4
            max_rep : int
                maximum repetitions
                Default=32
            max_tau : int
                upper bound for the shifting. will shift from -tau to +tau
                Default=10
            resample_factor : float or None
                before shifting, resample with this factor. after return
                resample with inverse of this factor, if None ignore
                Default=None
            cut_down: bool
                If True, cut down to original size, stripping the padding
                dimensions. If False, return with the padding dimensions.
                Default=True
            dtype : numpy.dtype
                dtype handed to the base node constructor
                Default=numpy.float32
            debug : bool
                If True, be verbose.
                Default=False
        """

        # super
        super(AlignmentNode, self).__init__(dtype=dtype)

        # members
        self.nchan = int(nchan)
        self.tau = None
        self.spikes = None
        self.debug = bool(debug)
        self.max_rep = int(max_rep)
        self.max_tau = int(max_tau)
        self.resample_factor = None
        if resample_factor is not None:
            self.resample_factor = float(resample_factor)
        self.cut_down = bool(cut_down)

    ## node implementation

    # NOTE(review): the method name is spelled `is_invertable`; confirm that
    # the base class (not visible here) looks it up under this spelling.
    def is_invertable(self):
        """this node provides no inverse operation"""
        return False

    def is_trainable(self):
        """this node has no training phase"""
        return False

    def _reset(self):
        """drop the results of the previous execution"""
        self.tau = None
        self.spikes = None

    def _execute(self, x):
        """align the spikes in `x` on their common mean waveform

        :Parameters:
            x : ndarray
                spike set, one spike per row, channels concatenated, i.e.
                shape (n, dim) with dim a multiple of `nchan`
        :Returns:
            ndarray
                the aligned spikes; shape (n, dim) if `cut_down` is True,
                else (n, dim + 2 * max_tau * nchan) including the padding
        :Raises:
            ValueError
                if less than two spikes are passed, or if dim is not a
                multiple of `nchan`
        """

        # inits
        n, dim = x.shape
        if n < 2:
            raise ValueError('too few spikes to align')
        if dim % self.nchan != 0:
            raise ValueError(
                'spike dimension %d is not a multiple of nchan=%d' %
                (dim, self.nchan))
        # work on a local copy of the shift bound: it gets rescaled while the
        # data is upsampled, and the configured value must survive the call
        # so that repeated executions behave identically
        max_tau = self.max_tau
        chan_len = dim // self.nchan
        padded_len = dim + 2 * max_tau * self.nchan
        self.spikes = np.zeros((n, padded_len))
        self.tau = np.zeros(n)

        # put spikes in, resample and extrapolate
        # channel c occupies chan_len samples, preceded by the (2c+1) pads
        # of length max_tau that lie to its left
        spike_idx = np.concatenate([
            np.arange(chan_len) + c * chan_len + (2 * c + 1) * max_tau
            for c in range(self.nchan)])
        self.spikes[:, spike_idx] = x
        if self.resample_factor is not None:
            if self.debug:
                print('upsampling by %f' % self.resample_factor)
            # resample needs an integral sample count
            self.spikes = resample(
                self.spikes,
                int(round(padded_len * self.resample_factor)),
                axis=1)
            max_tau = int(max_tau * self.resample_factor)
            if self.debug:
                print('upsampled size: %d, maxtau: %d' % (
                    self.spikes.shape[1], max_tau))

        # get the mean spike and start iteration
        mean_spike = self.spikes.mean(axis=0)
        changes = np.inf
        cur_rep = 0
        while cur_rep < self.max_rep and changes > n * 0.005:
            changes = 0
            q_avg = 0.0
            for s in range(n):
                # take the current spike out of the mean
                mean_spike -= self.spikes[s, :] / n

                # fit quality
                q_max = 0.0
                best_tau = 0
                for tau in range(-max_tau, max_tau + 1):
                    # shift the current spike and compare it to the mean
                    shifted_spike = shift_row(self.spikes[s, :], tau)
                    q_tau = np.absolute(np.dot(mean_spike, shifted_spike))
                    # `q_max == 0.0` lets the search move on as long as no
                    # lag has produced a non-zero quality yet
                    if q_tau > q_max or q_max == 0.0:
                        best_tau = tau
                        q_max = q_tau
                if best_tau != 0:
                    # apply shift
                    self.spikes[s, :] = shift_row(self.spikes[s, :], best_tau)
                    self.tau[s] += best_tau
                    changes += 1
                q_avg += q_max

                # put the shifted spike back into the mean
                mean_spike += self.spikes[s, :] / n
            cur_rep += 1
            if self.debug:
                print('\t[%s] -> qual=%.4f (*%d)' % (
                    cur_rep, q_avg / n, changes))

        # get rid of the padding and resampling
        if self.resample_factor is not None:
            if self.debug:
                print('downsampling again')
            # go back to exactly the padded width used before upsampling, so
            # that spike_idx still addresses the right columns
            self.spikes = resample(self.spikes, padded_len, axis=1)
            # correct taus
            self.tau = (self.tau / self.resample_factor).round().astype(int)
        if self.cut_down:
            self.spikes = self.spikes[:, spike_idx]

        # return aligned spikes
        return self.spikes
## HELPERS
def shift_row(row, shift):
    """shift `row` by `shift` samples, filling the vacated samples with zeros

    :Parameters:
        row : ndarray
            the data to shift, the shift is applied along the first axis
        shift : int
            number of samples to shift by. positive values move the data
            towards higher indices, negative values towards lower indices
    :Returns:
        ndarray
            `row` itself (no copy) if `shift` is 0, else a new array of the
            same shape and dtype as `row`. Samples moved beyond either end
            are dropped, so a shift of at least the length of `row` yields
            all zeros.
    """

    if shift == 0:
        return row
    row = np.asarray(row)
    shifted = np.zeros_like(row)
    # a shift of the full length or more moves every sample out of the row;
    # returning early keeps the result at the input length
    if abs(shift) >= row.shape[0]:
        return shifted
    if shift > 0:
        shifted[shift:] = row[:-shift]
    else:
        shifted[:shift] = row[-shift:]
    return shifted
## MAIN
if __name__ == "__main__":
    # nothing to run: executing this module directly is a no-op
    pass