Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 6 additions & 19 deletions specparam/objs/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import numpy as np

from specparam.core.utils import unlog
from specparam.core.items import OBJ_DESC
from specparam.core.funcs import infer_ap_func
from specparam.core.utils import check_inds, check_array_dim
from specparam.modutils.errors import NoModelError
from specparam.modutils.dependencies import safe_import
from specparam.data import FitResults, ModelSettings
from specparam.data.conversions import group_to_dict, event_group_to_dict
from specparam.data.utils import get_group_params, get_results_by_ind, get_results_by_row
from specparam.core.items import OBJ_DESC
from specparam.modutils.dependencies import safe_import
from specparam.utils.gof import compute_r_squared, compute_error

###################################################################################################
###################################################################################################
Expand Down Expand Up @@ -292,8 +293,7 @@ def _reset_results(self, clear_results=False):
def _calc_r_squared(self):
    """Calculate the r-squared goodness of fit of the model, compared to the original data.

    The result is stored in the ``r_squared_`` attribute.
    """

    # Delegate to the shared goodness-of-fit helper so every r-squared computation agrees;
    # a single assignment avoids computing the value twice and discarding the first result
    self.r_squared_ = compute_r_squared(self.power_spectrum, self.modeled_spectrum_)


def _calc_error(self, metric=None):
Expand All @@ -317,21 +317,8 @@ def _calc_error(self, metric=None):
Which measure is applied is by default controlled by the `_error_metric` attribute.
"""

# If metric is not specified, use the default approach
metric = self._error_metric if not metric else metric

if metric == 'MAE':
self.error_ = np.abs(self.power_spectrum - self.modeled_spectrum_).mean()

elif metric == 'MSE':
self.error_ = ((self.power_spectrum - self.modeled_spectrum_) ** 2).mean()

elif metric == 'RMSE':
self.error_ = np.sqrt(((self.power_spectrum - self.modeled_spectrum_) ** 2).mean())

else:
error_msg = "Error metric '{}' not understood or not implemented.".format(metric)
raise ValueError(error_msg)
self.error_ = compute_error(self.power_spectrum, self.modeled_spectrum_,
self._error_metric if not metric else metric)


class BaseResults2D(BaseResults):
Expand Down
2 changes: 0 additions & 2 deletions specparam/tests/objs/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ def test_fit_measures():
assert np.isclose(tfm.error_, 0.8)
tfm._calc_error(metric='RMSE')
assert np.isclose(tfm.error_, np.sqrt(0.8))
with raises(ValueError):
tfm._calc_error(metric='BAD')

def test_checks():
"""Test various checks, errors and edge cases for model fitting.
Expand Down
43 changes: 43 additions & 0 deletions specparam/tests/utils/test_gof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test functions for specparam.utils.gof."""

import numpy as np

from specparam.utils.gof import *

###################################################################################################
###################################################################################################

## GOF FUNCTIONS

def test_compute_r_squared(tfm):
    """The r-squared of a fitted model should come back as a float."""

    result = compute_r_squared(tfm.power_spectrum, tfm.modeled_spectrum_)

    assert isinstance(result, float)

def test_compute_adj_r_squared(tfm):
    """The adjusted r-squared of a fitted model should come back as a float."""

    adj_r_squared = compute_adj_r_squared(tfm.power_spectrum, tfm.modeled_spectrum_, n_params=5)

    assert isinstance(adj_r_squared, float)

## ERROR FUNCTIONS

def test_compute_mean_abs_error(tfm):
    """Mean absolute error of a fitted model should come back as a float."""

    mae = compute_mean_abs_error(tfm.power_spectrum, tfm.modeled_spectrum_)

    assert isinstance(mae, float)

def test_compute_mean_squared_error(tfm):
    """Mean squared error of a fitted model should come back as a float."""

    mse = compute_mean_squared_error(tfm.power_spectrum, tfm.modeled_spectrum_)

    assert isinstance(mse, float)

def test_compute_root_mean_squared_error(tfm):
    """Root mean squared error of a fitted model should come back as a float."""

    rmse = compute_root_mean_squared_error(tfm.power_spectrum, tfm.modeled_spectrum_)

    assert isinstance(rmse, float)

def test_compute_error(tfm):
    """Each supported error metric, by name, should compute to a float."""

    # Include upper-case names too, since callers pass metrics such as 'RMSE'
    for metric in ['mae', 'mse', 'rmse', 'MAE', 'MSE', 'RMSE']:
        # Pass the metric explicitly so every entry is exercised, not just the default
        error = compute_error(tfm.power_spectrum, tfm.modeled_spectrum_, metric)
        assert isinstance(error, float)
159 changes: 159 additions & 0 deletions specparam/utils/gof.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
"""Goodness of fit related functions & utilities."""

import numpy as np

###################################################################################################
###################################################################################################

## Goodness of fit measures

def compute_r_squared(power_spectrum, modeled_spectrum):
    """Compute the r-squared between a power spectrum and its model fit.

    The value is the square of the Pearson correlation coefficient between the
    data and the model.

    Parameters
    ----------
    power_spectrum : 1d array
        Real data power spectrum.
    modeled_spectrum : 1d array
        Modelled power spectrum.

    Returns
    -------
    r_squared : float
        R-squared of the model fit.
    """

    # The off-diagonal entry of the 2x2 correlation matrix is the correlation between the inputs
    pearson_r = np.corrcoef(power_spectrum, modeled_spectrum)[0, 1]

    return pearson_r ** 2


def compute_adj_r_squared(power_spectrum, modeled_spectrum, n_params):
    """Calculate the adjusted r-squared of the model, compared to the original data.

    Parameters
    ----------
    power_spectrum : 1d array
        Real data power spectrum.
    modeled_spectrum : 1d array
        Modelled power spectrum.
    n_params : int
        Number of parameters in the model, used to penalize the r-squared.

    Returns
    -------
    adj_r_squared : float
        Adjusted R-squared of the model fit.

    Raises
    ------
    ValueError
        If the number of data points is not greater than ``n_params + 1``,
        in which case the adjusted r-squared is undefined.
    """

    n_points = len(power_spectrum)

    # Residual degrees of freedom: the adjustment divides by this, so it must be positive,
    # otherwise the result is infinite or has a meaningless sign
    dof = n_points - n_params - 1
    if dof <= 0:
        raise ValueError("Adjusted r-squared requires more data points ({}) than the "
                         "number of parameters plus one ({}).".format(n_points, n_params + 1))

    r_squared = compute_r_squared(power_spectrum, modeled_spectrum)
    adj_r_squared = 1 - (1 - r_squared) * (n_points - 1) / dof

    return adj_r_squared


# Collect available r-squared functions together, keyed by name
# NOTE: 'adj_r_squared' additionally takes `n_params`, so the entries do not share a call signature
RSQUARED_FUNCS = {
    'r_squared' : compute_r_squared,
    'adj_r_squared' : compute_adj_r_squared,
}


## ERROR FUNCTIONS

def compute_mean_abs_error(power_spectrum, modeled_spectrum):
    """Compute the mean absolute error (MAE) between a power spectrum and its model.

    Parameters
    ----------
    power_spectrum : 1d array
        Real data power spectrum.
    modeled_spectrum : 1d array
        Modelled power spectrum.

    Returns
    -------
    error : float
        Computed MAE.
    """

    residuals = power_spectrum - modeled_spectrum

    return np.mean(np.abs(residuals))


def compute_mean_squared_error(power_spectrum, modeled_spectrum):
    """Compute the mean squared error (MSE) between a power spectrum and its model.

    Parameters
    ----------
    power_spectrum : 1d array
        Real data power spectrum.
    modeled_spectrum : 1d array
        Modelled power spectrum.

    Returns
    -------
    error : float
        Computed MSE.
    """

    residuals = power_spectrum - modeled_spectrum
    squared_residuals = residuals ** 2

    return squared_residuals.mean()


def compute_root_mean_squared_error(power_spectrum, modeled_spectrum):
    """Compute the root mean squared error (RMSE) between a power spectrum and its model.

    Parameters
    ----------
    power_spectrum : 1d array
        Real data power spectrum.
    modeled_spectrum : 1d array
        Modelled power spectrum.

    Returns
    -------
    error : float
        Computed RMSE.
    """

    residuals = power_spectrum - modeled_spectrum
    mean_squared = (residuals ** 2).mean()

    return np.sqrt(mean_squared)


# Lookup table from lower-case error metric name to the function that computes it
ERROR_FUNCS = dict(
    mae=compute_mean_abs_error,
    mse=compute_mean_squared_error,
    rmse=compute_root_mean_squared_error,
)


def compute_error(power_spectrum, modeled_spectrum, error_metric='mae'):
    """Compute error between a model and a power spectrum.

    Parameters
    ----------
    power_spectrum : 1d array
        Real data power spectrum.
    modeled_spectrum : 1d array
        Modelled power spectrum.
    error_metric : {'mae', 'mse', 'rmse'} or callable, optional, default: 'mae'
        Which approach to take to compute the error.
        Metric names are matched case-insensitively.
        A callable is invoked as ``error_metric(power_spectrum, modeled_spectrum)``.

    Returns
    -------
    error : float
        Computed error.

    Raises
    ------
    ValueError
        If `error_metric` is neither a recognized metric name nor a callable.
    """

    if isinstance(error_metric, str):
        try:
            error_func = ERROR_FUNCS[error_metric.lower()]
        except KeyError:
            # Report unknown names as a bad argument value, with a readable message
            raise ValueError("Error metric '{}' not understood or not "
                             "implemented.".format(error_metric)) from None
    elif callable(error_metric):
        # Accept any callable (functions, builtins, partials, ufuncs), not just plain functions
        error_func = error_metric
    else:
        raise ValueError("Error metric '{}' not understood or not "
                         "implemented.".format(error_metric))

    return error_func(power_spectrum, modeled_spectrum)