Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ Metric Object
:toctree: generated/

Metric
check_metrics

Metric Functions (Error)
~~~~~~~~~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions specparam/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
"""Metrics sub-module."""

from .metric import Metric
from .check import check_metrics
22 changes: 22 additions & 0 deletions specparam/metrics/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Functionality to check available metrics."""

###################################################################################################
###################################################################################################

def check_metrics(category='all'):
    """Print the available metrics, grouped by category.

    Parameters
    ----------
    category : {'all', 'error', 'gof'}
        Which category of metrics to list.
        If 'all', the 'error' and 'gof' categories are both listed, in that order.
    """

    # Imported at call time rather than at module level
    #   NOTE(review): presumably to avoid a circular import - confirm
    from specparam.metrics.definitions import METRICS

    if category == 'all':
        selected = ['error', 'gof']
    else:
        selected = [category]

    for label in selected:
        print('Available {} metrics:'.format(label))
        for entry in METRICS[label].values():
            print(' {:15s} {:s}'.format(entry.measure, entry.description))
68 changes: 45 additions & 23 deletions specparam/metrics/definitions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Collect together library of available built in metrics."""

from functools import partial

from specparam.metrics.metrics import Metric
from specparam.metrics.error import (compute_mean_abs_error, compute_mean_squared_error,
compute_root_mean_squared_error, compute_median_abs_error)
Expand Down Expand Up @@ -39,6 +37,14 @@
func=compute_median_abs_error,
)

# Collect available error metrics
ERROR_METRICS = {
'mae' : error_mae,
'mse' : error_mse,
'rmse' : error_rmse,
'medae' : error_medae,
}

###################################################################################################
## GOF

Expand All @@ -58,30 +64,46 @@
results.params.periodic.params.size + results.params.aperiodic.params.size},
)

# Collect available GOF metrics
GOF_METRICS = {
'rsquared' : gof_rsquared,
'adjrsquared' : gof_adjrsquared,
}

###################################################################################################
## COLLECT ALL METRICS TOGETHER

# Collect a store of all available metrics
METRICS = {

# Available error metrics
'error_mae' : error_mae,
'error_mse' : error_mse,
'error_rmse' : error_rmse,
'error_medae' : error_medae,

# Available GOF / r-squared metrics
'gof_rsquared' : gof_rsquared,
'gof_adjrsquared' : gof_adjrsquared,

'error' : ERROR_METRICS,
'gof' : GOF_METRICS,
}


def check_metrics():
"""Check the set of available metrics."""

print('Available metrics:')
for metric in METRICS.values():
print(' {:8s} {:12s} : {:s}'.format(metric.category, metric.measure, metric.description))


check_metric_definition = partial(check_selection, options=METRICS, definition=Metric)
###################################################################################################
## CHECKER FUNCTION

def check_metric_definition(metric):
    """Check a metric definition.

    Parameters
    ----------
    metric : Metric or dict or str
        Definition for a metric to check.
        If dict, should have keys corresponding to a metric definition.
        If str, should be the label corresponding to a defined metric (see `check_metrics`),
        given in the form 'category_measure', for example 'error_mae'.

    Returns
    -------
    Metric
        Metric definition.

    Raises
    ------
    ValueError
        If `metric` is a string that does not contain a '_' separating category and measure.
    KeyError
        If `metric` is a string whose category is not a key of `METRICS`.
    """

    if isinstance(metric, dict):
        metric = Metric(**metric)
    elif isinstance(metric, str):
        # Split on the first underscore only, so that a measure label that itself
        #   contains an underscore is still looked up within its category
        category, separator, label = metric.partition('_')
        if not separator:
            raise ValueError(
                "Metric label '{}' not understood: expected the form "
                "'category_measure' (for example 'error_mae').".format(metric))
        metric = check_selection(label, METRICS[category], Metric)
    else:
        # Non-string input is handed to `check_selection` with no selectable options
        #   NOTE(review): presumably this validates the input is a `Metric` - confirm
        metric = check_selection(metric, [], Metric)

    return metric
14 changes: 5 additions & 9 deletions specparam/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class Metrics():

Parameters
----------
metrics : list of Metric or list of dict
metrics : list of Metric or list of dict or list of str
Metric(s) to add to the object.
"""

Expand Down Expand Up @@ -55,25 +55,21 @@ def add_metric(self, metric):

Parameters
----------
metric : Metric or dict
metric : Metric or dict or str
Metric to add to the object.
If dict, should have keys corresponding to a metric definition.
If str, should be the label corresponding to a defined metric (see `check_metrics`).
"""

if isinstance(metric, dict):
metric = Metric(**metric)

metric = check_metric_definition(metric)

self.metrics.append(deepcopy(metric))
self.metrics.append(deepcopy(check_metric_definition(metric)))


def add_metrics(self, metrics):
"""Add metric(s) to object

Parameters
----------
metrics : list of Metric or list of dict
metrics : list of Metric or list of dict or list of str
Metric(s) to add to the object.
"""

Expand Down
2 changes: 1 addition & 1 deletion specparam/modes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

from .mode import Mode
from .params import ParamDefinition
from .definitions import check_modes
from .check import check_modes
29 changes: 29 additions & 0 deletions specparam/modes/check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Functionality to check available modes."""

###################################################################################################
###################################################################################################

def check_modes(component='all', check_params=False):
    """Print the fit modes that are available.

    Parameters
    ----------
    component : {'all', 'aperiodic', 'periodic'}
        Which component to list available modes for.
        If 'all', the 'aperiodic' and 'periodic' components are both listed, in that order.
    check_params : bool, optional, default: False
        Whether to also print out information on the parameters of each mode.
    """

    # Imported at call time rather than at module level
    #   NOTE(review): presumably to avoid a circular import - confirm
    from specparam.modes.definitions import MODES

    if component == 'all':
        selected = ['aperiodic', 'periodic']
    else:
        selected = [component]

    for label in selected:
        print('Available {:s} modes:'.format(label))
        for cur_mode in MODES[label].values():
            if check_params:
                # Detailed listing: name, description, then the mode's own parameter report
                print('\n{:s}'.format(cur_mode.name))
                print(' {:s}'.format(cur_mode.description))
                cur_mode.check_params()
            else:
                # Compact listing: one line per mode
                print(' {:15s} {:s}'.format(cur_mode.name, cur_mode.description))
29 changes: 12 additions & 17 deletions specparam/modes/definitions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
"""Define fitting modes."""

from functools import partial
from collections import OrderedDict

from specparam.modes.mode import Mode
Expand Down Expand Up @@ -223,25 +222,21 @@
###################################################################################################
## CHECKER FUNCTION

def check_modes(component, check_params=False):
"""Check the set of modes that are available.
def check_mode_definition(mode, component):
"""Check a mode definition.

Parameters
----------
mode : Mode or str
Definition for a mode to check.
If str, should be the label corresponding to a defined mode (see `check_modes`).
component : {'aperiodic', 'periodic'}
Which component to check available modes for.
check_params : bool, optional, default: False
Whether to print out information on the parameters of each mode.
"""

print('Available {:s} modes:'.format(component))
for mode in MODES[component].values():
if not check_params:
print(' {:10s} {:s}'.format(mode.name, mode.description))
else:
print('\n{:s}'.format(mode.name))
print(' {:s}'.format(mode.description))
mode.check_params()
Which component the mode corresponds to.

Returns
-------
Mode
Mode definition.
"""

check_mode_definition = partial(check_selection, definition=Mode)
return check_selection(mode, MODES[component], definition=Mode)
6 changes: 3 additions & 3 deletions specparam/modes/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from specparam.data import ModelModes
from specparam.modes.mode import VALID_COMPONENTS
from specparam.modes.definitions import check_mode_definition, AP_MODES, PE_MODES
from specparam.modes.definitions import check_mode_definition
from specparam.reports.strings import gen_modes_str

###################################################################################################
Expand All @@ -28,8 +28,8 @@ def __init__(self, aperiodic, periodic, model=None):
self.components = VALID_COMPONENTS

# Add mode definitions for each component
self.aperiodic = check_mode_definition(aperiodic, AP_MODES)
self.periodic = check_mode_definition(periodic, PE_MODES)
self.aperiodic = check_mode_definition(aperiodic, 'aperiodic')
self.periodic = check_mode_definition(periodic, 'periodic')

self.model = model

Expand Down
3 changes: 1 addition & 2 deletions specparam/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from specparam.sim.gen import gen_freqs
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.modes.modes import check_mode_definition
from specparam.modes.definitions import AP_MODES
from specparam.plts.settings import ITERABLES, PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
Expand Down Expand Up @@ -99,7 +98,7 @@ def plot_aperiodic_fits(aps, freq_range, aperiodic_mode, control_offset=False, a
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

aperiodic_mode = check_mode_definition(aperiodic_mode, AP_MODES)
aperiodic_mode = check_mode_definition(aperiodic_mode, 'aperiodic')

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))

Expand Down
3 changes: 1 addition & 2 deletions specparam/plts/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from specparam.sim import gen_freqs
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.modes.modes import check_mode_definition
from specparam.modes.definitions import PE_MODES
from specparam.plts.settings import ITERABLES, PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
Expand Down Expand Up @@ -103,7 +102,7 @@ def plot_peak_fits(peaks, periodic_mode, freq_range=None, average='mean', shade=
Additional plot related keyword arguments, with styling options managed by ``style_plot``.
"""

periodic_mode = check_mode_definition(periodic_mode, PE_MODES)
periodic_mode = check_mode_definition(periodic_mode, 'periodic')

ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))

Expand Down
5 changes: 2 additions & 3 deletions specparam/sim/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

from specparam.utils.checks import check_flat
from specparam.modes.modes import check_mode_definition
from specparam.modes.definitions import AP_MODES, PE_MODES

from specparam.sim.transform import rotate_spectrum

Expand Down Expand Up @@ -59,7 +58,7 @@ def gen_aperiodic(freqs, aperiodic_mode, aperiodic_params):
Aperiodic values, in log10 spacing.
"""

ap_mode = check_mode_definition(aperiodic_mode, AP_MODES)
ap_mode = check_mode_definition(aperiodic_mode, 'aperiodic')

ap_vals = ap_mode.func(freqs, *aperiodic_params)

Expand All @@ -84,7 +83,7 @@ def gen_periodic(freqs, periodic_mode, periodic_params):
Peak values, in log10 spacing.
"""

pe_mode = check_mode_definition(periodic_mode, PE_MODES)
pe_mode = check_mode_definition(periodic_mode, 'periodic')

pe_vals = pe_mode.func(freqs, *check_flat(periodic_params))

Expand Down
3 changes: 1 addition & 2 deletions specparam/sim/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from specparam.data import SimParams
from specparam.modes.modes import check_mode_definition
from specparam.modes.definitions import AP_MODES
from specparam.utils.checks import check_flat
from specparam.modutils.errors import InconsistentDataError

Expand Down Expand Up @@ -74,7 +73,7 @@ def update_sim_ap_params(sim_params, delta, field=None):
# If labels are given, update deltas according to their labels
else:

aperiodic_mode = check_mode_definition(ap_mode, AP_MODES)
aperiodic_mode = check_mode_definition(ap_mode, 'aperiodic')

# This loop checks & casts to list, to work for single or multiple passed in values
for cur_field, cur_delta in zip(list([field]) if not isinstance(field, list) else field,
Expand Down
10 changes: 10 additions & 0 deletions specparam/tests/metrics/test_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Test functions for specparam.metrics.check."""

from specparam.metrics.check import *

###################################################################################################
###################################################################################################

def test_check_metrics():
    """Smoke test: `check_metrics` runs with its default arguments without raising."""

    # Only checks that the call completes; the printed output is not inspected
    check_metrics()
21 changes: 16 additions & 5 deletions specparam/tests/metrics/test_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,21 @@

def test_metrics_library():

for key, metric in METRICS.items():
assert isinstance(metric, Metric)
assert metric.label == key
for category, collection in METRICS.items():
for key, metric in collection.items():
assert isinstance(metric, Metric)
assert metric.label == category + '_' + key

def test_check_metrics():
def test_check_metric_definition():

check_metrics()
mdict = {'category' : 'test', 'measure' : 'test',
'description' : 'test', 'func' : lambda x: x}

m1 = check_metric_definition(Metric(**mdict))
assert isinstance(m1, Metric)

m2 = check_metric_definition(mdict)
assert isinstance(m2, Metric)

m3 = check_metric_definition('error_mae')
assert isinstance(m3, Metric)
Loading
Loading