Source code for biobss.imutools.acc_features

from typing import Callable

import numpy as np

from biobss.imutools.acc_correlation import *
from biobss.imutools.acc_freqdomain import *
from biobss.imutools.acc_statistical import *


[docs]def get_domain_function(domain: str) -> Callable: if domain == "Corr": return acc_corr_features elif domain == "Freq": return acc_freq_features elif domain == "Stat": return acc_stat_features else: raise ValueError("Unknown domain:", domain)
[docs]def get_acc_features( signals: list, signal_names: list, sampling_rate: float, feature_types: list = ["Freq", "Stat", "Corr"], magnitude: bool = False, ) -> dict: """Calculates ACC features. Args: signals (list): List of ACC signals for different axes. signal_names (list): List of ACC signal names for different axes. It must have the same order with the signals. sampling_rate (float): Sampling rate of the ACC signals (Hz). feature_types (list, optional): Feature types to be calculated. It can be a list of 'Freq', 'Stat' and 'Corr'. Defaults to ['Freq','Stat','Corr']. magnitude (bool, optional): If True, features are also calculated for magnitude signal. Defaults to False. Raises: ValueError: If sampling rate is not greater than zero. ValueError: If feature type is invalid. Returns: dict: Dictionary of ACC features. """ if np.ndim(signals) == 1: signals = [signals] if isinstance(signal_names, str): signal_names = [signal_names] data = dict(zip(signal_names, signals)) if sampling_rate <= 0: raise ValueError("Sampling rate must be greater than 0.") feature_types = [x.capitalize() for x in feature_types] valid_types = ["Freq", "Stat", "Corr"] features = {} for domain in feature_types: if domain not in valid_types: raise ValueError("invalid feature type: " + domain) else: domain_function = get_domain_function(domain) features.update(domain_function(signals=signals, signal_names=signal_names, sampling_rate=sampling_rate)) if domain in ["Freq", "Stat"]: if magnitude: sum = 0 for sig in data.values(): sum += np.square(sig) magn = np.sqrt(sum) features.update(domain_function(signals=[magn], signal_names=["magn"], sampling_rate=sampling_rate)) return features