Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions specparam/plts/aperiodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.modes.modes import check_mode_definition
from specparam.modes.definitions import AP_MODES
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.settings import ITERABLES, PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs
Expand All @@ -28,7 +28,7 @@ def plot_aperiodic_params(aps, colors=None, labels=None, ax=None, **plot_kwargs)
----------
aps : 2d array or list of 2d array
Aperiodic parameters. Each row is a parameter set, as [Off, Exp] or [Off, Knee, Exp].
colors : str or list of str, optional
colors : str or iterable, optional
Color(s) to plot data.
labels : list of str, optional
Label(s) for plotted data, to be added in a legend.
Expand Down Expand Up @@ -89,7 +89,7 @@ def plot_aperiodic_fits(aps, freq_range, aperiodic_mode, control_offset=False, a
Whether to control for the offset, by setting it to zero.
log_freqs : boolean, optional, default: False
Whether to plot the x-axis in log space.
colors : str or list of str, optional
colors : str or iterable, optional
Color(s) to plot data.
labels : list of str, optional
Label(s) for plotted data, to be added in a legend.
Expand Down Expand Up @@ -117,7 +117,7 @@ def plot_aperiodic_fits(aps, freq_range, aperiodic_mode, control_offset=False, a
freqs = gen_freqs(freq_range, 0.1)
plt_freqs = np.log10(freqs) if log_freqs else freqs

colors = colors[0] if isinstance(colors, list) else colors
colors = colors[0] if isinstance(colors, ITERABLES) else colors

all_ap_vals = np.zeros(shape=(len(aps), len(freqs)))
for ind, ap_params in enumerate(aps):
Expand Down
8 changes: 4 additions & 4 deletions specparam/plts/periodic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.modes.modes import check_mode_definition
from specparam.modes.definitions import PE_MODES
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.settings import ITERABLES, PLT_FIGSIZES
from specparam.plts.templates import plot_yshade
from specparam.plts.style import style_param_plot, style_plot
from specparam.plts.utils import check_ax, recursive_plot, savefig, check_plot_kwargs
Expand All @@ -30,7 +30,7 @@ def plot_peak_params(peaks, freq_range=None, colors=None, labels=None, ax=None,
Peak data. Each row is a peak, as [CF, PW, BW].
freq_range : list of [float, float] , optional
The frequency range to plot the peak parameters across, as [f_min, f_max].
colors : str or list of str, optional
colors : str or iterable, optional
Color(s) to plot data.
labels : list of str, optional
Label(s) for plotted data, to be added in a legend.
Expand Down Expand Up @@ -93,7 +93,7 @@ def plot_peak_fits(peaks, periodic_mode, freq_range=None, average='mean', shade=
plot_individual : bool, optional, default: True
Whether to plot individual component reconstructions.
If False, only the average component reconstruction is plotted.
colors : str or list of str, optional
colors : str or iterable, optional
Color(s) to plot data.
labels : list of str, optional
Label(s) for plotted data, to be added in a legend.
Expand Down Expand Up @@ -133,7 +133,7 @@ def plot_peak_fits(peaks, periodic_mode, freq_range=None, average='mean', shade=
# Create the frequency axis, which will be the plot x-axis
freqs = gen_freqs(freq_range, 0.1)

colors = colors[0] if isinstance(colors, list) else colors
colors = colors[0] if isinstance(colors, ITERABLES) else colors

all_peak_vals = np.zeros(shape=(len(peaks), len(freqs)))
for ind, peak_params in enumerate(peaks):
Expand Down
5 changes: 5 additions & 0 deletions specparam/plts/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,18 @@

from collections import OrderedDict

import numpy as np

from specparam.modutils.dependencies import safe_import

plt = safe_import('.pyplot', 'matplotlib')

###################################################################################################
###################################################################################################

# Container types that plotting code treats as holding multiple values (e.g. one color or
# label per plotted item) rather than a single value; used as `isinstance(obj, ITERABLES)`.
# NOTE(review): matplotlib also accepts an RGB(A) tuple / array as ONE color, so an argument
# such as `colors=(1, 0, 0)` is treated as three separate values by checks against this tuple.
ITERABLES = (list, tuple, np.ndarray)

# Default plot colors: the color list of matplotlib's current `axes.prop_cycle` setting,
# or None when `safe_import` returned a falsy value (presumably matplotlib is unavailable).
if plt:
    DEFAULT_COLORS = plt.rcParams['axes.prop_cycle'].by_key()['color']
else:
    DEFAULT_COLORS = None

Expand Down
12 changes: 6 additions & 6 deletions specparam/plts/spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.utils.select import dict_extract_keys
from specparam.plts.templates import plot_yshade
from specparam.plts.settings import PLT_FIGSIZES
from specparam.plts.settings import ITERABLES, PLT_FIGSIZES
from specparam.plts.style import style_spectrum_plot, style_plot
from specparam.plts.utils import check_ax, add_shades, savefig, check_plot_kwargs

Expand Down Expand Up @@ -41,7 +41,7 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r
Whether to plot the power axis in log spacing.
freq_range : list of [float, float], optional
Frequency range to plot, defined in linear space.
colors : list of str, optional, default: None
colors : str or iterable, optional, default: None
Line colors of the spectra.
labels : list of str, optional, default: None
Legend labels for the spectra.
Expand Down Expand Up @@ -69,16 +69,16 @@ def plot_spectra(freqs, power_spectra, log_freqs=False, log_powers=False, freq_r
# Set labels
labels = plot_kwargs.pop('label') \
if 'label' in plot_kwargs.keys() and labels is None else labels
labels = repeat(labels) if not isinstance(labels, list) else cycle(labels)
colors = repeat(colors) if not isinstance(colors, list) else cycle(colors)
labels = repeat(labels) if not isinstance(labels, ITERABLES) else cycle(labels)
colors = repeat(colors) if not isinstance(colors, ITERABLES) else cycle(colors)

# Plot power spectra, looping across all spectra to plot
for freqs, powers, color, label in zip(plt_freqs, plt_powers, colors, labels):

# Set plot data, logging if requested, and collect color, if absent
freqs = np.log10(freqs) if log_freqs else freqs
powers = np.log10(powers) if log_powers else powers
if color:
if color is not None:
plot_kwargs['color'] = color

ax.plot(freqs, powers, label=label, **plot_kwargs)
Expand Down Expand Up @@ -106,7 +106,7 @@ def plot_spectra_shading(freqs, power_spectra, shades, shade_colors='r',
Power values, to be plotted on the y-axis.
shades : list of [float, float] or list of list of [float, float]
Shaded region(s) to add to plot, defined as [lower_bound, upper_bound].
shade_colors : str or list of string
shade_colors : str or iterable, default: 'red'
Color(s) to plot shades.
add_center : bool, optional, default: False
Whether to add a line at the center point of the shaded regions.
Expand Down
12 changes: 6 additions & 6 deletions specparam/plts/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from specparam.modutils.dependencies import safe_import, check_dependency
from specparam.measures.properties import compute_average, compute_dispersion
from specparam.plts.utils import check_ax, set_alpha
from specparam.plts.settings import (PLT_FIGSIZES, DEFAULT_COLORS, PLT_TEXT_FONT,
from specparam.plts.settings import (ITERABLES, PLT_FIGSIZES, DEFAULT_COLORS, PLT_TEXT_FONT,
TITLE_FONTSIZE, LABEL_SIZE, TICK_LABELSIZE)

plt = safe_import('.pyplot', 'matplotlib')
Expand Down Expand Up @@ -84,7 +84,7 @@ def plot_scatter_2(data_0, label_0, data_1, label_1,
Label for the data on the second axis, to be set as the axis label.
title : str, optional
Title for the plot.
colors : list of str, optional
colors : iterable, optional
Color(s) to plot data.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
Expand All @@ -99,7 +99,7 @@ def plot_scatter_2(data_0, label_0, data_1, label_1,
ax = check_ax(ax)
ax1 = ax.twinx()

colors = iter(colors) if isinstance(colors, list) else repeat(colors)
colors = iter(colors) if isinstance(colors, ITERABLES) else repeat(colors)

plot_scatter_1(data_0, label_0, color=next(colors), ax=ax, **plot_kwargs)
plot_scatter_1(data_1, label_1, x_val=1, color=next(colors), ax=ax1, **plot_kwargs)
Expand Down Expand Up @@ -277,16 +277,16 @@ def plot_params_over_time(times, params, labels=None, title=None, colors=None,
Label(s) for the data, to be set as the y-axis label(s).
title : str, optional
Title for the plot.
colors : list of str
colors : iterable, optional
Color(s) to plot data.
ax : matplotlib.Axes, optional
Figure axes upon which to plot.
**plot_kwargs
Additional keyword arguments for the plot call.
"""

labels = repeat(labels) if not isinstance(labels, list) else cycle(labels)
colors = cycle(DEFAULT_COLORS) if not isinstance(colors, list) else cycle(colors)
labels = repeat(labels) if not isinstance(labels, ITERABLES) else cycle(labels)
colors = cycle(DEFAULT_COLORS) if not isinstance(colors, ITERABLES) else cycle(colors)

ax0 = check_ax(ax, plot_kwargs.pop('figsize', PLT_FIGSIZES['time']))

Expand Down
8 changes: 4 additions & 4 deletions specparam/plts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from specparam.io.utils import create_file_path
from specparam.modutils.dependencies import safe_import
from specparam.modutils.functions import resolve_aliases
from specparam.plts.settings import PLT_ALPHA_LEVELS, PLT_ALIASES
from specparam.plts.settings import ITERABLES, PLT_ALPHA_LEVELS, PLT_ALIASES

plt = safe_import('.pyplot', 'matplotlib')

Expand Down Expand Up @@ -75,7 +75,7 @@ def add_shades(ax, shades, colors='r', shade_alpha=0.2,
Figure axes upon which to plot.
shades : list of [float, float] or list of list of [float, float]
Shaded region(s) to add to plot, defined as [lower_bound, upper_bound].
colors : str or list of string
colors : str or iterable, default: 'red'
Color(s) to plot shades.
shade_alpha : float or list of float, optional, default: 0.2
The alpha level to add the shade regions with.
Expand All @@ -92,8 +92,8 @@ def add_shades(ax, shades, colors='r', shade_alpha=0.2,
if not isinstance(shades[0], (tuple, list)):
shades = [shades]

colors = repeat(colors) if not isinstance(colors, list) else colors
shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, list) else shade_alpha
colors = repeat(colors) if not isinstance(colors, ITERABLES) else colors
shade_alphas = repeat(shade_alpha) if not isinstance(shade_alpha, ITERABLES) else shade_alpha

for shade, color, alpha in zip(shades, colors, shade_alphas):

Expand Down
Loading