Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions fooof/core/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
###################################################################################################

@check_dependency(plt, 'matplotlib')
def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=True):
def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=True, **plot_kwargs):
"""Generate and save out a PDF report for a power spectrum model fit.

Parameters
Expand All @@ -37,6 +37,8 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=Tr
Whether or not to plot the frequency axis in log space.
add_settings : bool, optional, default: True
Whether to add a print out of the model settings to the end of the report.
plot_kwargs : keyword arguments
Keyword arguments to pass into the plot method.
"""

# Define grid settings based on what is to be plotted
Expand All @@ -56,7 +58,7 @@ def save_report_fm(fm, file_name, file_path=None, plt_log=False, add_settings=Tr

# Second - data plot
ax1 = plt.subplot(grid[1])
fm.plot(plt_log=plt_log, ax=ax1)
fm.plot(plt_log=plt_log, ax=ax1, **plot_kwargs)

# Third - FOOOF settings
if add_settings:
Expand Down
11 changes: 7 additions & 4 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def add_results(self, fooof_result):
self._check_loaded_results(fooof_result._asdict())


def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False):
def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False, **plot_kwargs):
"""Run model fit, and display a report, which includes a plot, and printed results.

Parameters
Expand All @@ -392,14 +392,16 @@ def report(self, freqs=None, power_spectrum=None, freq_range=None, plt_log=False
If not provided, fits across the entire given range.
plt_log : bool, optional, default: False
Whether or not to plot the frequency axis in log space.
**plot_kwargs
Keyword arguments to pass into the plot method.

Notes
-----
Data is optional, if data has already been added to the object.
"""

self.fit(freqs, power_spectrum, freq_range)
self.plot(plt_log=plt_log)
self.plot(plt_log=plt_log, **plot_kwargs)
self.print_results(concise=False)


Expand Down Expand Up @@ -648,9 +650,10 @@ def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,


@copy_doc_func_to_method(save_report_fm)
def save_report(self, file_name, file_path=None, plt_log=False, add_settings=True):
def save_report(self, file_name, file_path=None, plt_log=False,
add_settings=True, **plot_kwargs):

save_report_fm(self, file_name, file_path, plt_log, add_settings)
save_report_fm(self, file_name, file_path, plt_log, add_settings, **plot_kwargs)


@copy_doc_func_to_method(save_fm)
Expand Down
24 changes: 24 additions & 0 deletions fooof/objs/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,30 @@ def print_results(self, concise=False):
print(gen_results_fg_str(self, concise))


def save_model_report(self, index, file_name, file_path=None, plt_log=False,
                      add_settings=True, **plot_kwargs):
    """Save out an individual model report for a specified model fit.

    Parameters
    ----------
    index : int
        Index of the model fit to save out.
    file_name : str
        Name to give the saved out file.
    file_path : str, optional
        Path to directory to save to. If None, saves to current directory.
    plt_log : bool, optional, default: False
        Whether or not to plot the frequency axis in log space.
    add_settings : bool, optional, default: True
        Whether to add a print out of the model settings to the end of the report.
    plot_kwargs : keyword arguments
        Keyword arguments to pass into the plot method.
    """

    # Extract the requested fit as an individual model object, regenerating the model
    model = self.get_fooof(ind=index, regenerate=True)

    # Forward every report option - `add_settings` must be passed explicitly,
    #   otherwise the model's `save_report` silently falls back to its own default
    model.save_report(file_name, file_path, plt_log, add_settings, **plot_kwargs)


def to_df(self, peak_org):
"""Convert and extract the model results as a pandas object.

Expand Down
11 changes: 10 additions & 1 deletion fooof/tests/objs/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
They serve rather as 'smoke tests', for if anything fails completely.
"""

import os

import numpy as np
from numpy.testing import assert_equal

Expand All @@ -17,7 +19,7 @@

pd = safe_import('pandas')

from fooof.tests.settings import TEST_DATA_PATH
from fooof.tests.settings import TEST_DATA_PATH, TEST_REPORTS_PATH
from fooof.tests.tutils import default_group_params, plot_test

from fooof.objs.group import *
Expand Down Expand Up @@ -212,6 +214,13 @@ def test_fg_print(tfg):
tfg.print_results()
assert True

def test_save_model_report(tfg):
    """Check that an individual model report can be saved out from a group object."""

    report_name = 'test_group_model_report'
    tfg.save_model_report(0, report_name, TEST_REPORTS_PATH)

    # The report is expected to be written as a PDF in the reports test directory
    saved_pdf = os.path.join(TEST_REPORTS_PATH, report_name + '.pdf')
    assert os.path.exists(saved_pdf)

def test_get_results(tfg):
"""Check get results method."""

Expand Down