From 1f436cb10a412360933e9a15fee6ed8517bf391c Mon Sep 17 00:00:00 2001
From: ryanhammonds
Date: Tue, 9 Mar 2021 13:05:57 -0800
Subject: [PATCH 1/3] optimization improvements
---
fooof/core/utils.py | 6 +++++-
fooof/objs/fit.py | 14 ++++++--------
fooof/objs/utils.py | 13 +++++++++----
3 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/fooof/core/utils.py b/fooof/core/utils.py
index f0fcdb588..0f4c3e748 100644
--- a/fooof/core/utils.py
+++ b/fooof/core/utils.py
@@ -30,7 +30,11 @@ def group_three(vec):
if len(vec) % 3 != 0:
raise ValueError("Wrong size array to group by three.")
- return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
+ if isinstance(vec, np.ndarray):
+ # Reshaping is faster if already an array
+ return np.reshape(vec, (-1, 3))
+ else:
+ return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
def nearest_ind(array, value):
diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py
index fba745d27..ccdb9b16f 100644
--- a/fooof/objs/fit.py
+++ b/fooof/objs/fit.py
@@ -1007,18 +1007,16 @@ def _create_peak_params(self, gaus_params):
with `freqs`, `fooofed_spectrum_` and `_ap_fit` all required to be available.
"""
- peak_params = np.empty([0, 3])
+ peak_params = np.empty((len(gaus_params), 3))
for ii, peak in enumerate(gaus_params):
# Gets the index of the power_spectrum at the frequency closest to the CF of the peak
- ind = min(range(len(self.freqs)), key=lambda ii: abs(self.freqs[ii] - peak[0]))
+ ind = np.argmin(np.abs(self.freqs - peak[0]))
# Collect peak parameter data
- peak_params = np.vstack((peak_params,
- [peak[0],
- self.fooofed_spectrum_[ind] - self._ap_fit[ind],
- peak[2] * 2]))
+ peak_params[ii] = [peak[0], self.fooofed_spectrum_[ind] - self._ap_fit[ind],
+ peak[2] * 2]
return peak_params
@@ -1037,8 +1035,8 @@ def _drop_peak_cf(self, guess):
Guess parameters for gaussian peak fits. Shape: [n_peaks, 3].
"""
- cf_params = [item[0] for item in guess]
- bw_params = [item[2] * self._bw_std_edge for item in guess]
+ cf_params = guess[:, 0]
+ bw_params = guess[:, 2] * self._bw_std_edge
# Check if peaks within drop threshold from the edge of the frequency range
keep_peak = \
diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py
index 4ef05d97d..b4aab1043 100644
--- a/fooof/objs/utils.py
+++ b/fooof/objs/utils.py
@@ -219,9 +219,14 @@ def fit_fooof_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
>>> fgs = fit_fooof_3d(fg, freqs, power_spectra, freq_range=[3, 30]) # doctest:+SKIP
"""
- fgs = []
- for cond_spectra in power_spectra:
- fg.fit(freqs, cond_spectra, freq_range, n_jobs)
- fgs.append(fg.copy())
+ # Reshape 3d to 2d and fit
+ shape = np.shape(power_spectra)
+ powers_2d = np.reshape(power_spectra, (shape[0] * shape[1], shape[2]))
+
+ fg.fit(freqs, powers_2d, freq_range, n_jobs)
+
+ # Reorganize 2d results
+ fgs = [fg.get_group(range(dim_a * shape[1], (dim_a + 1) * shape[1])) \
+ for dim_a in range(shape[0]) ]
return fgs
From cf721a2d1e66b87d6942f47f24058322b21bc780 Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Sun, 28 Mar 2021 16:00:37 -0400
Subject: [PATCH 2/3] small doc & code notes tweaks
---
fooof/core/utils.py | 10 +++++-----
fooof/objs/utils.py | 6 +++---
2 files changed, 8 insertions(+), 8 deletions(-)
diff --git a/fooof/core/utils.py b/fooof/core/utils.py
index 0f4c3e748..3b50e44c7 100644
--- a/fooof/core/utils.py
+++ b/fooof/core/utils.py
@@ -13,13 +13,13 @@ def group_three(vec):
Parameters
----------
- vec : 1d array
- Array of items to group by 3. Length of array must be divisible by three.
+ vec : list or 1d array
+ List or array of items to group by 3. Length must be divisible by three.
Returns
-------
- list of list
- List of lists, each with three items.
+ array or list of list
+ Array or list of lists, each with three items. Output type will match input type.
Raises
------
@@ -30,8 +30,8 @@ def group_three(vec):
if len(vec) % 3 != 0:
raise ValueError("Wrong size array to group by three.")
+ # Reshape, if an array, as it's faster, otherwise assume list
if isinstance(vec, np.ndarray):
- # Reshaping is faster if already an array
return np.reshape(vec, (-1, 3))
else:
return [list(vec[ii:ii+3]) for ii in range(0, len(vec), 3)]
diff --git a/fooof/objs/utils.py b/fooof/objs/utils.py
index b4aab1043..b7a6f052f 100644
--- a/fooof/objs/utils.py
+++ b/fooof/objs/utils.py
@@ -219,14 +219,14 @@ def fit_fooof_3d(fg, freqs, power_spectra, freq_range=None, n_jobs=1):
>>> fgs = fit_fooof_3d(fg, freqs, power_spectra, freq_range=[3, 30]) # doctest:+SKIP
"""
- # Reshape 3d to 2d and fit
+ # Reshape 3d data to 2d and fit, in order to fit with a single group model object
shape = np.shape(power_spectra)
powers_2d = np.reshape(power_spectra, (shape[0] * shape[1], shape[2]))
fg.fit(freqs, powers_2d, freq_range, n_jobs)
- # Reorganize 2d results
+ # Reorganize 2d results into a list of model group objects, to reflect original shape
fgs = [fg.get_group(range(dim_a * shape[1], (dim_a + 1) * shape[1])) \
- for dim_a in range(shape[0]) ]
+ for dim_a in range(shape[0])]
return fgs
From c2c8643a0985ae41426580c24d540bae3d9ef62d Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Sun, 28 Mar 2021 16:00:50 -0400
Subject: [PATCH 3/3] extend fit_fooof_3d test
---
fooof/tests/objs/test_utils.py | 10 +++++++---
1 file changed, 7 insertions(+), 3 deletions(-)
diff --git a/fooof/tests/objs/test_utils.py b/fooof/tests/objs/test_utils.py
index 6af08372a..6f9479141 100644
--- a/fooof/tests/objs/test_utils.py
+++ b/fooof/tests/objs/test_utils.py
@@ -120,13 +120,17 @@ def test_combine_errors(tfm, tfg):
def test_fit_fooof_3d(tfg):
- n_spectra = 2
+ n_groups = 2
+ n_spectra = 3
xs, ys = gen_group_power_spectra(n_spectra, *default_group_params())
- ys = np.stack([ys, ys], axis=0)
+ ys = np.stack([ys] * n_groups, axis=0)
+ spectra_shape = np.shape(ys)
tfg = FOOOFGroup()
fgs = fit_fooof_3d(tfg, xs, ys)
- assert len(fgs) == 2
+ assert len(fgs) == n_groups == spectra_shape[0]
for fg in fgs:
assert fg
+ assert len(fg) == n_spectra
+ assert fg.power_spectra.shape == spectra_shape[1:]