import pandas as pd
import os
from contextlib import redirect_stderr
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
with redirect_stderr(open(os.devnull, "w")):
import tensorflow as tf
import h5py
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from importlib import resources
from automea import util
[docs]class Analysis:
"""
Class for performing analysis of MEA datasets.
This class provides methods for loading datasets, performing various analyses,
and visualizing the results.
Attributes
----------
util : module
Utility functions module.
dataset : list
The dataset. Default is an empty list.
dataset_filename : str
The filename of the dataset. Default is an empty string.
dataset_index : int or None
The index of the dataset. Default is None.
path_to_dataset : str
The path to the dataset. Default is an empty string.
path_to_csv : str
The path to the CSV file. Default is an empty string.
path_to_model : str
The path to the model file. Default is an empty string.
output_name : str
The name of the output. Default is an empty string.
output_folder : str
The output folder. Default is an empty string.
wellsIDs : list
The IDs of the wells. Default is an empty list.
wellsLabels : list
The labels of the wells. Default is an empty list.
well : str or None
The selected well. Default is None.
model : object or None
The ML model. Default is None.
model_name : str or None
The name of the model. Default is None.
signal : array_like or None
The recorded signal. Default is None.
time : array_like or None
The time array. Default is None.
spikes : array_like or None
The detected spikes. List with timestamp of every spike. Default is None.
spikes_binary : array_like or None
The binary representation of detected spikes. Default is None.
active_channels : array_like or None
Indices of channels that are active (mean firing rate > 0.1 spikes/s). Is set when detect_spikes() is called. Default is None.
reverbs : array_like or None
The detected reverberations. List of lists with start and end timestamp of reverb. Default is None.
reverbs_binary : array_like or None
The binary representation of detected reverberations. Default is None.
bursts : array_like or None
The detected bursts. List of lists with start and end timestamp of burst. Default is None.
bursts_binary : array_like or None
The binary representation of detected bursts. Default is None.
net_reverbs : array_like or None
The detected network reverberations. List of lists with start and end timestamp of network reverb. Default is None.
reverbs_binary : array_like or None
The binary representation of detected network reverberations. Default is None.
net_bursts : array_like or None
The detected network bursts. List of lists with start and end timestamp of network burst. Default is None.
net_bursts_binary : array_like or None
The binary representation of detected network bursts. Default is None.
adZero : int
Variable to convert integer signal (default in h5 dataset) to mV. Is subracted from integer signal. Default is 0
conversionFactor : int
Variable to convert integer signal (default in h5 dataset) to mV. Is multiplied to integer signal. Default is 59604
exponent : int
Variable to convert integer signal (default in h5 dataset) to mV. Is multiplied to integer signal. Default is -12
samplingFreq : int
Sampling frequency used to measure the signal. Default is 10_000 (10 kHz)
total_timesteps_signal : int
Number of timesteps present in measured signal. Default is 6 million.
time : array_like
Array containing time for measured signal in seconds. Uses samplingFreq and total_timesteps_signal.
threshold_params : dict
Dictionary of parameters used to detect threshold. Includes:
'rising' : (key : str, value : int)
How many standard deviations are used to set threshold. Default is 5.
'startTime' : (key : str, value : float)
Start time (s) to consider signal. Default is 0.
'baseTime' : (key : str, value : float)
Time lenght in ms to consider signal for std calculation. Default is 0.250.
'segments' : (key : str, value : int)
Number of segment to perform for calculation. Default is 10.
spike_params : dict
Dictionary of parameters used to detect spikes. Includes:
- 'deadtime': (key : str, value : float)
Amount of time after detecting a spike for which no spike is detect. Default is 3_000*1e-6.
reverbs_params : dict
Dictionary of parameters used to detect reverberations. Includes:
'max_interval_start' : (key : str, value : int)
Maximum interval between spikes to start reverberation detection in ms. Default is 15.
'max_interval_end' : (key : str, value : int)
Maximum interval between spikes to end reverberation detection in ms. Default is 20.
'min_interval_between' : (key : str, value : int)
Minimum interval between spikes to consider for reverberation detection in ms. Default is 25.
'min_duration' : (key : str, value : int)
Minimum reverberation duration in ms. Default is 20.
'min_spikes' : (key : str, value : int)
Minimum number of spikes to consider in a reverberation. Default is 5.
bursts_params : dict
Dictionary of parameters used to detect bursts (from reverberations). Includes:
'min_interval_between' : (key : str, value : int)
Minimum interval between spikes to consider for reverberation detection in ms. Default is 300.
_pretrained_models : list of str
List containing name of pretrained ML models used for reverberation detection.
model_params : dict
Dictionary of parameters used in model-based reverberation detection. Includes:
'name' : (key : str, value : str)
Name of the model. Default is None.
'input_type' : (key : str, value : str)
String defining type of input used by the model ('signal' or 'spikes'). Deault is None and changed when model is loaded.
'input_average' : (key : str, value : int)
How many points are used to calculate average of input. Default is 30.
'window_size' : (key : str, value : int)
Size of window (in timestamps) used as input for the model (before averaging). Default is 50_000.
'window_overlap' : (key : str, value : int)
Overlap between windows when sweeping a channel. Default is 25_000.
analysis_params : dict
Parameters for analysis - which quantities user wants to save - each one creates an output file. Includes:
'save_spikes' : (key : str, value : bool)
Whether or not to save spikes in a dedicated file. Default is False.
'save_reverbs' : (key : str, value : bool)
Whether or not to save reverberations in a dedicated file. Default is False.
'save_bursts' : (key : str, value : bool)
Whether or not to save bursts in a dedicated file. Default is False.
'save_net_reverbs : (key : str, value : bool)
Whether or not to save network reverberations in a dedicated file. Default is False.
'save_net_bursts' : (key : str, value : bool)
Whether or not to save network bursts in a dedicated file. Default is False.
'save_stats' : (key : str, value : bool)
Whether or not to save statistics in a dedicated file. Default is False.
wellIndexLabelDict : dict
Dictionary to convert well-index to well-label.
wellLabelIndexDict : dict
Dictionary to convert well-label to well-index.
Methods
-------
files_and_well_csv(file)
Load filenames and associated wells from a CSV file.
loadmodel(setting = 'default')
Load a machine learning model to perform reverberations/bursts detection.
loadsignal(signal)
Load a signal into the analysis object.
loadh5(filename = None)
Load data from an HDF5 file into the analysis object.
convert_signal(signal, adzero, conversionfactor, exponent)
Convert raw signal values to physical units using the provided parameters:
adzero, conversionfactor, and exponent
convert_threshold(threshold, adzero, conversionfactor, exponent)
Convert threshold value to physical units using the provided parameters:
adzero, conversionfactor, and exponent.
loadspikes(spikes)
Load spike data into the analysis object.
loadwell(file:str, well, method = 'default', spikes = False, reverbs = False,
bursts = False, net_reverbs = False, net_bursts = False)
Load data for a specific well.
detect_threshold()
Detect threshold value for spike detection.
convert_timestamps_to_binary(input_timestamp, input_type, size = None)
Convert timestamps to binary representation.
convert_binary_to_timestamps(input_binary, input_type)
Convert binary representation data to timestamps.
_convert_binary_reverb_to_timestamp(input_binary)
Convert binary representation of reverbs to timestamps.
_detect_spikes(signal, threshold)
Detect spikes in a signal based on a given threshold value.
detect_spikes()
Detect spikes in the signal.
reverbs_params_default()
Set default parameters for reverberations/bursts detection.
reverbs_params_manual(params)
Set manual parameters for reverberations detection.
detect_reverbs(method = 'default')
Detects reverberations based on the specified method.
_detect_reverbs(spikes)
Detect reverberations based on spikes timestamps.
detect_bursts()
Detect bursts in the signal based on the detected reverberations.
_detect_bursts(reverbs)
Detect bursts based on reverberations.
_predict_reverbs(model_input, spikes_binary)
Predict reverberations using a pre-trained machine learning model.
normalize_signal(signal)
Normalize the input signal.
normalize_threshold(signal, threshold)
Normalize the threshold relative to the input signal.
reduce_dimension(X, input_type = None, reduction_factor = None)
Reduce the dimensionality of the input data.
reduce_norm_abs_signal(signal)
Reduce the dimensionality of the normalized absolute signal.
detect_net(inp, minChannelsParticipating=None, minSimultaneousChannels=None)
Detect net bursts or net reverbs.
save_spikes()
Save spikes data to a CSV file.
analyze_dataset(file=None, mode='csv', save_default=False)
Analyze the datasets from a CSV file, and save the results to CSV files.
plot_window(signal, start_time=None, duration=None, threshold=None, spikes=None, reverberations=None,
bursts=None, net_bursts=None, save=False, output_name=None, figsize=(6, 6), yunits='a.u.',
xunits='s')
Plot a window of the signal with detected spikes, reverberations, bursts, and network bursts.
plot_raster(spikes, reverbs = None, bursts = None, net_reverbs = None, net_bursts = None)
Plot a raster plot of spikes with optional overlay of reverberations, bursts, and network bursts.
plot_raster_well(file:str, well, method = 'default', reverbs = False, bursts = False,
net_reverbs = False, net_bursts = False)
Plot a raster plot for a specific well with optional overlay of events.
"""
def __init__(self):
"""
Initializes class with default attributes.
"""
# import util functions
self.util = util
# initialize attributes to empty list/ empty string / None
self.dataset = []
self.dataset_filename = ''
self.dataset_index = None
self.path_to_dataset = ''
self.path_to_csv = ''
self.path_to_model = ''
self.output_name = ''
self.output_folder = ''
self.wellsIDs = []
self.wellsLabels = []
self.well = None
self.model = None
self.model_name = None
self.signal = None
self.time = None
self.spikes = None
self.spikes_binary = None
self.active_channels = None
self.reverbs = None
self.reverbs_binary = None
self.bursts = None
self.bursts_binary = None
self.net_reverbs = None
self.net_reverbs_binary = None
self.net_bursts = None
self.net_bursts_binary = None
self.adZero, self.conversionFactor, self.exponent = 0, 59604, -12
self.plate = None
self.plot_name = None
self.samplingFreq = 10_000 # default sampling frequency = 10kHZ
self.total_timesteps_signal = 6_000_000 #default number of timesteps in one recording - 10min = 60s (for sampl freq 10kHz) = 6million points
self.time = np.linspace(0, (self.total_timesteps_signal-1)/self.samplingFreq, self.total_timesteps_signal) # time array in seconds
# parameters used to detect threshold
self.threshold_params = {'rising' : 5, # how many standard deviations are used to set threshold
'startTime' : 0, # start time to consider signal
'baseTime' : 0.250, # in ms, for how long to consider signal for std calculation
'segments' : 10} # how many segment to perform calculation
self.spikes_params = {'deadtime': 3_000*1e-6} # time after detecting a spike for which spike detection is halted
# parameters used for reverberations / bursts detection - Max Interval parameters
self.reverbs_params = {'max_interval_start' : 15,
'max_interval_end' : 20,
'min_interval_between' : 25,
'min_duration' : 20,
'min_spikes' : 5}
# parameter to merge reverberations into bursts
self.bursts_params = {'min_interval_between' : 300}
# pretained ML models for burst detection distributed with the package
self._pretrained_models = ['signal30.h5', 'signal100.h5', 'spikes30.h5']
# parameters for model based burst detection
self.model_params = {'name' : None, # name of the model
'input_type' : None, # type of input: signal or spikes
'input_average' : 30, # how many points is used to calculate average of input
'window_size' : 50_000, # size of window (in timestamps) used as input for the model (before averaging)
'window_overlap' : 25_000} # overlap between windows when sweeping a channel
# parameters for analysis - which quantities user wants to save - each one creates an output file
self.analysis_params = {'save_spikes' : False,
'save_reverbs' : False,
'save_bursts' : False,
'save_net_reverbs' : False,
'save_net_bursts' : False,
'save_stats' : False}
# define dictionary for well labels to index and vice-versa,
# based on plate with 4x6 wells
keys = []
items = []
i = 0
j_labels = ['A', 'B', 'C', 'D']
k_labels = range(1,7)
for j in j_labels:
for k in k_labels:
keys.append(i)
items.append(f'{j}{k}')
i += 1
self.wellIndexLabelDict = dict(zip(keys, items)) # dictionary to convert well-index to well-label
self.wellLabelIndexDict = dict(zip(items, keys)) # dictionary to convert well-label to well-index
def _infer_plate(self):
"""
Infer plate type from loaded HDF5 file using wellsFromData.
Sets self.plate automatically and warns if some wells are missing.
"""
unique_wells, counts = np.unique(self.wellsFromData, return_counts=True)
n_wells = len(unique_wells)
# get the most common number of channels per well
from collections import Counter
count_freq = Counter(counts)
most_common_channels, freq = count_freq.most_common(1)[0]
# heuristics for known plates
if most_common_channels == 3:
self.plate = '96_3'
expected_wells = 96
elif most_common_channels == 12:
self.plate = '24_12'
expected_wells = 24
else:
self.plate = f'custom_{n_wells}w_{most_common_channels}c'
expected_wells = n_wells
# warn if some wells are missing
if n_wells < expected_wells:
print(f'Warning: only {n_wells}/{expected_wells} wells detected in HDF5 file.')
print(f'Inferred plate: {self.plate} ({n_wells} wells, {most_common_channels} channels/well)')
def _define_well_dicts(self):
"""
Define well index ↔ label dictionaries depending on detected plate type.
"""
if self.plate == '96_3':
rows = 'ABCDEFGH' # 8 rows
cols = range(1, 13) # 12 columns
elif self.plate == '24_12':
rows = ['A', 'B', 'C', 'D'] # 4 rows
cols = range(1, 7) # 6 columns
keys = []
items = []
i = 0
for r in rows:
for c in cols:
keys.append(i)
items.append(f'{r}{c}')
i += 1
self.wellIndexLabelDict = dict(zip(keys, items))
self.wellLabelIndexDict = dict(zip(items, keys))
[docs] def files_and_well_csv(self, file):
"""
Load filenames and associated wells from a CSV file.
This method reads a CSV file containing filenames and associated wells,
and populates the `dataset` and `wellsLabels` attributes accordingly.
Parameters
----------
file : str
The name of the CSV file to load, relative to the path_to_csv attribute.
Returns
-------
None
Examples
--------
>>> obj = Analysis()
>>> obj.files_and_well_csv('data.csv')
CSV file format example:
filename;wells
file1.h5;A1,B2,C3
file2.h5f;all
file3.h5;D4,E5,F6
In the above example, we want to analyze wells A1, B2, and C3 fromfile1.h5,
all wells from file2.h5, and D4, E5 and F6 from and file3.h5.
"""
if self.plate is not None:
self._define_well_dicts()
filenames_and_wells = pd.read_csv(self.path_to_csv+file, sep = ';')
self.dataset = []
for index in filenames_and_wells.index:
self.dataset.append(filenames_and_wells['filename'][index])
if filenames_and_wells['wells'][index] == 'all':
self.wellsLabels.append(list(self.wellIndexLabelDict.values()))
else:
self.wellsLabels.append([])
for well in filenames_and_wells['wells'][index].split(','):
self.wellsLabels[index].append(well)
[docs] def loadmodel(self, setting='default'):
"""
Load a machine learning model to perform reverberations/bursts detection.
Automatically loads bundled pretrained models if available.
"""
# Set default model name if none specified
if self.model_name is None:
self.model_name = 'signal30.h5'
# Set default or manual parameters
if setting == 'default':
self.model_params['input_average'] = 30
self.model_params['window_size'] = 50_000
self.model_params['window_overlap'] = 25_000
if 'signal' in self.model_name.lower():
self.model_params['input_type'] = 'signal'
elif 'spikes' in self.model_name.lower():
self.model_params['input_type'] = 'spikes'
if '100' in self.model_name.lower():
self.model_params['input_average'] = 100
elif setting == 'manual':
pass # keep user-defined parameters
# If the model is one of the bundled pretrained models
if self.model_name in self._pretrained_models:
try:
# Use importlib.resources to get path inside package
with resources.path("automea.models", self.model_name) as model_path:
self.model = tf.keras.models.load_model(model_path, compile=False)
except FileNotFoundError:
raise ValueError(f"Pretrained model '{self.model_name}' not found in package.")
else:
# Otherwise, treat model_name as a path
self.model = tf.keras.models.load_model(self.path_to_model + self.model_name, compile=False)
[docs] def loadsignal(self, signal):
"""
Load a signal into the analysis object.
Parameters
----------
signal : array_like
The signal data to be loaded for analysis. It can be a 1D or 2D array,
depending if the signal has one or multiple channels.
Returns
-------
None
Notes
-----
- If the signal is 1D, it is assumed to be a one-channel time series signal.
- If the signal is 2D, it is assumed to be a collection of channels, with signals over time.
- The `total_timesteps_signal` attribute is updated to reflect the length of the signal.
- The `time` attribute is generated based on the sampling frequency and the length of the signal.
Examples
--------
>>> obj = Analysis()
>>> obj.loadsignal(signal_data)
"""
self.signal = signal
if np.array(signal.ndim) == 1:
self.total_timesteps_signal = len(signal)
else:
self.total_timesteps_signal = np.array(signal).shape[1]
self.time = np.linspace(0, (self.total_timesteps_signal-1)/self.samplingFreq, self.total_timesteps_signal)
[docs] def loadh5(self, filename = None):
"""
Load data from an HDF5 file into the analysis object.
Parameters
----------
filename : str, optional
The name of the HDF5 file to load. If None, it loads the dataset specified in the dataset attribute.
Returns
-------
None
Notes
-----
- If None, sets the dataset attribute equals to the filename input
- The 'infoChannel', 'adZero', 'conversionFactor', 'exponent', and 'wellsFromData' attributes are updated
based on the data loaded from the HDF5 file.
Examples
--------
>>> obj = Analysis()
>>> obj.loadh5() # Load the dataset from dataset attribute
>>> obj.loadh5('example.h5') # Load data from a specific HDF5 file
"""
if filename is None:
self.h5 = h5py.File(self.path_to_dataset + self.dataset, 'r')
else:
self.h5 = h5py.File(self.path_to_dataset + filename, 'r')
self.dataset = filename
#self.file.append(filename)
self.infoChannel = np.asarray(self.h5['Data']['Recording_0']['AnalogStream']['Stream_0']['InfoChannel'])
try:
self.loadsignal(np.asarray(self.h5['Data']['Recording_0']['AnalogStream']['Stream_0']['ChannelData'], dtype = np.int32))
#self.signal = np.asarray(self.h5['Data']['Recording_0']['AnalogStream']['Stream_0']['ChannelData'], dtype = np.int32)[:,:self.total_timesteps_signal]
except:
print('Error when trying to allocate "signal".')
self.adZero = np.array([self.infoChannel[i][8] for i in range(len(self.infoChannel))]).reshape(len(self.infoChannel), 1)
if len(np.unique(self.adZero)): self.adZero = self.adZero[0][0]
self.conversionFactor = np.array([self.infoChannel[i][10] for i in range(len(self.infoChannel))]).reshape(len(self.infoChannel), 1)
if len(np.unique(self.conversionFactor)): self.conversionFactor = self.conversionFactor[0][0]
self.exponent = np.array([self.infoChannel[i][7] for i in range(len(self.infoChannel))]).reshape(len(self.infoChannel), 1)
if len(np.unique(self.exponent)): self.exponent = self.exponent[0][0]
self.wellsFromData = np.array([self.infoChannel[i][2] for i in range(len(self.infoChannel))]).reshape(len(self.infoChannel), 1)
self.wellsFromData = self.wellsFromData.flatten()
if self.plate is None:
self._infer_plate()
self._define_well_dicts()
[docs] def convert_signal(self, signal, adzero, conversionfactor, exponent):
"""
Convert raw signal values to physical units using the provided parameters:
adzero, conversionfactor, and exponent
Parameters
----------
signal : array_like
The raw signal values to be converted.
adzero : float
The adZero value to subract from signal.
conversionfactor : float
The conversion factor to convert raw values to physical units.
exponent : float
The exponent applied during conversion.
Returns
-------
array_like
The converted signal values in physical units.
Examples
--------
>>> obj = Analysis()
>>> converted_signal = obj.convert_signal(raw_signal, ad_zero_value, conversion_factor, exponent_value)
"""
return 1e6*(signal-adzero)*conversionfactor*10.**exponent
[docs] def convert_threshold(self, threshold, adzero, conversionfactor, exponent):
"""
Convert threshold value to physical units using the provided parameters:
adzero, conversionfactor, and exponent.
Parameters
----------
threshold : float
The threshold value to be converted.
adzero : float
The AD Zero value.
conversionfactor : float
The conversion factor to convert raw values to physical units.
exponent : float
The exponent applied during conversion.
Returns
-------
float
The converted threshold value in physical units.
Examples
--------
>>> obj = Analysis()
>>> converted_threshold = obj.convert_threshold(threshold_value, ad_zero_value, conversion_factor, exponent_value)
"""
return self.convert_signal(threshold, adzero, conversionfactor, exponent)
[docs] def loadspikes(self, spikes):
"""
Load spike data into the analysis object.
This method allows loading spike data into the analysis object for further processing.
Parameters
----------
spikes : array_like
The spike data to be loaded for analysis.
Returns
-------
None
"""
self.spikes = spikes
if self.signal.ndim == 1:
self.spikes_binary = self.convert_timestamps_to_binary(self.spikes, input_type = 'spikes')
else:
self.spikes_binary = np.zeros((len(self.spikes), self.total_timesteps_signal))
for channel, spikes_ in enumerate(self.spikes):
self.spikes_binary[channel] = self.convert_timestamps_to_binary(spikes_, input_type = 'spikes')
[docs] def loadwell(self, file:str, well, method = 'default', spikes = False, reverbs = False, bursts = False, net_reverbs = False, net_bursts = False):
"""
Load data for a specific well.
This method loads data for a specific well from an HDF5 file into the analysis object for further processing.
Optionally, it can also detect spikes and analyze them for reverberations and bursts, calling the deticated methods.
Parameters
----------
file : str
The name of the HDF5 file containing the data.
well : str or int
The label or index of the well to load.
method : str, optional
The method to use for detecting reverbs and bursts if spikes are detected. Default is 'default'.
spikes : bool, optional
Whether to detect spikes. Default is False.
reverbs : bool, optional
Whether to analyze reverberations. Default is False.
bursts : bool, optional
Whether to analyze bursts. Default is False.
net_reverbs : bool, optional
Whether to analyze net reverberations. Default is False.
net_bursts : bool, optional
Whether to analyze net bursts. Default is False.
Returns
-------
None
Examples
--------
>>> obj = Analysis()
>>> obj.loadwell('data.h5', 'well_label', method='model', spikes=True, reverbs=True)
Notes
-----
- The `file` parameter specifies the name of the HDF5 file containing the data.
- The `well` parameter can be either a well label (str) or index (int).
- The `method` parameter determines the method to use for detecting reverberations and bursts.
It is only used if spikes are detected and defaults to 'default'.
- The `spikes`, `reverbs`, `bursts`, `net_reverbs`, and `net_bursts` parameters control which
analyses to perform after loading the data. They default to False.
"""
if type(well) is str:
well_id = self.wellLabelIndexDict[well]
elif type(well) is int:
well_id = well
else:
print('Well type not valid.')
return
self.loadh5(file)
if well_id not in self.wellsFromData:
print('Well not found in the dataset.')
return
self.signal = self.signal[self.wellsFromData == well_id]
self.detect_threshold()
if spikes:
self.detect_spikes()
if method == 'model' and self.model is None:
pass
else:
self.detect_reverbs(method=method)
self.detect_bursts()
self.detect_net('reverbs')
self.detect_net('bursts')
else:
if reverbs or bursts or net_reverbs or net_bursts:
print('Reverbs and Bursts need Spikes to be detected.')
[docs] def detect_threshold(self):
"""
Detect threshold value for spike detection.
This method calculates the threshold(s) value(s) based on the signal data and threshold parameters
provided in the analysis object.
Returns
-------
None
Examples
--------
>>> obj = Analysis()
>>> obj.detect_threshold()
Notes
-----
- If the 'signal' attribute is not a numpy array, an error message is printed, and the method returns.
- The threshold value is calculated segment-wise based on the mean and standard deviation of the signal.
- The calculated threshold value is stored in the 'threshold' attribute of the analysis object.
"""
if isinstance(self.signal, np.ndarray) is False:
print('"signal" has to be a numpy array.')
return
if self.signal.ndim == 1:
signal = self.signal.reshape(1,-1)
else:
signal = self.signal
startTimestamp = int(self.threshold_params['startTime'] * self.samplingFreq)
totalTimestamps = int(self.threshold_params['baseTime'] * self.samplingFreq)
stds = np.zeros((len(signal), self.threshold_params['segments']))
mean = signal.mean(axis = 1)
for seg in range(self.threshold_params['segments']):
signal_ = signal[:,startTimestamp:startTimestamp+totalTimestamps]
stds[:, seg] = (signal_.T - mean).T.std(axis = 1)
startTimestamp+=totalTimestamps
thr = self.threshold_params['rising']*stds.min(axis = 1)
if len(thr) == 1: self.threshold = thr[0]
self.threshold = thr
[docs] def convert_timestamps_to_binary(self, input_timestamp, input_type, size = None):
"""
Convert timestamps to binary representation.
This method converts timestamps to a binary representation suitable for the specified input type,
such as 'spikes', 'reverbs', or 'bursts'.
Parameters
----------
input_timestamp : array_like
The timestamps data to be converted to binary representation.
input_type : str
The type of input timestamps. Options are 'spikes', 'reverbs', or 'bursts'.
size : int, optional
The size of the binary representation dta. Default is None.
Returns
-------
array_like
The binary representation of the timestamps data.
Examples
--------
>>> obj = Analysis()
>>> spikes_binary = obj.convert_timestamps_to_binary(spikes_timestamps, input_type='spikes')
Notes
-----
- If 'size' is not provided, it defaults to the total number of timesteps in the signal.
- For 'spikes', each spike timestamp is represented as 1 in the binary array.
- For 'reverbs' and 'bursts', each timestamp range is represented as 1 in the binary array.
"""
if size == None: size = self.total_timesteps_signal
if input_type == 'spikes':
if self.util._has_list(input_timestamp) is False:
#if np.array(input_timestamp).ndim == 1:
_binary = np.zeros(size)
_binary[input_timestamp] = 1
else:
_binary = np.zeros((len(input_timestamp), size))
for channel, timestamps in enumerate(input_timestamp):
_binary[channel][timestamps] = 1
elif input_type == 'reverbs' or input_type == 'bursts':
if np.array(input_timestamp, dtype = object).ndim in [1,2]:
_binary = np.zeros(size)
for item in input_timestamp:
_binary[item[0]:item[1]+1] = 1
else:
_binary = np.zeros((len(input_timestamp), size))
for channel, timestamps in enumerate(input_timestamp):
for item in timestamps:
_binary[channel][item[0]:item[1]+1] = 1
return _binary
[docs] def convert_binary_to_timestamps(self, input_binary, input_type):
"""
Convert binary representation data to timestamps.
This method converts a binary representation to timestamps for the specified input type,
such as 'spikes', 'reverbs', or 'bursts'.
Parameters
----------
input_binary : array_like
The binary representation data to be converted to timestamps.
input_type : str, optional
The type of input binary. Options are 'spikes', 'reverbs', or 'bursts'.
Returns
-------
list or list of lists
The timestamps corresponding to the binary representation.
Examples
--------
>>> obj = Analysis()
>>> spikes_timestamps = obj.convert_binary_to_timestamps(spikes_binary, input_type='spikes')
Notes
-----
- For 'spikes', each 1 in the binary array represents a spike timestamp.
- For 'reverbs' and 'bursts', consecutive 1s in the binary array represent timestamp ranges.
"""
if input_type == 'spikes':
if np.array(input_binary).ndim == 1:
return list(np.where(input_binary == 1)[0])
else:
timestamps_ = []
for spikes_bin in input_binary:
timestamps_.append(list(np.where(spikes_bin == 1)[0]))
return timestamps_
elif input_type == 'reverbs' or input_type == 'bursts':
if np.array(input_binary).ndim == 1:
return self._convert_binary_reverb_to_timestamp(input_binary)
else:
sequence = []
for input_channel in input_binary:
sequence.append(self._convert_binary_reverb_to_timestamp(input_channel))
return sequence
[docs] def _convert_binary_reverb_to_timestamp(self, input_binary):
"""
Convert binary representation of reverbs to timestamps.
This method converts a binary representation of reverbs to timestamps.
Parameters
----------
input_binary : array_like
The binary representation of reverbs.
Returns
-------
list
List of timestamp ranges corresponding to reverbs.
Examples
--------
>>> obj = Analysis()
>>> reverb_timestamps = obj._convert_binary_reverb_to_timestamp(reverb_binary)
Notes
-----
- This method converts a binary representation of reverbs to a list of timestamp ranges.
- Consecutive 1s in the binary array represent timestamp ranges for reverberations.
"""
start_reverb = False
reverbs_ = []
for i, item in enumerate(input_binary):
if item == 1 and not start_reverb:
reverbs_.append([i])
start_reverb = True
elif item == 0 and start_reverb:
reverbs_[-1].append(i-1)
start_reverb = False
return reverbs_
[docs] def _detect_spikes(self, signal, threshold):
"""
Detect spikes in a signal based on a given threshold value.
Parameters
----------
signal : array_like
The signal in which spikes are to be detected.
threshold : float
The threshold value for spike detection.
Returns
-------
list
List of spike timestamps.
Examples
--------
>>> obj = Analysis()
>>> spikes = obj._detect_spikes(signal_data, threshold_value)
Notes
-----
- It uses a deadtime parameter to avoid detecting multiple spikes within a short time period.
"""
dead = self.spikes_params['deadtime']*self.samplingFreq
spikes = []
for i in range(len(signal)):
if(not len(spikes)):
if(abs(signal[i])>threshold):
spikes.append(i)
else:
if(abs(signal[i])>threshold and (abs(i - spikes[-1]) > dead)):
spikes.append(i)
return spikes
[docs] def detect_spikes(self):
"""
Detect spikes in the signal.
This method detects spikes from the signal attribute of the analysis object based on the threshold attribute.
Returns
-------
None
Notes
-----
- It checks if the signal attribute is a numpy array, and if not, prints an error message and returns.
- If the signal attribute is a 1D numpy array, it detects spikes and updates the spikes and spikes_binary attributes.
- If the signal attribute is a 2D numpy array (multiple channels), it detects spikes for each channel and updates
the spikes and spikes_binary attributes accordingly.
- Calculates the number of active channels (mean-firing-rate > 0.1 spikes/s)
"""
if isinstance(self.signal, np.ndarray) is False:
print('"signal" attribute has to be a numpy array.')
return
if self.signal.ndim == 1:
self.spikes = self._detect_spikes(self.signal, self.threshold)
if len(self.spikes)/(self.total_timesteps_signal/self.samplingFreq) > 0.1:
self.active_channels = np.array([0])
else:
self.active_channels = []
self.spikes_binary = self.convert_timestamps_to_binary(self.spikes, input_type = 'spikes')
else:
self.spikes = list(map(self._detect_spikes, self.signal, self.threshold))
self.active_channels = np.array([index for index, spikes in enumerate(self.spikes) if len(spikes)/(self.total_timesteps_signal/self.samplingFreq) > 0.1])
self.spikes_binary = np.zeros((len(self.spikes), self.total_timesteps_signal))
for channel, spikes_ in enumerate(self.spikes):
self.spikes_binary[channel] = self.convert_timestamps_to_binary(spikes_, input_type = 'spikes')
[docs] def reverbs_params_default(self):
"""
Set default parameters for reverberations/bursts detection.
Returns
-------
None
Notes
-----
- This method sets default parameters for detecting reverberations, including:
- max_interval_start: Maximum interval start time for a reverberation.
- max_interval_end: Maximum interval end time for a reverberation.
- min_interval_between: Minimum interval between reverberations.
- min_duration: Minimum duration of a reverberation.
- min_spikes: Minimum number of spikes required to consider a reverberation.
"""
self.reverbs_params['max_interval_start'] = 15
self.reverbs_params['max_interval_end'] = 20
self.reverbs_params['min_interval_between'] = 25
self.reverbs_params['min_duration'] = 20
self.reverbs_params['min_spikes'] = 5
[docs] def reverbs_params_manual(self, params):
"""
Set manual parameters for reverberations detection.
Parameters
----------
params : list
List containing manual parameters for reverberations detection in the following order:
- max_interval_start: Maximum interval start time for a reverberation.
- max_interval_end: Maximum interval end time for a reverberation.
- min_interval_between: Minimum interval between reverberations.
- min_duration: Minimum duration of a reverberation.
- min_spikes: Minimum number of spikes required to consider a reverberation.
Returns
-------
None
"""
self.reverbs_params['max_interval_start'] = params[0]
self.reverbs_params['max_interval_end'] = params[1]
self.reverbs_params['min_interval_between'] = params[2]
self.reverbs_params['min_duration'] = params[3]
self.reverbs_params['min_spikes'] = params[4]
[docs] def detect_reverbs(self, method = 'default'):
"""
Detects reverberations based on the specified method.
Parameters
----------
method : str, optional
The method used for reverberations detection. Options are 'default', 'manual', or 'model'. Default is 'default'.
Returns
-------
None
Notes
-----
- If method is 'default', the default parameters for reverberations detection are used.
- If method is 'manual', the user can specify manual parameters for reverberations detection.
- If method is 'model', reverberations are detected using a pre-trained machine learning model.
- Detect reverberations are stored in the reverb attribute.
"""
if method == 'default':
self.reverbs_params_default()
if self.spikes is None: self.detect_spikes() # detect spikes if "detect_bursts" is called but no spikes are defined
if method == 'default' or method == 'manual':
if self.util._has_list(self.spikes) is False:
self.reverbs = self._detect_reverbs(self.spikes)
self.reverbs_binary = self.convert_timestamps_to_binary(self.reverbs, input_type = 'reverbs')
else:
self.reverbs = list(map(self._detect_reverbs, np.array(self.spikes, dtype = object)))
self.reverbs_binary = np.zeros((len(self.reverbs), self.total_timesteps_signal))
for channel, reverbs_ in enumerate(self.reverbs):
self.reverbs_binary[channel] = self.convert_timestamps_to_binary(reverbs_, input_type = 'reverbs')
elif method == 'model':
input_type = self.model_params['input_type']
if input_type == 'signal':
model_input = self.signal
run_once = self.signal.ndim == 1
elif input_type == 'spikes':
model_input = self.spikes_binary
run_once = not self.util._has_list(self.spikes)
else:
print('Input type or model not defined!')
return
if run_once:
self.reverbs = self._predict_reverbs(model_input, self.spikes_binary)
self.reverbs_binary = self.convert_timestamps_to_binary(self.reverbs, input_type = 'reverbs')
else:
self.reverbs = []
self.reverbs_binary = np.zeros((len(model_input), len(model_input[0])))
for channel, channel_input in enumerate(model_input):
self.reverbs.append(self._predict_reverbs(channel_input, self.spikes_binary[channel]))
self.reverbs_binary[channel] = self.convert_timestamps_to_binary(self.reverbs[-1], input_type = 'reverbs')
[docs] def _detect_reverbs(self, spikes):
"""
Detect reverberations based on spikes timestamps.
Parameters
----------
spikes : list
List of spike timestamps.
Returns
-------
list
List of timestamp ranges corresponding to reverberations.
Notes
-----
- Implements the Max Interval Method
"""
params_multiplier = 10
# detect bursts with a max interval to start and end
reverbs = []
i = 0
k = 0
while (i < len(spikes)-2):
if(abs(spikes[i] - spikes[i+1]) <= params_multiplier*self.reverbs_params['max_interval_start']):
reverbs.append([spikes[i]])
j = i+1
while(abs(spikes[j] - spikes[j+1]) < params_multiplier*self.reverbs_params['max_interval_end']):
reverbs[k].append(spikes[j])
j += 1
if(j >= len(spikes) - 1): break
i = j
k += 1
else:
i += 1
# check valid reverbarations with minimum duration and minimum number of spikes
valid_reverbs = []
for reverb in reverbs:
if len(reverb) > self.reverbs_params['min_spikes'] and \
abs(reverb[0] - reverb[-1]) > params_multiplier*self.reverbs_params['min_duration']:
valid_reverbs.append([reverb[0], reverb[-1]])
# merge all reverberations that are closer than the minimum interval
minIntervalBetweenBursts = self.bursts_params['min_interval_between']
self.bursts_params['min_interval_between'] = self.reverbs_params['min_interval_between']
merged_reverbs = self._detect_bursts(valid_reverbs)
self.bursts_params['min_interval_between'] = minIntervalBetweenBursts
return merged_reverbs
[docs] def detect_bursts(self):
"""
Detect bursts in the signal based on the detected reverberations.
Returns
-------
None
Notes
-----
- If no reverberations are detected or if the 'reverbs' attribute is empty, no bursts will be detected.
- Detected bursts are stored in the 'bursts' attribute.
"""
if self.util._has_list(self.spikes) is False:
self.bursts = self._detect_bursts(self.reverbs)
else:
self.bursts = list(map(self._detect_bursts, self.reverbs))
self.bursts_binary = np.zeros((len(self.bursts), self.total_timesteps_signal))
for channel, bursts_ in enumerate(self.bursts):
self.bursts_binary[channel] = self.convert_timestamps_to_binary(bursts_, input_type = 'bursts')
[docs] def _detect_bursts(self, reverbs):
"""
Detect bursts based on reverberations.
Parameters
----------
reverbs : list
List of timestamp ranges corresponding to detected reverberations.
Returns
-------
list
List of timestamp ranges corresponding to detected bursts.
Notes
-----
- This method iteratively merges adjacent reverberation intervals if they are closer than the minimum interval
between bursts.
- The 'min_interval_between' parameter from 'bursts_params' is used to determine the minimum interval between bursts.
"""
if len(reverbs) == 1 or not len(reverbs): return reverbs
params_multiplier = 10
r = np.array(reverbs)
diff = r[1:,0] - r[:-1,1]
merge = diff < self.bursts_params['min_interval_between']*params_multiplier
if not merge.any(): return reverbs
bursts = []
index = 0
while index < len(diff):
if merge[index]:
bursts.append([r[index,0],r[index+1,1]])
index += 1
else:
bursts.append(list(r[index]))
index += 1
if index == len(diff): bursts.append(list(r[-1]))
return self._detect_bursts(bursts)
[docs] def _predict_reverbs(self, model_input, spikes_binary):
"""
Predict reverberations using a pre-trained machine learning model.
This method predicts reverberations based on the given model input and binary spike data.
Parameters
----------
model_input : array_like
Input data for the model.
spikes_binary : array_like
Binary spike data.
Returns
-------
list
List of timestamp ranges corresponding to predicted reverberations.
Notes
-----
- The model input can be either signal or spike data, specified by the 'input_type' parameter in 'model_params' attribute.
- The method iterates through the input data in windows, predicts reverberations for each window, and combines the predictions.
- The 'max_interval_start', 'max_interval_end', and 'min_interval_between' parameters for reverberations detection are predicted by the model.
- Predicted reverberations are filtered based on the minimum duration criteria defined in 'reverbs_params' attribute.
"""
if self.model_params['input_type'] not in ['signal', 'spikes']:
print('Input type not correctly defined!')
return
if self.total_timesteps_signal%self.model_params['window_overlap'] != 0:
print("Number of windows and overlap do not match lenght of input channel (signal or spike)!")
return
reverbs_binary_pred_long = np.zeros(self.total_timesteps_signal)
number_of_windows = self.total_timesteps_signal//self.model_params['window_overlap']
for window_index in range(number_of_windows - 1):
window_start = window_index*self.model_params['window_overlap']
window_end = window_start + self.model_params['window_size']
input_in_window = model_input[window_start:window_end]
spikes_in_window = spikes_binary[window_start:window_end]
if self.model_params['input_type'] == 'signal':
input_in_window = self.reduce_norm_abs_signal(input_in_window)
elif self.model_params['input_type'] == 'spikes':
input_in_window = self.reduce_dimension(input_in_window)
mip_pred = self.model.predict(input_in_window.reshape(1,-1,1), verbose = 0, use_multiprocessing = True)[0]
self.reverbs_params['max_interval_start'], self.reverbs_params['max_interval_end'], self.reverbs_params['min_interval_between'] = mip_pred
reverbs_in_window = self._detect_reverbs(self.convert_binary_to_timestamps(spikes_in_window, input_type = 'spikes'))
reverbs_binary_pred_long[window_start:window_end] = self.convert_timestamps_to_binary(reverbs_in_window, input_type = 'reverbs', size = self.model_params['window_size'])
reverbs_binary_pred_long[reverbs_binary_pred_long >= 1] = 1
reverbs_pred = self.convert_binary_to_timestamps(reverbs_binary_pred_long, input_type = 'reverbs')
valid_reverbs = []
params_multiplier = 10
for reverb in reverbs_pred:
if abs(reverb[0] - reverb[-1]) > params_multiplier*self.reverbs_params['min_duration']:
valid_reverbs.append([reverb[0], reverb[-1]])
reverbs_pred = valid_reverbs
return reverbs_pred
[docs] def normalize_signal(self, signal):
"""
Normalize the input signal.
Parameters
----------
signal : array_like
Input signal to be normalized.
Returns
-------
array_like
Normalized signal.
Notes
-----
- The normalization is performed by dividing the signal by its maximum absolute value.
"""
return signal/abs(signal).max()
[docs] def normalize_threshold(self, signal, threshold):
"""
Normalize the threshold relative to the input signal.
Parameters
----------
signal : array_like
Input signal used for normalization.
threshold : float
Threshold value to be normalized.
Returns
-------
float
Normalized threshold.
Notes
-----
- The threshold is normalized by dividing it by the maximum absolute value of the input signal.
"""
return threshold/abs(signal).max()
[docs] def reduce_dimension(self, X, input_type = None, reduction_factor = None):
"""
Reduce the dimensionality of the input data.
Parameters
----------
X : array_like
Input data to be dimensionally reduced.
input_type : str, optional
Type of input data ('signal' or 'spikes').
reduction_factor : int, optional
Factor by which to reduce the dimensionality.
Returns
-------
array_like
Dimensionally reduced input data.
Notes
-----
- If 'input_type' is 'spikes', the dimensionality is reduced by selecting every 'reduction_factor' element.
- If 'input_type' is 'signal', the dimensionality is reduced by averaging every 'reduction_factor' elements.
"""
if reduction_factor is None:
reduction_factor = self.model_params['input_average']
if reduction_factor is None:
print('No reduction factor defined!')
return
if self.model_params['input_type'] == 'spikes' or input_type == 'spikes':
if X.ndim == 1:
X2 = np.zeros(len(X[::reduction_factor]))
X2 = np.array([1 if len(np.where(X[reduction_factor*j:reduction_factor*j+reduction_factor] == 1)[0]) else 0 for j in range(len(X[::reduction_factor]))])
else:
X2 = np.zeros((len(X), len(X[0][::reduction_factor])))
for i, x in enumerate(X):
X2[i] = np.array([1 if len(np.where(x[reduction_factor*j:reduction_factor*j+reduction_factor] == 1)[0]) else 0 for j in range(len(x[::reduction_factor]))])
return X2
elif self.model_params['input_type'] == 'signal' or input_type == 'signal':
numbers = X
i = 0
moving_averages = []
while (i+1)*self.model_params['input_average'] < len(numbers):
this_window = numbers[i*self.model_params['input_average'] : (i+1)*self.model_params['input_average']]
window_average = sum(this_window) / self.model_params['input_average']
moving_averages.append(window_average)
i += 1
if (i-1)*self.model_params['input_average'] < len(numbers):
this_window = numbers[(i-1)*self.model_params['input_average'] :]
window_average = sum(this_window) / len(this_window)
moving_averages.append(window_average)
return np.array(moving_averages)
else:
print('No input type (signal or spikes) defined!')
return
[docs] def reduce_norm_abs_signal(self, signal):
"""
Reduce the dimensionality of the normalized absolute signal.
Parameters
----------
signal : array_like
Input signal to be normalized and dimensionally reduced.
Returns
-------
array_like
Dimensionally reduced normalized absolute signal.
"""
if signal.ndim == 1:
return self.reduce_dimension(self.normalize_signal(abs(signal)), input_type = 'signal')
else:
return np.array(list(map(self.reduce_dimension, self.normalize_signal(abs(signal)), ['signal' for _ in range(len(signal))])))
[docs] def detect_net(self, inp, minChannelsParticipating=None, minSimultaneousChannels=None):
"""
Detect net bursts or net reverbs.
Parameters
----------
inp : str
Input type for which to detect the net ('reverbs' or 'bursts').
minChannelsParticipating : int, optional
Minimum number of channels participating to start a net burst.
minSimultaneousChannels : int, optional
Minimum number of simultaneous channels to continue a net burst.
Notes
-----
- A net burst or net reverb is defined by bursts or reverbs occurring simultaneously on multiple channels.
- The function detects net bursts or net reverbs based on the input type.
- It starts a net burst when the number of channels bursting together exceeds 'minChannelsParticipating'.
- It ends a net burst when the number of simultaneous channels drops below 'minSimultaneousChannels'.
- The resulting net bursts or net reverbs are stored in the corresponding attributes ('net_reverbs' or 'net_bursts')
and their binary representations are stored in 'net_reverbs_binary' or 'net_bursts_binary', respectively.
"""
if minChannelsParticipating is None:
minChannelsParticipating = np.ceil((2/3)*len(self.active_channels))
if minSimultaneousChannels is None:
minSimultaneousChannels = np.ceil(0.5*len(self.active_channels))
if inp == 'reverbs':
bursts_binary = self.reverbs_binary
elif inp == 'bursts':
bursts_binary = self.bursts_binary
howMuchIsBursting = np.sum(bursts_binary, axis = 0)
netBursts = []
alreadyInBurst = False
for i in range(len(howMuchIsBursting)):
## start net burst when at a timestamp there are more than minChannelParticipating bursting
if (not alreadyInBurst) and (howMuchIsBursting[i] >= minChannelsParticipating):
netBursts.append([i,i])
alreadyInBurst = True
## when the number of channels bursting together drops below minSimultaneousChannels, end the net burst
elif alreadyInBurst and howMuchIsBursting[i] < minSimultaneousChannels:
alreadyInBurst = False
netBursts[-1][-1] = i
for k, netburst in enumerate(netBursts):
## starting at the beginning of every netburst, go against time (decrease timestamp)
## until less than 2 channels are bursting. Reassing the beginning time of the net burst
## to this timestamp
startNetBurst, endNetBurst = netburst
for i in range(startNetBurst, 0, -1):
if(howMuchIsBursting[i] < 2):
netBursts[k][0] = i
break
## starting at the end of every netburst, go ahead in time (increase timestamp)
## until the last burst in the netburst group finishes. Assign this timestamp as
## the new end of netburst
for i in range(endNetBurst, len(howMuchIsBursting)):
if(howMuchIsBursting[i] < 1):
netBursts[k][-1] = i
break
netBursts_copy = []
for i in range(1,len(netBursts)):
if netBursts[i] != netBursts[i-1]:
netBursts_copy.append(netBursts[i])
netbin = np.zeros(len(bursts_binary[1]))
for net in netBursts_copy:
netbin[net[0]:net[1]+1] += 1
netbin[netbin > 1] = 1
if inp == 'reverbs':
self.net_reverbs_binary = netbin
self.net_reverbs = self.convert_binary_to_timestamps(self.net_reverbs_binary, input_type = 'reverbs')
elif inp == 'bursts':
self.net_bursts_binary = netbin
self.net_bursts = self.convert_binary_to_timestamps(self.net_bursts_binary, input_type = 'bursts')
[docs] def save_spikes(self):
"""
Save spikes data to a CSV file.
Notes
-----
- The method saves spikes data to a CSV file containing columns for channel ID, channel label, well ID, well label, compound ID, compound name,
experiment, dose in pM, dose label, and timestamp.
- The CSV file is saved in the output folder with a filename based on the output name attribute and the type of data being saved
(e.g., '_REVERBS_PREDICTED.csv').
"""
columns_spikes = ['Channel ID', 'Channel Label', 'Well ID', 'Well Label', 'Compound ID', 'Compound Name']
columns_spikes.extend(['Experiment', 'Dose [pM]', 'Dose Label'])
columns_spikes.extend(['timestamp'])
spikes_data_file = os.path.join(self.output_folder, f'{self.output_name}_REVERBS_PREDICTED.csv')
[docs] def analyze_dataset(self, file=None, mode='csv', save_default=False):
"""
Analyze the datasets from a CSV file, and save the results to CSV files.
This method conducts a thorough analysis of the dataset, encompassing the detection of reverberations (reverbs),
bursts, network bursts, and various statistics. It then saves the results to CSV files based on the specified mode
and whether to save default analysis results.
Parameters
----------
file : str or None, optional
The file name or path of the CSV with datasets to analyze. If None, the dataset associated with the
dataset attribute will be analyzed.
mode : str, optional
The mode of analysis. Default is 'csv', meaning it will receiva a CSV file with the dataset information.
save_default : bool, optional
Determines whether to also save the results of default analysis alongside model-based analysis. Default is False.
Notes
-----
- The method first creates dataframes to store information about reverbs, bursts, network bursts, and statistics.
- It loads the CSV file, iterates over each dataset, and analyzes each well speficified within the dataset.
- For each well, it performs spike detection, threshold detection, and then applies the specified methods
(default or model-based) for reverbs, bursts, and network bursts detection.
- It calculates various statistics based on the detected bursts and network bursts.
- Finally, it saves the results to separate CSV files according to the specified mode and whether to save default
analysis results.
"""
if self.output_folder != '':
os.makedirs(self.output_folder)
print('--- Running full analysis ---\n')
### create dataframes to save reverbs, bursts, network bursts, and statistics
columns_spikes = ['Channel ID', 'Channel Label', 'Well ID', 'Well Label', 'Compound ID', 'Compound Name']
columns_spikes.extend(['Experiment', 'Dose [pM]', 'Dose Label', 'Timestamp'])
columns_reverbs = ['Channel ID', 'Channel Label', 'Well ID', 'Well Label', 'Compound ID', 'Compound Name']
columns_reverbs.extend(['Experiment', 'Dose [pM]', 'Dose Label'])
columns_reverbs.extend(['Start timestamp [\u03BCs]', 'Duration [\u03BCs]', 'Spike Count', 'Spike Frequency [Hz]'])
columns_bursts = ['Channel ID', 'Channel Label', 'Well ID', 'Well Label', 'Compound ID', 'Compound Name']
columns_bursts.extend(['Experiment', 'Dose [pM]', 'Dose Label'])
columns_bursts.extend(['Start timestamp [\u03BCs]', 'Duration [\u03BCs]', 'Spike Count', 'Spike Frequency [Hz]'])
columns_net = ['Well ID', 'Well Label', 'Compound ID', 'Compound Name']
columns_net.extend(['Experiment', 'Dose [pM]', 'Dose Label'])
columns_net.extend(['Start timestamp [\u03BCs]', 'Duration [\u03BCs]', 'Spike Count', 'Spike Frequency [Hz]'])
columns_stats = ['Filename', 'Well Label', 'Number of channels', 'Total number of spikes', 'Mean Firing Rate [Hz]']
columns_stats.extend(['Stray spikes (%)'])
columns_stats.extend(['Total number of networks bursts', 'Mean Network Bursting Rate [bursts/minute]', 'Mean Network Burst Duration [ms]'])
columns_stats.extend(['NIBI', 'CV of NIBI'])
columns_stats.extend(['Mean reverb per burst', 'Median of reverb per burst'])
columns_stats.extend(['Mean net reverb per net burst', 'Median of net reverb per net burst', 'Total number of network reverb'])
columns_stats.extend(['Mean net reverb frequency [reverb/min]', 'Mean net reverb duration [ms]', 'Mean in-netreverb freq [Hz]'])
# bursts stats
columns_stats.extend(['Stray spikes bursts (%)'])
columns_stats.extend(['Total number of bursts', 'Mean Bursting Rate [bursts/minute]', 'Mean Burst Duration [ms]'])
columns_stats.extend(['Total number of reverb'])
columns_stats.extend(['Mean reverb frequency [reverb/min]', 'Mean reverb duration [ms]', 'Mean in-reverb freq [Hz]'])
columns_stats.extend(['Stray spikes reverbs (%)'])
columns_stats.extend(['NIBI bursts', 'CV of NIBI bursts'])
columns_stats.extend(['NIBI reverbs', 'CV of NIBI reverbs'])
spikes_data = []
spikes_data.append(columns_spikes)
reverbs_data_pred = []
bursts_data_pred = []
netBursts_data_pred = []
stats_data_pred = []
reverbs_data_pred.append(columns_reverbs)
bursts_data_pred.append(columns_reverbs)
netBursts_data_pred.append(columns_net)
stats_data_pred.append(columns_stats)
reverbs_data_def = []
bursts_data_def = []
netBursts_data_def = []
stats_data_def = []
reverbs_data_def.append(columns_reverbs)
bursts_data_def.append(columns_reverbs)
netBursts_data_def.append(columns_net)
stats_data_def.append(columns_stats)
self.loadmodel()
print('--- Using model ', self.model_name, ' ---\n')
if mode == 'csv':
self.files_and_well_csv(file)
for file_index, filename in enumerate(self.dataset):
self.dataset_filename = filename
self.dataset_index = file_index
print('\n Analyzing dataset: ', filename, '\n')
self.loadh5(filename)
fullSignal = self.signal
wellIDsToUse = [self.wellLabelIndexDict[label] for label in self.wellsLabels[file_index]]
self.wellsIDs.append(wellIDsToUse)
for well in wellIDsToUse:
self.well = well
print('Well: ', self.wellIndexLabelDict[well])
signal = fullSignal[np.where(self.wellsFromData == well)[0]]
self.loadsignal(signal)
self.detect_threshold()
self.detect_spikes()
if len(self.active_channels) < len(signal):
self.loadsignal(signal[self.active_channels])
self.detect_threshold()
self.detect_spikes()
if save_default:
self.detect_reverbs(method = 'default')
self.detect_bursts()
self.detect_net('reverbs')
self.detect_net('bursts')
default_reverbs = self.reverbs
default_reverbs_binary = self.reverbs_binary
default_net_reverbs = self.net_reverbs
default_net_reverbs_binary = self.net_reverbs_binary
default_bursts = self.bursts
default_bursts_binary = self.bursts_binary
default_net_bursts = self.net_bursts
default_net_bursts_binary = self.net_bursts_binary
self.detect_reverbs(method = 'model')
self.detect_bursts()
self.detect_net('reverbs')
self.detect_net('bursts')
pred_reverbs = self.reverbs
pred_reverbs_binary = self.reverbs_binary
pred_net_reverbs = self.net_reverbs
pred_net_reverbs_binary = self.net_reverbs_binary
pred_bursts = self.bursts
pred_bursts_binary = self.bursts_binary
pred_net_bursts = self.net_bursts
pred_net_bursts_binary = self.net_bursts_binary
########################
### predicted
########################
compoundID = 'No Compound'
compoundName = 'No Compound'
experiment = filename
dose = 0
doseLabel = 'Control'
i = 0
wellID = well
commonToRow = [wellID, self.wellIndexLabelDict[wellID], compoundID, compoundName, experiment, dose, doseLabel]
for burst in pred_net_bursts:
startTime = int(burst[0]*100)
duration = int((burst[-1]-burst[0])*100)
spikes = self.spikes_binary
number = self.util.number_of_spikes_inside_burst(self.spikes_binary, burst)
freq = self.util.mean_innetburst_frequency(spikes, [burst])
newRow = commonToRow.copy()
newRow.extend([startTime, duration, number, freq])
netBursts_data_pred.append(newRow)
for channel in range(np.where(self.wellsFromData == well)[0].shape[0]):
if channel not in self.active_channels: continue
channelID = self.infoChannel[np.where(self.wellsFromData == well)[0][channel]]['ChannelID']
channelLabel = str(self.infoChannel[np.where(self.wellsFromData == well)[0][channel]]['Label'])[2:-1]
wellID = well
commonToRow = [channelID, channelLabel, wellID, self.wellIndexLabelDict[wellID],
compoundID, compoundName, experiment, dose, doseLabel]
for spike in self.spikes[channel]:
newRow = commonToRow.copy()
newRow.extend([spike])
spikes_data.append(newRow)
for burst in pred_reverbs[channel]:
startTime = int(burst[0]*100)
duration = int((burst[-1]-burst[0])*100)
spikes = self.spikes_binary[channel].reshape(-1,1).T
number = self.util.number_of_spikes_inside_burst(spikes, burst)
freq = self.util.mean_innetburst_frequency(spikes, [burst])
newRow = commonToRow.copy()
newRow.extend([startTime, duration, number, freq])
reverbs_data_pred.append(newRow)
for burst in pred_bursts[channel]:
startTime = int(burst[0]*100)
duration = int((burst[-1]-burst[0])*100)
spikes = self.spikes_binary[channel].reshape(-1,1).T
number = self.util.number_of_spikes_inside_burst(spikes, burst)
freq = self.util.mean_innetburst_frequency(spikes, [burst])
newRow = commonToRow.copy()
newRow.extend([startTime, duration, number, freq])
bursts_data_pred.append(newRow)
# newRow = [filename, self.wellIndexLabelDict[wellID], np.where(self.wellsFromData == well)[0].shape[0]]
newRow = [filename, self.wellIndexLabelDict[wellID], len(self.active_channels)]
newRow.extend([self.util.total_number_of_binary(self.spikes_binary)])
newRow.extend([self.util.mean_firing_rate(self.spikes_binary, total_seconds = self.total_timesteps_signal//self.samplingFreq)])
newRow.extend([self.util.random_spikes_percentage_net(self.spikes_binary, pred_net_bursts)]) # net
newRow.extend([self.util.total_number_of_netBursts(pred_net_bursts)]) # net
newRow.extend([self.util.mean_netbursting_rate(pred_net_bursts, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)]) # net
newRow.extend([self.util.mean_netburst_duration(pred_net_bursts)]) # net
newRow.extend([self.util.mean_interNetBurstTrain_interval(pred_net_bursts)])
newRow.extend([self.util.coeff_variance_interNetBurstTrain_interval(pred_net_bursts)])
newRow.extend([self.util.mean_bursts_per_burstTrain(pred_reverbs, pred_bursts)])
newRow.extend([self.util.median_bursts_per_burstTrain(pred_reverbs, pred_bursts)])
newRow.extend([self.util.mean_bursts_per_burstTrain(pred_net_reverbs, pred_net_bursts, net = True)])
newRow.extend([self.util.median_bursts_per_burstTrain(pred_net_reverbs, pred_net_bursts, net = True)])
newRow.extend([self.util.total_number_of_netBursts(pred_net_reverbs)]) # net
newRow.extend([self.util.mean_netbursting_rate(pred_net_reverbs, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)]) # net
newRow.extend([self.util.mean_netburst_duration(pred_net_reverbs)]) # net
newRow.extend([self.util.mean_innetburst_frequency(self.spikes_binary, pred_net_reverbs)]) # net
# bursts stats
newRow.extend([self.util.random_spikes_percentage(self.spikes_binary, pred_bursts)])
newRow.extend([self.util.total_number_of_bursts(pred_bursts)])
newRow.extend([self.util.mean_bursting_rate(pred_bursts, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)])
newRow.extend([self.util.mean_burst_duration(pred_bursts)])
newRow.extend([self.util.total_number_of_bursts(pred_bursts)])
newRow.extend([self.util.mean_bursting_rate(pred_reverbs, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)])
newRow.extend([self.util.mean_burst_duration(pred_reverbs)])
newRow.extend([self.util.mean_inburst_frequency(self.spikes_binary, pred_reverbs)])
newRow.extend([self.util.random_spikes_percentage(self.spikes_binary, pred_reverbs)])
newRow.extend([self.util.mean_interBurstTrain_interval(pred_bursts)])
newRow.extend([self.util.coeff_variance_interBurstTrain_interval(pred_bursts)])
newRow.extend([self.util.mean_interBurstTrain_interval(pred_reverbs)])
newRow.extend([self.util.coeff_variance_interBurstTrain_interval(pred_reverbs)])
stats_data_pred.append(newRow)
if save_default:
########################
### default
########################
compoundID = 'No Compound'
compoundName = 'No Compound'
experiment = filename
dose = 0
doseLabel = 'Control'
i = 0
wellID = well
commonToRow = [wellID, self.wellIndexLabelDict[wellID], compoundID, compoundName, experiment, dose, doseLabel]
for burst in default_net_bursts:
startTime = int(burst[0]*100)
duration = int((burst[-1]-burst[0])*100)
spikes = self.spikes_binary
number = self.util.number_of_spikes_inside_burst(self.spikes_binary, burst)
freq = self.util.mean_innetburst_frequency(spikes, [burst])
newRow = commonToRow.copy()
newRow.extend([startTime, duration, number, freq])
netBursts_data_def.append(newRow)
for channel in range(np.where(self.wellsFromData == well)[0].shape[0]):
if channel not in self.active_channels: continue
channelID = self.infoChannel[np.where(self.wellsFromData == well)[0][channel]]['ChannelID']
channelLabel = str(self.infoChannel[np.where(self.wellsFromData == well)[0][channel]]['Label'])[2:-1]
wellID = well
commonToRow = [channelID, channelLabel, wellID, self.wellIndexLabelDict[wellID],
compoundID, compoundName, experiment, dose, doseLabel]
for burst in default_reverbs[channel]:
startTime = int(burst[0]*100)
duration = int((burst[-1]-burst[0])*100)
spikes = self.spikes_binary[channel].reshape(-1,1).T
number = self.util.number_of_spikes_inside_burst(spikes, burst)
freq = self.util.mean_innetburst_frequency(spikes, [burst])
newRow = commonToRow.copy()
newRow.extend([startTime, duration, number, freq])
reverbs_data_def.append(newRow)
for burst in default_bursts[channel]:
startTime = int(burst[0]*100)
duration = int((burst[-1]-burst[0])*100)
spikes = self.spikes_binary[channel].reshape(-1,1).T
number = self.util.number_of_spikes_inside_burst(spikes, burst)
freq = self.util.mean_innetburst_frequency(spikes, [burst])
newRow = commonToRow.copy()
newRow.extend([startTime, duration, number, freq])
bursts_data_def.append(newRow)
# newRow = [filename, self.wellIndexLabelDict[wellID], np.where(self.wellsFromData == well)[0].shape[0]]
newRow = [filename, self.wellIndexLabelDict[wellID], len(self.active_channels)]
newRow.extend([self.util.total_number_of_binary(self.spikes_binary)])
newRow.extend([self.util.mean_firing_rate(self.spikes_binary, total_seconds = self.total_timesteps_signal//self.samplingFreq)])
newRow.extend([self.util.random_spikes_percentage_net(self.spikes_binary, default_net_bursts)])
newRow.extend([self.util.total_number_of_netBursts(default_net_bursts)])
newRow.extend([self.util.mean_netbursting_rate(default_net_bursts, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)])
newRow.extend([self.util.mean_netburst_duration(default_net_bursts)])
newRow.extend([self.util.mean_interNetBurstTrain_interval(default_net_bursts)])
newRow.extend([self.util.coeff_variance_interNetBurstTrain_interval(default_net_bursts)])
newRow.extend([self.util.mean_bursts_per_burstTrain(default_reverbs, default_bursts)])
newRow.extend([self.util.median_bursts_per_burstTrain(default_reverbs, default_bursts)])
newRow.extend([self.util.mean_bursts_per_burstTrain(default_net_reverbs, default_net_bursts, net = True)])
newRow.extend([self.util.median_bursts_per_burstTrain(default_net_reverbs, default_net_bursts, net = True)]) # net reverbs per net bursts
newRow.extend([self.util.total_number_of_netBursts(default_net_reverbs)])
newRow.extend([self.util.mean_netbursting_rate(default_net_reverbs, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)])
newRow.extend([self.util.mean_netburst_duration(default_net_reverbs)])
newRow.extend([self.util.mean_innetburst_frequency(self.spikes_binary, default_net_reverbs)])
# bursts stats
newRow.extend([self.util.random_spikes_percentage(self.spikes_binary, default_bursts)])
newRow.extend([self.util.total_number_of_bursts(default_bursts)])
newRow.extend([self.util.mean_bursting_rate(default_bursts, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)])
newRow.extend([self.util.mean_burst_duration(default_bursts)])
newRow.extend([self.util.total_number_of_bursts(default_bursts)])
newRow.extend([self.util.mean_bursting_rate(default_reverbs, total_minutes=self.total_timesteps_signal//self.samplingFreq//60)])
newRow.extend([self.util.mean_burst_duration(default_reverbs)])
newRow.extend([self.util.mean_inburst_frequency(self.spikes_binary, default_reverbs)])
newRow.extend([self.util.random_spikes_percentage(self.spikes_binary, default_reverbs)])
newRow.extend([self.util.mean_interBurstTrain_interval(default_bursts)])
newRow.extend([self.util.coeff_variance_interBurstTrain_interval(default_bursts)])
newRow.extend([self.util.mean_interBurstTrain_interval(default_reverbs)])
newRow.extend([self.util.coeff_variance_interBurstTrain_interval(default_reverbs)])
stats_data_def.append(newRow)
########################
### save files
########################
spikes_data_file = os.path.join(self.output_folder, f'{self.output_name}_SPIKES.csv')
if self.analysis_params['save_spikes']: pd.DataFrame(spikes_data[1:], columns = spikes_data[0]).to_csv(spikes_data_file, index = False)
reverbs_data_pred_file = os.path.join(self.output_folder, f'{self.output_name}_REVERBS_PREDICTED.csv')
bursts_data_pred_file = os.path.join(self.output_folder, f'{self.output_name}_BURSTS_PREDICTED.csv')
netBursts_data_pred_file = os.path.join(self.output_folder, f'{self.output_name}_NET_BURSTS_PREDICTED.csv')
stats_data_pred_file = os.path.join(self.output_folder, f'{self.output_name}_STATS_PREDICTED.csv')
if self.analysis_params['save_reverbs'] : pd.DataFrame(reverbs_data_pred[1:], columns = reverbs_data_pred[0]).to_csv(reverbs_data_pred_file, index = False)
if self.analysis_params['save_bursts'] : pd.DataFrame(bursts_data_pred[1:], columns = bursts_data_pred[0]).to_csv(bursts_data_pred_file, index = False)
if self.analysis_params['save_net_bursts'] : pd.DataFrame(netBursts_data_pred[1:], columns = netBursts_data_pred[0]).to_csv(netBursts_data_pred_file, index = False)
if self.analysis_params['save_stats'] : pd.DataFrame(stats_data_pred[1:], columns = stats_data_pred[0]).to_csv(stats_data_pred_file, index = False)
if save_default:
reverbs_data_def_file = os.path.join(self.output_folder, f'{self.output_name}_REVERBS_DEFAULT.csv')
bursts_data_def_file = os.path.join(self.output_folder, f'{self.output_name}_BURSTS_DEFAULT.csv')
netBursts_data_def_file = os.path.join(self.output_folder, f'{self.output_name}_NET_BURSTS_DEFAULT.csv')
stats_data_def_file = os.path.join(self.output_folder, f'{self.output_name}_STATS_DEFAULT.csv')
if self.analysis_params['save_reverbs'] : pd.DataFrame(reverbs_data_def[1:], columns = reverbs_data_def[0]).to_csv(reverbs_data_def_file, index = False)
if self.analysis_params['save_bursts'] : pd.DataFrame(bursts_data_def[1:], columns = bursts_data_def[0]).to_csv(bursts_data_def_file, index = False)
if self.analysis_params['save_net_bursts'] : pd.DataFrame(netBursts_data_def[1:], columns = netBursts_data_def[0]).to_csv(netBursts_data_def_file, index = False)
if self.analysis_params['save_stats'] : pd.DataFrame(stats_data_def[1:], columns = stats_data_def[0]).to_csv(stats_data_def_file, index = False)
print('\n--- Done! ---')
##########################################################################################################
##########################################################################################################
# plotting functions
########
########
[docs] def plot_window(self, signal, start_time=None, duration=None, threshold=None, spikes=None,
reverberations=None, net_reverberations=None, bursts=None, net_bursts=None,
save=None, show = True, figsize=(6, 6), yunits='a.u.', xunits='s'):
"""
Plot a window of the signal with detected spikes, reverberations, bursts, and network bursts.
Parameters
----------
signal : array_like
The input signal data.
start_time : float, optional
The start time of the window in seconds. Default is None (start from the beginning).
duration : float, optional
The duration of the window in seconds. Default is None (plot until the end of the signal).
threshold : float, optional
The threshold value for plotting. Default is None (no threshold line).
spikes : array_like, optional
The timestamps of detected spikes. Default is None (no spikes in the plot)
reverberations : array_like, optional
The timestamps data of detected reverberations. Default is None
bursts : array_like, optional
The timestamps data detected bursts. Default is None
net_bursts : array_like, optional
The timestamps data of detected network bursts. Default is None
save : str, optional
Name of the generated plot figure.
show : bool, optional
Whether to show or not the plot. Default is True.
figsize : tuple, optional
The size of the figure (width, height) in inches. Default is (6, 6).
yunits : str, optional
The units of the y-axis. Default is 'a.u.' (arbitrary units).
xunits : str, optional
The units of the x-axis. Default is 's' (seconds).
Notes
-----
- It normalizes the signal if yunits is 'a.u.' and converts it to millivolts (mV) if yunits is 'mV'.
- Detected spikes, reverberations, bursts, and network bursts are overlaid on the plot with different colors.
"""
if signal.ndim > 1:
print("More than one signal (channel) input!")
return
if start_time is None: start_time = 0
start_timestamp = int(start_time*self.samplingFreq)
if duration is None: duration = int(len(signal)/self.samplingFreq)
if start_timestamp > len(signal):
print("Start time (in seconds) bigger than recorded signal time.")
return
duration_ts = int(duration*self.samplingFreq)
if start_timestamp + duration_ts > len(signal):
print("Start time + duration (in seconds) bigger than recorded signal time.")
return
end_timestamp = start_timestamp + duration_ts
spikes_to_plot = []
if spikes:
if isinstance(spikes, np.ndarray) and signal.ndim == 1:
print('More than one spike channel input!')
return
elif self.util._has_list(spikes):
print('More than one spike channel input!')
return
spikes = np.array(spikes)
spikes_to_plot = spikes[(spikes >= start_timestamp) & (spikes <= end_timestamp)]
def select_bursts_to_plot(bursts, start_timestamp, end_timestamp):
bursts = np.array(bursts)
bursts_starts = bursts.T[0]
bursts_ends = bursts.T[1]
try:
plot_from = np.where(bursts_starts >= start_timestamp)[0][0] - 1
plot_to = np.where(bursts_ends <= end_timestamp)[0][-1] + 1
if plot_from <= 0: plot_from = 0
if plot_to >= len(bursts): plot_to = len(bursts)
return bursts[plot_from:plot_to]
except:
return np.array([])
reverbs_to_plot = []
if reverberations:
if isinstance(reverberations, np.ndarray):
if reverberations.ndim > 2:
print("More than one reverberations channel input!")
return
elif self.util._has_list(reverberations[0]):
print("More than one reverberations channel input!")
return
reverbs_to_plot = select_bursts_to_plot(reverberations, start_timestamp, end_timestamp)
bursts_to_plot = []
if bursts:
if isinstance(bursts, np.ndarray):
if bursts.ndim > 2:
print("More than one reverberations channel input!")
return
elif self.util._has_list(bursts[0]):
print("More than one reverberations channel input!")
return
bursts_to_plot = select_bursts_to_plot(bursts, start_timestamp, end_timestamp)
net_reverbs_to_plot = []
if net_reverberations:
if isinstance(net_reverberations, np.ndarray):
if net_reverberations.ndim > 2:
print("More than one reverberations channel input!")
return
elif self.util._has_list(net_reverberations[0]):
print("More than one reverberations channel input!")
return
net_reverbs_to_plot = select_bursts_to_plot(net_reverberations, start_timestamp, end_timestamp)
net_bursts_to_plot = []
if net_bursts:
if isinstance(net_bursts, np.ndarray):
if net_bursts.ndim > 2:
print("More than one reverberations channel input!")
return
elif self.util._has_list(net_bursts[0]):
print("More than one reverberations channel input!")
return
net_bursts_to_plot = select_bursts_to_plot(net_bursts, start_timestamp, end_timestamp)
if yunits.lower() == 'a.u.':
sig = self.normalize_signal(signal)
elif yunits.lower() == 'v':
sig = self.convert_signal(signal, self.adZero, self.conversionFactor, self.exponent)
fig = plt.figure(figsize=figsize)
for spike in spikes_to_plot:
plt.axvline(spike/self.samplingFreq,0,0.022, c = 'k', lw = 0.5)
if yunits.lower() == 'a.u.':
ylabel = yunits
position_reverb = -1.24
position_bursts = -1.164
position_net_bursts = -1.07
elif yunits.lower() == 'v':
ylabel = chr(956)+'V'
position_reverb = 1.25*(min(sig))
position_net_reverb = 1.175*(min(sig))
position_bursts = 1.1*(min(sig))
position_net_bursts = 1.025*(min(sig))
for reverb in reverbs_to_plot:
plt.hlines(position_reverb,reverb[0]/self.samplingFreq,reverb[-1]/self.samplingFreq, colors = mpl.cm.Set3(2/11), lw = 6.5)
for net_reverb in net_reverbs_to_plot:
plt.hlines(position_net_reverb,net_reverb[0]/self.samplingFreq,net_reverb[-1]/self.samplingFreq, colors = mpl.cm.Set3(3/11), lw = 6.5)
for burst in bursts_to_plot:
plt.hlines(position_bursts,burst[0]/self.samplingFreq,burst[-1]/self.samplingFreq, colors = mpl.cm.Set3(4/11), lw = 6.5)
for net_bursts in net_bursts_to_plot:
plt.hlines(position_net_bursts,net_bursts[0]/self.samplingFreq,net_bursts[-1]/self.samplingFreq, colors = mpl.cm.Set3(5/11), lw = 6.5)
plt.plot(np.arange(start_timestamp, end_timestamp)/self.samplingFreq, sig[start_timestamp:end_timestamp], c = 'k', lw = 1)
if threshold is not None:
if yunits.lower() == 'a.u.':
thresh = self.normalize_threshold(signal, threshold)
elif yunits.lower() == 'v':
thresh = self.convert_threshold(threshold, self.adZero, self.conversionFactor, self.exponent)
plt.axhline(thresh, color = 'k', lw = 2)
plt.axhline(-thresh, color = 'k', lw = 2)
plt.xlabel(f'Time [{xunits}]')
plt.ylabel(f'Signal [{ylabel}]')
if yunits.lower() == 'a.u.':
plt.ylim(-1.35,1.35)
elif yunits.lower() == 'v':
plt.ylim(1.35*min(sig), 1.35*max(sig))
plt.xlim(start_timestamp/self.samplingFreq, end_timestamp/self.samplingFreq)
plt.grid(ls = 'dotted')
if save is not None:
plt.savefig(save)
if show: plt.show()
plt.close()
[docs] def plot_raster(self, spikes, reverbs = None, bursts = None, net_reverbs = None, net_bursts = None,
start_time = None, end_time = None, save = None, show = True):
"""
Plot a raster plot of spikes with optional overlay of reverberations, bursts, and network bursts.
Parameters
----------
spikes : array_like
The timestamps of spikes. Can be a list of timestamps for multiple channels.
reverbs : array_like, optional
The timestamps data of detected reverberations.
bursts : array_like, optional
The timestamps data of detected bursts.
net_reverbs : array_like, optional
The timestamps data of detected network reverberations.
net_bursts : array_like, optional
The timestamps data of detected network bursts.
start_time : float, optional
Starting time of the plot, in seconds.
end_time : float, optional
Ending time of the plot, in seconds.
save : str, optional
Name of the generated plot figure.
show : bool, optional
Whether to show or not the plot. Default is True.
Returns
-------
None
Notes
-----
- Each spike is represented by a vertical line at its timestamp.
- Detected events are filled between their start and end timestamps.
"""
if self.util._has_list(spikes):
number_of_channels = len(spikes)
else:
number_of_channels = 1
spikes = [spikes]
def plot_channel_(bursts):
for channel, bursts_channel in enumerate(bursts):
for burst in bursts_channel:
y1 = (height*(number_of_channels-channel-1))+(0.9*height)
y2 = (height*(number_of_channels-channel))-(0.9*height)
plt.fill_between(np.array([burst[0]/self.samplingFreq, burst[1]/self.samplingFreq]),
y1, y2, color = 'r', alpha = 0.3, zorder = 1)
def plot_well_(nets):
for net in nets:
plt.fill_between(np.array([net[0]/self.samplingFreq, net[1]/self.samplingFreq]),
1, 0, color = '#DEE3E2', zorder = 1)
plt.figure(figsize = (10,5))
height = 1/number_of_channels
y_ticks = np.linspace(0,1,number_of_channels+1)[1:]-(height/2)
plt.hlines(y_ticks, 0, self.total_timesteps_signal/self.samplingFreq, color = 'k', lw = 1)
for channel, spikes_c in enumerate(spikes):
plt.vlines(np.array(spikes_c)/self.samplingFreq,(height*(number_of_channels-channel-1))+(height/4), (height*(number_of_channels-channel))-(height/4), color = 'k', lw = 1)
if reverbs: plot_channel_(reverbs)
if net_bursts: plot_well_(net_bursts)
plt.ylim(0,1)
if start_time is None: start_time = 0
if end_time is None: end_time = self.total_timesteps_signal/self.samplingFreq
plt.xlim(start_time, end_time)
plt.tick_params(axis = 'y',
which = 'both',
direction = 'in')
plt.yticks(y_ticks[::-1], labels = [f'e{channel+1}' for channel in range(number_of_channels)], fontsize = 20)
plt.xticks([])
if save is not None:
plt.savefig(save)
if show: plt.show()
plt.close()
[docs] def plot_raster_well(self, file:str, well, method = 'default', reverbs = False, bursts = False, net_reverbs = False,
net_bursts = False, start_time = None, end_time = None, save = None):
"""
Plot a raster plot for a specific well with optional overlay of events.
Parameters
----------
file : str
The filename or path of the data file.
well : str
The label or ID of the well to plot.
method : str, optional
The method used for detecting events. Default is 'default'.
reverbs : bool, optional
Whether to overlay detected reverberations on the raster plot. Default is False.
bursts : bool, optional
Whether to overlay detected bursts on the raster plot. Default is False.
net_reverbs : bool, optional
Whether to overlay detected network reverberations on the raster plot. Default is False.
net_bursts : bool, optional
Whether to overlay detected network bursts on the raster plot. Default is False.
start_time : float, optional
Starting time of the plot, in seconds.
end_time : float, optional
Ending time of the plot, in seconds.
save : str, optional
Name of the generated plot figure.
Returns
-------
None
"""
self.loadwell(file, well,
method = method, spikes = True,
reverbs = reverbs, bursts = bursts,
net_reverbs = net_reverbs, net_bursts = net_bursts)
reverbs_ = self.reverbs if reverbs else None
bursts_ = self.bursts if bursts else None
net_reverbs_ = self.net_reverbs if net_reverbs else None
net_bursts_ = self.net_bursts if net_bursts else None
self.plot_raster(self.spikes, reverbs = reverbs_, bursts = bursts_,
net_reverbs = net_reverbs_, net_bursts = net_bursts_,
start_time = start_time, end_time = end_time, save = save)