From b09f9db955c0b734b119f10124a2429bf45c1e39 Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Tue, 18 Jul 2023 16:30:17 -0400
Subject: [PATCH 1/3] add average_reconstructions helper function
---
doc/api.rst | 1 +
fooof/objs/utils.py | 38 ++++++++++++++++++++++++++++++++++
fooof/tests/objs/test_utils.py | 7 +++++++
3 files changed, 46 insertions(+)
diff --git a/doc/api.rst b/doc/api.rst
index 23f06e5ee..1338100df 100644
--- a/doc/api.rst
+++ b/doc/api.rst
@@ -52,6 +52,7 @@ Functions to manipulate, examine and analyze FOOOF objects, and related utilitie
compare_info
average_fg
+ average_reconstructions
combine_fooofs
.. currentmodule:: fooof
diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py
index fe1b740cd..599b9302b 100644
--- a/fooof/objs/utils.py
+++ b/fooof/objs/utils.py
@@ -116,6 +116,44 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
return fm
+def average_reconstructions(fg, avg_method='mean'):
+ """Average across model reconstructions for a group of power spectra.
+
+ Parameters
+ ----------
+ fg : FOOOFGroup
+ Object with model fit results to average across.
+ avg_method : {'mean', 'median'}
+ Averaging function to use.
+
+ Returns
+ -------
+ freqs : 1d array
+ Frequency values for the average model reconstruction.
+ avg_model : 1d array
+ Power values for the average model reconstruction.
+ Note that power values are in log10 space.
+ """
+
+ if avg_method not in ['mean', 'median']:
+ raise ValueError("Requested average method not understood.")
+ if not fg.has_model:
+ raise NoModelError("No model fit results are available, can not proceed.")
+
+ if avg_method == 'mean':
+ avg_func = np.nanmean
+ elif avg_method == 'median':
+ avg_func = np.nanmedian
+
+ models = np.zeros(shape=fg.power_spectra.shape)
+ for ind in range(len(fg)):
+ models[ind, :] = fg.get_fooof(ind, regenerate=True).fooofed_spectrum_
+
+ avg_model = avg_func(models, 0)
+
+ return fg.freqs, avg_model
+
+
def combine_fooofs(fooofs):
"""Combine a group of FOOOF and/or FOOOFGroup objects into a single FOOOFGroup object.
diff --git a/fooof/tests/objs/test_utils.py b/fooof/tests/objs/test_utils.py
index b496c1bc9..28a4c87eb 100644
--- a/fooof/tests/objs/test_utils.py
+++ b/fooof/tests/objs/test_utils.py
@@ -45,6 +45,13 @@ def test_average_fg(tfg, tbands):
with raises(NoModelError):
average_fg(ntfg, tbands)
+def test_average_reconstructions(tfg):
+
+ freqs, avg_model = average_reconstructions(tfg)
+ assert isinstance(freqs, np.ndarray)
+ assert isinstance(avg_model, np.ndarray)
+ assert freqs.shape == avg_model.shape
+
def test_combine_fooofs(tfm, tfg):
tfm2 = tfm.copy()
From c2ba4da195d2758ae9780aee24e9c5bf373db29f Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Tue, 18 Jul 2023 16:34:55 -0400
Subject: [PATCH 2/3] refactor use of avg_func in objs utils
---
fooof/objs/utils.py | 30 ++++++++++++------------------
1 file changed, 12 insertions(+), 18 deletions(-)
diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py
index 599b9302b..a9ac7934d 100644
--- a/fooof/objs/utils.py
+++ b/fooof/objs/utils.py
@@ -65,18 +65,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
If there are no model fit results available to average across.
"""
- if avg_method not in ['mean', 'median']:
- raise ValueError("Requested average method not understood.")
if not fg.has_model:
raise NoModelError("No model fit results are available, can not proceed.")
- if avg_method == 'mean':
- avg_func = np.nanmean
- elif avg_method == 'median':
- avg_func = np.nanmedian
+ avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
+ if avg_method not in avg_funcs.keys():
+ raise ValueError("Requested average method not understood.")
# Aperiodic parameters: extract & average
- ap_params = avg_func(fg.get_params('aperiodic_params'), 0)
+ ap_params = avg_funcs[avg_method](fg.get_params('aperiodic_params'), 0)
# Periodic parameters: extract & average
peak_params = []
@@ -90,15 +87,15 @@ def average_fg(fg, bands, avg_method='mean', regenerate=True):
# Check if there are any extracted peaks - if not, don't add
# Note that we only check peaks, but gauss should be the same
if not np.all(np.isnan(peaks)):
- peak_params.append(avg_func(peaks, 0))
- gauss_params.append(avg_func(gauss, 0))
+ peak_params.append(avg_funcs[avg_method](peaks, 0))
+ gauss_params.append(avg_funcs[avg_method](gauss, 0))
peak_params = np.array(peak_params)
gauss_params = np.array(gauss_params)
# Goodness of fit measures: extract & average
- r2 = avg_func(fg.get_params('r_squared'))
- error = avg_func(fg.get_params('error'))
+ r2 = avg_funcs[avg_method](fg.get_params('r_squared'))
+ error = avg_funcs[avg_method](fg.get_params('error'))
# Collect all results together, to be added to FOOOF object
results = FOOOFResults(ap_params, peak_params, r2, error, gauss_params)
@@ -135,21 +132,18 @@ def average_reconstructions(fg, avg_method='mean'):
Note that power values are in log10 space.
"""
- if avg_method not in ['mean', 'median']:
- raise ValueError("Requested average method not understood.")
if not fg.has_model:
raise NoModelError("No model fit results are available, can not proceed.")
- if avg_method == 'mean':
- avg_func = np.nanmean
- elif avg_method == 'median':
- avg_func = np.nanmedian
+ avg_funcs = {'mean' : np.nanmean, 'median' : np.nanmedian}
+ if avg_method not in avg_funcs.keys():
+ raise ValueError("Requested average method not understood.")
models = np.zeros(shape=fg.power_spectra.shape)
for ind in range(len(fg)):
models[ind, :] = fg.get_fooof(ind, regenerate=True).fooofed_spectrum_
- avg_model = avg_func(models, 0)
+ avg_model = avg_funcs[avg_method](models, 0)
return fg.freqs, avg_model
From bb96d8ebb6715e1dc0acdb0f968018d73a06f44c Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Tue, 18 Jul 2023 16:37:04 -0400
Subject: [PATCH 3/3] update aliasing of plot_spectrum
---
fooof/plts/__init__.py | 3 +--
fooof/plts/spectra.py | 4 ++++
2 files changed, 5 insertions(+), 2 deletions(-)
diff --git a/fooof/plts/__init__.py b/fooof/plts/__init__.py
index 95e05f403..981ba12b4 100644
--- a/fooof/plts/__init__.py
+++ b/fooof/plts/__init__.py
@@ -1,4 +1,3 @@
"""Plots sub-module for FOOOF."""
-from .spectra import plot_spectra
-from .spectra import plot_spectra as plot_spectrum
+from .spectra import plot_spectrum, plot_spectra
diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py
index c68acc698..8141b7cde 100644
--- a/fooof/plts/spectra.py
+++ b/fooof/plts/spectra.py
@@ -77,6 +77,10 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False,
style_spectrum_plot(ax, log_freqs, log_powers)
+# Alias `plot_spectrum` to `plot_spectra` for backwards compatibility
+plot_spectrum = plot_spectra
+
+
@savefig
@check_dependency(plt, 'matplotlib')
def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',