-
Notifications
You must be signed in to change notification settings - Fork 112
Expand file tree
/
Copy pathutils.py
More file actions
173 lines (128 loc) · 4.98 KB
/
utils.py
File metadata and controls
173 lines (128 loc) · 4.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""Utility functions for plotting.
Notes
-----
These utility functions should be considered private.
They are not expected to be called directly by the user.
"""
from itertools import repeat
from collections.abc import Iterator
import numpy as np
from fooof.core.modutils import safe_import
from fooof.core.utils import resolve_aliases
from fooof.plts.settings import PLT_ALPHA_LEVELS, PLT_ALIASES
plt = safe_import('.pyplot', 'matplotlib')
###################################################################################################
###################################################################################################
def check_ax(ax, figsize=None):
    """Check whether a figure axes object is defined, and define if not.

    Parameters
    ----------
    ax : matplotlib.Axes or None
        Object to check if is already an axes object.
        Anything other than None is returned unchanged.
    figsize : tuple of float, optional
        Size to create the figure, if not already created.
        Only used when a new figure is created.

    Returns
    -------
    ax : matplotlib.Axes
        Figure axes object to use.

    Notes
    -----
    Only an input of None triggers the creation of a new figure.
    """

    # Compare against None explicitly rather than relying on truthiness:
    #   `not ax` raises for an array of axes (ambiguous truth value), and would
    #   discard any axes-like object that happens to evaluate as False
    if ax is None:
        _, ax = plt.subplots(figsize=figsize)

    return ax
def set_alpha(n_points):
    """Set an alpha value for plotting that is scaled by the number of points.

    Parameters
    ----------
    n_points : int
        Number of points that will be in the plot.

    Returns
    -------
    alpha : float
        Value for alpha to use for plotting.

    Notes
    -----
    Alpha values are looked up in `PLT_ALPHA_LEVELS`, which maps point-count thresholds
    to alpha values. Entries are scanned in iteration order, and the last entry whose
    threshold is exceeded by `n_points` is used.

    If `n_points` exceeds none of the thresholds (for example `n_points=0`), the value of
    the first entry is used as a fallback, instead of failing on an unassigned variable.
    If `PLT_ALPHA_LEVELS` has no entries at all, None is returned.
    """

    # NOTE(review): presumably `PLT_ALPHA_LEVELS` is ordered by increasing threshold,
    #   so that the first entry is the level for the fewest points - confirm in settings
    alpha = None

    for threshold, level in PLT_ALPHA_LEVELS.items():
        # The first entry seeds the fallback; afterwards only exceeded thresholds override
        if alpha is None or n_points > threshold:
            alpha = level

    return alpha
def add_shades(ax, shades, colors='r', add_center=False, logged=False):
    """Add shaded regions to a plot.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes upon which to plot.
    shades : list of [float, float] or list of list of [float, float]
        Shaded region(s) to add to plot, defined as [lower_bound, upper_bound].
        Regions may be given as lists, tuples or arrays.
        If empty, nothing is added to the plot.
    colors : str or list of string
        Color(s) to plot shades.
        A list is matched one-to-one with the regions; any other value is used for all regions.
    add_center : boolean, default: False
        Whether to add a line at the center point of the shaded regions.
    logged : boolean, default: False
        Whether the shade values should be logged before applying to plot axes.

    Notes
    -----
    When `logged` is True, the center line is placed at the midpoint of the logged bounds.
    """

    # With no regions there is nothing to draw, and no first element to inspect
    if len(shades) == 0:
        return

    # If only one shade region is specified, this embeds in a list, so that the loop works
    #   A single region is detected by its first element being a scalar bound, so that
    #   multiple regions given as tuples or array rows are not mistaken for a single region
    if np.ndim(shades[0]) == 0:
        shades = [shades]

    # Only a list counts as per-region colors - a tuple may be a single RGB(A) color
    colors = colors if isinstance(colors, list) else repeat(colors)

    for shade, color in zip(shades, colors):

        shade = np.log10(shade) if logged else shade
        ax.axvspan(shade[0], shade[1], color=color, alpha=0.2, lw=0)

        if add_center:
            # A zero-width span draws as a vertical line at the center of the region
            center = sum(shade) / 2
            ax.axvspan(center, center, color='k', alpha=0.6)
def recursive_plot(data, plot_function, ax, **kwargs):
    """Plot a collection of datasets onto a single axis, one plot call per dataset.

    Parameters
    ----------
    data : list
        Datasets to add to the plot, each passed to its own call of `plot_function`.
    plot_function : callable
        Function used to plot each dataset.
    ax : matplotlib.Axes
        Figure axes shared by all the plot calls.
    **kwargs
        Keyword arguments to forward to the plot function.
        Each input can be given as:

        - an iterator, from which one value is drawn per plot call
        - a list, with one element used per plot call, in order
        - any other single value (such as an int, str or None), reused for every plot call

    Notes
    -----
    The `plot_function` is called as `plot_function(dataset, ax=ax, **call_kwargs)`,
    so it must accept `ax` as a keyword argument.
    """

    def _as_stream(value):
        """Wrap a keyword input so that one value can be drawn from it per plot call."""

        # Iterators are used directly
        if isinstance(value, Iterator):
            return value

        # Lists provide one element per call
        if isinstance(value, list):
            return iter(value)

        # Anything else is a single value, repeated for all calls
        return repeat(value)

    streams = {name: _as_stream(value) for name, value in kwargs.items()}

    # Every dataset is drawn onto the same axis, with the next value from each stream
    for dataset in data:

        call_kwargs = {}
        for name, stream in streams.items():
            call_kwargs[name] = next(stream)

        plot_function(dataset, ax=ax, **call_kwargs)
def check_plot_kwargs(plot_kwargs, defaults):
    """Fill in plot keyword arguments with default values for any that are missing.

    Parameters
    ----------
    plot_kwargs : dict or None
        Keyword arguments for a plot call.
    defaults : dict
        Arguments, and their default values, to add to `plot_kwargs` if not already present.

    Returns
    -------
    plot_kwargs : dict
        Keyword arguments for a plot call, including any defaults that were added.

    Notes
    -----
    - If `plot_kwargs` is None or empty, the `defaults` object itself is returned.
    - Otherwise `plot_kwargs` is updated in place, and the same object is returned.
    - Argument names are compared after passing both inputs through `resolve_aliases`
      with `PLT_ALIASES`.
    """

    if plot_kwargs:

        resolved_defaults = resolve_aliases(defaults, PLT_ALIASES)

        for name, default_value in resolved_defaults.items():

            # Anything already specified by the caller takes priority over the default
            if name in resolve_aliases(plot_kwargs, PLT_ALIASES):
                continue

            plot_kwargs[name] = default_value

        return plot_kwargs

    # Nothing was specified, so the defaults are used as they are
    return defaults