From 4395e14026912ca6b01abcfab5ed41fd9cd850e9 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Mon, 29 Jun 2020 21:21:36 -0700 Subject: [PATCH 1/8] plot_style and savefig decorators added to plotting --- fooof/objs/fit.py | 6 +- fooof/objs/group.py | 4 +- fooof/plts/annotate.py | 4 +- fooof/plts/aperiodic.py | 8 +- fooof/plts/error.py | 6 +- fooof/plts/fg.py | 38 +++++--- fooof/plts/fm.py | 20 ++--- fooof/plts/periodic.py | 8 +- fooof/plts/settings.py | 21 +++++ fooof/plts/spectra.py | 12 ++- fooof/plts/style.py | 194 ++++++++++++++++++++++++++++++++++++++++ fooof/plts/utils.py | 22 +++++ 12 files changed, 305 insertions(+), 38 deletions(-) diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index e5b007b77..7b31fbac2 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -616,12 +616,12 @@ def get_results(self): @copy_doc_func_to_method(plot_fm) def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, save_fig=False, file_name=None, file_path=None, - ax=None, plot_style=style_spectrum_plot, - data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None): + ax=None, plot_style=style_spectrum_plot, data_kwargs=None, model_kwargs=None, + aperiodic_kwargs=None, peak_kwargs=None, **kwargs): plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend, save_fig, file_name, file_path, ax, plot_style, - data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs) + data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs, **kwargs) @copy_doc_func_to_method(save_report_fm) diff --git a/fooof/objs/group.py b/fooof/objs/group.py index 8bd293d3a..5c0e0b8d8 100644 --- a/fooof/objs/group.py +++ b/fooof/objs/group.py @@ -398,9 +398,9 @@ def get_params(self, name, col=None): @copy_doc_func_to_method(plot_fg) - def plot(self, save_fig=False, file_name=None, file_path=None): + def plot(self, save_fig=False, file_name=None, file_path=None, **kwargs): - plot_fg(self, save_fig, file_name, file_path) + plot_fg(self, save_fig, file_name, file_path, **kwargs) @copy_doc_func_to_method(save_report_fg) diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py index 04faacf35..96161239f 100644 --- a/fooof/plts/annotate.py +++ b/fooof/plts/annotate.py @@ -7,7 +7,7 @@ from fooof.core.funcs import gaussian_function from fooof.core.modutils import safe_import, check_dependency from fooof.sim.gen import gen_aperiodic -from fooof.plts.utils import check_ax +from fooof.plts.utils import check_ax, savefig from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS from fooof.plts.style import check_n_style, style_spectrum_plot @@ -20,6 +20,7 @@ ################################################################################################### ################################################################################################### +@savefig @check_dependency(plt, 'matplotlib') def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): """Plot a series of plots illustrating the peak search from a flattened spectrum. @@ -74,6 +75,7 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): check_n_style(plot_style, ax, False, True) +@savefig @check_dependency(plt, 'matplotlib') def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True, ax=None, plot_style=style_spectrum_plot): diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py index b20038b81..2fd8afc8a 100644 --- a/fooof/plts/aperiodic.py +++ b/fooof/plts/aperiodic.py @@ -7,14 +7,16 @@ from fooof.sim.gen import gen_freqs, gen_aperiodic from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_param_plot -from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs +from fooof.plts.style import check_n_style, style_param_plot, style_plot +from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, plot_style=style_param_plot, **plot_kwargs): @@ -58,6 +60,8 @@ def plot_aperiodic_params(aps, colors=None, labels=None, check_n_style(plot_style, ax) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, log_freqs=False, colors=None, labels=None, diff --git a/fooof/plts/error.py b/fooof/plts/error.py index c870900bc..ea0f3b81f 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -5,14 +5,16 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_spectrum_plot -from fooof.plts.utils import check_ax +from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot +from fooof.plts.utils import check_ax, savefig plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py index f4d121419..cd11d8d0e 100644 --- a/fooof/plts/fg.py +++ b/fooof/plts/fg.py @@ -10,6 +10,8 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES from fooof.plts.templates import plot_scatter_1, plot_scatter_2, plot_hist +from fooof.plts.utils import savefig +from fooof.plts.style import style_plot plt = safe_import('.pyplot', 'matplotlib') gridspec = safe_import('.gridspec', 'matplotlib') @@ -17,8 +19,9 @@ ################################################################################################### ################################################################################################### +@savefig @check_dependency(plt, 'matplotlib') -def plot_fg(fg, save_fig=False, file_name=None, file_path=None): +def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs): """Plot a figure with subplots visualizing the parameters from a FOOOFGroup object. Parameters @@ -44,26 +47,27 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None): fig = plt.figure(figsize=PLT_FIGSIZES['group']) gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2]) + # Apply scatter kwargs to all subplots + scatter_kwargs = kwargs + scatter_kwargs['all_axes'] = True + # Aperiodic parameters plot ax0 = plt.subplot(gs[0, 0]) - plot_fg_ap(fg, ax0) + plot_fg_ap(fg, ax0, **scatter_kwargs) # Goodness of fit plot ax1 = plt.subplot(gs[0, 1]) - plot_fg_gf(fg, ax1) + plot_fg_gf(fg, ax1, **scatter_kwargs) # Center frequencies plot ax2 = plt.subplot(gs[1, :]) - plot_fg_peak_cens(fg, ax2) - - if save_fig: - if not file_name: - raise ValueError("Input 'file_name' is required to save out the plot.") - plt.savefig(fpath(file_path, fname(file_name, 'png'))) + plot_fg_peak_cens(fg, ax2, **kwargs) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_ap(fg, ax=None): +def plot_fg_ap(fg, ax=None, **kwargs): """Plot aperiodic fit parameters, in a scatter plot. Parameters @@ -72,6 +76,8 @@ def plot_fg_ap(fg, ax=None): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. + **kwargs + Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. """ if fg.aperiodic_mode == 'knee': @@ -83,8 +89,10 @@ def plot_fg_ap(fg, ax=None): 'Aperiodic Fit', ax=ax) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_gf(fg, ax=None): +def plot_fg_gf(fg, ax=None, **kwargs): """Plot goodness of fit results, in a scatter plot. Parameters @@ -93,14 +101,18 @@ def plot_fg_gf(fg, ax=None): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. + **kwargs + Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. """ plot_scatter_2(fg.get_params('error'), 'Error', fg.get_params('r_squared'), 'R^2', 'Goodness of Fit', ax=ax) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_peak_cens(fg, ax=None): +def plot_fg_peak_cens(fg, ax=None, **kwargs): """Plot peak center frequencies, in a histogram. Parameters @@ -109,6 +121,8 @@ def plot_fg_peak_cens(fg, ax=None): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. + **kwargs + Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. """ plot_hist(fg.get_params('peak_params', 0)[:, 0], 'Center Frequency', diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py index 6818e4840..2fb2fad7f 100644 --- a/fooof/plts/fm.py +++ b/fooof/plts/fm.py @@ -7,7 +7,6 @@ import numpy as np -from fooof.core.io import fname, fpath from fooof.core.utils import nearest_ind from fooof.core.modutils import safe_import, check_dependency from fooof.sim.gen import gen_periodic @@ -15,19 +14,20 @@ from fooof.utils.params import compute_fwhm from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS -from fooof.plts.utils import check_ax, check_plot_kwargs -from fooof.plts.style import check_n_style, style_spectrum_plot +from fooof.plts.utils import check_ax, check_plot_kwargs, savefig +from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, - save_fig=False, file_name=None, file_path=None, - ax=None, plot_style=style_spectrum_plot, - data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None): + save_fig=False, file_name=None, file_path=None, ax=None, plot_style=style_spectrum_plot, + data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **kwargs): """Plot the power spectrum and model fit results from a FOOOF object. Parameters @@ -55,6 +55,8 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= A function to call to apply styling & aesthetics to the plot. data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional Keyword arguments to pass into the plot call for each plot element. + **kwargs + Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. Notes ----- @@ -99,12 +101,6 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= # Apply style to plot check_n_style(plot_style, ax, log_freqs, True) - # Save out figure, if requested - if save_fig: - if not file_name: - raise ValueError("Input 'file_name' is required to save out the plot.") - plt.savefig(fpath(file_path, fname(file_name, 'png'))) - def _add_peaks(fm, approach, plt_log, ax, peak_kwargs): """Add peaks to a model plot. diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py index e654bfabc..0bd73e03d 100644 --- a/fooof/plts/periodic.py +++ b/fooof/plts/periodic.py @@ -8,14 +8,16 @@ from fooof.core.funcs import gaussian_function from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_param_plot -from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs +from fooof.plts.style import check_n_style, style_param_plot, style_plot +from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, plot_style=style_param_plot, **plot_kwargs): @@ -69,6 +71,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, check_n_style(plot_style, ax) +@savefig +@style_plot def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, plot_style=style_param_plot, **plot_kwargs): """Plot reconstructions of model peak fits. diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py index 94d525054..b695ab40c 100644 --- a/fooof/plts/settings.py +++ b/fooof/plts/settings.py @@ -26,3 +26,24 @@ PLT_ALIASES = {'linewidth' : ['lw', 'linewidth'], 'markersize' : ['ms', 'markersize'], 'linestyle' : ['ls', 'linestyle']} + +# Plot style arguments are those that can be defined on an axis object +AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim'] + +# Line style arguments are those that can be defined on a line object +LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle', + 'marker', 'ms', 'markersize', 'color'] + +# Custom style arguments are those that are custom-handled by the plot style function +CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize', + 'legend_size', 'legend_loc'] +STYLERS = ['axis_styler', 'line_styler', 'custom_styler'] +STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS + +## Define default values for aesthetic +# These are all custom style arguments +TITLE_FONTSIZE = 20 +LABEL_SIZE = 16 +TICK_LABELSIZE = 16 +LEGEND_SIZE = 12 +LEGEND_LOC = 'best' diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index eaeafdbe7..24e207770 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -11,14 +11,16 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_spectrum_plot -from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs +from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot +from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs, savefig plt = safe_import('.pyplot', 'matplotlib') ################################################################################################### ################################################################################################### +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): @@ -55,6 +57,8 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, check_n_style(plot_style, ax, log_freqs, log_powers) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels=None, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): @@ -93,6 +97,8 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels check_n_style(plot_style, ax, log_freqs, log_powers) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_center=False, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): @@ -129,6 +135,8 @@ def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_c plot_kwargs.get('log_powers', False)) +@savefig +@style_plot @check_dependency(plt, 'matplotlib') def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_center=False, ax=None, plot_style=style_spectrum_plot, **plot_kwargs): diff --git a/fooof/plts/style.py b/fooof/plts/style.py index f9dbcfe80..f1353ba85 100644 --- a/fooof/plts/style.py +++ b/fooof/plts/style.py @@ -1,5 +1,15 @@ """Style and aesthetics definitions for plots.""" +from itertools import cycle +from functools import wraps +import warnings + +import matplotlib.pyplot as plt + +from fooof.plts.settings import AXIS_STYLE_ARGS, LINE_STYLE_ARGS, CUSTOM_STYLE_ARGS, STYLE_ARGS +from fooof.plts.settings import (LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC, + TICK_LABELSIZE, TITLE_FONTSIZE) + ################################################################################################### ################################################################################################### @@ -75,3 +85,187 @@ def style_param_plot(ax): legend = ax.legend(prop={'size': 16}) for handle in legend.legendHandles: handle._sizes = [100] + + +# Additional plot style customization + +def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs): + """Apply axis plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + style_args : list of str + A list of arguments to be sub-selected from `kwargs` and applied as axis styling. + **kwargs + Keyword arguments that define plot style to apply. + """ + + # Apply any provided axis style arguments + plot_kwargs = {key : val for key, val in kwargs.items() if key in style_args} + ax.set(**plot_kwargs) + + +def apply_plot_style(ax, style_args=LINE_STYLE_ARGS, **kwargs): + """Apply line/scatter/histogram plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + style_args : list of str + A list of arguments to be sub-selected from `kwargs` and applied as styling. + **kwargs + Keyword arguments that define style to apply. + """ + + # Get the plot object related styling arguments from the keyword arguments + style_kwargs = {key : val for key, val in kwargs.items() if key in style_args} + + # For line plots + if len(ax.lines) > 0: + plot_objs = ax.lines + + # For scatter plots + elif len(ax.collections) > 0: + plot_objs = ax.collections + + # For histograms + elif len(ax.patches) > 0: + plot_objs = ax.patches + + # There is no styling to apply + else: + return + + plot_objs = [plot_objs] if not isinstance(plot_objs, list) else plot_objs + + # Apply any provided plot object style arguments + for style, value in style_kwargs.items(): + + # Values should be either a single value, for all plot objects, or a list, one value per + # object. This line checks type, and makes a cycle-able / loop-able object out of the values + values = cycle([value] if isinstance(value, (int, float, str)) else value) + + # For line plots + for plot_obj in plot_objs: + plot_obj.set(**{style : next(values)}) + + +def apply_custom_style(ax, **kwargs): + """Apply custom plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + **kwargs + Keyword arguments that define custom style to apply. + """ + + # If a title was provided, update the size + if ax.get_title(): + ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE)) + + # Settings for the axis labels + label_size = kwargs.pop('label_size', LABEL_SIZE) + ax.xaxis.label.set_size(label_size) + ax.yaxis.label.set_size(label_size) + + # Settings for the axis ticks + ax.tick_params(axis='both', which='major', + labelsize=kwargs.pop('tick_labelsize', TICK_LABELSIZE)) + + # If labels were provided, add a legend + if ax.get_legend_handles_labels()[0]: + ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)}, + loc=kwargs.pop('legend_loc', LEGEND_LOC)) + + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + plt.tight_layout() + + +def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_plot_style, + custom_styler=apply_custom_style, **kwargs): + """Apply plot style to a figure axis. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + axis_styler, line_styler, custom_styler : callable, optional + Functions to apply style to aspects of the plot. + **kwargs + Keyword arguments that define style to apply. + + Notes + ----- + This function wraps sub-functions which apply style to different plot elements. + Each of these sub-functions can be replaced by passing in replacement callables. + """ + + axis_styler(ax, **kwargs) + line_styler(ax, **kwargs) + custom_styler(ax, **kwargs) + + +def style_plot(func, *args, **kwargs): + """Decorator function to apply a plot style function, after plot generation. + + Parameters + ---------- + func : callable + The plotting function for creating a plot. + *args, **kwargs + Arguments & keyword arguments. + These should include any arguments for the plot, and those for applying plot style. + + Notes + ----- + This decorator works by: + + - catching all inputs that relate to plot style + - creating a plot, using the passed in plotting function & passing in all non-style arguments + - passing the style related arguments into a `apply_style` function which applies plot styling + + By default, this function applies styling with the `apply_style` function. Custom + functions for applying style can be passed in using `apply_style` as a keyword argument. + + The `apply_style` function calls sub-functions for applying style different plot elements, + and these sub-functions can be overridden by passing in alternatives for `axis_styler`, + `line_styler`, and `custom_styler`. + """ + + @wraps(func) + def decorated(*args, **kwargs): + + # Grab a custom style function, if provided, and grab any provided style arguments + style_func = kwargs.pop('apply_style', apply_style) + style_args = kwargs.pop('style_args', STYLE_ARGS) + style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs} + + # Check how many lines are already on the plot, if it exists already + n_lines_pre = len(kwargs['ax'].lines) if 'ax' in kwargs and kwargs['ax'] is not None else 0 + + # Create the plot + func(*args, **kwargs) + + # Get plot axis, if a specific one was provided, or if not, grab the current axis + cur_ax = kwargs['ax'] if 'ax' in kwargs and kwargs['ax'] is not None else plt.gca() + + # Check how many lines were added to the plot, and make info available to plot styling + n_lines_apply = len(cur_ax.lines) - n_lines_pre + style_kwargs['n_lines_apply'] = n_lines_apply + + # Determine if styling should be applied to all axes + all_axes = kwargs.pop('all_axes', False) + cur_ax = plt.gcf().get_axes() if all_axes is True else cur_ax + cur_ax = [cur_ax] if not isinstance(cur_ax, list) else cur_ax + + # Apply the styling function + for ax in cur_ax: + style_func(ax, **style_kwargs) + + return decorated diff --git a/fooof/plts/utils.py b/fooof/plts/utils.py index ef5b53901..0a970ee91 100644 --- a/fooof/plts/utils.py +++ b/fooof/plts/utils.py @@ -8,9 +8,11 @@ from itertools import repeat from collections.abc import Iterator +from functools import wraps import numpy as np +from fooof.core.io import fname, fpath from fooof.core.modutils import safe_import from fooof.core.utils import resolve_aliases from fooof.plts.settings import PLT_ALPHA_LEVELS, PLT_ALIASES @@ -171,3 +173,23 @@ def check_plot_kwargs(plot_kwargs, defaults): plot_kwargs[key] = value return plot_kwargs + + +def savefig(func): + """Decorator function to save out figures.""" + + @wraps(func) + def decorated(*args, **kwargs): + + save_fig = kwargs.pop('save_fig', False) + file_name = kwargs.pop('file_name', None) + file_path = kwargs.pop('file_path', None) + + func(*args, **kwargs) + + if save_fig: + if not file_name: + raise ValueError("Input 'file_name' is required to save out the plot.") + plt.savefig(fpath(file_path, fname(file_name, 'png'))) + + return decorated From 871121578a54ffe7c4ab421ea06add09fb324187 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Mon, 29 Jun 2020 21:21:53 -0700 Subject: [PATCH 2/8] tests updated --- fooof/tests/plts/test_styles.py | 98 +++++++++++++++++++++++++++++++++ fooof/tests/plts/test_utils.py | 13 +++++ 2 files changed, 111 insertions(+) diff --git a/fooof/tests/plts/test_styles.py b/fooof/tests/plts/test_styles.py index 75b002d1b..caf09e3eb 100644 --- a/fooof/tests/plts/test_styles.py +++ b/fooof/tests/plts/test_styles.py @@ -1,5 +1,6 @@ """Tests for fooof.plts.styles.""" +from fooof.tests.tutils import plot_test from fooof.plts.style import * ################################################################################################### @@ -27,3 +28,100 @@ def test_style_spectrum_plot(skip_if_no_mpl): # Check that axis labels are added - use as proxy that it ran correctly assert ax.xaxis.get_label().get_text() assert ax.yaxis.get_label().get_text() + + +def test_apply_axis_style(): + + _, ax = plt.subplots() + + title = 'Ploty McPlotface' + xlim = (1.0, 10.0) + ylabel = 'Line Value' + + apply_axis_style(ax, title=title, xlim=xlim, ylabel=ylabel) + + assert ax.get_title() == title + assert ax.get_xlim() == xlim + assert ax.get_ylabel() == ylabel + + +def test_apply_plot_style(): + + # Check applying style to one line + _, ax = plt.subplots() + ax.plot([1, 2], [3, 4]) + + lw = 4 + apply_plot_style(ax, lw=lw) + + assert ax.get_lines()[0].get_lw() == lw + + # Check applying style across multiple lines + _, ax = plt.subplots() + ax.plot([1, 2], [[3, 4], [5, 6]]) + + alphas = [0.5, 0.75] + apply_plot_style(ax, alpha=alphas) + + for line, alpha in zip(ax.get_lines(), alphas): + assert line.get_alpha() == alpha + + # Check applying style to a scatter plot + _, ax = plt.subplots() + ax.scatter([1, 2], [2, 4]) + apply_plot_style(ax, alpha=0.123) + assert ax.collections[0]._alpha == 0.123 + + # Check applying style to a histogram + _, ax = plt.subplots() + ax.hist([1, 2, 3]) + apply_plot_style(ax, alpha=0.123) + assert ax.patches[0]._alpha == 0.123 + + +def test_apply_custom_style(): + + _, ax = plt.subplots() + ax.set_title('placeholder') + + # Test simple application of custom plot style + apply_custom_style(ax) + assert ax.title.get_size() == TITLE_FONTSIZE + + # Test adding input parameters to custom plot style + new_title_fontsize = 15.0 + apply_custom_style(ax, title_fontsize=new_title_fontsize) + assert ax.title.get_size() == new_title_fontsize + + +def test_apply_style(): + + _, ax = plt.subplots() + + def my_custom_styler(ax, **kwargs): + ax.set_title('DATA!') + + # Apply plot style using all defaults + apply_style(ax) + + # Apply plot style passing in a styler + apply_style(ax, custom_styler=my_custom_styler) + + +@plot_test +def test_style_plot(): + + @style_plot + def example_plot(): + plt.plot([1, 2], [3, 4]) + + def my_plot_style(ax, **kwargs): + ax.set_title('Custom!') + + # Test with applying default custom styling + lw = 5 + title = 'Science.' + example_plot(title=title, lw=lw) + + # Test with passing in own plot_style function + example_plot(apply_style=my_plot_style) diff --git a/fooof/tests/plts/test_utils.py b/fooof/tests/plts/test_utils.py index 51a2117f4..89956717f 100644 --- a/fooof/tests/plts/test_utils.py +++ b/fooof/tests/plts/test_utils.py @@ -1,5 +1,8 @@ """Tests for fooof.plts.utils.""" +import os +import tempfile + from fooof.tests.tutils import plot_test from fooof.core.modutils import safe_import @@ -69,3 +72,13 @@ def test_check_plot_kwargs(skip_if_no_mpl): assert len(plot_kwargs) == 2 assert plot_kwargs['alpha'] == 0.5 assert plot_kwargs['linewidth'] == 2 + +def test_savefig(): + + @savefig + def example_plot(): + plt.plot([1, 2], [3, 4]) + + with tempfile.NamedTemporaryFile(mode='w+') as file: + example_plot(save_fig=True, file_name=file.name) + assert os.path.exists(file.name) From 8dc0b0892dab97cd964d6aa2807c21bd5df26290 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Wed, 5 Aug 2020 17:25:24 -0700 Subject: [PATCH 3/8] consolidated plotting approach --- fooof/objs/fit.py | 9 +++-- fooof/objs/group.py | 2 +- fooof/plts/annotate.py | 38 ++++++++---------- fooof/plts/aperiodic.py | 56 +++++++++++++------------- fooof/plts/error.py | 13 +++--- fooof/plts/fg.py | 24 ++++++------ fooof/plts/fm.py | 87 +++++++++++++++++++---------------------- fooof/plts/periodic.py | 48 ++++++++++++----------- fooof/plts/settings.py | 2 +- fooof/plts/spectra.py | 79 ++++++++++++++++++------------------- fooof/plts/style.py | 80 ++++++++++--------------------------- 11 files changed, 192 insertions(+), 246 deletions(-) diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index 7b31fbac2..fb274118f 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -616,12 +616,13 @@ def get_results(self): @copy_doc_func_to_method(plot_fm) def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, save_fig=False, file_name=None, file_path=None, - ax=None, plot_style=style_spectrum_plot, data_kwargs=None, model_kwargs=None, + ax=None, data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **kwargs): - plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend, - save_fig, file_name, file_path, ax, plot_style, - data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs, **kwargs) + plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log, + add_legend=add_legend, save_fig=save_fig, file_name=file_name, + file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, + aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **kwargs) @copy_doc_func_to_method(save_report_fm) diff --git a/fooof/objs/group.py b/fooof/objs/group.py index 5c0e0b8d8..c3f7bdc2d 100644 --- a/fooof/objs/group.py +++ b/fooof/objs/group.py @@ -400,7 +400,7 @@ def get_params(self, name, col=None): @copy_doc_func_to_method(plot_fg) def plot(self, save_fig=False, file_name=None, file_path=None, **kwargs): - plot_fg(self, save_fig, file_name, file_path, **kwargs) + plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **kwargs) @copy_doc_func_to_method(save_report_fg) diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py index 96161239f..d67befb24 100644 --- a/fooof/plts/annotate.py +++ b/fooof/plts/annotate.py @@ -10,7 +10,7 @@ from fooof.plts.utils import check_ax, savefig from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS -from fooof.plts.style import check_n_style, style_spectrum_plot +from fooof.plts.style import style_spectrum_plot from fooof.analysis.periodic import get_band_peak_fm from fooof.utils.params import compute_knee_frequency, compute_fwhm @@ -22,15 +22,13 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): +def plot_annotated_peak_search(fm): """Plot a series of plots illustrating the peak search from a flattened spectrum. Parameters ---------- fm : FOOOF FOOOF object, with model fit, data and settings available. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plots. """ # Recalculate the initial aperiodic fit and flattened spectrum that @@ -47,14 +45,12 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): # This forces the creation of a new plotting axes per iteration ax = check_ax(None, PLT_FIGSIZES['spectral']) - plot_spectrum(fm.freqs, flatspec, ax=ax, plot_style=None, - label='Flattened Spectrum', color=PLT_COLORS['data'], linewidth=2.5) - plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), - ax=ax, plot_style=None, label='Relative Threshold', - color='orange', linewidth=2.5, linestyle='dashed') - plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), - ax=ax, plot_style=None, label='Absolute Threshold', - color='red', linewidth=2.5, linestyle='dashed') + plot_spectrum(fm.freqs, flatspec, ax=ax, linewidth=2.5, + label='Flattened Spectrum', color=PLT_COLORS['data']) + plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), ax=ax, + label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed') + plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), ax=ax, + label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed') maxi = np.argmax(flatspec) ax.plot(fm.freqs[maxi], flatspec[maxi], '.', @@ -66,19 +62,18 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot): if ind < fm.n_peaks_: gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :]) - plot_spectrum(fm.freqs, gauss, ax=ax, plot_style=None, - label='Gaussian Fit', color=PLT_COLORS['periodic'], - linestyle=':', linewidth=3.0) + plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit', + color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0) flatspec = flatspec - gauss - check_n_style(plot_style, ax, False, True) + style_spectrum_plot(ax, False, True) @savefig @check_dependency(plt, 'matplotlib') -def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True, - ax=None, plot_style=style_spectrum_plot): +def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, + annotate_aperiodic=True, ax=None): """Plot a an annotated power spectrum and model, from a FOOOF object. Parameters @@ -89,8 +84,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio Whether to plot the frequency values in log10 spacing. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plots. + Raises ------ @@ -110,7 +104,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio # Create the baseline figure ax = check_ax(ax, PLT_FIGSIZES['spectral']) - fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, plot_style=None, + fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, data_kwargs={'lw' : lw1, 'alpha' : 0.6}, aperiodic_kwargs={'lw' : lw1, 'zorder' : 10}, model_kwargs={'lw' : lw1, 'alpha' : 0.5}, @@ -217,7 +211,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio color=PLT_COLORS['aperiodic'], fontsize=fontsize) # Apply style to plot & tune grid styling - check_n_style(plot_style, ax, plt_log, True) + style_spectrum_plot(ax, plt_log, True) ax.grid(True, alpha=0.5) # Add labels to plot in the legend diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py index 2fd8afc8a..f69bd36bb 100644 --- a/fooof/plts/aperiodic.py +++ b/fooof/plts/aperiodic.py @@ -3,12 +3,13 @@ from itertools import cycle import numpy as np +import matplotlib.pyplot as plt from fooof.sim.gen import gen_freqs, gen_aperiodic from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_param_plot, style_plot -from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig +from fooof.plts.style import style_param_plot, style_plot +from fooof.plts.utils import check_ax, recursive_plot, savefig plt = safe_import('.pyplot', 'matplotlib') @@ -18,8 +19,7 @@ @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_aperiodic_params(aps, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): +def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs): """Plot aperiodic parameters as dots representing offset and exponent value. Parameters @@ -32,17 +32,14 @@ def plot_aperiodic_params(aps, colors=None, labels=None, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) if isinstance(aps, list): - recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels) else: @@ -50,14 +47,26 @@ def plot_aperiodic_params(aps, colors=None, labels=None, xs, ys = aps[:, 0], aps[:, -1] sizes = plot_kwargs.pop('s', 150) - plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7}) - ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs) + colors = 'C0' if colors is None else colors + colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors) + labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels) + + for xi, yi in zip(xs, ys): + + # Prevent duplicate labels when recursively plotting + _, cur_labels = plt.gca().get_legend_handles_labels() + + label = next(labels) + if label not in cur_labels: + ax.scatter(xi, yi, s=sizes, color=next(colors), label=label, alpha=0.7) + else: + ax.scatter(xi, yi, s=sizes, color=next(colors), alpha=0.7) # Add axis labels ax.set_xlabel('Offset') ax.set_ylabel('Exponent') - check_n_style(plot_style, ax) + style_param_plot(ax) @savefig @@ -65,7 +74,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, @check_dependency(plt, 'matplotlib') def plot_aperiodic_fits(aps, freq_range, control_offset=False, log_freqs=False, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): + ax=None, **plot_kwargs): """Plot reconstructions of model aperiodic fits. Parameters @@ -84,10 +93,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) @@ -97,11 +104,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, if not colors: colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color']) - recursive_plot(aps, plot_function=plot_aperiodic_fits, ax=ax, - freq_range=tuple(freq_range), - control_offset=control_offset, - log_freqs=log_freqs, colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + recursive_plot(aps, plot_aperiodic_fits, ax=ax, freq_range=tuple(freq_range), + control_offset=control_offset, log_freqs=log_freqs, colors=colors, + labels=labels, **plot_kwargs) else: freqs = gen_freqs(freq_range, 0.1) @@ -122,8 +127,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, # Recreate & plot the aperiodic component from parameters ap_vals = gen_aperiodic(freqs, ap_params) - plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25}) - ax.plot(plt_freqs, ap_vals, color=colors, **plot_kwargs) + ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25) # Collect a running average across components avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0) @@ -131,8 +135,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, # Plot the average component avg = avg_vals / aps.shape[0] avg_color = 'black' if not colors else colors - ax.plot(plt_freqs, avg, linewidth=plot_kwargs.get('linewidth')*3, - color=avg_color, label=labels) + ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels) # Add axis labels ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency') @@ -141,5 +144,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False, # Set plot limit ax.set_xlim(np.log10(freq_range) if log_freqs else freq_range) - # Apply plot style - check_n_style(plot_style, ax) + style_param_plot(ax) diff --git a/fooof/plts/error.py b/fooof/plts/error.py index ea0f3b81f..f7cbfdf7b 100644 --- a/fooof/plts/error.py +++ b/fooof/plts/error.py @@ -5,7 +5,7 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot +from fooof.plts.style import style_spectrum_plot, style_plot from fooof.plts.utils import check_ax, savefig plt = safe_import('.pyplot', 'matplotlib') @@ -16,8 +16,7 @@ @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectral_error(freqs, error, shade=None, log_freqs=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **plot_kwargs): """Plot frequency by frequency error values. Parameters @@ -33,17 +32,15 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, Whether to plot the frequency axis in log spacing. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to `plot_spectra` or to the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) plt_freqs = np.log10(freqs) if log_freqs else freqs - plot_spectrum(plt_freqs, error, plot_style=None, ax=ax, linewidth=3, **plot_kwargs) + plot_spectrum(plt_freqs, error, ax=ax, linewidth=3) if np.any(shade): ax.fill_between(plt_freqs, error-shade, error+shade, alpha=0.25) @@ -53,5 +50,5 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax.set_ylim([0, ymax]) ax.set_xlim(plt_freqs.min(), plt_freqs.max()) - check_n_style(plot_style, ax, log_freqs, True) + style_spectrum_plot(ax, log_freqs, True) ax.set_ylabel('Absolute Error') diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py index cd11d8d0e..342065c64 100644 --- a/fooof/plts/fg.py +++ b/fooof/plts/fg.py @@ -21,7 +21,7 @@ @savefig @check_dependency(plt, 'matplotlib') -def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs): +def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs): """Plot a figure with subplots visualizing the parameters from a FOOOFGroup object. Parameters @@ -48,7 +48,7 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs): gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2]) # Apply scatter kwargs to all subplots - scatter_kwargs = kwargs + scatter_kwargs = plot_kwargs scatter_kwargs['all_axes'] = True # Aperiodic parameters plot @@ -61,13 +61,13 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs): # Center frequencies plot ax2 = plt.subplot(gs[1, :]) - plot_fg_peak_cens(fg, ax2, **kwargs) + plot_fg_peak_cens(fg, ax2, **plot_kwargs) @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_ap(fg, ax=None, **kwargs): +def plot_fg_ap(fg, ax=None, **plot_kwargs): """Plot aperiodic fit parameters, in a scatter plot. Parameters @@ -76,8 +76,8 @@ def plot_fg_ap(fg, ax=None, **kwargs): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. - **kwargs - Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. """ if fg.aperiodic_mode == 'knee': @@ -92,7 +92,7 @@ def plot_fg_ap(fg, ax=None, **kwargs): @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_gf(fg, ax=None, **kwargs): +def plot_fg_gf(fg, ax=None, **plot_kwargs): """Plot goodness of fit results, in a scatter plot. Parameters @@ -101,8 +101,8 @@ def plot_fg_gf(fg, ax=None, **kwargs): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. - **kwargs - Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. """ plot_scatter_2(fg.get_params('error'), 'Error', @@ -112,7 +112,7 @@ def plot_fg_gf(fg, ax=None, **kwargs): @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_fg_peak_cens(fg, ax=None, **kwargs): +def plot_fg_peak_cens(fg, ax=None, **plot_kwargs): """Plot peak center frequencies, in a histogram. Parameters @@ -121,8 +121,8 @@ def plot_fg_peak_cens(fg, ax=None, **kwargs): Object to plot data from. ax : matplotlib.Axes, optional Figure axes upon which to plot. - **kwargs - Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. """ plot_hist(fg.get_params('peak_params', 0)[:, 0], 'Center Frequency', diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py index 2fb2fad7f..54a76a8a2 100644 --- a/fooof/plts/fm.py +++ b/fooof/plts/fm.py @@ -15,7 +15,7 @@ from fooof.plts.spectra import plot_spectrum from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS from fooof.plts.utils import check_ax, check_plot_kwargs, savefig -from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot +from fooof.plts.style import style_spectrum_plot, style_plot plt = safe_import('.pyplot', 'matplotlib') @@ -26,8 +26,8 @@ @style_plot @check_dependency(plt, 'matplotlib') def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, - save_fig=False, file_name=None, file_path=None, ax=None, plot_style=style_spectrum_plot, - data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **kwargs): + save_fig=False, file_name=None, file_path=None, ax=None, data_kwargs=None, + model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): """Plot the power spectrum and model fit results from a FOOOF object. Parameters @@ -51,12 +51,10 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= Path to directory to save to. If None, saves to current directory. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional Keyword arguments to pass into the plot call for each plot element. - **kwargs - Keyword arguments for customizing the plot, passed to the 'style_plot' decorator. + **plot_kwargs + Keyword arguments to pass into the ``style_plot``. Notes ----- @@ -72,34 +70,29 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= # Plot the data, if available if fm.has_data: - data_kwargs = check_plot_kwargs(data_kwargs, \ - {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, - 'label' : 'Original Spectrum' if add_legend else None}) - plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, - ax=ax, plot_style=None, **data_kwargs) + data_kwargs = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, + 'label' : 'Original Spectrum' if add_legend else None} + plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, ax=ax, **data_kwargs) # Add the full model fit, and components (if requested) if fm.has_model: - model_kwargs = check_plot_kwargs(model_kwargs, \ - {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, - 'label' : 'Full Model Fit' if add_legend else None}) - plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, - ax=ax, plot_style=None, **model_kwargs) + model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, + 'label' : 'Full Model Fit' if add_legend else None} + plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit if plot_aperiodic: - aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, \ - {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, 'alpha' : 0.5, - 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None}) - plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, - ax=ax, plot_style=None, **aperiodic_kwargs) + aperiodic_kwargs = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, + 'alpha' : 0.5, 'linestyle' : 'dashed', + 'label' : 'Aperiodic Fit' if add_legend else None} + plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit if plot_peaks: - _add_peaks(fm, plot_peaks, plt_log, ax=ax, peak_kwargs=peak_kwargs) + _add_peaks(fm, plot_peaks, plt_log, ax, peak_kwargs) - # Apply style to plot - check_n_style(plot_style, ax, log_freqs, True) + # Apply default style to plot + style_spectrum_plot(ax, log_freqs, True) def _add_peaks(fm, approach, plt_log, ax, peak_kwargs): @@ -158,18 +151,18 @@ def _add_peaks_shade(fm, plt_log, ax, **plot_kwargs): ax : matplotlib.Axes Figure axes upon which to plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``fill_between``. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}) + plot_kwargs = check_plot_kwargs(plot_kwargs, + {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}) for peak in fm.get_params('gaussian_params'): peak_freqs = np.log10(fm.freqs) if plt_log else fm.freqs peak_line = fm._ap_fit + gen_periodic(fm.freqs, peak) - ax.fill_between(peak_freqs, peak_line, fm._ap_fit, **kwargs) + ax.fill_between(peak_freqs, peak_line, fm._ap_fit, **plot_kwargs) def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): @@ -187,9 +180,9 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) + plot_kwargs = check_plot_kwargs(plot_kwargs, + {'color' : PLT_COLORS['periodic'], + 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) for peak in fm.get_params('peak_params'): @@ -197,10 +190,10 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): freq_point = np.log10(peak[0]) if plt_log else peak[0] # Add the line from the aperiodic fit up the tip of the peak - ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **kwargs) + ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **plot_kwargs) # Add an extra dot at the tip of the peak - ax.plot(freq_point, ap_point + peak[1], marker='o', **kwargs) + ax.plot(freq_point, ap_point + peak[1], marker='o', **plot_kwargs) def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): @@ -218,9 +211,9 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.7, 'lw' : 1.5}) + plot_kwargs = check_plot_kwargs(plot_kwargs, + {'color' : PLT_COLORS['periodic'], + 'alpha' : 0.7, 'lw' : 1.5}) for peak in fm.get_params('gaussian_params'): @@ -233,7 +226,7 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): # Plot the peak outline peak_freqs = np.log10(peak_freqs) if plt_log else peak_freqs - ax.plot(peak_freqs, peak_line, **kwargs) + ax.plot(peak_freqs, peak_line, **plot_kwargs) def _add_peaks_line(fm, plt_log, ax, **plot_kwargs): @@ -251,16 +244,16 @@ def _add_peaks_line(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10}) + plot_kwargs = check_plot_kwargs(plot_kwargs, + {'color' : PLT_COLORS['periodic'], + 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10}) ylims = ax.get_ylim() for peak in fm.get_params('peak_params'): freq_point = np.log10(peak[0]) if plt_log else peak[0] - ax.plot([freq_point, freq_point], ylims, '-', **kwargs) - ax.plot(freq_point, ylims[1], 'v', **kwargs) + ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs) + ax.plot(freq_point, ylims[1], 'v', **plot_kwargs) def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): @@ -283,9 +276,9 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): the peak, though what is literally plotted is the full-width half-max. """ - kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) + plot_kwargs = check_plot_kwargs(plot_kwargs, + {'color' : PLT_COLORS['periodic'], + 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) for peak in fm.gaussian_params_: @@ -296,7 +289,7 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): if plt_log: bw_freqs = np.log10(bw_freqs) - ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **kwargs) + ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **plot_kwargs) # Collect all the possible `add_peak_*` functions together diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py index 0bd73e03d..55d147a99 100644 --- a/fooof/plts/periodic.py +++ b/fooof/plts/periodic.py @@ -8,8 +8,8 @@ from fooof.core.funcs import gaussian_function from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_param_plot, style_plot -from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig +from fooof.plts.style import style_param_plot, style_plot +from fooof.plts.utils import check_ax, recursive_plot, savefig plt = safe_import('.pyplot', 'matplotlib') @@ -19,8 +19,7 @@ @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): +def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): """Plot peak parameters as dots representing center frequency, power and bandwidth. Parameters @@ -35,18 +34,15 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to pass into the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params'])) # If there is a list, use recurse function to loop across arrays of data and plot them if isinstance(peaks, list): - recursive_plot(peaks, plot_peak_params, ax, colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + recursive_plot(peaks, plot_peak_params, ax, colors=colors, labels=labels) # Otherwise, plot the array of data else: @@ -55,9 +51,20 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, xs, ys = peaks[:, 0], peaks[:, 1] sizes = peaks[:, 2] * plot_kwargs.pop('s', 150) - # Create the plot - plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7}) - ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs) + colors = 'C0' if colors is None else colors + colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors) + labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels) + + for xi, yi, size in zip(xs, ys, sizes): + + # Prevent duplicate labels when recursively plotting + _, cur_labels = plt.gca().get_legend_handles_labels() + + label = next(labels) + if label not in cur_labels: + ax.scatter(xi, yi, s=size, color=next(colors), label=label, alpha=0.7) + else: + ax.scatter(xi, yi, s=size, color=next(colors), alpha=0.7) # Add axis labels ax.set_xlabel('Center Frequency') @@ -68,13 +75,12 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax.set_xlim(freq_range) ax.set_ylim([0, ax.get_ylim()[1]]) - check_n_style(plot_style, ax) + style_param_plot(ax) @savefig @style_plot -def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, - ax=None, plot_style=style_param_plot, **plot_kwargs): +def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs): """Plot reconstructions of model peak fits. Parameters @@ -90,8 +96,6 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, Label(s) for plotted data, to be added in a legend. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_param_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs Keyword arguments to pass into the plot call. """ @@ -105,8 +109,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, recursive_plot(peaks, plot_function=plot_peak_fits, ax=ax, freq_range=tuple(freq_range) if freq_range else freq_range, - colors=colors, labels=labels, - plot_style=plot_style, **plot_kwargs) + colors=colors, labels=labels, **plot_kwargs) else: @@ -132,8 +135,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, # Create & plot the peak model from parameters peak_vals = gaussian_function(freqs, *peak_params) - plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25}) - ax.plot(freqs, peak_vals, color=colors, **plot_kwargs) + ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25) # Collect a running average average peaks avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0) @@ -141,7 +143,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, # Plot the average across all components avg = avg_vals / peaks.shape[0] avg_color = 'black' if not colors else colors - ax.plot(freqs, avg, color=avg_color, linewidth=plot_kwargs.get('linewidth')*3, label=labels) + ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels) # Add axis labels ax.set_xlabel('Frequency') @@ -152,4 +154,4 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax.set_ylim([0, ax.get_ylim()[1]]) # Apply plot style - check_n_style(plot_style, ax) + style_param_plot(ax) diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py index b695ab40c..a68043af4 100644 --- a/fooof/plts/settings.py +++ b/fooof/plts/settings.py @@ -32,7 +32,7 @@ # Line style arguments are those that can be defined on a line object LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle', - 'marker', 'ms', 'markersize', 'color'] + 'marker', 'ms', 'markersize'] # Custom style arguments are those that are custom-handled by the plot style function CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize', diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py index 24e207770..9c9f4ab53 100644 --- a/fooof/plts/spectra.py +++ b/fooof/plts/spectra.py @@ -5,14 +5,14 @@ This file contains functions for plotting power spectra, that take in data directly. """ -from itertools import repeat +from itertools import repeat, cycle import numpy as np from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES -from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot -from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs, savefig +from fooof.plts.style import style_spectrum_plot, style_plot +from fooof.plts.utils import check_ax, add_shades, savefig plt = safe_import('.pyplot', 'matplotlib') @@ -23,7 +23,7 @@ @style_plot @check_dependency(plt, 'matplotlib') def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): + color=None, label=None, ax=None, **plot_kwargs): """Plot a power spectrum. Parameters @@ -36,12 +36,14 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, Whether to plot the frequency axis in log spacing. log_powers : bool, optional, default: False Whether to plot the power axis in log spacing. + label : str, optional, default: None + Legend label for the spectrum. + color : str, optional, default: None + Line color of the spectrum. ax : matplotlib.Axes, optional Figure axis upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) @@ -51,17 +53,16 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False, plt_powers = np.log10(power_spectrum) if log_powers else power_spectrum # Create the plot - plot_kwargs = check_plot_kwargs(plot_kwargs, {'linewidth' : 2.0}) - ax.plot(plt_freqs, plt_powers, **plot_kwargs) + ax.plot(plt_freqs, plt_powers, linewidth=2.0, color=color, label=label) - check_n_style(plot_style, ax, log_freqs, log_powers) + style_spectrum_plot(ax, log_freqs, log_powers) @savefig @style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels=None, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, + colors=None, labels=None, ax=None, **plot_kwargs): """Plot multiple power spectra on the same plot. Parameters @@ -74,34 +75,35 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels Whether to plot the frequency axis in log spacing. log_powers : bool, optional, default: False Whether to plot the power axis in log spacing. - labels : list of str, optional - Legend labels, for each power spectrum. + labels : list of str, optional, default: None + Legend labels for the spectra. + colors : list of str, optional, default: None + Line colors of the spectra. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to the plot call. + Keyword arguments to pass into the ``style_plot``. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) # Make inputs iterable if need to be passed multiple times to plot each spectrum freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs - labels = repeat(labels) if not isinstance(labels, list) else labels - for freq, power_spectrum, label in zip(freqs, power_spectra, labels): - plot_spectrum(freq, power_spectrum, log_freqs, log_powers, label=label, - plot_style=None, ax=ax, **plot_kwargs) + colors = repeat(colors) if not isinstance(colors, list) else cycle(colors) + labels = repeat(labels) if not isinstance(labels, list) else cycle(labels) - check_n_style(plot_style, ax, log_freqs, log_powers) + for freq, power_spectrum, color, label in zip(freqs, power_spectra, colors, labels): + plot_spectrum(freq, power_spectrum, log_freqs, log_powers, + color=color, label=label, ax=ax) + + style_spectrum_plot(ax, log_freqs, log_powers) @savefig -@style_plot @check_dependency(plt, 'matplotlib') -def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_center=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', + add_center=False, ax=None, **plot_kwargs): """Plot a power spectrum with a shaded frequency region (or regions). Parameters @@ -118,28 +120,24 @@ def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_c Whether to add a line at the center point of the shaded regions. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to the plot call. + Keyword arguments to pass into :func:`~.plot_spectrum`. """ ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - plot_spectrum(freqs, power_spectrum, plot_style=None, ax=ax, **plot_kwargs) + plot_spectrum(freqs, power_spectrum, ax=ax, **plot_kwargs) add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False)) - check_n_style(plot_style, ax, - plot_kwargs.get('log_freqs', False), - plot_kwargs.get('log_powers', False)) + style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), + plot_kwargs.get('log_powers', False)) @savefig -@style_plot @check_dependency(plt, 'matplotlib') -def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_center=False, - ax=None, plot_style=style_spectrum_plot, **plot_kwargs): +def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', + add_center=False, ax=None, **plot_kwargs): """Plot a group of power spectra with a shaded frequency region (or regions). Parameters @@ -156,10 +154,8 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_cen Whether to add a line at the center point of the shaded regions. ax : matplotlib.Axes, optional Figure axes upon which to plot. - plot_style : callable, optional, default: style_spectrum_plot - A function to call to apply styling & aesthetics to the plot. **plot_kwargs - Keyword arguments to be passed to `plot_spectra` or to the plot call. + Keyword arguments to pass into :func:`~.plot_spectra`. Notes ----- @@ -170,10 +166,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_cen ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral'])) - plot_spectra(freqs, power_spectra, ax=ax, plot_style=None, **plot_kwargs) + plot_spectra(freqs, power_spectra, ax=ax, **plot_kwargs) add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False)) - check_n_style(plot_style, ax, - plot_kwargs.get('log_freqs', False), - plot_kwargs.get('log_powers', False)) + style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False), + plot_kwargs.get('log_powers', False)) diff --git a/fooof/plts/style.py b/fooof/plts/style.py index f1353ba85..3596e71b5 100644 --- a/fooof/plts/style.py +++ b/fooof/plts/style.py @@ -13,20 +13,7 @@ ################################################################################################### ################################################################################################### -def check_n_style(style_func, *args): - """"Check if a style function has been passed, and apply it to a plot if so. - - Parameters - ---------- - style_func : callable or None - Function to apply styling to a plot axis. - *args - Inputs to the style plot. - """ - - if style_func: - style_func(*args) - +# Default plot styling def style_spectrum_plot(ax, log_freqs, log_powers): """Apply style and aesthetics to a power spectrum plot. @@ -86,8 +73,7 @@ def style_param_plot(ax): for handle in legend.legendHandles: handle._sizes = [100] - -# Additional plot style customization +# Custom plot styling def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs): """Apply axis plot style. @@ -107,50 +93,34 @@ def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs): ax.set(**plot_kwargs) -def apply_plot_style(ax, style_args=LINE_STYLE_ARGS, **kwargs): - """Apply line/scatter/histogram plot style. +def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs): + """Apply line plot style. Parameters ---------- ax : matplotlib.Axes Figure axes to apply style to. style_args : list of str - A list of arguments to be sub-selected from `kwargs` and applied as styling. + A list of arguments to be sub-selected from `kwargs` and applied as line styling. **kwargs - Keyword arguments that define style to apply. + Keyword arguments that define line style to apply. """ - # Get the plot object related styling arguments from the keyword arguments - style_kwargs = {key : val for key, val in kwargs.items() if key in style_args} - - # For line plots - if len(ax.lines) > 0: - plot_objs = ax.lines - - # For scatter plots - elif len(ax.collections) > 0: - plot_objs = ax.collections - - # For histograms - elif len(ax.patches) > 0: - plot_objs = ax.patches + # Check how many lines are from the current plot call, to apply style to + # If available, this indicates the apply styling to the last 'n' lines + n_lines_apply = kwargs.pop('n_lines_apply', 0) - # There is no styling to apply - else: - return + # Get the line related styling arguments from the keyword arguments + line_kwargs = {key : val for key, val in kwargs.items() if key in style_args} - plot_objs = [plot_objs] if not isinstance(plot_objs, list) else plot_objs + # Apply any provided line style arguments + for style, value in line_kwargs.items(): - # Apply any provided plot object style arguments - for style, value in style_kwargs.items(): - - # Values should be either a single value, for all plot objects, or a list, one value per - # object. This line checks type, and makes a cycle-able / loop-able object out of the values + # Values should be either a single value, for all lines, or a list, of a value per line + # This line checks type, and makes a cycle-able / loop-able object out of the values values = cycle([value] if isinstance(value, (int, float, str)) else value) - - # For line plots - for plot_obj in plot_objs: - plot_obj.set(**{style : next(values)}) + for line in ax.lines[-n_lines_apply:]: + line.set(**{style : next(values)}) def apply_custom_style(ax, **kwargs): @@ -182,12 +152,10 @@ def apply_custom_style(ax, **kwargs): ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)}, loc=kwargs.pop('legend_loc', LEGEND_LOC)) - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - plt.tight_layout() + plt.tight_layout() -def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_plot_style, +def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style, custom_styler=apply_custom_style, **kwargs): """Apply plot style to a figure axis. @@ -242,7 +210,7 @@ def style_plot(func, *args, **kwargs): def decorated(*args, **kwargs): # Grab a custom style function, if provided, and grab any provided style arguments - style_func = kwargs.pop('apply_style', apply_style) + style_func = kwargs.pop('plot_style', apply_style) style_args = kwargs.pop('style_args', STYLE_ARGS) style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs} @@ -259,13 +227,7 @@ def decorated(*args, **kwargs): n_lines_apply = len(cur_ax.lines) - n_lines_pre style_kwargs['n_lines_apply'] = n_lines_apply - # Determine if styling should be applied to all axes - all_axes = kwargs.pop('all_axes', False) - cur_ax = plt.gcf().get_axes() if all_axes is True else cur_ax - cur_ax = [cur_ax] if not isinstance(cur_ax, list) else cur_ax - # Apply the styling function - for ax in cur_ax: - style_func(ax, **style_kwargs) + style_func(cur_ax, **style_kwargs) return decorated From 9420e623739e8853583af838c19f0f0abfb9ec45 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Wed, 5 Aug 2020 17:48:40 -0700 Subject: [PATCH 4/8] style testing updated --- fooof/tests/plts/test_styles.py | 30 ++++-------------------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/fooof/tests/plts/test_styles.py b/fooof/tests/plts/test_styles.py index caf09e3eb..1ddca9ffc 100644 --- a/fooof/tests/plts/test_styles.py +++ b/fooof/tests/plts/test_styles.py @@ -6,16 +6,6 @@ ################################################################################################### ################################################################################################### -def test_check_n_style(skip_if_no_mpl): - - # Check can pass None and do nothing - check_n_style(None) - assert True - - # Check can pass a callable - def checker(*args): - return True - check_n_style(checker) def test_style_spectrum_plot(skip_if_no_mpl): @@ -45,14 +35,14 @@ def test_apply_axis_style(): assert ax.get_ylabel() == ylabel -def test_apply_plot_style(): +def test_apply_line_style(): # Check applying style to one line _, ax = plt.subplots() ax.plot([1, 2], [3, 4]) lw = 4 - apply_plot_style(ax, lw=lw) + apply_line_style(ax, lw=lw) assert ax.get_lines()[0].get_lw() == lw @@ -61,23 +51,11 @@ def test_apply_plot_style(): ax.plot([1, 2], [[3, 4], [5, 6]]) alphas = [0.5, 0.75] - apply_plot_style(ax, alpha=alphas) + apply_line_style(ax, alpha=alphas) for line, alpha in zip(ax.get_lines(), alphas): assert line.get_alpha() == alpha - # Check applying style to a scatter plot - _, ax = plt.subplots() - ax.scatter([1, 2], [2, 4]) - apply_plot_style(ax, alpha=0.123) - assert ax.collections[0]._alpha == 0.123 - - # Check applying style to a histogram - _, ax = plt.subplots() - ax.hist([1, 2, 3]) - apply_plot_style(ax, alpha=0.123) - assert ax.patches[0]._alpha == 0.123 - def test_apply_custom_style(): @@ -124,4 +102,4 @@ def my_plot_style(ax, **kwargs): example_plot(title=title, lw=lw) # Test with passing in own plot_style function - example_plot(apply_style=my_plot_style) + example_plot(plot_style=my_plot_style) From e249ab56ef2aa758baaed3dec1286c22742c09d2 Mon Sep 17 00:00:00 2001 From: ryanhammonds Date: Mon, 17 Aug 2020 12:19:56 -0700 Subject: [PATCH 5/8] save out test plots --- fooof/tests/conftest.py | 4 +++- fooof/tests/plts/test_annotate.py | 7 +++++-- fooof/tests/plts/test_aperiodic.py | 7 +++++-- fooof/tests/plts/test_error.py | 4 +++- fooof/tests/plts/test_fg.py | 13 +++++++++---- fooof/tests/plts/test_fm.py | 14 +++++++------- fooof/tests/plts/test_periodic.py | 7 +++++-- fooof/tests/plts/test_spectra.py | 27 +++++++++++++++++++-------- fooof/tests/settings.py | 1 + 9 files changed, 57 insertions(+), 27 deletions(-) diff --git a/fooof/tests/conftest.py b/fooof/tests/conftest.py index 65a8b00a9..b943456ce 100644 --- a/fooof/tests/conftest.py +++ b/fooof/tests/conftest.py @@ -9,7 +9,8 @@ from fooof.core.modutils import safe_import from fooof.tests.tutils import get_tfm, get_tfg, get_tbands -from fooof.tests.settings import BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH +from fooof.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH, + TEST_REPORTS_PATH, TEST_PLOTS_PATH) plt = safe_import('.pyplot', 'matplotlib') @@ -33,6 +34,7 @@ def check_dir(): os.mkdir(BASE_TEST_FILE_PATH) os.mkdir(TEST_DATA_PATH) os.mkdir(TEST_REPORTS_PATH) + os.mkdir(TEST_PLOTS_PATH) @pytest.fixture(scope='session') def tfm(): diff --git a/fooof/tests/plts/test_annotate.py b/fooof/tests/plts/test_annotate.py index c096612cc..84f3848df 100644 --- a/fooof/tests/plts/test_annotate.py +++ b/fooof/tests/plts/test_annotate.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.annotate import * @@ -12,11 +13,13 @@ @plot_test def test_plot_annotated_peak_search(tfm, skip_if_no_mpl): - plot_annotated_peak_search(tfm) + plot_annotated_peak_search(tfm, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_annotated_peak_search.png') @plot_test def test_plot_annotated_model(tfm, skip_if_no_mpl): # Make sure model has been fit & then plot annotated model tfm.fit() - plot_annotated_model(tfm) + plot_annotated_model(tfm, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_annotated_model.png') diff --git a/fooof/tests/plts/test_aperiodic.py b/fooof/tests/plts/test_aperiodic.py index 167907846..477b22053 100644 --- a/fooof/tests/plts/test_aperiodic.py +++ b/fooof/tests/plts/test_aperiodic.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.aperiodic import * @@ -21,7 +22,8 @@ def test_plot_aperiodic_params(skip_if_no_mpl): # Test for 'knee' mode: offset, knee exponent aps = np.array([[1, 100, 1], [0.5, 150, 0.5], [2, 200, 2]]) - plot_aperiodic_params(aps) + plot_aperiodic_params(aps, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_aperiodic_params.png') @plot_test def test_plot_aperiodic_fits(skip_if_no_mpl): @@ -36,4 +38,5 @@ def test_plot_aperiodic_fits(skip_if_no_mpl): # Test for 'knee' mode: offset, knee exponent aps = np.array([[1, 100, 1], [0.5, 150, 0.5], [2, 200, 2]]) - plot_aperiodic_fits(aps, [1, 50]) + plot_aperiodic_fits(aps, [1, 50], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_aperiodic_fits.png') diff --git a/fooof/tests/plts/test_error.py b/fooof/tests/plts/test_error.py index 2bffbc7a2..3e8b817bd 100644 --- a/fooof/tests/plts/test_error.py +++ b/fooof/tests/plts/test_error.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.error import * @@ -15,4 +16,5 @@ def test_plot_spectral_error(skip_if_no_mpl): fs = np.arange(3, 41, 1) errs = np.ones(len(fs)) - plot_spectral_error(fs, errs) + plot_spectral_error(fs, errs, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectral_error.png') diff --git a/fooof/tests/plts/test_fg.py b/fooof/tests/plts/test_fg.py index 8841c9fba..24103a916 100644 --- a/fooof/tests/plts/test_fg.py +++ b/fooof/tests/plts/test_fg.py @@ -6,6 +6,7 @@ from fooof.core.errors import NoModelError from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.fg import * @@ -15,7 +16,8 @@ @plot_test def test_plot_fg(tfg, skip_if_no_mpl): - plot_fg(tfg) + plot_fg(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg.png') # Test error if no data available to plot tfg = FOOOFGroup() @@ -25,14 +27,17 @@ def test_plot_fg(tfg, skip_if_no_mpl): @plot_test def test_plot_fg_ap(tfg, skip_if_no_mpl): - plot_fg_ap(tfg) + plot_fg_ap(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg_ap.png') @plot_test def test_plot_fg_gf(tfg, skip_if_no_mpl): - plot_fg_gf(tfg) + plot_fg_gf(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg_gf.png') @plot_test def test_plot_fg_peak_cens(tfg, skip_if_no_mpl): - plot_fg_peak_cens(tfg) + plot_fg_peak_cens(tfg, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fg_peak_cens.png') diff --git a/fooof/tests/plts/test_fm.py b/fooof/tests/plts/test_fm.py index 7fce65ad5..6d7a0f02f 100644 --- a/fooof/tests/plts/test_fm.py +++ b/fooof/tests/plts/test_fm.py @@ -1,6 +1,7 @@ """Tests for fooof.plts.fm.""" from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.fm import * @@ -13,7 +14,8 @@ def test_plot_fm(tfm, skip_if_no_mpl): # Make sure model has been fit tfm.fit() - plot_fm(tfm) + plot_fm(tfm, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_fm.png') @plot_test def test_plot_fm_add_peaks(tfm, skip_if_no_mpl): @@ -22,9 +24,7 @@ def test_plot_fm_add_peaks(tfm, skip_if_no_mpl): tfm.fit() # Test run each of the add peak approaches - for add_peak in ['shade', 'dot', 'outline', 'line']: - plot_fm(tfm, plot_peaks=add_peak) - - # Test run some combinations - for add_peak in ['shade-dot', 'outline-line']: - plot_fm(tfm, plot_peaks=add_peak) + for add_peak in ['shade', 'dot', 'outline', 'line', 'shade-dot', 'outline-line']: + file_name = 'test_plot_fm_add_peaks_' + add_peak + '.png' + plot_fm(tfm, plot_peaks=add_peak, save_fig=True, + file_path=TEST_PLOTS_PATH, file_name=file_name) diff --git a/fooof/tests/plts/test_periodic.py b/fooof/tests/plts/test_periodic.py index 647c967bd..83e77daf9 100644 --- a/fooof/tests/plts/test_periodic.py +++ b/fooof/tests/plts/test_periodic.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.periodic import * @@ -18,7 +19,8 @@ def test_plot_peak_params(skip_if_no_mpl): plot_peak_params(peaks) # Test with multiple set of params - plot_peak_params([peaks, peaks]) + plot_peak_params([peaks, peaks], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_peak_params.png') @plot_test def test_plot_peak_fits(skip_if_no_mpl): @@ -29,4 +31,5 @@ def test_plot_peak_fits(skip_if_no_mpl): plot_peak_fits(peaks) # Test with multiple set of params - plot_peak_fits([peaks, peaks]) + plot_peak_fits([peaks, peaks], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_peak_fits.png') diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index 9abb95b2a..03b137975 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -3,6 +3,7 @@ import numpy as np from fooof.tests.tutils import plot_test +from fooof.tests.settings import TEST_PLOTS_PATH from fooof.plts.spectra import * @@ -15,36 +16,46 @@ def test_plot_spectrum(tfm, skip_if_no_mpl): plot_spectrum(tfm.freqs, tfm.power_spectrum) # Test with logging both axes - plot_spectrum(tfm.freqs, tfm.power_spectrum, True, True) + plot_spectrum(tfm.freqs, tfm.power_spectrum, True, True, save_fig=True, + file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrum.png') @plot_test def test_plot_spectra(tfg, skip_if_no_mpl): # Test with 1d inputs - 1d freq array and list of 1d power spectra - plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]]) + plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_1d.png') # Test with multiple freq inputs - list of 1d freq array and list of 1d power spectra - plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]]) + plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_list_of_1d.png') # Test with 2d array inputs plot_spectra(np.vstack([tfg.freqs, tfg.freqs]), - np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]])) + np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]), + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png') # Test with labels - plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B']) + plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'], + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_labels.png') @plot_test def test_plot_spectrum_shading(tfm, skip_if_no_mpl): - plot_spectrum_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True) + plot_spectrum_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True, + save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectrum_shading.png') @plot_test def test_plot_spectra_shading(tfg, skip_if_no_mpl): plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], - shades=[8, 12], add_center=True) + shades=[8, 12], add_center=True, save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_shading.png') # Test with **kwargs that pass into plot_spectra plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], shades=[8, 12], add_center=True, log_freqs=True, log_powers=True, - labels=['A', 'B']) + labels=['A', 'B'], save_fig=True, file_path=TEST_PLOTS_PATH, + file_name='test_plot_spectra_shading_kwargs.png') diff --git a/fooof/tests/settings.py b/fooof/tests/settings.py index 9beae1afa..856c532f6 100644 --- a/fooof/tests/settings.py +++ b/fooof/tests/settings.py @@ -10,3 +10,4 @@ BASE_TEST_FILE_PATH = pkg.resource_filename(__name__, 'test_files') TEST_DATA_PATH = os.path.join(BASE_TEST_FILE_PATH, 'data') TEST_REPORTS_PATH = os.path.join(BASE_TEST_FILE_PATH, 'reports') +TEST_PLOTS_PATH = os.path.join(BASE_TEST_FILE_PATH, 'plots') From b4b53ff576d78c751721c555537a8b640fb666e3 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 11 Apr 2021 21:46:34 -0400 Subject: [PATCH 6/8] plot lints --- fooof/objs/fit.py | 7 +++---- fooof/objs/group.py | 4 ++-- fooof/plts/annotate.py | 3 +-- fooof/plts/fg.py | 1 - fooof/plts/fm.py | 28 ++++++++++++---------------- fooof/plts/settings.py | 2 +- fooof/plts/style.py | 4 ---- fooof/tests/plts/test_spectra.py | 2 +- 8 files changed, 20 insertions(+), 31 deletions(-) diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py index 415fc895f..7e5333d50 100644 --- a/fooof/objs/fit.py +++ b/fooof/objs/fit.py @@ -68,7 +68,6 @@ gen_issue_str, gen_width_warning_str) from fooof.plts.fm import plot_fm -from fooof.plts.style import style_spectrum_plot from fooof.utils.data import trim_spectrum from fooof.utils.params import compute_gauss_std from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData @@ -618,12 +617,12 @@ def get_results(self): def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True, save_fig=False, file_name=None, file_path=None, ax=None, data_kwargs=None, model_kwargs=None, - aperiodic_kwargs=None, peak_kwargs=None, **kwargs): + aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs): plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log, add_legend=add_legend, save_fig=save_fig, file_name=file_name, - file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, - aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **kwargs) + file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs, + aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs) @copy_doc_func_to_method(save_report_fm) diff --git a/fooof/objs/group.py b/fooof/objs/group.py index ed42bb0e2..064a4bb8c 100644 --- a/fooof/objs/group.py +++ b/fooof/objs/group.py @@ -398,9 +398,9 @@ def get_params(self, name, col=None): @copy_doc_func_to_method(plot_fg) - def plot(self, save_fig=False, file_name=None, file_path=None, **kwargs): + def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs): - plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **kwargs) + plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs) @copy_doc_func_to_method(save_report_fg) diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py index d67befb24..093306d42 100644 --- a/fooof/plts/annotate.py +++ b/fooof/plts/annotate.py @@ -62,7 +62,7 @@ def plot_annotated_peak_search(fm): if ind < fm.n_peaks_: gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :]) - plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit', + plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit', color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0) flatspec = flatspec - gauss @@ -85,7 +85,6 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, ax : matplotlib.Axes, optional Figure axes upon which to plot. - Raises ------ NoModelError diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py index 342065c64..d2f2cc476 100644 --- a/fooof/plts/fg.py +++ b/fooof/plts/fg.py @@ -5,7 +5,6 @@ This file contains plotting functions that take as input a FOOOFGroup object. """ -from fooof.core.io import fname, fpath from fooof.core.errors import NoModelError from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py index 575b8b97b..40ff683b6 100644 --- a/fooof/plts/fm.py +++ b/fooof/plts/fm.py @@ -76,8 +76,8 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= # Add the full model fit, and components (if requested) if fm.has_model: - model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, - 'label' : 'Full Model Fit' if add_legend else None} + model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, + 'label' : 'Full Model Fit' if add_legend else None} plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit @@ -154,8 +154,8 @@ def _add_peaks_shade(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the ``fill_between``. """ - plot_kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.get_params('gaussian_params'): @@ -180,9 +180,8 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - plot_kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.get_params('peak_params'): @@ -211,9 +210,8 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - plot_kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.7, 'lw' : 1.5}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.get_params('gaussian_params'): @@ -244,9 +242,8 @@ def _add_peaks_line(fm, plt_log, ax, **plot_kwargs): Keyword arguments to pass into the plot call. """ - plot_kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) ylims = ax.get_ylim() for peak in fm.get_params('peak_params'): @@ -276,9 +273,8 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs): the peak, though what is literally plotted is the full-width half-max. """ - plot_kwargs = check_plot_kwargs(plot_kwargs, - {'color' : PLT_COLORS['periodic'], - 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}) + defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6} + plot_kwargs = check_plot_kwargs(plot_kwargs, defaults) for peak in fm.gaussian_params_: diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py index a68043af4..ff579d2e3 100644 --- a/fooof/plts/settings.py +++ b/fooof/plts/settings.py @@ -40,7 +40,7 @@ STYLERS = ['axis_styler', 'line_styler', 'custom_styler'] STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS -## Define default values for aesthetic +## Define default values for plot aesthetics # These are all custom style arguments TITLE_FONTSIZE = 20 LABEL_SIZE = 16 diff --git a/fooof/plts/style.py b/fooof/plts/style.py index 3596e71b5..0c3acb484 100644 --- a/fooof/plts/style.py +++ b/fooof/plts/style.py @@ -2,7 +2,6 @@ from itertools import cycle from functools import wraps -import warnings import matplotlib.pyplot as plt @@ -13,8 +12,6 @@ ################################################################################################### ################################################################################################### -# Default plot styling - def style_spectrum_plot(ax, log_freqs, log_powers): """Apply style and aesthetics to a power spectrum plot. @@ -73,7 +70,6 @@ def style_param_plot(ax): for handle in legend.legendHandles: handle._sizes = [100] -# Custom plot styling def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs): """Apply axis plot style. diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py index 03b137975..0b85e2d9e 100644 --- a/fooof/tests/plts/test_spectra.py +++ b/fooof/tests/plts/test_spectra.py @@ -34,7 +34,7 @@ def test_plot_spectra(tfg, skip_if_no_mpl): # Test with 2d array inputs plot_spectra(np.vstack([tfg.freqs, tfg.freqs]), np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]), - save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png') + save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png') # Test with labels plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'], From 96ac42a64c5075fce6945b88d7a47872770d8bae Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 11 Apr 2021 23:26:36 -0400 Subject: [PATCH 7/8] small plot updates --- fooof/plts/aperiodic.py | 19 ++++--------------- fooof/plts/fm.py | 17 ++++++++++------- fooof/plts/periodic.py | 19 ++++--------------- 3 files changed, 18 insertions(+), 37 deletions(-) diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py index c3636ce22..905b95145 100644 --- a/fooof/plts/aperiodic.py +++ b/fooof/plts/aperiodic.py @@ -9,7 +9,7 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES from fooof.plts.style import style_param_plot, style_plot -from fooof.plts.utils import check_ax, recursive_plot, savefig +from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs plt = safe_import('.pyplot', 'matplotlib') @@ -47,20 +47,9 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs) xs, ys = aps[:, 0], aps[:, -1] sizes = plot_kwargs.pop('s', 150) - colors = 'C0' if colors is None else colors - colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors) - labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels) - - for xi, yi in zip(xs, ys): - - # Prevent duplicate labels when recursively plotting - _, cur_labels = plt.gca().get_legend_handles_labels() - - label = next(labels) - if label not in cur_labels: - ax.scatter(xi, yi, s=sizes, color=next(colors), label=label, alpha=0.7) - else: - ax.scatter(xi, yi, s=sizes, color=next(colors), alpha=0.7) + # Create the plot + plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7}) + ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs) # Add axis labels ax.set_xlabel('Offset') diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py index 40ff683b6..6674848a3 100644 --- a/fooof/plts/fm.py +++ b/fooof/plts/fm.py @@ -70,21 +70,24 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend= # Plot the data, if available if fm.has_data: - data_kwargs = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, - 'label' : 'Original Spectrum' if add_legend else None} + data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0, + 'label' : 'Original Spectrum' if add_legend else None} + data_kwargs = check_plot_kwargs(data_kwargs, data_defaults) plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, ax=ax, **data_kwargs) # Add the full model fit, and components (if requested) if fm.has_model: - model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, - 'label' : 'Full Model Fit' if add_legend else None} + model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5, + 'label' : 'Full Model Fit' if add_legend else None} + model_kwargs = check_plot_kwargs(model_kwargs, model_defaults) plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs) # Plot the aperiodic component of the model fit if plot_aperiodic: - aperiodic_kwargs = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, - 'alpha' : 0.5, 'linestyle' : 'dashed', - 'label' : 'Aperiodic Fit' if add_legend else None} + aperiodic_defaults = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, + 'alpha' : 0.5, 'linestyle' : 'dashed', + 'label' : 'Aperiodic Fit' if add_legend else None} + aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults) plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs) # Plot the periodic components of the model fit diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py index 55d147a99..17e66f1b9 100644 --- a/fooof/plts/periodic.py +++ b/fooof/plts/periodic.py @@ -9,7 +9,7 @@ from fooof.core.modutils import safe_import, check_dependency from fooof.plts.settings import PLT_FIGSIZES from fooof.plts.style import style_param_plot, style_plot -from fooof.plts.utils import check_ax, recursive_plot, savefig +from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs plt = safe_import('.pyplot', 'matplotlib') @@ -51,20 +51,9 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, xs, ys = peaks[:, 0], peaks[:, 1] sizes = peaks[:, 2] * plot_kwargs.pop('s', 150) - colors = 'C0' if colors is None else colors - colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors) - labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels) - - for xi, yi, size in zip(xs, ys, sizes): - - # Prevent duplicate labels when recursively plotting - _, cur_labels = plt.gca().get_legend_handles_labels() - - label = next(labels) - if label not in cur_labels: - ax.scatter(xi, yi, s=size, color=next(colors), label=label, alpha=0.7) - else: - ax.scatter(xi, yi, s=size, color=next(colors), alpha=0.7) + # Create the plot + plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7}) + ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs) # Add axis labels ax.set_xlabel('Center Frequency') From f5e448eba4ba03d0eea26dd7fc861353bcc02f57 Mon Sep 17 00:00:00 2001 From: Tom Donoghue Date: Sun, 11 Apr 2021 23:47:02 -0400 Subject: [PATCH 8/8] add collection styler --- fooof/plts/settings.py | 3 +++ fooof/plts/style.py | 33 ++++++++++++++++++++++++++++----- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py index ff579d2e3..4b7f1050e 100644 --- a/fooof/plts/settings.py +++ b/fooof/plts/settings.py @@ -34,6 +34,9 @@ LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle', 'marker', 'ms', 'markersize'] +# Collection style arguments are those that can be defined on a collections object +COLLECTION_STYLE_ARGS = ['alpha', 'edgecolor'] + # Custom style arguments are those that are custom-handled by the plot style function CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize', 'legend_size', 'legend_loc'] diff --git a/fooof/plts/style.py b/fooof/plts/style.py index 0c3acb484..dc72a1142 100644 --- a/fooof/plts/style.py +++ b/fooof/plts/style.py @@ -5,9 +5,9 @@ import matplotlib.pyplot as plt -from fooof.plts.settings import AXIS_STYLE_ARGS, LINE_STYLE_ARGS, CUSTOM_STYLE_ARGS, STYLE_ARGS -from fooof.plts.settings import (LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC, - TICK_LABELSIZE, TITLE_FONTSIZE) +from fooof.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS, + CUSTOM_STYLE_ARGS, STYLE_ARGS, LABEL_SIZE, LEGEND_SIZE, + LEGEND_LOC, TICK_LABELSIZE, TITLE_FONTSIZE) ################################################################################################### ################################################################################################### @@ -119,6 +119,27 @@ def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs): line.set(**{style : next(values)}) +def apply_collection_style(ax, style_args=COLLECTION_STYLE_ARGS, **kwargs): + """Apply collection plot style. + + Parameters + ---------- + ax : matplotlib.Axes + Figure axes to apply style to. + style_args : list of str + A list of arguments to be sub-selected from `kwargs` and applied as collection styling. + **kwargs + Keyword arguments that define collection style to apply. + """ + + # Get the collection related styling arguments from the keyword arguments + collection_kwargs = {key : val for key, val in kwargs.items() if key in style_args} + + # Apply any provided collection style arguments + for collection in ax.collections: + collection.set(**collection_kwargs) + + def apply_custom_style(ax, **kwargs): """Apply custom plot style. @@ -152,14 +173,15 @@ def apply_custom_style(ax, **kwargs): def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style, - custom_styler=apply_custom_style, **kwargs): + collection_styler=apply_collection_style, custom_styler=apply_custom_style, + **kwargs): """Apply plot style to a figure axis. Parameters ---------- ax : matplotlib.Axes Figure axes to apply style to. - axis_styler, line_styler, custom_styler : callable, optional + axis_styler, line_styler, collection_style, custom_styler : callable, optional Functions to apply style to aspects of the plot. **kwargs Keyword arguments that define style to apply. @@ -172,6 +194,7 @@ def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style, axis_styler(ax, **kwargs) line_styler(ax, **kwargs) + collection_styler(ax, **kwargs) custom_styler(ax, **kwargs)