Skip to content
Merged
10 changes: 8 additions & 2 deletions specparam/algorithms/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(self, name, description, public_settings, private_settings=None,
self.set_debug(debug)


def _fit_prechecks(self):
def _fit_prechecks(self, verbose):
"""Pre-checks to run before the fit function - if are some, overload this function."""


Expand Down Expand Up @@ -195,11 +195,17 @@ def _initialize_bounds(self, mode):
[high_bound_param1, high_bound_param2])
"""

n_params = getattr(self.modes, mode).n_params
# If modes defined, get number of params - otherwise set stores as empty
if self.modes is not None:
n_params = getattr(self.modes, mode).n_params
else:
n_params = 0

bounds = (np.array([-np.inf] * n_params), np.array([np.inf] * n_params))

return bounds


def _initialize_guess(self, mode):
"""Initialize a guess definition.

Expand Down
16 changes: 16 additions & 0 deletions specparam/algorithms/definitions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
"""Define collection of fitting algorithms."""

from functools import partial

from specparam.utils.checks import check_selection
from specparam.algorithms.algorithm import Algorithm
from specparam.algorithms.spectral_fit import SpectralFitAlgorithm

###################################################################################################
Expand All @@ -9,3 +13,15 @@
ALGORITHMS = {
'spectral_fit' : SpectralFitAlgorithm,
}


def check_algorithms():
"""Check the set of available fit algorithms."""

print('Available algorithms:')
for algorithm in ALGORITHMS.values():
algorithm = algorithm()
print(' {:12s} : {:s}'.format(algorithm.name, algorithm.description))


check_algorithm_definition = partial(check_selection, definition=Algorithm)
2 changes: 1 addition & 1 deletion specparam/algorithms/spectral_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h

# Initialize base algorithm object with algorithm metadata
super().__init__(
name='spectral fit',
name='spectral_fit',
description='Original parameterizing neural power spectra algorithm.',
public_settings=SPECTRAL_FIT_SETTINGS_DEF,
private_settings=SPECTRAL_FIT_PRIVATE_SETTINGS_DEF,
Expand Down
6 changes: 6 additions & 0 deletions specparam/metrics/definitions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
"""Collect together library of available built in metrics."""

from functools import partial

from specparam.metrics.metrics import Metric
from specparam.metrics.error import (compute_mean_abs_error, compute_mean_squared_error,
compute_root_mean_squared_error, compute_median_abs_error)
from specparam.metrics.gof import compute_r_squared, compute_adj_r_squared
from specparam.utils.checks import check_selection

###################################################################################################
## ERROR METRICS
Expand Down Expand Up @@ -79,3 +82,6 @@ def check_metrics():
print('Available metrics:')
for metric in METRICS.values():
print(' {:8s} {:12s} : {:s}'.format(metric.category, metric.measure, metric.description))


check_metric_definition = partial(check_selection, definition=Metric)
3 changes: 3 additions & 0 deletions specparam/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np

from specparam.metrics.metric import Metric
from specparam.metrics.definitions import METRICS, check_metric_definition

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -61,6 +62,8 @@ def add_metric(self, metric):
if isinstance(metric, dict):
metric = Metric(**metric)

metric = check_metric_definition(metric, METRICS)

self.metrics.append(deepcopy(metric))


Expand Down
4 changes: 2 additions & 2 deletions specparam/models/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def __init__(self, *args, **kwargs):
self.data = Data3D()

self.results = Results3D(modes=self.modes,
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))

self.algorithm._reset_subobjects(data=self.data, results=self.results)

Expand Down
4 changes: 2 additions & 2 deletions specparam/models/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,8 @@ def __init__(self, *args, **kwargs):
self.data = Data2D()

self.results = Results2D(modes=self.modes,
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))

self.algorithm._reset_subobjects(data=self.data, results=self.results)

Expand Down
30 changes: 17 additions & 13 deletions specparam/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from specparam.data.data import Data
from specparam.data.conversions import model_to_dataframe
from specparam.results.results import Results

from specparam.algorithms.spectral_fit import SpectralFitAlgorithm, SPECTRAL_FIT_SETTINGS_DEF
from specparam.algorithms.definitions import ALGORITHMS, check_algorithm_definition

from specparam.reports.save import save_model_report
from specparam.reports.strings import gen_model_results_str
from specparam.modutils.errors import NoDataError, FitError
Expand All @@ -34,10 +37,14 @@ class SpectralModel(BaseModel):
Parameters
----------
% copied in from Spectral Fit Algorithm Settings
aperiodic_mode : {'fixed', 'knee'}
aperiodic_mode : {'fixed', 'knee'} or Mode
Which approach to take for fitting the aperiodic component.
periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'}
periodic_mode : {'gaussian', 'skewed_gaussian', 'cauchy'} or Mode
Which approach to take for fitting the periodic component.
algorithm : {'spectral_fit'} or Algorithm
The fitting algorithm to use.
algorithm_settings : dict
Setting for the algorithm.
metrics : Metrics or list of Metric or list or str
Metrics definition(s) to use to evaluate the model.
bands : Bands or dict or int or None, optional
Expand All @@ -49,6 +56,7 @@ class SpectralModel(BaseModel):
Verbosity mode. If True, prints out warnings and general status updates.
**model_kwargs
Additional model fitting related keyword arguments.
These are passed into the algorithm object.

Attributes
----------
Expand All @@ -71,25 +79,21 @@ class SpectralModel(BaseModel):
as this will give better model fits.
"""

def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
peak_threshold=2.0, aperiodic_mode='fixed', periodic_mode='gaussian',
def __init__(self, aperiodic_mode='fixed', periodic_mode='gaussian',
algorithm='spectral_fit', algorithm_settings=None,
metrics=None, bands=None, debug=False, verbose=True, **model_kwargs):
"""Initialize model object."""

BaseModel.__init__(self,
aperiodic_mode=aperiodic_mode,
periodic_mode=periodic_mode,
verbose=verbose)
BaseModel.__init__(self, aperiodic_mode, periodic_mode, verbose)

self.data = Data()

self.results = Results(modes=self.modes, metrics=metrics, bands=bands)

self.algorithm = SpectralFitAlgorithm(
peak_width_limits=peak_width_limits, max_n_peaks=max_n_peaks,
min_peak_height=min_peak_height, peak_threshold=peak_threshold,
modes=self.modes, data=self.data, results=self.results,
debug=debug, **model_kwargs)
algorithm_settings = {} if algorithm_settings is None else algorithm_settings
self.algorithm = check_algorithm_definition(algorithm, ALGORITHMS)(
**algorithm_settings, modes=self.modes, data=self.data,
results=self.results, debug=debug, **model_kwargs)


@replace_docstring_sections([docs_get_section(Data.add_data.__doc__, 'Parameters'),
Expand Down
4 changes: 2 additions & 2 deletions specparam/models/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ def __init__(self, *args, **kwargs):
self.data = Data2DT()

self.results = Results2DT(modes=self.modes,
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))
metrics=kwargs.pop('metrics', None),
bands=kwargs.pop('bands', None))

self.algorithm._reset_subobjects(data=self.data, results=self.results)

Expand Down
3 changes: 2 additions & 1 deletion specparam/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,8 @@ def combine_model_objs(model_objs):
"or meta data, and so cannot be combined.")

# Initialize group model object, with settings derived from input objects
group = SpectralGroupModel(*model_objs[0].algorithm.get_settings(),
group = SpectralGroupModel(**model_objs[0].modes.get_modes()._asdict(),
**model_objs[0].algorithm.get_settings()._asdict(),
verbose=model_objs[0].verbose)

# Use a temporary store to collect spectra, as we'll only add it if it is consistently present
Expand Down
5 changes: 5 additions & 0 deletions specparam/modes/definitions.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Define fitting modes."""

from functools import partial
from collections import OrderedDict

from specparam.modes.mode import Mode
from specparam.modes.params import ParamDefinition
from specparam.modes.funcs import (expo_function, expo_nk_function, double_expo_function,
gaussian_function, skewed_gaussian_function, cauchy_function)
from specparam.modes.jacobians import jacobian_gauss
from specparam.utils.checks import check_selection

###################################################################################################
## APERIODIC MODES
Expand Down Expand Up @@ -184,3 +186,6 @@ def check_modes(component, check_params=False):
print('\n{:s}'.format(mode.name))
print(' {:s}'.format(mode.description))
mode.check_params()


check_mode_definition = partial(check_selection, definition=Mode)
37 changes: 2 additions & 35 deletions specparam/modes/modes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Modes object."""

from specparam.data import ModelModes
from specparam.modes.mode import Mode, VALID_COMPONENTS
from specparam.modes.definitions import AP_MODES, PE_MODES
from specparam.modes.mode import VALID_COMPONENTS
from specparam.modes.definitions import check_mode_definition, AP_MODES, PE_MODES

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -49,36 +49,3 @@ def get_modes(self):

return ModelModes(aperiodic_mode=self.aperiodic.name if self.aperiodic else None,
periodic_mode=self.periodic.name if self.periodic else None)


def check_mode_definition(mode, options):
"""Check a mode specification.

Parameters
----------
mode : str or None or Mode
Fit mode. If str, should be a label corresponding to an entry in `options`.
options : dict
Available modes.

Returns
-------
mode : Mode or None
Mode object, if defined, or None if not defined.

Raises
------
ValueError
If the mode definition is not found / understood.
"""

if isinstance(mode, str):
assert mode in list(options.keys()), 'Specific Mode not found.'
mode = options[mode]

if mode is None:
mode = None
elif not isinstance(mode, Mode):
raise ValueError('Mode input not understood.')

return mode
11 changes: 6 additions & 5 deletions specparam/results/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,17 +125,18 @@ def add_metrics(self, metrics):

Parameters
----------
metrics : Metrics or list of Metric or list or str
metrics : Metrics or list of Metric or list of str or None
Metrics definition(s) to add to object.
If None, initialized with default metrics.
"""

if metrics is None:
metrics = DEFAULT_METRICS

if isinstance(metrics, Metrics):
self.metrics = deepcopy(metrics)
elif isinstance(metrics, list):
self.metrics = Metrics(\
[METRICS[metric] if isinstance(metric, str) else metric for metric in metrics])
else:
self.metrics = Metrics([METRICS[metric] for metric in DEFAULT_METRICS])
self.metrics = Metrics(metrics)


def add_results(self, results):
Expand Down
25 changes: 25 additions & 0 deletions specparam/tests/algorithms/test_definitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Tests for specparam.algorithms.definitions."""

from specparam.algorithms.algorithm import Algorithm

from specparam.algorithms.definitions import *

###################################################################################################
###################################################################################################

def test_algorithms_library():

for key, algorithm in ALGORITHMS.items():
algorithm = algorithm()
assert isinstance(algorithm, Algorithm)
assert algorithm.name == key

def test_check_algorithms():

check_algorithms()

def test_check_algorithm_definition():

for algorithm in ALGORITHMS.keys():
algorithm = check_algorithm_definition(algorithm, ALGORITHMS)
assert issubclass(algorithm, Algorithm)
9 changes: 7 additions & 2 deletions specparam/tests/metrics/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,10 @@

def test_metrics_library():

for key in METRICS:
assert isinstance(METRICS[key], Metric)
for key, metric in METRICS.items():
assert isinstance(metric, Metric)
assert metric.label == key

def test_check_metrics():

check_metrics()
19 changes: 19 additions & 0 deletions specparam/tests/modes/test_definitions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,30 @@
"""Tests for specparam.modes.definitions."""

from specparam.modes.mode import Mode, VALID_COMPONENTS

from specparam.modes.definitions import *

###################################################################################################
###################################################################################################

def test_modes_library():

for component in VALID_COMPONENTS:
for key, mode in MODES[component].items():
assert isinstance(mode, Mode)
assert mode.name == key

def test_check_modes():

check_modes('aperiodic')
check_modes('periodic')

def test_check_mode_definition():

for ap_mode in AP_MODES.keys():
mode = check_mode_definition(ap_mode, AP_MODES)
assert isinstance(mode, Mode)

for pe_mode in PE_MODES.keys():
mode = check_mode_definition(pe_mode, PE_MODES)
assert isinstance(mode, Mode)
11 changes: 1 addition & 10 deletions specparam/tests/modes/test_modes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for specparam.modes.modes."""

from specparam.data import ModelModes
from specparam.modes.mode import Mode
from specparam.modes.definitions import AP_MODES, PE_MODES

from specparam.modes.modes import *
Expand All @@ -26,13 +27,3 @@ def test_modes_get_modes():
assert isinstance(mode_names, ModelModes)
assert mode_names.aperiodic_mode == ap_mode_name
assert mode_names.periodic_mode == pe_mode_name

def test_check_mode_definition():

for ap_mode in AP_MODES.keys():
mode = check_mode_definition(ap_mode, AP_MODES)
assert isinstance(mode, Mode)

for pe_mode in PE_MODES.keys():
mode = check_mode_definition(pe_mode, PE_MODES)
assert isinstance(mode, Mode)
Loading