From 4395e14026912ca6b01abcfab5ed41fd9cd850e9 Mon Sep 17 00:00:00 2001
From: ryanhammonds
Date: Mon, 29 Jun 2020 21:21:36 -0700
Subject: [PATCH 1/8] Add style_plot and savefig decorators to plotting functions
---
fooof/objs/fit.py | 6 +-
fooof/objs/group.py | 4 +-
fooof/plts/annotate.py | 4 +-
fooof/plts/aperiodic.py | 8 +-
fooof/plts/error.py | 6 +-
fooof/plts/fg.py | 38 +++++---
fooof/plts/fm.py | 20 ++---
fooof/plts/periodic.py | 8 +-
fooof/plts/settings.py | 21 +++++
fooof/plts/spectra.py | 12 ++-
fooof/plts/style.py | 194 ++++++++++++++++++++++++++++++++++++++++
fooof/plts/utils.py | 22 +++++
12 files changed, 305 insertions(+), 38 deletions(-)
diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py
index e5b007b77..7b31fbac2 100644
--- a/fooof/objs/fit.py
+++ b/fooof/objs/fit.py
@@ -616,12 +616,12 @@ def get_results(self):
@copy_doc_func_to_method(plot_fm)
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
add_legend=True, save_fig=False, file_name=None, file_path=None,
- ax=None, plot_style=style_spectrum_plot,
- data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None):
+ ax=None, plot_style=style_spectrum_plot, data_kwargs=None, model_kwargs=None,
+ aperiodic_kwargs=None, peak_kwargs=None, **kwargs):
plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend,
save_fig, file_name, file_path, ax, plot_style,
- data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs)
+ data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs, **kwargs)
@copy_doc_func_to_method(save_report_fm)
diff --git a/fooof/objs/group.py b/fooof/objs/group.py
index 8bd293d3a..5c0e0b8d8 100644
--- a/fooof/objs/group.py
+++ b/fooof/objs/group.py
@@ -398,9 +398,9 @@ def get_params(self, name, col=None):
@copy_doc_func_to_method(plot_fg)
- def plot(self, save_fig=False, file_name=None, file_path=None):
+ def plot(self, save_fig=False, file_name=None, file_path=None, **kwargs):
- plot_fg(self, save_fig, file_name, file_path)
+ plot_fg(self, save_fig, file_name, file_path, **kwargs)
@copy_doc_func_to_method(save_report_fg)
diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py
index 04faacf35..96161239f 100644
--- a/fooof/plts/annotate.py
+++ b/fooof/plts/annotate.py
@@ -7,7 +7,7 @@
from fooof.core.funcs import gaussian_function
from fooof.core.modutils import safe_import, check_dependency
from fooof.sim.gen import gen_aperiodic
-from fooof.plts.utils import check_ax
+from fooof.plts.utils import check_ax, savefig
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
from fooof.plts.style import check_n_style, style_spectrum_plot
@@ -20,6 +20,7 @@
###################################################################################################
###################################################################################################
+@savefig
@check_dependency(plt, 'matplotlib')
def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
"""Plot a series of plots illustrating the peak search from a flattened spectrum.
@@ -74,6 +75,7 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
check_n_style(plot_style, ax, False, True)
+@savefig
@check_dependency(plt, 'matplotlib')
def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True,
ax=None, plot_style=style_spectrum_plot):
diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py
index b20038b81..2fd8afc8a 100644
--- a/fooof/plts/aperiodic.py
+++ b/fooof/plts/aperiodic.py
@@ -7,14 +7,16 @@
from fooof.sim.gen import gen_freqs, gen_aperiodic
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_param_plot
-from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs
+from fooof.plts.style import check_n_style, style_param_plot, style_plot
+from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_params(aps, colors=None, labels=None,
ax=None, plot_style=style_param_plot, **plot_kwargs):
@@ -58,6 +60,8 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
check_n_style(plot_style, ax)
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
log_freqs=False, colors=None, labels=None,
diff --git a/fooof/plts/error.py b/fooof/plts/error.py
index c870900bc..ea0f3b81f 100644
--- a/fooof/plts/error.py
+++ b/fooof/plts/error.py
@@ -5,14 +5,16 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_spectrum_plot
-from fooof.plts.utils import check_ax
+from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot
+from fooof.plts.utils import check_ax, savefig
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py
index f4d121419..cd11d8d0e 100644
--- a/fooof/plts/fg.py
+++ b/fooof/plts/fg.py
@@ -10,6 +10,8 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
from fooof.plts.templates import plot_scatter_1, plot_scatter_2, plot_hist
+from fooof.plts.utils import savefig
+from fooof.plts.style import style_plot
plt = safe_import('.pyplot', 'matplotlib')
gridspec = safe_import('.gridspec', 'matplotlib')
@@ -17,8 +19,9 @@
###################################################################################################
###################################################################################################
+@savefig
@check_dependency(plt, 'matplotlib')
-def plot_fg(fg, save_fig=False, file_name=None, file_path=None):
+def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs):
"""Plot a figure with subplots visualizing the parameters from a FOOOFGroup object.
Parameters
@@ -44,26 +47,27 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None):
fig = plt.figure(figsize=PLT_FIGSIZES['group'])
gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2])
+    # Copy kwargs, so that setting 'all_axes' does not also leak into the histogram call
+    scatter_kwargs = dict(kwargs)
+    scatter_kwargs['all_axes'] = True
+
# Aperiodic parameters plot
ax0 = plt.subplot(gs[0, 0])
- plot_fg_ap(fg, ax0)
+ plot_fg_ap(fg, ax0, **scatter_kwargs)
# Goodness of fit plot
ax1 = plt.subplot(gs[0, 1])
- plot_fg_gf(fg, ax1)
+ plot_fg_gf(fg, ax1, **scatter_kwargs)
# Center frequencies plot
ax2 = plt.subplot(gs[1, :])
- plot_fg_peak_cens(fg, ax2)
-
- if save_fig:
- if not file_name:
- raise ValueError("Input 'file_name' is required to save out the plot.")
- plt.savefig(fpath(file_path, fname(file_name, 'png')))
+ plot_fg_peak_cens(fg, ax2, **kwargs)
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_fg_ap(fg, ax=None):
+def plot_fg_ap(fg, ax=None, **kwargs):
"""Plot aperiodic fit parameters, in a scatter plot.
Parameters
@@ -72,6 +76,8 @@ def plot_fg_ap(fg, ax=None):
Object to plot data from.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
+ **kwargs
+ Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
"""
if fg.aperiodic_mode == 'knee':
@@ -83,8 +89,10 @@ def plot_fg_ap(fg, ax=None):
'Aperiodic Fit', ax=ax)
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_fg_gf(fg, ax=None):
+def plot_fg_gf(fg, ax=None, **kwargs):
"""Plot goodness of fit results, in a scatter plot.
Parameters
@@ -93,14 +101,18 @@ def plot_fg_gf(fg, ax=None):
Object to plot data from.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
+ **kwargs
+ Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
"""
plot_scatter_2(fg.get_params('error'), 'Error',
fg.get_params('r_squared'), 'R^2', 'Goodness of Fit', ax=ax)
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_fg_peak_cens(fg, ax=None):
+def plot_fg_peak_cens(fg, ax=None, **kwargs):
"""Plot peak center frequencies, in a histogram.
Parameters
@@ -109,6 +121,8 @@ def plot_fg_peak_cens(fg, ax=None):
Object to plot data from.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
+ **kwargs
+ Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
"""
plot_hist(fg.get_params('peak_params', 0)[:, 0], 'Center Frequency',
diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py
index 6818e4840..2fb2fad7f 100644
--- a/fooof/plts/fm.py
+++ b/fooof/plts/fm.py
@@ -7,7 +7,6 @@
import numpy as np
-from fooof.core.io import fname, fpath
from fooof.core.utils import nearest_ind
from fooof.core.modutils import safe_import, check_dependency
from fooof.sim.gen import gen_periodic
@@ -15,19 +14,20 @@
from fooof.utils.params import compute_fwhm
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
-from fooof.plts.utils import check_ax, check_plot_kwargs
-from fooof.plts.style import check_n_style, style_spectrum_plot
+from fooof.plts.utils import check_ax, check_plot_kwargs, savefig
+from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True,
- save_fig=False, file_name=None, file_path=None,
- ax=None, plot_style=style_spectrum_plot,
- data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None):
+ save_fig=False, file_name=None, file_path=None, ax=None, plot_style=style_spectrum_plot,
+ data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **kwargs):
"""Plot the power spectrum and model fit results from a FOOOF object.
Parameters
@@ -55,6 +55,8 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
A function to call to apply styling & aesthetics to the plot.
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
Keyword arguments to pass into the plot call for each plot element.
+ **kwargs
+ Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
Notes
-----
@@ -99,12 +101,6 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
# Apply style to plot
check_n_style(plot_style, ax, log_freqs, True)
- # Save out figure, if requested
- if save_fig:
- if not file_name:
- raise ValueError("Input 'file_name' is required to save out the plot.")
- plt.savefig(fpath(file_path, fname(file_name, 'png')))
-
def _add_peaks(fm, approach, plt_log, ax, peak_kwargs):
"""Add peaks to a model plot.
diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py
index e654bfabc..0bd73e03d 100644
--- a/fooof/plts/periodic.py
+++ b/fooof/plts/periodic.py
@@ -8,14 +8,16 @@
from fooof.core.funcs import gaussian_function
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_param_plot
-from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs
+from fooof.plts.style import check_n_style, style_param_plot, style_plot
+from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_peak_params(peaks, freq_range=None, colors=None, labels=None,
ax=None, plot_style=style_param_plot, **plot_kwargs):
@@ -69,6 +71,8 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None,
check_n_style(plot_style, ax)
+@savefig
+@style_plot
def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
ax=None, plot_style=style_param_plot, **plot_kwargs):
"""Plot reconstructions of model peak fits.
diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py
index 94d525054..b695ab40c 100644
--- a/fooof/plts/settings.py
+++ b/fooof/plts/settings.py
@@ -26,3 +26,24 @@
PLT_ALIASES = {'linewidth' : ['lw', 'linewidth'],
'markersize' : ['ms', 'markersize'],
'linestyle' : ['ls', 'linestyle']}
+
+# Plot style arguments are those that can be defined on an axis object
+AXIS_STYLE_ARGS = ['title', 'xlabel', 'ylabel', 'xlim', 'ylim']
+
+# Line style arguments are those that can be defined on a line object
+LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle',
+ 'marker', 'ms', 'markersize', 'color']
+
+# Custom style arguments are those that are custom-handled by the plot style function
+CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize',
+ 'legend_size', 'legend_loc']
+STYLERS = ['axis_styler', 'line_styler', 'custom_styler']
+STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS
+
+## Define default values for aesthetic
+# These are all custom style arguments
+TITLE_FONTSIZE = 20
+LABEL_SIZE = 16
+TICK_LABELSIZE = 16
+LEGEND_SIZE = 12
+LEGEND_LOC = 'best'
diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py
index eaeafdbe7..24e207770 100644
--- a/fooof/plts/spectra.py
+++ b/fooof/plts/spectra.py
@@ -11,14 +11,16 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_spectrum_plot
-from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs
+from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot
+from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs, savefig
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False,
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
@@ -55,6 +57,8 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False,
check_n_style(plot_style, ax, log_freqs, log_powers)
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels=None,
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
@@ -93,6 +97,8 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels
check_n_style(plot_style, ax, log_freqs, log_powers)
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_center=False,
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
@@ -129,6 +135,8 @@ def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_c
plot_kwargs.get('log_powers', False))
+@savefig
+@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_center=False,
ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
diff --git a/fooof/plts/style.py b/fooof/plts/style.py
index f9dbcfe80..f1353ba85 100644
--- a/fooof/plts/style.py
+++ b/fooof/plts/style.py
@@ -1,5 +1,15 @@
"""Style and aesthetics definitions for plots."""
+from itertools import cycle
+from functools import wraps
+import warnings
+
+import matplotlib.pyplot as plt
+
+from fooof.plts.settings import AXIS_STYLE_ARGS, LINE_STYLE_ARGS, CUSTOM_STYLE_ARGS, STYLE_ARGS
+from fooof.plts.settings import (LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC,
+ TICK_LABELSIZE, TITLE_FONTSIZE)
+
###################################################################################################
###################################################################################################
@@ -75,3 +85,187 @@ def style_param_plot(ax):
legend = ax.legend(prop={'size': 16})
for handle in legend.legendHandles:
handle._sizes = [100]
+
+
+# Additional plot style customization
+
+def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
+ """Apply axis plot style.
+
+ Parameters
+ ----------
+ ax : matplotlib.Axes
+ Figure axes to apply style to.
+ style_args : list of str
+ A list of arguments to be sub-selected from `kwargs` and applied as axis styling.
+ **kwargs
+ Keyword arguments that define plot style to apply.
+ """
+
+ # Apply any provided axis style arguments
+ plot_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
+ ax.set(**plot_kwargs)
+
+
+def apply_plot_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
+    """Apply line/scatter/histogram plot style.
+
+    Parameters
+    ----------
+    ax : matplotlib.Axes
+        Figure axes to apply style to.
+    style_args : list of str
+        A list of arguments to be sub-selected from `kwargs` and applied as styling.
+    **kwargs
+        Keyword arguments that define style to apply. A list value gives one value per object.
+    """
+
+    # Get the plot object related styling arguments from the keyword arguments
+    style_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
+
+    # For line plots
+    if len(ax.lines) > 0:
+        plot_objs = ax.lines
+
+    # For scatter plots
+    elif len(ax.collections) > 0:
+        plot_objs = ax.collections
+
+    # For histograms
+    elif len(ax.patches) > 0:
+        plot_objs = ax.patches
+
+    # There is no styling to apply
+    else:
+        return
+
+    plot_objs = list(plot_objs)
+
+    # Apply any provided plot object style arguments
+    for style, value in style_kwargs.items():
+
+        # Values are either a single value, for all plot objects, or a list, one per object.
+        # Only lists are cycled per object, so tuples (eg. RGB colors) count as a single value
+        values = cycle(value if isinstance(value, list) else [value])
+
+        # Set the next value on each plot object, cycling if there are fewer values than objects
+        for plot_obj in plot_objs:
+            plot_obj.set(**{style : next(values)})
+
+
+def apply_custom_style(ax, **kwargs):
+ """Apply custom plot style.
+
+ Parameters
+ ----------
+ ax : matplotlib.Axes
+ Figure axes to apply style to.
+ **kwargs
+ Keyword arguments that define custom style to apply.
+ """
+
+ # If a title was provided, update the size
+ if ax.get_title():
+ ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE))
+
+ # Settings for the axis labels
+ label_size = kwargs.pop('label_size', LABEL_SIZE)
+ ax.xaxis.label.set_size(label_size)
+ ax.yaxis.label.set_size(label_size)
+
+ # Settings for the axis ticks
+ ax.tick_params(axis='both', which='major',
+ labelsize=kwargs.pop('tick_labelsize', TICK_LABELSIZE))
+
+ # If labels were provided, add a legend
+ if ax.get_legend_handles_labels()[0]:
+ ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
+ loc=kwargs.pop('legend_loc', LEGEND_LOC))
+
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore')
+ plt.tight_layout()
+
+
+def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_plot_style,
+ custom_styler=apply_custom_style, **kwargs):
+ """Apply plot style to a figure axis.
+
+ Parameters
+ ----------
+ ax : matplotlib.Axes
+ Figure axes to apply style to.
+ axis_styler, line_styler, custom_styler : callable, optional
+ Functions to apply style to aspects of the plot.
+ **kwargs
+ Keyword arguments that define style to apply.
+
+ Notes
+ -----
+ This function wraps sub-functions which apply style to different plot elements.
+ Each of these sub-functions can be replaced by passing in replacement callables.
+ """
+
+ axis_styler(ax, **kwargs)
+ line_styler(ax, **kwargs)
+ custom_styler(ax, **kwargs)
+
+
+def style_plot(func, *args, **kwargs):
+    """Decorator function to apply a plot style function, after plot generation.
+
+    Parameters
+    ----------
+    func : callable
+        The plotting function for creating a plot. Its return value is passed through.
+    *args, **kwargs
+        Arguments & keyword arguments.
+        These should include any arguments for the plot, and those for applying plot style.
+
+    Notes
+    -----
+    This decorator works by:
+
+    - catching all inputs that relate to plot style (including the `all_axes` flag)
+    - creating a plot, using the passed in plotting function & passing in all non-style arguments
+    - passing the style related arguments into a `apply_style` function which applies plot styling
+
+    By default, this function applies styling with the `apply_style` function. Custom
+    functions for applying style can be passed in using `apply_style` as a keyword argument.
+
+    The `apply_style` function calls sub-functions for applying style to different plot elements,
+    and these sub-functions can be overridden by passing in alternatives for `axis_styler`,
+    `line_styler`, and `custom_styler`.
+    """
+
+    @wraps(func)
+    def decorated(*args, **kwargs):
+
+        # Pop all style-related inputs first, so none of them are passed into the plot function
+        style_func = kwargs.pop('apply_style', apply_style)
+        style_args = kwargs.pop('style_args', STYLE_ARGS)
+        style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs}
+        all_axes = kwargs.pop('all_axes', False)
+
+        # Check how many lines are already on the plot, if it exists already
+        n_lines_pre = len(kwargs['ax'].lines) if 'ax' in kwargs and kwargs['ax'] is not None else 0
+
+        # Create the plot
+        output = func(*args, **kwargs)
+
+        # Get plot axis, if a specific one was provided, or if not, grab the current axis
+        cur_ax = kwargs['ax'] if 'ax' in kwargs and kwargs['ax'] is not None else plt.gca()
+
+        # Check how many lines were added to the plot, and make info available to plot styling
+        n_lines_apply = len(cur_ax.lines) - n_lines_pre
+        style_kwargs['n_lines_apply'] = n_lines_apply
+
+        # Style every axis of the current figure if requested, or else just the plot axis
+        cur_axes = plt.gcf().get_axes() if all_axes is True else [cur_ax]
+
+        # Apply the styling function
+        for ax in cur_axes:
+            style_func(ax, **style_kwargs)
+
+        return output
+    return decorated
diff --git a/fooof/plts/utils.py b/fooof/plts/utils.py
index ef5b53901..0a970ee91 100644
--- a/fooof/plts/utils.py
+++ b/fooof/plts/utils.py
@@ -8,9 +8,11 @@
from itertools import repeat
from collections.abc import Iterator
+from functools import wraps
import numpy as np
+from fooof.core.io import fname, fpath
from fooof.core.modutils import safe_import
from fooof.core.utils import resolve_aliases
from fooof.plts.settings import PLT_ALPHA_LEVELS, PLT_ALIASES
@@ -171,3 +173,23 @@ def check_plot_kwargs(plot_kwargs, defaults):
plot_kwargs[key] = value
return plot_kwargs
+
+
+def savefig(func):
+    """Decorator function to save out figures, passing through the plot function's output."""
+
+    @wraps(func)
+    def decorated(*args, **kwargs):
+
+        save_fig = kwargs.pop('save_fig', False)
+        file_name = kwargs.pop('file_name', None)
+        file_path = kwargs.pop('file_path', None)
+        if save_fig and not file_name:
+            raise ValueError("Input 'file_name' is required to save out the plot.")
+
+        output = func(*args, **kwargs)
+        if save_fig:
+            plt.savefig(fpath(file_path, fname(file_name, 'png')))
+
+        return output
+    return decorated
From 871121578a54ffe7c4ab421ea06add09fb324187 Mon Sep 17 00:00:00 2001
From: ryanhammonds
Date: Mon, 29 Jun 2020 21:21:53 -0700
Subject: [PATCH 2/8] Add tests for plot style helpers and the savefig decorator
---
fooof/tests/plts/test_styles.py | 98 +++++++++++++++++++++++++++++++++
fooof/tests/plts/test_utils.py | 13 +++++
2 files changed, 111 insertions(+)
diff --git a/fooof/tests/plts/test_styles.py b/fooof/tests/plts/test_styles.py
index 75b002d1b..caf09e3eb 100644
--- a/fooof/tests/plts/test_styles.py
+++ b/fooof/tests/plts/test_styles.py
@@ -1,5 +1,6 @@
"""Tests for fooof.plts.styles."""
+from fooof.tests.tutils import plot_test
from fooof.plts.style import *
###################################################################################################
@@ -27,3 +28,100 @@ def test_style_spectrum_plot(skip_if_no_mpl):
# Check that axis labels are added - use as proxy that it ran correctly
assert ax.xaxis.get_label().get_text()
assert ax.yaxis.get_label().get_text()
+
+
+def test_apply_axis_style():
+
+ _, ax = plt.subplots()
+
+ title = 'Ploty McPlotface'
+ xlim = (1.0, 10.0)
+ ylabel = 'Line Value'
+
+ apply_axis_style(ax, title=title, xlim=xlim, ylabel=ylabel)
+
+ assert ax.get_title() == title
+ assert ax.get_xlim() == xlim
+ assert ax.get_ylabel() == ylabel
+
+
+def test_apply_plot_style():
+
+ # Check applying style to one line
+ _, ax = plt.subplots()
+ ax.plot([1, 2], [3, 4])
+
+ lw = 4
+ apply_plot_style(ax, lw=lw)
+
+ assert ax.get_lines()[0].get_lw() == lw
+
+ # Check applying style across multiple lines
+ _, ax = plt.subplots()
+ ax.plot([1, 2], [[3, 4], [5, 6]])
+
+ alphas = [0.5, 0.75]
+ apply_plot_style(ax, alpha=alphas)
+
+ for line, alpha in zip(ax.get_lines(), alphas):
+ assert line.get_alpha() == alpha
+
+ # Check applying style to a scatter plot
+ _, ax = plt.subplots()
+ ax.scatter([1, 2], [2, 4])
+ apply_plot_style(ax, alpha=0.123)
+ assert ax.collections[0]._alpha == 0.123
+
+ # Check applying style to a histogram
+ _, ax = plt.subplots()
+ ax.hist([1, 2, 3])
+ apply_plot_style(ax, alpha=0.123)
+ assert ax.patches[0]._alpha == 0.123
+
+
+def test_apply_custom_style():
+
+ _, ax = plt.subplots()
+ ax.set_title('placeholder')
+
+ # Test simple application of custom plot style
+ apply_custom_style(ax)
+ assert ax.title.get_size() == TITLE_FONTSIZE
+
+ # Test adding input parameters to custom plot style
+ new_title_fontsize = 15.0
+ apply_custom_style(ax, title_fontsize=new_title_fontsize)
+ assert ax.title.get_size() == new_title_fontsize
+
+
+def test_apply_style():
+
+ _, ax = plt.subplots()
+
+ def my_custom_styler(ax, **kwargs):
+ ax.set_title('DATA!')
+
+ # Apply plot style using all defaults
+ apply_style(ax)
+
+ # Apply plot style passing in a styler
+ apply_style(ax, custom_styler=my_custom_styler)
+
+
+@plot_test
+def test_style_plot():
+
+ @style_plot
+ def example_plot():
+ plt.plot([1, 2], [3, 4])
+
+ def my_plot_style(ax, **kwargs):
+ ax.set_title('Custom!')
+
+ # Test with applying default custom styling
+ lw = 5
+ title = 'Science.'
+ example_plot(title=title, lw=lw)
+
+ # Test with passing in own plot_style function
+ example_plot(apply_style=my_plot_style)
diff --git a/fooof/tests/plts/test_utils.py b/fooof/tests/plts/test_utils.py
index 51a2117f4..89956717f 100644
--- a/fooof/tests/plts/test_utils.py
+++ b/fooof/tests/plts/test_utils.py
@@ -1,5 +1,8 @@
"""Tests for fooof.plts.utils."""
+import os
+import tempfile
+
from fooof.tests.tutils import plot_test
from fooof.core.modutils import safe_import
@@ -69,3 +72,13 @@ def test_check_plot_kwargs(skip_if_no_mpl):
assert len(plot_kwargs) == 2
assert plot_kwargs['alpha'] == 0.5
assert plot_kwargs['linewidth'] == 2
+
+def test_savefig():
+
+    @savefig
+    def example_plot():
+        plt.plot([1, 2], [3, 4])
+
+    with tempfile.TemporaryDirectory() as tmp_dir:
+        example_plot(save_fig=True, file_name='test_savefig', file_path=tmp_dir)
+        assert os.path.exists(os.path.join(tmp_dir, 'test_savefig.png'))
From 8dc0b0892dab97cd964d6aa2807c21bd5df26290 Mon Sep 17 00:00:00 2001
From: ryanhammonds
Date: Wed, 5 Aug 2020 17:25:24 -0700
Subject: [PATCH 3/8] consolidated plotting approach
---
fooof/objs/fit.py | 9 +++--
fooof/objs/group.py | 2 +-
fooof/plts/annotate.py | 38 ++++++++----------
fooof/plts/aperiodic.py | 56 +++++++++++++-------------
fooof/plts/error.py | 13 +++---
fooof/plts/fg.py | 24 ++++++------
fooof/plts/fm.py | 87 +++++++++++++++++++----------------------
fooof/plts/periodic.py | 48 ++++++++++++-----------
fooof/plts/settings.py | 2 +-
fooof/plts/spectra.py | 79 ++++++++++++++++++-------------------
fooof/plts/style.py | 80 ++++++++++---------------------------
11 files changed, 192 insertions(+), 246 deletions(-)
diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py
index 7b31fbac2..fb274118f 100644
--- a/fooof/objs/fit.py
+++ b/fooof/objs/fit.py
@@ -616,12 +616,13 @@ def get_results(self):
@copy_doc_func_to_method(plot_fm)
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
add_legend=True, save_fig=False, file_name=None, file_path=None,
- ax=None, plot_style=style_spectrum_plot, data_kwargs=None, model_kwargs=None,
+ ax=None, data_kwargs=None, model_kwargs=None,
aperiodic_kwargs=None, peak_kwargs=None, **kwargs):
- plot_fm(self, plot_peaks, plot_aperiodic, plt_log, add_legend,
- save_fig, file_name, file_path, ax, plot_style,
- data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs, **kwargs)
+ plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log,
+ add_legend=add_legend, save_fig=save_fig, file_name=file_name,
+ file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
+ aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **kwargs)
@copy_doc_func_to_method(save_report_fm)
diff --git a/fooof/objs/group.py b/fooof/objs/group.py
index 5c0e0b8d8..c3f7bdc2d 100644
--- a/fooof/objs/group.py
+++ b/fooof/objs/group.py
@@ -400,7 +400,7 @@ def get_params(self, name, col=None):
@copy_doc_func_to_method(plot_fg)
def plot(self, save_fig=False, file_name=None, file_path=None, **kwargs):
- plot_fg(self, save_fig, file_name, file_path, **kwargs)
+ plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **kwargs)
@copy_doc_func_to_method(save_report_fg)
diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py
index 96161239f..d67befb24 100644
--- a/fooof/plts/annotate.py
+++ b/fooof/plts/annotate.py
@@ -10,7 +10,7 @@
from fooof.plts.utils import check_ax, savefig
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
-from fooof.plts.style import check_n_style, style_spectrum_plot
+from fooof.plts.style import style_spectrum_plot
from fooof.analysis.periodic import get_band_peak_fm
from fooof.utils.params import compute_knee_frequency, compute_fwhm
@@ -22,15 +22,13 @@
@savefig
@check_dependency(plt, 'matplotlib')
-def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
+def plot_annotated_peak_search(fm):
"""Plot a series of plots illustrating the peak search from a flattened spectrum.
Parameters
----------
fm : FOOOF
FOOOF object, with model fit, data and settings available.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plots.
"""
# Recalculate the initial aperiodic fit and flattened spectrum that
@@ -47,14 +45,12 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
# This forces the creation of a new plotting axes per iteration
ax = check_ax(None, PLT_FIGSIZES['spectral'])
- plot_spectrum(fm.freqs, flatspec, ax=ax, plot_style=None,
- label='Flattened Spectrum', color=PLT_COLORS['data'], linewidth=2.5)
- plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs),
- ax=ax, plot_style=None, label='Relative Threshold',
- color='orange', linewidth=2.5, linestyle='dashed')
- plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs),
- ax=ax, plot_style=None, label='Absolute Threshold',
- color='red', linewidth=2.5, linestyle='dashed')
+ plot_spectrum(fm.freqs, flatspec, ax=ax, linewidth=2.5,
+ label='Flattened Spectrum', color=PLT_COLORS['data'])
+ plot_spectrum(fm.freqs, [fm.peak_threshold * np.std(flatspec)]*len(fm.freqs), ax=ax,
+ label='Relative Threshold', color='orange', linewidth=2.5, linestyle='dashed')
+ plot_spectrum(fm.freqs, [fm.min_peak_height]*len(fm.freqs), ax=ax,
+ label='Absolute Threshold', color='red', linewidth=2.5, linestyle='dashed')
maxi = np.argmax(flatspec)
ax.plot(fm.freqs[maxi], flatspec[maxi], '.',
@@ -66,19 +62,18 @@ def plot_annotated_peak_search(fm, plot_style=style_spectrum_plot):
if ind < fm.n_peaks_:
gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :])
- plot_spectrum(fm.freqs, gauss, ax=ax, plot_style=None,
- label='Gaussian Fit', color=PLT_COLORS['periodic'],
- linestyle=':', linewidth=3.0)
+ plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit',
+ color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0)
flatspec = flatspec - gauss
- check_n_style(plot_style, ax, False, True)
+ style_spectrum_plot(ax, False, True)
@savefig
@check_dependency(plt, 'matplotlib')
-def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperiodic=True,
- ax=None, plot_style=style_spectrum_plot):
+def plot_annotated_model(fm, plt_log=False, annotate_peaks=True,
+ annotate_aperiodic=True, ax=None):
"""Plot a an annotated power spectrum and model, from a FOOOF object.
Parameters
@@ -89,8 +84,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
Whether to plot the frequency values in log10 spacing.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plots.
+
Raises
------
@@ -110,7 +104,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
# Create the baseline figure
ax = check_ax(ax, PLT_FIGSIZES['spectral'])
- fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax, plot_style=None,
+ fm.plot(plot_peaks='dot-shade-width', plt_log=plt_log, ax=ax,
data_kwargs={'lw' : lw1, 'alpha' : 0.6},
aperiodic_kwargs={'lw' : lw1, 'zorder' : 10},
model_kwargs={'lw' : lw1, 'alpha' : 0.5},
@@ -217,7 +211,7 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True, annotate_aperio
color=PLT_COLORS['aperiodic'], fontsize=fontsize)
# Apply style to plot & tune grid styling
- check_n_style(plot_style, ax, plt_log, True)
+ style_spectrum_plot(ax, plt_log, True)
ax.grid(True, alpha=0.5)
# Add labels to plot in the legend
diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py
index 2fd8afc8a..f69bd36bb 100644
--- a/fooof/plts/aperiodic.py
+++ b/fooof/plts/aperiodic.py
@@ -3,12 +3,13 @@
from itertools import cycle
import numpy as np
+import matplotlib.pyplot as plt
from fooof.sim.gen import gen_freqs, gen_aperiodic
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_param_plot, style_plot
-from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig
+from fooof.plts.style import style_param_plot, style_plot
+from fooof.plts.utils import check_ax, recursive_plot, savefig
plt = safe_import('.pyplot', 'matplotlib')
@@ -18,8 +19,7 @@
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_aperiodic_params(aps, colors=None, labels=None,
- ax=None, plot_style=style_param_plot, **plot_kwargs):
+def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot aperiodic parameters as dots representing offset and exponent value.
Parameters
@@ -32,17 +32,14 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
Label(s) for plotted data, to be added in a legend.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_param_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to pass into the plot call.
+ Keyword arguments to pass into the ``style_plot``.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
if isinstance(aps, list):
- recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels,
- plot_style=plot_style, **plot_kwargs)
+ recursive_plot(aps, plot_aperiodic_params, ax, colors=colors, labels=labels)
else:
@@ -50,14 +47,26 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
xs, ys = aps[:, 0], aps[:, -1]
sizes = plot_kwargs.pop('s', 150)
- plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
- ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)
+ colors = 'C0' if colors is None else colors
+ colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors)
+ labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels)
+
+ for xi, yi in zip(xs, ys):
+
+ # Prevent duplicate labels when recursively plotting
+ _, cur_labels = plt.gca().get_legend_handles_labels()
+
+ label = next(labels)
+ if label not in cur_labels:
+ ax.scatter(xi, yi, s=sizes, color=next(colors), label=label, alpha=0.7)
+ else:
+ ax.scatter(xi, yi, s=sizes, color=next(colors), alpha=0.7)
# Add axis labels
ax.set_xlabel('Offset')
ax.set_ylabel('Exponent')
- check_n_style(plot_style, ax)
+ style_param_plot(ax)
@savefig
@@ -65,7 +74,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None,
@check_dependency(plt, 'matplotlib')
def plot_aperiodic_fits(aps, freq_range, control_offset=False,
log_freqs=False, colors=None, labels=None,
- ax=None, plot_style=style_param_plot, **plot_kwargs):
+ ax=None, **plot_kwargs):
"""Plot reconstructions of model aperiodic fits.
Parameters
@@ -84,10 +93,8 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
Label(s) for plotted data, to be added in a legend.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_param_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to pass into the plot call.
+ Keyword arguments to pass into the ``style_plot``.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
@@ -97,11 +104,9 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
if not colors:
colors = cycle(plt.rcParams['axes.prop_cycle'].by_key()['color'])
- recursive_plot(aps, plot_function=plot_aperiodic_fits, ax=ax,
- freq_range=tuple(freq_range),
- control_offset=control_offset,
- log_freqs=log_freqs, colors=colors, labels=labels,
- plot_style=plot_style, **plot_kwargs)
+ recursive_plot(aps, plot_aperiodic_fits, ax=ax, freq_range=tuple(freq_range),
+ control_offset=control_offset, log_freqs=log_freqs, colors=colors,
+ labels=labels, **plot_kwargs)
else:
freqs = gen_freqs(freq_range, 0.1)
@@ -122,8 +127,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
# Recreate & plot the aperiodic component from parameters
ap_vals = gen_aperiodic(freqs, ap_params)
- plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25})
- ax.plot(plt_freqs, ap_vals, color=colors, **plot_kwargs)
+ ax.plot(plt_freqs, ap_vals, color=colors, alpha=0.35, linewidth=1.25)
# Collect a running average across components
avg_vals = np.nansum(np.vstack([avg_vals, ap_vals]), axis=0)
@@ -131,8 +135,7 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
# Plot the average component
avg = avg_vals / aps.shape[0]
avg_color = 'black' if not colors else colors
- ax.plot(plt_freqs, avg, linewidth=plot_kwargs.get('linewidth')*3,
- color=avg_color, label=labels)
+ ax.plot(plt_freqs, avg, linewidth=3.75, color=avg_color, label=labels)
# Add axis labels
ax.set_xlabel('log(Frequency)' if log_freqs else 'Frequency')
@@ -141,5 +144,4 @@ def plot_aperiodic_fits(aps, freq_range, control_offset=False,
# Set plot limit
ax.set_xlim(np.log10(freq_range) if log_freqs else freq_range)
- # Apply plot style
- check_n_style(plot_style, ax)
+ style_param_plot(ax)
diff --git a/fooof/plts/error.py b/fooof/plts/error.py
index ea0f3b81f..f7cbfdf7b 100644
--- a/fooof/plts/error.py
+++ b/fooof/plts/error.py
@@ -5,7 +5,7 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot
+from fooof.plts.style import style_spectrum_plot, style_plot
from fooof.plts.utils import check_ax, savefig
plt = safe_import('.pyplot', 'matplotlib')
@@ -16,8 +16,7 @@
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
- ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
+def plot_spectral_error(freqs, error, shade=None, log_freqs=False, ax=None, **plot_kwargs):
"""Plot frequency by frequency error values.
Parameters
@@ -33,17 +32,15 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
Whether to plot the frequency axis in log spacing.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to be passed to `plot_spectra` or to the plot call.
+ Keyword arguments to pass into the ``style_plot``.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
plt_freqs = np.log10(freqs) if log_freqs else freqs
- plot_spectrum(plt_freqs, error, plot_style=None, ax=ax, linewidth=3, **plot_kwargs)
+ plot_spectrum(plt_freqs, error, ax=ax, linewidth=3)
if np.any(shade):
ax.fill_between(plt_freqs, error-shade, error+shade, alpha=0.25)
@@ -53,5 +50,5 @@ def plot_spectral_error(freqs, error, shade=None, log_freqs=False,
ax.set_ylim([0, ymax])
ax.set_xlim(plt_freqs.min(), plt_freqs.max())
- check_n_style(plot_style, ax, log_freqs, True)
+ style_spectrum_plot(ax, log_freqs, True)
ax.set_ylabel('Absolute Error')
diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py
index cd11d8d0e..342065c64 100644
--- a/fooof/plts/fg.py
+++ b/fooof/plts/fg.py
@@ -21,7 +21,7 @@
@savefig
@check_dependency(plt, 'matplotlib')
-def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs):
+def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **plot_kwargs):
"""Plot a figure with subplots visualizing the parameters from a FOOOFGroup object.
Parameters
@@ -48,7 +48,7 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs):
gs = gridspec.GridSpec(2, 2, wspace=0.4, hspace=0.25, height_ratios=[1, 1.2])
# Apply scatter kwargs to all subplots
- scatter_kwargs = kwargs
+ scatter_kwargs = plot_kwargs
scatter_kwargs['all_axes'] = True
# Aperiodic parameters plot
@@ -61,13 +61,13 @@ def plot_fg(fg, save_fig=False, file_name=None, file_path=None, **kwargs):
# Center frequencies plot
ax2 = plt.subplot(gs[1, :])
- plot_fg_peak_cens(fg, ax2, **kwargs)
+ plot_fg_peak_cens(fg, ax2, **plot_kwargs)
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_fg_ap(fg, ax=None, **kwargs):
+def plot_fg_ap(fg, ax=None, **plot_kwargs):
"""Plot aperiodic fit parameters, in a scatter plot.
Parameters
@@ -76,8 +76,8 @@ def plot_fg_ap(fg, ax=None, **kwargs):
Object to plot data from.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- **kwargs
- Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
+ **plot_kwargs
+ Keyword arguments to pass into the ``style_plot``.
"""
if fg.aperiodic_mode == 'knee':
@@ -92,7 +92,7 @@ def plot_fg_ap(fg, ax=None, **kwargs):
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_fg_gf(fg, ax=None, **kwargs):
+def plot_fg_gf(fg, ax=None, **plot_kwargs):
"""Plot goodness of fit results, in a scatter plot.
Parameters
@@ -101,8 +101,8 @@ def plot_fg_gf(fg, ax=None, **kwargs):
Object to plot data from.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- **kwargs
- Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
+ **plot_kwargs
+ Keyword arguments to pass into the ``style_plot``.
"""
plot_scatter_2(fg.get_params('error'), 'Error',
@@ -112,7 +112,7 @@ def plot_fg_gf(fg, ax=None, **kwargs):
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_fg_peak_cens(fg, ax=None, **kwargs):
+def plot_fg_peak_cens(fg, ax=None, **plot_kwargs):
"""Plot peak center frequencies, in a histogram.
Parameters
@@ -121,8 +121,8 @@ def plot_fg_peak_cens(fg, ax=None, **kwargs):
Object to plot data from.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- **kwargs
- Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
+ **plot_kwargs
+ Keyword arguments to pass into the ``style_plot``.
"""
plot_hist(fg.get_params('peak_params', 0)[:, 0], 'Center Frequency',
diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py
index 2fb2fad7f..54a76a8a2 100644
--- a/fooof/plts/fm.py
+++ b/fooof/plts/fm.py
@@ -15,7 +15,7 @@
from fooof.plts.spectra import plot_spectrum
from fooof.plts.settings import PLT_FIGSIZES, PLT_COLORS
from fooof.plts.utils import check_ax, check_plot_kwargs, savefig
-from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot
+from fooof.plts.style import style_spectrum_plot, style_plot
plt = safe_import('.pyplot', 'matplotlib')
@@ -26,8 +26,8 @@
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=True,
- save_fig=False, file_name=None, file_path=None, ax=None, plot_style=style_spectrum_plot,
- data_kwargs=None, model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **kwargs):
+ save_fig=False, file_name=None, file_path=None, ax=None, data_kwargs=None,
+ model_kwargs=None, aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
"""Plot the power spectrum and model fit results from a FOOOF object.
Parameters
@@ -51,12 +51,10 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
Path to directory to save to. If None, saves to current directory.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plot.
data_kwargs, model_kwargs, aperiodic_kwargs, peak_kwargs : None or dict, optional
Keyword arguments to pass into the plot call for each plot element.
- **kwargs
- Keyword arguments for customizing the plot, passed to the 'style_plot' decorator.
+ **plot_kwargs
+ Keyword arguments to pass into the ``style_plot``.
Notes
-----
@@ -72,34 +70,29 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
# Plot the data, if available
if fm.has_data:
- data_kwargs = check_plot_kwargs(data_kwargs, \
- {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
- 'label' : 'Original Spectrum' if add_legend else None})
- plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers,
- ax=ax, plot_style=None, **data_kwargs)
+ data_kwargs = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
+ 'label' : 'Original Spectrum' if add_legend else None}
+ plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, ax=ax, **data_kwargs)
# Add the full model fit, and components (if requested)
if fm.has_model:
- model_kwargs = check_plot_kwargs(model_kwargs, \
- {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
- 'label' : 'Full Model Fit' if add_legend else None})
- plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers,
- ax=ax, plot_style=None, **model_kwargs)
+ model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
+ 'label' : 'Full Model Fit' if add_legend else None}
+ plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs)
# Plot the aperiodic component of the model fit
if plot_aperiodic:
- aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, \
- {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0, 'alpha' : 0.5,
- 'linestyle' : 'dashed', 'label' : 'Aperiodic Fit' if add_legend else None})
- plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers,
- ax=ax, plot_style=None, **aperiodic_kwargs)
+ aperiodic_kwargs = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0,
+ 'alpha' : 0.5, 'linestyle' : 'dashed',
+ 'label' : 'Aperiodic Fit' if add_legend else None}
+ plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs)
# Plot the periodic components of the model fit
if plot_peaks:
- _add_peaks(fm, plot_peaks, plt_log, ax=ax, peak_kwargs=peak_kwargs)
+ _add_peaks(fm, plot_peaks, plt_log, ax, peak_kwargs)
- # Apply style to plot
- check_n_style(plot_style, ax, log_freqs, True)
+ # Apply default style to plot
+ style_spectrum_plot(ax, log_freqs, True)
def _add_peaks(fm, approach, plt_log, ax, peak_kwargs):
@@ -158,18 +151,18 @@ def _add_peaks_shade(fm, plt_log, ax, **plot_kwargs):
ax : matplotlib.Axes
Figure axes upon which to plot.
**plot_kwargs
- Keyword arguments to pass into the plot call.
+ Keyword arguments to pass into the ``fill_between``.
"""
- kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25})
+ plot_kwargs = check_plot_kwargs(plot_kwargs,
+ {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25})
for peak in fm.get_params('gaussian_params'):
peak_freqs = np.log10(fm.freqs) if plt_log else fm.freqs
peak_line = fm._ap_fit + gen_periodic(fm.freqs, peak)
- ax.fill_between(peak_freqs, peak_line, fm._ap_fit, **kwargs)
+ ax.fill_between(peak_freqs, peak_line, fm._ap_fit, **plot_kwargs)
def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs):
@@ -187,9 +180,9 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the plot call.
"""
- kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6})
+ plot_kwargs = check_plot_kwargs(plot_kwargs,
+ {'color' : PLT_COLORS['periodic'],
+ 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6})
for peak in fm.get_params('peak_params'):
@@ -197,10 +190,10 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs):
freq_point = np.log10(peak[0]) if plt_log else peak[0]
# Add the line from the aperiodic fit up the tip of the peak
- ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **kwargs)
+ ax.plot([freq_point, freq_point], [ap_point, ap_point + peak[1]], **plot_kwargs)
# Add an extra dot at the tip of the peak
- ax.plot(freq_point, ap_point + peak[1], marker='o', **kwargs)
+ ax.plot(freq_point, ap_point + peak[1], marker='o', **plot_kwargs)
def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs):
@@ -218,9 +211,9 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the plot call.
"""
- kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.7, 'lw' : 1.5})
+ plot_kwargs = check_plot_kwargs(plot_kwargs,
+ {'color' : PLT_COLORS['periodic'],
+ 'alpha' : 0.7, 'lw' : 1.5})
for peak in fm.get_params('gaussian_params'):
@@ -233,7 +226,7 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs):
# Plot the peak outline
peak_freqs = np.log10(peak_freqs) if plt_log else peak_freqs
- ax.plot(peak_freqs, peak_line, **kwargs)
+ ax.plot(peak_freqs, peak_line, **plot_kwargs)
def _add_peaks_line(fm, plt_log, ax, **plot_kwargs):
@@ -251,16 +244,16 @@ def _add_peaks_line(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the plot call.
"""
- kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10})
+ plot_kwargs = check_plot_kwargs(plot_kwargs,
+ {'color' : PLT_COLORS['periodic'],
+ 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10})
ylims = ax.get_ylim()
for peak in fm.get_params('peak_params'):
freq_point = np.log10(peak[0]) if plt_log else peak[0]
- ax.plot([freq_point, freq_point], ylims, '-', **kwargs)
- ax.plot(freq_point, ylims[1], 'v', **kwargs)
+ ax.plot([freq_point, freq_point], ylims, '-', **plot_kwargs)
+ ax.plot(freq_point, ylims[1], 'v', **plot_kwargs)
def _add_peaks_width(fm, plt_log, ax, **plot_kwargs):
@@ -283,9 +276,9 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs):
the peak, though what is literally plotted is the full-width half-max.
"""
- kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6})
+ plot_kwargs = check_plot_kwargs(plot_kwargs,
+ {'color' : PLT_COLORS['periodic'],
+ 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6})
for peak in fm.gaussian_params_:
@@ -296,7 +289,7 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs):
if plt_log:
bw_freqs = np.log10(bw_freqs)
- ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **kwargs)
+ ax.plot(bw_freqs, [peak_top-(0.5*peak[1]), peak_top-(0.5*peak[1])], **plot_kwargs)
# Collect all the possible `add_peak_*` functions together
diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py
index 0bd73e03d..55d147a99 100644
--- a/fooof/plts/periodic.py
+++ b/fooof/plts/periodic.py
@@ -8,8 +8,8 @@
from fooof.core.funcs import gaussian_function
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_param_plot, style_plot
-from fooof.plts.utils import check_ax, recursive_plot, check_plot_kwargs, savefig
+from fooof.plts.style import style_param_plot, style_plot
+from fooof.plts.utils import check_ax, recursive_plot, savefig
plt = safe_import('.pyplot', 'matplotlib')
@@ -19,8 +19,7 @@
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_peak_params(peaks, freq_range=None, colors=None, labels=None,
- ax=None, plot_style=style_param_plot, **plot_kwargs):
+def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot peak parameters as dots representing center frequency, power and bandwidth.
Parameters
@@ -35,18 +34,15 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None,
Label(s) for plotted data, to be added in a legend.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_param_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to pass into the plot call.
+ Keyword arguments to pass into the ``style_plot``.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['params']))
# If there is a list, use recurse function to loop across arrays of data and plot them
if isinstance(peaks, list):
- recursive_plot(peaks, plot_peak_params, ax, colors=colors, labels=labels,
- plot_style=plot_style, **plot_kwargs)
+ recursive_plot(peaks, plot_peak_params, ax, colors=colors, labels=labels)
# Otherwise, plot the array of data
else:
@@ -55,9 +51,20 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None,
xs, ys = peaks[:, 0], peaks[:, 1]
sizes = peaks[:, 2] * plot_kwargs.pop('s', 150)
- # Create the plot
- plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
- ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)
+ colors = 'C0' if colors is None else colors
+ colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors)
+ labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels)
+
+ for xi, yi, size in zip(xs, ys, sizes):
+
+ # Prevent duplicate labels when recursively plotting
+ _, cur_labels = plt.gca().get_legend_handles_labels()
+
+ label = next(labels)
+ if label not in cur_labels:
+ ax.scatter(xi, yi, s=size, color=next(colors), label=label, alpha=0.7)
+ else:
+ ax.scatter(xi, yi, s=size, color=next(colors), alpha=0.7)
# Add axis labels
ax.set_xlabel('Center Frequency')
@@ -68,13 +75,12 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None,
ax.set_xlim(freq_range)
ax.set_ylim([0, ax.get_ylim()[1]])
- check_n_style(plot_style, ax)
+ style_param_plot(ax)
@savefig
@style_plot
-def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
- ax=None, plot_style=style_param_plot, **plot_kwargs):
+def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot reconstructions of model peak fits.
Parameters
@@ -90,8 +96,6 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
Label(s) for plotted data, to be added in a legend.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_param_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
Keyword arguments to pass into the plot call.
"""
@@ -105,8 +109,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
recursive_plot(peaks, plot_function=plot_peak_fits, ax=ax,
freq_range=tuple(freq_range) if freq_range else freq_range,
- colors=colors, labels=labels,
- plot_style=plot_style, **plot_kwargs)
+ colors=colors, labels=labels, **plot_kwargs)
else:
@@ -132,8 +135,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
# Create & plot the peak model from parameters
peak_vals = gaussian_function(freqs, *peak_params)
- plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.35, 'linewidth' : 1.25})
- ax.plot(freqs, peak_vals, color=colors, **plot_kwargs)
+ ax.plot(freqs, peak_vals, color=colors, alpha=0.35, linewidth=1.25)
# Collect a running average average peaks
avg_vals = np.nansum(np.vstack([avg_vals, peak_vals]), axis=0)
@@ -141,7 +143,7 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
# Plot the average across all components
avg = avg_vals / peaks.shape[0]
avg_color = 'black' if not colors else colors
- ax.plot(freqs, avg, color=avg_color, linewidth=plot_kwargs.get('linewidth')*3, label=labels)
+ ax.plot(freqs, avg, color=avg_color, linewidth=3.75, label=labels)
# Add axis labels
ax.set_xlabel('Frequency')
@@ -152,4 +154,4 @@ def plot_peak_fits(peaks, freq_range=None, colors=None, labels=None,
ax.set_ylim([0, ax.get_ylim()[1]])
# Apply plot style
- check_n_style(plot_style, ax)
+ style_param_plot(ax)
diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py
index b695ab40c..a68043af4 100644
--- a/fooof/plts/settings.py
+++ b/fooof/plts/settings.py
@@ -32,7 +32,7 @@
# Line style arguments are those that can be defined on a line object
LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle',
- 'marker', 'ms', 'markersize', 'color']
+ 'marker', 'ms', 'markersize']
# Custom style arguments are those that are custom-handled by the plot style function
CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize',
diff --git a/fooof/plts/spectra.py b/fooof/plts/spectra.py
index 24e207770..9c9f4ab53 100644
--- a/fooof/plts/spectra.py
+++ b/fooof/plts/spectra.py
@@ -5,14 +5,14 @@
This file contains functions for plotting power spectra, that take in data directly.
"""
-from itertools import repeat
+from itertools import repeat, cycle
import numpy as np
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
-from fooof.plts.style import check_n_style, style_spectrum_plot, style_plot
-from fooof.plts.utils import check_ax, add_shades, check_plot_kwargs, savefig
+from fooof.plts.style import style_spectrum_plot, style_plot
+from fooof.plts.utils import check_ax, add_shades, savefig
plt = safe_import('.pyplot', 'matplotlib')
@@ -23,7 +23,7 @@
@style_plot
@check_dependency(plt, 'matplotlib')
def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False,
- ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
+ color=None, label=None, ax=None, **plot_kwargs):
"""Plot a power spectrum.
Parameters
@@ -36,12 +36,14 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False,
Whether to plot the frequency axis in log spacing.
log_powers : bool, optional, default: False
Whether to plot the power axis in log spacing.
+ label : str, optional, default: None
+ Legend label for the spectrum.
+ color : str, optional, default: None
+ Line color of the spectrum.
ax : matplotlib.Axes, optional
Figure axis upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to be passed to the plot call.
+ Keyword arguments to pass into the ``style_plot``.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
@@ -51,17 +53,16 @@ def plot_spectrum(freqs, power_spectrum, log_freqs=False, log_powers=False,
plt_powers = np.log10(power_spectrum) if log_powers else power_spectrum
# Create the plot
- plot_kwargs = check_plot_kwargs(plot_kwargs, {'linewidth' : 2.0})
- ax.plot(plt_freqs, plt_powers, **plot_kwargs)
+ ax.plot(plt_freqs, plt_powers, linewidth=2.0, color=color, label=label)
- check_n_style(plot_style, ax, log_freqs, log_powers)
+ style_spectrum_plot(ax, log_freqs, log_powers)
@savefig
@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels=None,
- ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
+def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False,
+ colors=None, labels=None, ax=None, **plot_kwargs):
"""Plot multiple power spectra on the same plot.
Parameters
@@ -74,34 +75,35 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, labels
Whether to plot the frequency axis in log spacing.
log_powers : bool, optional, default: False
Whether to plot the power axis in log spacing.
- labels : list of str, optional
- Legend labels, for each power spectrum.
+ labels : list of str, optional, default: None
+ Legend labels for the spectra.
+ colors : list of str, optional, default: None
+ Line colors of the spectra.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to be passed to the plot call.
+ Keyword arguments to pass into the ``style_plot``.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
# Make inputs iterable if need to be passed multiple times to plot each spectrum
freqs = repeat(freqs) if isinstance(freqs, np.ndarray) and freqs.ndim == 1 else freqs
- labels = repeat(labels) if not isinstance(labels, list) else labels
- for freq, power_spectrum, label in zip(freqs, power_spectra, labels):
- plot_spectrum(freq, power_spectrum, log_freqs, log_powers, label=label,
- plot_style=None, ax=ax, **plot_kwargs)
+ colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)
+ labels = repeat(labels) if not isinstance(labels, list) else cycle(labels)
- check_n_style(plot_style, ax, log_freqs, log_powers)
+ for freq, power_spectrum, color, label in zip(freqs, power_spectra, colors, labels):
+ plot_spectrum(freq, power_spectrum, log_freqs, log_powers,
+ color=color, label=label, ax=ax)
+
+ style_spectrum_plot(ax, log_freqs, log_powers)
@savefig
-@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_center=False,
- ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
+def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r',
+ add_center=False, ax=None, **plot_kwargs):
"""Plot a power spectrum with a shaded frequency region (or regions).
Parameters
@@ -118,28 +120,24 @@ def plot_spectrum_shading(freqs, power_spectrum, shades, shade_colors='r', add_c
Whether to add a line at the center point of the shaded regions.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to be passed to the plot call.
+ Keyword arguments to pass into :func:`~.plot_spectrum`.
"""
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
- plot_spectrum(freqs, power_spectrum, plot_style=None, ax=ax, **plot_kwargs)
+ plot_spectrum(freqs, power_spectrum, ax=ax, **plot_kwargs)
add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False))
- check_n_style(plot_style, ax,
- plot_kwargs.get('log_freqs', False),
- plot_kwargs.get('log_powers', False))
+ style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False),
+ plot_kwargs.get('log_powers', False))
@savefig
-@style_plot
@check_dependency(plt, 'matplotlib')
-def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_center=False,
- ax=None, plot_style=style_spectrum_plot, **plot_kwargs):
+def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
+ add_center=False, ax=None, **plot_kwargs):
"""Plot a group of power spectra with a shaded frequency region (or regions).
Parameters
@@ -156,10 +154,8 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_cen
Whether to add a line at the center point of the shaded regions.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
- plot_style : callable, optional, default: style_spectrum_plot
- A function to call to apply styling & aesthetics to the plot.
**plot_kwargs
- Keyword arguments to be passed to `plot_spectra` or to the plot call.
+ Keyword arguments to pass into :func:`~.plot_spectra`.
Notes
-----
@@ -170,10 +166,9 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r', add_cen
ax = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['spectral']))
- plot_spectra(freqs, power_spectra, ax=ax, plot_style=None, **plot_kwargs)
+ plot_spectra(freqs, power_spectra, ax=ax, **plot_kwargs)
add_shades(ax, shades, shade_colors, add_center, plot_kwargs.get('log_freqs', False))
- check_n_style(plot_style, ax,
- plot_kwargs.get('log_freqs', False),
- plot_kwargs.get('log_powers', False))
+ style_spectrum_plot(ax, plot_kwargs.get('log_freqs', False),
+ plot_kwargs.get('log_powers', False))
diff --git a/fooof/plts/style.py b/fooof/plts/style.py
index f1353ba85..3596e71b5 100644
--- a/fooof/plts/style.py
+++ b/fooof/plts/style.py
@@ -13,20 +13,7 @@
###################################################################################################
###################################################################################################
-def check_n_style(style_func, *args):
- """"Check if a style function has been passed, and apply it to a plot if so.
-
- Parameters
- ----------
- style_func : callable or None
- Function to apply styling to a plot axis.
- *args
- Inputs to the style plot.
- """
-
- if style_func:
- style_func(*args)
-
+# Default plot styling
def style_spectrum_plot(ax, log_freqs, log_powers):
"""Apply style and aesthetics to a power spectrum plot.
@@ -86,8 +73,7 @@ def style_param_plot(ax):
for handle in legend.legendHandles:
handle._sizes = [100]
-
-# Additional plot style customization
+# Custom plot styling
def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
"""Apply axis plot style.
@@ -107,50 +93,34 @@ def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
ax.set(**plot_kwargs)
-def apply_plot_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
- """Apply line/scatter/histogram plot style.
+def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
+ """Apply line plot style.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
style_args : list of str
- A list of arguments to be sub-selected from `kwargs` and applied as styling.
+ A list of arguments to be sub-selected from `kwargs` and applied as line styling.
**kwargs
- Keyword arguments that define style to apply.
+ Keyword arguments that define line style to apply.
"""
- # Get the plot object related styling arguments from the keyword arguments
- style_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
-
- # For line plots
- if len(ax.lines) > 0:
- plot_objs = ax.lines
-
- # For scatter plots
- elif len(ax.collections) > 0:
- plot_objs = ax.collections
-
- # For histograms
- elif len(ax.patches) > 0:
- plot_objs = ax.patches
+ # Check how many lines are from the current plot call, to apply style to
+ # If available, this indicates the apply styling to the last 'n' lines
+ n_lines_apply = kwargs.pop('n_lines_apply', 0)
- # There is no styling to apply
- else:
- return
+ # Get the line related styling arguments from the keyword arguments
+ line_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
- plot_objs = [plot_objs] if not isinstance(plot_objs, list) else plot_objs
+ # Apply any provided line style arguments
+ for style, value in line_kwargs.items():
- # Apply any provided plot object style arguments
- for style, value in style_kwargs.items():
-
- # Values should be either a single value, for all plot objects, or a list, one value per
- # object. This line checks type, and makes a cycle-able / loop-able object out of the values
+ # Values should be either a single value, for all lines, or a list, of a value per line
+ # This line checks type, and makes a cycle-able / loop-able object out of the values
values = cycle([value] if isinstance(value, (int, float, str)) else value)
-
- # For line plots
- for plot_obj in plot_objs:
- plot_obj.set(**{style : next(values)})
+ for line in ax.lines[-n_lines_apply:]:
+ line.set(**{style : next(values)})
def apply_custom_style(ax, **kwargs):
@@ -182,12 +152,10 @@ def apply_custom_style(ax, **kwargs):
ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
loc=kwargs.pop('legend_loc', LEGEND_LOC))
- with warnings.catch_warnings():
- warnings.simplefilter('ignore')
- plt.tight_layout()
+ plt.tight_layout()
-def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_plot_style,
+def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
custom_styler=apply_custom_style, **kwargs):
"""Apply plot style to a figure axis.
@@ -242,7 +210,7 @@ def style_plot(func, *args, **kwargs):
def decorated(*args, **kwargs):
# Grab a custom style function, if provided, and grab any provided style arguments
- style_func = kwargs.pop('apply_style', apply_style)
+ style_func = kwargs.pop('plot_style', apply_style)
style_args = kwargs.pop('style_args', STYLE_ARGS)
style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs}
@@ -259,13 +227,7 @@ def decorated(*args, **kwargs):
n_lines_apply = len(cur_ax.lines) - n_lines_pre
style_kwargs['n_lines_apply'] = n_lines_apply
- # Determine if styling should be applied to all axes
- all_axes = kwargs.pop('all_axes', False)
- cur_ax = plt.gcf().get_axes() if all_axes is True else cur_ax
- cur_ax = [cur_ax] if not isinstance(cur_ax, list) else cur_ax
-
# Apply the styling function
- for ax in cur_ax:
- style_func(ax, **style_kwargs)
+ style_func(cur_ax, **style_kwargs)
return decorated
From 9420e623739e8853583af838c19f0f0abfb9ec45 Mon Sep 17 00:00:00 2001
From: ryanhammonds
Date: Wed, 5 Aug 2020 17:48:40 -0700
Subject: [PATCH 4/8] style testing updated
---
fooof/tests/plts/test_styles.py | 30 ++++--------------------------
1 file changed, 4 insertions(+), 26 deletions(-)
diff --git a/fooof/tests/plts/test_styles.py b/fooof/tests/plts/test_styles.py
index caf09e3eb..1ddca9ffc 100644
--- a/fooof/tests/plts/test_styles.py
+++ b/fooof/tests/plts/test_styles.py
@@ -6,16 +6,6 @@
###################################################################################################
###################################################################################################
-def test_check_n_style(skip_if_no_mpl):
-
- # Check can pass None and do nothing
- check_n_style(None)
- assert True
-
- # Check can pass a callable
- def checker(*args):
- return True
- check_n_style(checker)
def test_style_spectrum_plot(skip_if_no_mpl):
@@ -45,14 +35,14 @@ def test_apply_axis_style():
assert ax.get_ylabel() == ylabel
-def test_apply_plot_style():
+def test_apply_line_style():
# Check applying style to one line
_, ax = plt.subplots()
ax.plot([1, 2], [3, 4])
lw = 4
- apply_plot_style(ax, lw=lw)
+ apply_line_style(ax, lw=lw)
assert ax.get_lines()[0].get_lw() == lw
@@ -61,23 +51,11 @@ def test_apply_plot_style():
ax.plot([1, 2], [[3, 4], [5, 6]])
alphas = [0.5, 0.75]
- apply_plot_style(ax, alpha=alphas)
+ apply_line_style(ax, alpha=alphas)
for line, alpha in zip(ax.get_lines(), alphas):
assert line.get_alpha() == alpha
- # Check applying style to a scatter plot
- _, ax = plt.subplots()
- ax.scatter([1, 2], [2, 4])
- apply_plot_style(ax, alpha=0.123)
- assert ax.collections[0]._alpha == 0.123
-
- # Check applying style to a histogram
- _, ax = plt.subplots()
- ax.hist([1, 2, 3])
- apply_plot_style(ax, alpha=0.123)
- assert ax.patches[0]._alpha == 0.123
-
def test_apply_custom_style():
@@ -124,4 +102,4 @@ def my_plot_style(ax, **kwargs):
example_plot(title=title, lw=lw)
# Test with passing in own plot_style function
- example_plot(apply_style=my_plot_style)
+ example_plot(plot_style=my_plot_style)
From e249ab56ef2aa758baaed3dec1286c22742c09d2 Mon Sep 17 00:00:00 2001
From: ryanhammonds
Date: Mon, 17 Aug 2020 12:19:56 -0700
Subject: [PATCH 5/8] save out test plots
---
fooof/tests/conftest.py | 4 +++-
fooof/tests/plts/test_annotate.py | 7 +++++--
fooof/tests/plts/test_aperiodic.py | 7 +++++--
fooof/tests/plts/test_error.py | 4 +++-
fooof/tests/plts/test_fg.py | 13 +++++++++----
fooof/tests/plts/test_fm.py | 14 +++++++-------
fooof/tests/plts/test_periodic.py | 7 +++++--
fooof/tests/plts/test_spectra.py | 27 +++++++++++++++++++--------
fooof/tests/settings.py | 1 +
9 files changed, 57 insertions(+), 27 deletions(-)
diff --git a/fooof/tests/conftest.py b/fooof/tests/conftest.py
index 65a8b00a9..b943456ce 100644
--- a/fooof/tests/conftest.py
+++ b/fooof/tests/conftest.py
@@ -9,7 +9,8 @@
from fooof.core.modutils import safe_import
from fooof.tests.tutils import get_tfm, get_tfg, get_tbands
-from fooof.tests.settings import BASE_TEST_FILE_PATH, TEST_DATA_PATH, TEST_REPORTS_PATH
+from fooof.tests.settings import (BASE_TEST_FILE_PATH, TEST_DATA_PATH,
+ TEST_REPORTS_PATH, TEST_PLOTS_PATH)
plt = safe_import('.pyplot', 'matplotlib')
@@ -33,6 +34,7 @@ def check_dir():
os.mkdir(BASE_TEST_FILE_PATH)
os.mkdir(TEST_DATA_PATH)
os.mkdir(TEST_REPORTS_PATH)
+ os.mkdir(TEST_PLOTS_PATH)
@pytest.fixture(scope='session')
def tfm():
diff --git a/fooof/tests/plts/test_annotate.py b/fooof/tests/plts/test_annotate.py
index c096612cc..84f3848df 100644
--- a/fooof/tests/plts/test_annotate.py
+++ b/fooof/tests/plts/test_annotate.py
@@ -3,6 +3,7 @@
import numpy as np
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.annotate import *
@@ -12,11 +13,13 @@
@plot_test
def test_plot_annotated_peak_search(tfm, skip_if_no_mpl):
- plot_annotated_peak_search(tfm)
+ plot_annotated_peak_search(tfm, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_annotated_peak_search.png')
@plot_test
def test_plot_annotated_model(tfm, skip_if_no_mpl):
# Make sure model has been fit & then plot annotated model
tfm.fit()
- plot_annotated_model(tfm)
+ plot_annotated_model(tfm, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_annotated_model.png')
diff --git a/fooof/tests/plts/test_aperiodic.py b/fooof/tests/plts/test_aperiodic.py
index 167907846..477b22053 100644
--- a/fooof/tests/plts/test_aperiodic.py
+++ b/fooof/tests/plts/test_aperiodic.py
@@ -3,6 +3,7 @@
import numpy as np
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.aperiodic import *
@@ -21,7 +22,8 @@ def test_plot_aperiodic_params(skip_if_no_mpl):
# Test for 'knee' mode: offset, knee exponent
aps = np.array([[1, 100, 1], [0.5, 150, 0.5], [2, 200, 2]])
- plot_aperiodic_params(aps)
+ plot_aperiodic_params(aps, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_aperiodic_params.png')
@plot_test
def test_plot_aperiodic_fits(skip_if_no_mpl):
@@ -36,4 +38,5 @@ def test_plot_aperiodic_fits(skip_if_no_mpl):
# Test for 'knee' mode: offset, knee exponent
aps = np.array([[1, 100, 1], [0.5, 150, 0.5], [2, 200, 2]])
- plot_aperiodic_fits(aps, [1, 50])
+ plot_aperiodic_fits(aps, [1, 50], save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_aperiodic_fits.png')
diff --git a/fooof/tests/plts/test_error.py b/fooof/tests/plts/test_error.py
index 2bffbc7a2..3e8b817bd 100644
--- a/fooof/tests/plts/test_error.py
+++ b/fooof/tests/plts/test_error.py
@@ -3,6 +3,7 @@
import numpy as np
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.error import *
@@ -15,4 +16,5 @@ def test_plot_spectral_error(skip_if_no_mpl):
fs = np.arange(3, 41, 1)
errs = np.ones(len(fs))
- plot_spectral_error(fs, errs)
+ plot_spectral_error(fs, errs, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_spectral_error.png')
diff --git a/fooof/tests/plts/test_fg.py b/fooof/tests/plts/test_fg.py
index 8841c9fba..24103a916 100644
--- a/fooof/tests/plts/test_fg.py
+++ b/fooof/tests/plts/test_fg.py
@@ -6,6 +6,7 @@
from fooof.core.errors import NoModelError
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.fg import *
@@ -15,7 +16,8 @@
@plot_test
def test_plot_fg(tfg, skip_if_no_mpl):
- plot_fg(tfg)
+ plot_fg(tfg, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_fg.png')
# Test error if no data available to plot
tfg = FOOOFGroup()
@@ -25,14 +27,17 @@ def test_plot_fg(tfg, skip_if_no_mpl):
@plot_test
def test_plot_fg_ap(tfg, skip_if_no_mpl):
- plot_fg_ap(tfg)
+ plot_fg_ap(tfg, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_fg_ap.png')
@plot_test
def test_plot_fg_gf(tfg, skip_if_no_mpl):
- plot_fg_gf(tfg)
+ plot_fg_gf(tfg, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_fg_gf.png')
@plot_test
def test_plot_fg_peak_cens(tfg, skip_if_no_mpl):
- plot_fg_peak_cens(tfg)
+ plot_fg_peak_cens(tfg, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_fg_peak_cens.png')
diff --git a/fooof/tests/plts/test_fm.py b/fooof/tests/plts/test_fm.py
index 7fce65ad5..6d7a0f02f 100644
--- a/fooof/tests/plts/test_fm.py
+++ b/fooof/tests/plts/test_fm.py
@@ -1,6 +1,7 @@
"""Tests for fooof.plts.fm."""
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.fm import *
@@ -13,7 +14,8 @@ def test_plot_fm(tfm, skip_if_no_mpl):
# Make sure model has been fit
tfm.fit()
- plot_fm(tfm)
+ plot_fm(tfm, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_fm.png')
@plot_test
def test_plot_fm_add_peaks(tfm, skip_if_no_mpl):
@@ -22,9 +24,7 @@ def test_plot_fm_add_peaks(tfm, skip_if_no_mpl):
tfm.fit()
# Test run each of the add peak approaches
- for add_peak in ['shade', 'dot', 'outline', 'line']:
- plot_fm(tfm, plot_peaks=add_peak)
-
- # Test run some combinations
- for add_peak in ['shade-dot', 'outline-line']:
- plot_fm(tfm, plot_peaks=add_peak)
+ for add_peak in ['shade', 'dot', 'outline', 'line', 'shade-dot', 'outline-line']:
+ file_name = 'test_plot_fm_add_peaks_' + add_peak + '.png'
+ plot_fm(tfm, plot_peaks=add_peak, save_fig=True,
+ file_path=TEST_PLOTS_PATH, file_name=file_name)
diff --git a/fooof/tests/plts/test_periodic.py b/fooof/tests/plts/test_periodic.py
index 647c967bd..83e77daf9 100644
--- a/fooof/tests/plts/test_periodic.py
+++ b/fooof/tests/plts/test_periodic.py
@@ -3,6 +3,7 @@
import numpy as np
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.periodic import *
@@ -18,7 +19,8 @@ def test_plot_peak_params(skip_if_no_mpl):
plot_peak_params(peaks)
# Test with multiple set of params
- plot_peak_params([peaks, peaks])
+ plot_peak_params([peaks, peaks], save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_peak_params.png')
@plot_test
def test_plot_peak_fits(skip_if_no_mpl):
@@ -29,4 +31,5 @@ def test_plot_peak_fits(skip_if_no_mpl):
plot_peak_fits(peaks)
# Test with multiple set of params
- plot_peak_fits([peaks, peaks])
+ plot_peak_fits([peaks, peaks], save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_peak_fits.png')
diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py
index 9abb95b2a..03b137975 100644
--- a/fooof/tests/plts/test_spectra.py
+++ b/fooof/tests/plts/test_spectra.py
@@ -3,6 +3,7 @@
import numpy as np
from fooof.tests.tutils import plot_test
+from fooof.tests.settings import TEST_PLOTS_PATH
from fooof.plts.spectra import *
@@ -15,36 +16,46 @@ def test_plot_spectrum(tfm, skip_if_no_mpl):
plot_spectrum(tfm.freqs, tfm.power_spectrum)
# Test with logging both axes
- plot_spectrum(tfm.freqs, tfm.power_spectrum, True, True)
+ plot_spectrum(tfm.freqs, tfm.power_spectrum, True, True, save_fig=True,
+ file_path=TEST_PLOTS_PATH, file_name='test_plot_spectrum.png')
@plot_test
def test_plot_spectra(tfg, skip_if_no_mpl):
# Test with 1d inputs - 1d freq array and list of 1d power spectra
- plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]])
+ plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]],
+ save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_1d.png')
# Test with multiple freq inputs - list of 1d freq array and list of 1d power spectra
- plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]])
+ plot_spectra([tfg.freqs, tfg.freqs], [tfg.power_spectra[0, :], tfg.power_spectra[1, :]],
+ save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_spectra_list_of_1d.png')
# Test with 2d array inputs
plot_spectra(np.vstack([tfg.freqs, tfg.freqs]),
- np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]))
+ np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]),
+ save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png')
# Test with labels
- plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'])
+ plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'],
+ save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_labels.png')
@plot_test
def test_plot_spectrum_shading(tfm, skip_if_no_mpl):
- plot_spectrum_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True)
+ plot_spectrum_shading(tfm.freqs, tfm.power_spectrum, shades=[8, 12], add_center=True,
+ save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_spectrum_shading.png')
@plot_test
def test_plot_spectra_shading(tfg, skip_if_no_mpl):
plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]],
- shades=[8, 12], add_center=True)
+ shades=[8, 12], add_center=True, save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_spectra_shading.png')
# Test with **kwargs that pass into plot_spectra
plot_spectra_shading(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]],
shades=[8, 12], add_center=True, log_freqs=True, log_powers=True,
- labels=['A', 'B'])
+ labels=['A', 'B'], save_fig=True, file_path=TEST_PLOTS_PATH,
+ file_name='test_plot_spectra_shading_kwargs.png')
diff --git a/fooof/tests/settings.py b/fooof/tests/settings.py
index 9beae1afa..856c532f6 100644
--- a/fooof/tests/settings.py
+++ b/fooof/tests/settings.py
@@ -10,3 +10,4 @@
BASE_TEST_FILE_PATH = pkg.resource_filename(__name__, 'test_files')
TEST_DATA_PATH = os.path.join(BASE_TEST_FILE_PATH, 'data')
TEST_REPORTS_PATH = os.path.join(BASE_TEST_FILE_PATH, 'reports')
+TEST_PLOTS_PATH = os.path.join(BASE_TEST_FILE_PATH, 'plots')
From b4b53ff576d78c751721c555537a8b640fb666e3 Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Sun, 11 Apr 2021 21:46:34 -0400
Subject: [PATCH 6/8] plot lints
---
fooof/objs/fit.py | 7 +++----
fooof/objs/group.py | 4 ++--
fooof/plts/annotate.py | 3 +--
fooof/plts/fg.py | 1 -
fooof/plts/fm.py | 28 ++++++++++++----------------
fooof/plts/settings.py | 2 +-
fooof/plts/style.py | 4 ----
fooof/tests/plts/test_spectra.py | 2 +-
8 files changed, 20 insertions(+), 31 deletions(-)
diff --git a/fooof/objs/fit.py b/fooof/objs/fit.py
index 415fc895f..7e5333d50 100644
--- a/fooof/objs/fit.py
+++ b/fooof/objs/fit.py
@@ -68,7 +68,6 @@
gen_issue_str, gen_width_warning_str)
from fooof.plts.fm import plot_fm
-from fooof.plts.style import style_spectrum_plot
from fooof.utils.data import trim_spectrum
from fooof.utils.params import compute_gauss_std
from fooof.data import FOOOFResults, FOOOFSettings, FOOOFMetaData
@@ -618,12 +617,12 @@ def get_results(self):
def plot(self, plot_peaks=None, plot_aperiodic=True, plt_log=False,
add_legend=True, save_fig=False, file_name=None, file_path=None,
ax=None, data_kwargs=None, model_kwargs=None,
- aperiodic_kwargs=None, peak_kwargs=None, **kwargs):
+ aperiodic_kwargs=None, peak_kwargs=None, **plot_kwargs):
plot_fm(self, plot_peaks=plot_peaks, plot_aperiodic=plot_aperiodic, plt_log=plt_log,
add_legend=add_legend, save_fig=save_fig, file_name=file_name,
- file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
- aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **kwargs)
+ file_path=file_path, ax=ax, data_kwargs=data_kwargs, model_kwargs=model_kwargs,
+ aperiodic_kwargs=aperiodic_kwargs, peak_kwargs=peak_kwargs, **plot_kwargs)
@copy_doc_func_to_method(save_report_fm)
diff --git a/fooof/objs/group.py b/fooof/objs/group.py
index ed42bb0e2..064a4bb8c 100644
--- a/fooof/objs/group.py
+++ b/fooof/objs/group.py
@@ -398,9 +398,9 @@ def get_params(self, name, col=None):
@copy_doc_func_to_method(plot_fg)
- def plot(self, save_fig=False, file_name=None, file_path=None, **kwargs):
+ def plot(self, save_fig=False, file_name=None, file_path=None, **plot_kwargs):
- plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **kwargs)
+ plot_fg(self, save_fig=save_fig, file_name=file_name, file_path=file_path, **plot_kwargs)
@copy_doc_func_to_method(save_report_fg)
diff --git a/fooof/plts/annotate.py b/fooof/plts/annotate.py
index d67befb24..093306d42 100644
--- a/fooof/plts/annotate.py
+++ b/fooof/plts/annotate.py
@@ -62,7 +62,7 @@ def plot_annotated_peak_search(fm):
if ind < fm.n_peaks_:
gauss = gaussian_function(fm.freqs, *fm.gaussian_params_[ind, :])
- plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit',
+ plot_spectrum(fm.freqs, gauss, ax=ax, label='Gaussian Fit',
color=PLT_COLORS['periodic'], linestyle=':', linewidth=3.0)
flatspec = flatspec - gauss
@@ -85,7 +85,6 @@ def plot_annotated_model(fm, plt_log=False, annotate_peaks=True,
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
-
Raises
------
NoModelError
diff --git a/fooof/plts/fg.py b/fooof/plts/fg.py
index 342065c64..d2f2cc476 100644
--- a/fooof/plts/fg.py
+++ b/fooof/plts/fg.py
@@ -5,7 +5,6 @@
This file contains plotting functions that take as input a FOOOFGroup object.
"""
-from fooof.core.io import fname, fpath
from fooof.core.errors import NoModelError
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py
index 575b8b97b..40ff683b6 100644
--- a/fooof/plts/fm.py
+++ b/fooof/plts/fm.py
@@ -76,8 +76,8 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
# Add the full model fit, and components (if requested)
if fm.has_model:
- model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
- 'label' : 'Full Model Fit' if add_legend else None}
+ model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
+ 'label' : 'Full Model Fit' if add_legend else None}
plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs)
# Plot the aperiodic component of the model fit
@@ -154,8 +154,8 @@ def _add_peaks_shade(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the ``fill_between``.
"""
- plot_kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25})
+ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.25}
+ plot_kwargs = check_plot_kwargs(plot_kwargs, defaults)
for peak in fm.get_params('gaussian_params'):
@@ -180,9 +180,8 @@ def _add_peaks_dot(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the plot call.
"""
- plot_kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6})
+ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}
+ plot_kwargs = check_plot_kwargs(plot_kwargs, defaults)
for peak in fm.get_params('peak_params'):
@@ -211,9 +210,8 @@ def _add_peaks_outline(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the plot call.
"""
- plot_kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.7, 'lw' : 1.5})
+ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.5}
+ plot_kwargs = check_plot_kwargs(plot_kwargs, defaults)
for peak in fm.get_params('gaussian_params'):
@@ -244,9 +242,8 @@ def _add_peaks_line(fm, plt_log, ax, **plot_kwargs):
Keyword arguments to pass into the plot call.
"""
- plot_kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10})
+ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.7, 'lw' : 1.4, 'ms' : 10}
+ plot_kwargs = check_plot_kwargs(plot_kwargs, defaults)
ylims = ax.get_ylim()
for peak in fm.get_params('peak_params'):
@@ -276,9 +273,8 @@ def _add_peaks_width(fm, plt_log, ax, **plot_kwargs):
the peak, though what is literally plotted is the full-width half-max.
"""
- plot_kwargs = check_plot_kwargs(plot_kwargs,
- {'color' : PLT_COLORS['periodic'],
- 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6})
+ defaults = {'color' : PLT_COLORS['periodic'], 'alpha' : 0.6, 'lw' : 2.5, 'ms' : 6}
+ plot_kwargs = check_plot_kwargs(plot_kwargs, defaults)
for peak in fm.gaussian_params_:
diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py
index a68043af4..ff579d2e3 100644
--- a/fooof/plts/settings.py
+++ b/fooof/plts/settings.py
@@ -40,7 +40,7 @@
STYLERS = ['axis_styler', 'line_styler', 'custom_styler']
STYLE_ARGS = AXIS_STYLE_ARGS + LINE_STYLE_ARGS + CUSTOM_STYLE_ARGS + STYLERS
-## Define default values for aesthetic
+## Define default values for plot aesthetics
# These are all custom style arguments
TITLE_FONTSIZE = 20
LABEL_SIZE = 16
diff --git a/fooof/plts/style.py b/fooof/plts/style.py
index 3596e71b5..0c3acb484 100644
--- a/fooof/plts/style.py
+++ b/fooof/plts/style.py
@@ -2,7 +2,6 @@
from itertools import cycle
from functools import wraps
-import warnings
import matplotlib.pyplot as plt
@@ -13,8 +12,6 @@
###################################################################################################
###################################################################################################
-# Default plot styling
-
def style_spectrum_plot(ax, log_freqs, log_powers):
"""Apply style and aesthetics to a power spectrum plot.
@@ -73,7 +70,6 @@ def style_param_plot(ax):
for handle in legend.legendHandles:
handle._sizes = [100]
-# Custom plot styling
def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
"""Apply axis plot style.
diff --git a/fooof/tests/plts/test_spectra.py b/fooof/tests/plts/test_spectra.py
index 03b137975..0b85e2d9e 100644
--- a/fooof/tests/plts/test_spectra.py
+++ b/fooof/tests/plts/test_spectra.py
@@ -34,7 +34,7 @@ def test_plot_spectra(tfg, skip_if_no_mpl):
# Test with 2d array inputs
plot_spectra(np.vstack([tfg.freqs, tfg.freqs]),
np.vstack([tfg.power_spectra[0, :], tfg.power_spectra[1, :]]),
- save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png')
+ save_fig=True, file_path=TEST_PLOTS_PATH, file_name='test_plot_spectra_2d.png')
# Test with labels
plot_spectra(tfg.freqs, [tfg.power_spectra[0, :], tfg.power_spectra[1, :]], labels=['A', 'B'],
From 96ac42a64c5075fce6945b88d7a47872770d8bae Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Sun, 11 Apr 2021 23:26:36 -0400
Subject: [PATCH 7/8] small plot updates
---
fooof/plts/aperiodic.py | 19 ++++---------------
fooof/plts/fm.py | 17 ++++++++++-------
fooof/plts/periodic.py | 19 ++++---------------
3 files changed, 18 insertions(+), 37 deletions(-)
diff --git a/fooof/plts/aperiodic.py b/fooof/plts/aperiodic.py
index c3636ce22..905b95145 100644
--- a/fooof/plts/aperiodic.py
+++ b/fooof/plts/aperiodic.py
@@ -9,7 +9,7 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
from fooof.plts.style import style_param_plot, style_plot
-from fooof.plts.utils import check_ax, recursive_plot, savefig
+from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs
plt = safe_import('.pyplot', 'matplotlib')
@@ -47,20 +47,9 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
xs, ys = aps[:, 0], aps[:, -1]
sizes = plot_kwargs.pop('s', 150)
- colors = 'C0' if colors is None else colors
- colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors)
- labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels)
-
- for xi, yi in zip(xs, ys):
-
- # Prevent duplicate labels when recursively plotting
- _, cur_labels = plt.gca().get_legend_handles_labels()
-
- label = next(labels)
- if label not in cur_labels:
- ax.scatter(xi, yi, s=sizes, color=next(colors), label=label, alpha=0.7)
- else:
- ax.scatter(xi, yi, s=sizes, color=next(colors), alpha=0.7)
+ # Create the plot
+ plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
+ ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)
# Add axis labels
ax.set_xlabel('Offset')
diff --git a/fooof/plts/fm.py b/fooof/plts/fm.py
index 40ff683b6..6674848a3 100644
--- a/fooof/plts/fm.py
+++ b/fooof/plts/fm.py
@@ -70,21 +70,24 @@ def plot_fm(fm, plot_peaks=None, plot_aperiodic=True, plt_log=False, add_legend=
# Plot the data, if available
if fm.has_data:
- data_kwargs = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
- 'label' : 'Original Spectrum' if add_legend else None}
+ data_defaults = {'color' : PLT_COLORS['data'], 'linewidth' : 2.0,
+ 'label' : 'Original Spectrum' if add_legend else None}
+ data_kwargs = check_plot_kwargs(data_kwargs, data_defaults)
plot_spectrum(fm.freqs, fm.power_spectrum, log_freqs, log_powers, ax=ax, **data_kwargs)
# Add the full model fit, and components (if requested)
if fm.has_model:
- model_kwargs = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
- 'label' : 'Full Model Fit' if add_legend else None}
+ model_defaults = {'color' : PLT_COLORS['model'], 'linewidth' : 3.0, 'alpha' : 0.5,
+ 'label' : 'Full Model Fit' if add_legend else None}
+ model_kwargs = check_plot_kwargs(model_kwargs, model_defaults)
plot_spectrum(fm.freqs, fm.fooofed_spectrum_, log_freqs, log_powers, ax=ax, **model_kwargs)
# Plot the aperiodic component of the model fit
if plot_aperiodic:
- aperiodic_kwargs = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0,
- 'alpha' : 0.5, 'linestyle' : 'dashed',
- 'label' : 'Aperiodic Fit' if add_legend else None}
+ aperiodic_defaults = {'color' : PLT_COLORS['aperiodic'], 'linewidth' : 3.0,
+ 'alpha' : 0.5, 'linestyle' : 'dashed',
+ 'label' : 'Aperiodic Fit' if add_legend else None}
+ aperiodic_kwargs = check_plot_kwargs(aperiodic_kwargs, aperiodic_defaults)
plot_spectrum(fm.freqs, fm._ap_fit, log_freqs, log_powers, ax=ax, **aperiodic_kwargs)
# Plot the periodic components of the model fit
diff --git a/fooof/plts/periodic.py b/fooof/plts/periodic.py
index 55d147a99..17e66f1b9 100644
--- a/fooof/plts/periodic.py
+++ b/fooof/plts/periodic.py
@@ -9,7 +9,7 @@
from fooof.core.modutils import safe_import, check_dependency
from fooof.plts.settings import PLT_FIGSIZES
from fooof.plts.style import style_param_plot, style_plot
-from fooof.plts.utils import check_ax, recursive_plot, savefig
+from fooof.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs
plt = safe_import('.pyplot', 'matplotlib')
@@ -51,20 +51,9 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,
xs, ys = peaks[:, 0], peaks[:, 1]
sizes = peaks[:, 2] * plot_kwargs.pop('s', 150)
- colors = 'C0' if colors is None else colors
- colors = cycle([colors]) if not isinstance(colors, list) else cycle(colors)
- labels = cycle([labels]) if not isinstance(labels, list) else cycle(labels)
-
- for xi, yi, size in zip(xs, ys, sizes):
-
- # Prevent duplicate labels when recursively plotting
- _, cur_labels = plt.gca().get_legend_handles_labels()
-
- label = next(labels)
- if label not in cur_labels:
- ax.scatter(xi, yi, s=size, color=next(colors), label=label, alpha=0.7)
- else:
- ax.scatter(xi, yi, s=size, color=next(colors), alpha=0.7)
+ # Create the plot
+ plot_kwargs = check_plot_kwargs(plot_kwargs, {'alpha' : 0.7})
+ ax.scatter(xs, ys, sizes, c=colors, label=labels, **plot_kwargs)
# Add axis labels
ax.set_xlabel('Center Frequency')
From f5e448eba4ba03d0eea26dd7fc861353bcc02f57 Mon Sep 17 00:00:00 2001
From: Tom Donoghue
Date: Sun, 11 Apr 2021 23:47:02 -0400
Subject: [PATCH 8/8] add collection styler
---
fooof/plts/settings.py | 3 +++
fooof/plts/style.py | 33 ++++++++++++++++++++++++++++-----
2 files changed, 31 insertions(+), 5 deletions(-)
diff --git a/fooof/plts/settings.py b/fooof/plts/settings.py
index ff579d2e3..4b7f1050e 100644
--- a/fooof/plts/settings.py
+++ b/fooof/plts/settings.py
@@ -34,6 +34,9 @@
LINE_STYLE_ARGS = ['alpha', 'lw', 'linewidth', 'ls', 'linestyle',
'marker', 'ms', 'markersize']
+# Collection style arguments are those that can be set on a matplotlib collection object
+COLLECTION_STYLE_ARGS = ['alpha', 'edgecolor']
+
# Custom style arguments are those that are custom-handled by the plot style function
CUSTOM_STYLE_ARGS = ['title_fontsize', 'label_size', 'tick_labelsize',
'legend_size', 'legend_loc']
diff --git a/fooof/plts/style.py b/fooof/plts/style.py
index 0c3acb484..dc72a1142 100644
--- a/fooof/plts/style.py
+++ b/fooof/plts/style.py
@@ -5,9 +5,9 @@
import matplotlib.pyplot as plt
-from fooof.plts.settings import AXIS_STYLE_ARGS, LINE_STYLE_ARGS, CUSTOM_STYLE_ARGS, STYLE_ARGS
-from fooof.plts.settings import (LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC,
- TICK_LABELSIZE, TITLE_FONTSIZE)
+from fooof.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS,
+ CUSTOM_STYLE_ARGS, STYLE_ARGS, LABEL_SIZE, LEGEND_SIZE,
+ LEGEND_LOC, TICK_LABELSIZE, TITLE_FONTSIZE)
###################################################################################################
###################################################################################################
@@ -119,6 +119,27 @@ def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
line.set(**{style : next(values)})
+def apply_collection_style(ax, style_args=COLLECTION_STYLE_ARGS, **kwargs):
+ """Apply collection plot style to every collection on an axis.
+
+ Parameters
+ ----------
+ ax : matplotlib.Axes
+ Figure axes whose `collections` will be styled.
+ style_args : list of str
+ A list of arguments to be sub-selected from `kwargs` and applied as collection styling.
+ **kwargs
+ Keyword arguments that define style to apply; keys not in `style_args` are ignored.
+ """
+
+ # Keep only the collection-related styling arguments from the keyword arguments
+ collection_kwargs = {key : val for key, val in kwargs.items() if key in style_args}
+
+ # Apply the selected arguments identically to every collection (no per-collection cycling)
+ for collection in ax.collections:
+ collection.set(**collection_kwargs)
+
+
def apply_custom_style(ax, **kwargs):
"""Apply custom plot style.
@@ -152,14 +173,15 @@ def apply_custom_style(ax, **kwargs):
def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
- custom_styler=apply_custom_style, **kwargs):
+ collection_styler=apply_collection_style, custom_styler=apply_custom_style,
+ **kwargs):
"""Apply plot style to a figure axis.
Parameters
----------
ax : matplotlib.Axes
Figure axes to apply style to.
- axis_styler, line_styler, custom_styler : callable, optional
+ axis_styler, line_styler, collection_styler, custom_styler : callable, optional
Functions to apply style to aspects of the plot.
**kwargs
Keyword arguments that define style to apply.
@@ -172,6 +194,7 @@ def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
axis_styler(ax, **kwargs)
line_styler(ax, **kwargs)
+ collection_styler(ax, **kwargs)
custom_styler(ax, **kwargs)