-
Notifications
You must be signed in to change notification settings - Fork 112
Expand file tree
/
Copy pathstyle.py
More file actions
286 lines (219 loc) · 10.4 KB
/
style.py
File metadata and controls
286 lines (219 loc) · 10.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
"""Style and aesthetics definitions for plots."""
from itertools import cycle
from functools import wraps
import matplotlib.pyplot as plt
from fooof.plts.settings import (AXIS_STYLE_ARGS, LINE_STYLE_ARGS, COLLECTION_STYLE_ARGS,
CUSTOM_STYLE_ARGS, STYLE_ARGS, TICK_LABELSIZE, TITLE_FONTSIZE,
LABEL_SIZE, LEGEND_SIZE, LEGEND_LOC)
###################################################################################################
###################################################################################################
def check_style_options():
    """Print every style argument accepted by the plot functions, grouped by category.

    The categories are axis, line, collection and custom style arguments, which are
    taken from the corresponding settings lists.
    """
    categories = (
        ('Axis', AXIS_STYLE_ARGS),
        ('Line', LINE_STYLE_ARGS),
        ('Collection', COLLECTION_STYLE_ARGS),
        ('Custom', CUSTOM_STYLE_ARGS),
    )
    print('Valid style arguments:')
    for category_name, arg_names in categories:
        # Category name is padded to a fixed width so the argument lists line up
        print(' {:10s} {}'.format(category_name, ', '.join(arg_names)))
def style_spectrum_plot(ax, log_freqs, log_powers, grid=True):
    """Style a power spectrum plot: axis labels, tick label sizes, grid and legend.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply styling to.
    log_freqs : bool
        Whether the frequency axis is plotted in log space.
    log_powers : bool
        Whether the power axis is plotted in log space.
    grid : bool, optional, default: True
        Whether to add grid lines to the plot.

    Notes
    -----
    Font sizes are fixed here: 20 for axis labels, 16 for tick labels and legend text.
    A legend, if drawn, is placed in the upper right corner.
    """
    # Wrap each axis name in 'log(...)' when that axis is plotted in log space
    freq_name = 'log(Frequency)' if log_freqs else 'Frequency'
    power_name = 'log(Power)' if log_powers else 'Power'

    for set_label, text in ((ax.set_xlabel, freq_name), (ax.set_ylabel, power_name)):
        set_label(text, fontsize=20)
    ax.tick_params(axis='both', which='major', labelsize=16)
    ax.grid(grid)

    # Only draw a legend when at least one plotted artist carries a label
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend(prop={'size': 16}, loc='upper right')
def style_param_plot(ax):
    """Apply style and aesthetics to a peak parameters plot.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply styling to.

    Notes
    -----
    Hides the top and right spines, thickens the remaining spines, sets label and tick
    font sizes, and, if any artist is labelled, adds a legend whose marker sizes are
    standardized to 100.
    """
    # Set the top and right side frame & ticks off
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    # Set linewidth of remaining spines
    ax.spines['left'].set_linewidth(1.5)
    ax.spines['bottom'].set_linewidth(1.5)

    # Aesthetics and axis labels
    for item in (ax.xaxis.label, ax.yaxis.label):
        item.set_fontsize(20)
    ax.tick_params(axis='both', which='major', labelsize=16)

    # If labels were provided, add a legend and standardize the dot size
    if ax.get_legend_handles_labels()[0]:
        legend = ax.legend(prop={'size': 16})

        # `Legend.legendHandles` was renamed to `legend_handles` in matplotlib 3.7 and the
        # old name was removed in 3.9, so prefer the new name and fall back for old versions
        handles = getattr(legend, 'legend_handles', None)
        if handles is None:
            handles = legend.legendHandles

        for handle in handles:
            if hasattr(handle, 'set_sizes'):
                # Public API for collection handles (e.g. from scatter plots)
                handle.set_sizes([100])
            else:
                # Preserve the previous behaviour for handles without `set_sizes`
                handle._sizes = [100]
def apply_axis_style(ax, style_args=AXIS_STYLE_ARGS, **kwargs):
    """Set axis-level properties from whichever keyword arguments are axis style arguments.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply style to.
    style_args : list of str
        Names of the keyword arguments that count as axis styling; any other
        keyword arguments are ignored.
    **kwargs
        Keyword arguments that define plot style to apply.
    """
    # Keep only the recognised axis arguments, then hand them to `ax.set` in one call
    axis_settings = {}
    for name, setting in kwargs.items():
        if name in style_args:
            axis_settings[name] = setting
    ax.set(**axis_settings)
def apply_line_style(ax, style_args=LINE_STYLE_ARGS, **kwargs):
    """Apply line plot style.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply style to.
    style_args : list of str
        A list of arguments to be sub-selected from `kwargs` and applied as line styling.
    **kwargs
        Keyword arguments that define line style to apply.
        Each value is either a single value, applied to every targeted line, or a
        sequence with one value per line (cycled if there are more lines than values).
        Strings and non-iterable values (e.g. None or numpy scalars) count as single values.
        May also contain `n_lines_apply`: the number of most recently added lines to style.
        If it is not given, all lines on the axis are styled; if it is 0, none are.
    """
    # Work out which lines to style: the last `n_lines_apply` lines if given, else all
    n_lines_apply = kwargs.pop('n_lines_apply', None)
    if n_lines_apply is None:
        target_lines = list(ax.lines)
    elif n_lines_apply > 0:
        target_lines = list(ax.lines[-n_lines_apply:])
    else:
        # `ax.lines[-0:]` would select *every* line, so zero new lines must mean no lines
        target_lines = []

    # Get the line related styling arguments from the keyword arguments
    line_kwargs = {key : val for key, val in kwargs.items() if key in style_args}

    # Apply any provided line style arguments
    for style, value in line_kwargs.items():
        # Make a cycle-able object out of the value(s): strings and non-iterables are
        # a single value for all lines; other iterables give one value per line
        if isinstance(value, str):
            values = cycle([value])
        else:
            try:
                values = cycle(value)
            except TypeError:
                values = cycle([value])
        for line in target_lines:
            line.set(**{style : next(values)})
def apply_collection_style(ax, style_args=COLLECTION_STYLE_ARGS, **kwargs):
    """Set properties on the axis' collections from matching keyword arguments.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply style to.
    style_args : list of str
        Names of the keyword arguments that count as collection styling; any other
        keyword arguments are ignored.
    **kwargs
        Keyword arguments that define collection style to apply.

    Notes
    -----
    The selected settings are applied to every collection on the axis, including
    collections that existed before the current plot call.
    """
    wanted = {}
    for name in kwargs:
        if name in style_args:
            wanted[name] = kwargs[name]

    for each_collection in ax.collections:
        each_collection.set(**wanted)
def apply_custom_style(ax, **kwargs):
    """Apply custom plot style.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply style to.
    **kwargs
        Keyword arguments that define custom style to apply. Recognised keys are
        `title_fontsize`, `label_size`, `tick_labelsize`, `legend_size` and `legend_loc`;
        each falls back to the corresponding module-level default when not given.

    Notes
    -----
    Also applies a tight layout to the figure that contains `ax`.
    """
    # If a title was provided, update the size
    if ax.get_title():
        ax.title.set_size(kwargs.pop('title_fontsize', TITLE_FONTSIZE))

    # Settings for the axis labels
    label_size = kwargs.pop('label_size', LABEL_SIZE)
    ax.xaxis.label.set_size(label_size)
    ax.yaxis.label.set_size(label_size)

    # Settings for the axis ticks
    ax.tick_params(axis='both', which='major',
                   labelsize=kwargs.pop('tick_labelsize', TICK_LABELSIZE))

    # If labels were provided, add a legend
    if ax.get_legend_handles_labels()[0]:
        ax.legend(prop={'size': kwargs.pop('legend_size', LEGEND_SIZE)},
                  loc=kwargs.pop('legend_loc', LEGEND_LOC))

    # Lay out the figure that owns `ax`, which is not necessarily the current pyplot figure
    fig = ax.get_figure()
    try:
        # If available, `.set_layout_engine` should be equivalent to
        # `plt.tight_layout()`, but seems to raise fewer warnings...
        fig.set_layout_engine('tight')
    except (AttributeError, RuntimeError):
        # AttributeError: matplotlib too old to have `set_layout_engine`
        # RuntimeError: the figure's existing layout engine cannot be swapped out
        if hasattr(fig, 'tight_layout'):
            fig.tight_layout()
        else:
            plt.tight_layout()
def apply_style(ax, axis_styler=apply_axis_style, line_styler=apply_line_style,
                collection_styler=apply_collection_style, custom_styler=apply_custom_style,
                **kwargs):
    """Apply plot style to a figure axis.

    Parameters
    ----------
    ax : matplotlib.Axes
        Figure axes to apply style to.
    axis_styler, line_styler, collection_styler, custom_styler : callable or None, optional
        Functions to apply style to aspects of the plot. Each is called as
        ``styler(ax, **kwargs)``. Pass None to skip that step.
    **kwargs
        Keyword arguments that define style to apply. All of them are passed to every
        styler; each styler is expected to pick out the arguments it handles.

    Notes
    -----
    This function wraps sub-functions which apply style to different plot elements.
    Each of these sub-functions can be replaced by passing in replacement callables.
    The stylers run in the order: axis, line, collection, custom.
    """
    # `**kwargs` unpacking gives each styler its own dict, so keys popped by one styler
    # (e.g. `n_lines_apply`) are still visible to the next
    for styler in (axis_styler, line_styler, collection_styler, custom_styler):
        if styler is not None:
            styler(ax, **kwargs)
def style_plot(func, *args, **kwargs):
    """Decorator function to apply a plot style function, after plot generation.

    Parameters
    ----------
    func : callable
        The plotting function for creating a plot.
    *args, **kwargs
        Arguments & keyword arguments.
        These should include any arguments for the plot, and those for applying plot style.

    Returns
    -------
    callable
        The wrapped plotting function. It returns whatever `func` returns.

    Notes
    -----
    This decorator works by:

    - catching all inputs that relate to plot style
    - creating a plot, using the passed in plotting function & passing in all non-style arguments
    - passing the style related arguments into a `apply_style` function which applies plot styling

    By default, this function applies styling with the `apply_style` function. Custom
    functions for applying style can be passed in using `plot_style` as a keyword argument.
    Passing `plot_style=None` skips styling entirely. The names of the arguments treated as
    style arguments can be overridden with the `style_args` keyword argument.

    The target axis is taken from the `ax` argument of the plotting function, whether it
    is passed by keyword or by position. If no axis is given, the current axis is styled.

    The `apply_style` function calls sub-functions for applying different plot elements, including:

    - `axis_styler`: apply style options to an axis
    - `line_styler`: applies style options to lines objects in a plot
    - `collection_styler`: applies style options to collections objects in a plot
    - `custom_styler`: applies custom style options

    Each of these sub-functions can be overridden by passing in alternatives.

    To see the full set of style arguments that are supported, run the following code:

    >>> from fooof.plts.style import check_style_options
    >>> check_style_options()
    Valid style arguments:
    Axis title, xlabel, ylabel, xlim, ylim, xticks, yticks, xticklabels, yticklabels
    Line alpha, lw, linewidth, ls, linestyle, marker, ms, markersize
    Collection alpha, edgecolor
    Custom title_fontsize, label_size, tick_labelsize, legend_size, legend_loc
    """
    from inspect import signature

    # Introspect once, at decoration time, so a positionally-passed `ax` can be found later
    try:
        func_signature = signature(func)
    except (TypeError, ValueError):
        func_signature = None

    @wraps(func)
    def decorated(*args, **kwargs):

        # Grab a custom style function, if provided, and grab any provided style arguments
        style_func = kwargs.pop('plot_style', apply_style)
        style_args = kwargs.pop('style_args', STYLE_ARGS)
        style_kwargs = {key : kwargs.pop(key) for key in style_args if key in kwargs}

        # Find the target axis, whether `ax` was passed by keyword or by position
        ax = kwargs.get('ax')
        if func_signature is not None:
            try:
                ax = func_signature.bind_partial(*args, **kwargs).arguments.get('ax', ax)
            except TypeError:
                # Arguments don't fit the signature; `func` itself will report the error
                pass

        # Check how many lines are already on the plot, if it exists already
        n_lines_pre = len(ax.lines) if ax is not None else 0

        # Create the plot, keeping its return value for the caller
        output = func(*args, **kwargs)

        if style_func is not None:

            # Get plot axis, if a specific one was provided, or if not, grab the current axis
            cur_ax = ax if ax is not None else plt.gca()

            # Check how many lines were added to the plot, and make info available to plot styling
            style_kwargs['n_lines_apply'] = len(cur_ax.lines) - n_lines_pre

            # Apply the styling function
            style_func(cur_ax, **style_kwargs)

        return output

    return decorated