"""Plot utilities."""
import math
from functools import wraps
from os.path import join as pjoin
import matplotlib.pyplot as plt
from matplotlib import gridspec
def check_ax(ax, figsize=None, return_current=False):
    """Check whether a figure axes object is defined, and define or return current axis if not.

    Parameters
    ----------
    ax : matplotlib.Axes or None
        Axes object to check if is defined.
    figsize : tuple of float, optional
        Size to make the axis. Only used when a new axis is created.
    return_current : bool, optional, default: False
        Whether to return the current axis, if axis is not defined.
        If False, creates a new plot axis instead.

    Returns
    -------
    ax : matplotlib.Axes
        Figure axes object to use.
    """

    if not ax:
        if return_current:
            ax = plt.gca()
        else:
            # Only create a new figure when the current axis was not requested
            _, ax = plt.subplots(figsize=figsize)

    return ax
def get_kwargs(kwargs, select):
    """Pop the selected keyword arguments out of a kwargs dictionary.

    Every name in `select` is removed from `kwargs` (if present). Only the
    names whose popped value is not None are returned.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments to extract from. Modified in place.
    select : list of str
        The arguments to extract.

    Returns
    -------
    setters : dict
        Selected keyword arguments, excluding any whose value is None.
    """

    collected = {}
    for name in select:
        collected[name] = kwargs.pop(name, None)

    return {name: val for name, val in collected.items() if val is not None}
def get_attr_kwargs(kwargs, attr):
    """Get keyword arguments related to a particular attribute.

    Keys of the form ``'{attr}_{name}'`` are popped from `kwargs` and returned
    keyed by ``name`` (e.g. ``'title_fontsize'`` -> ``'fontsize'``).

    Parameters
    ----------
    kwargs : dict
        Plotting related keyword arguments. Matching keys are removed in place.
    attr : str
        The attribute to select related arguments.

    Returns
    -------
    attr_kwargs : dict
        Selected keyword arguments, related to the given attribute, with the
        ``'{attr}_'`` prefix removed from each key.
    """

    # Match on the full prefix, not a substring, so unrelated keys that merely
    # contain `attr` (e.g. 'subtitle_size' for 'title') are left alone.
    prefix = attr + '_'

    # Collect keys first, since the dict is mutated while building the output
    labels = [key for key in kwargs if key.startswith(prefix)]

    return {label[len(prefix):]: kwargs.pop(label) for label in labels}
def savefig(func):
    """Decorator function to save out figures.

    The decorated function accepts these extra keyword arguments. They are
    consumed by the decorator and not forwarded to `func`:

    file_name : str, optional
        File name to save the figure to.
    file_path : str, optional
        Path to the directory in which to save the figure.
    save_fig : bool, optional
        Whether to save the figure. Defaults to True when `file_name` is given.
    save_kwargs : dict, optional
        Extra keyword arguments for `save_figure`. 'bbox_inches' defaults to
        'tight'. The caller's dictionary is not modified.
    close : bool, optional, default: False
        Forwarded to `save_figure` as its third positional argument.

    The decorated function returns whatever `func` returns.
    """

    @wraps(func)
    def decorated(*args, **kwargs):

        # Grab file name and path arguments, if they are in kwargs
        file_name = kwargs.pop('file_name', None)
        file_path = kwargs.pop('file_path', None)

        # Check for an explicit argument for whether to save figure or not
        # Defaults to saving when file name given (since bool(str)->True; bool(None)->False)
        save_fig = kwargs.pop('save_fig', bool(file_name))

        # Copy the save keywords, so that setting defaults never mutates the caller's dict
        save_kwargs = dict(kwargs.pop('save_kwargs', None) or {})
        save_kwargs.setdefault('bbox_inches', 'tight')

        # Check and collect whether to close the plot
        close = kwargs.pop('close', False)

        output = func(*args, **kwargs)

        if save_fig:
            # save_figure is not defined in this view; presumably defined elsewhere in the module
            save_figure(file_name, file_path, close, **save_kwargs)

        return output

    return decorated
def make_axes(n_axes, n_cols=5, figsize=None, row_size=4, col_size=3.6,
              wspace=None, hspace=None, title=None, **plt_kwargs):
    """Make a subplot with multiple axes.

    Parameters
    ----------
    n_axes : int
        The total number of axes to create in the figure.
    n_cols : int, optional, default: 5
        The number of columns in the figure.
    figsize : tuple of float, optional
        Size to make the overall figure.
        If not given, is estimated from the number of axes.
    row_size, col_size : float, optional
        The size to use per row / column.
        Only used if `figsize` is None.
    wspace, hspace : float, optional
        Spacing parameters for between subplots.
        These get passed into `plt.subplots_adjust`. An explicit value of 0 is applied.
    title : str, optional
        A title to add to the figure.
    **plt_kwargs
        Extra arguments to pass to `plt.subplots`.
        Arguments prefixed with 'title_' (e.g. 'title_fontsize') are instead
        passed, without the prefix, to `plt.suptitle`.

    Returns
    -------
    axes : 1d array of AxesSubplot
        Collection of axes objects.
    """

    n_rows = math.ceil(n_axes / n_cols)

    if not figsize:
        figsize = (n_cols * col_size, n_rows * row_size)

    # Separate out title settings before the rest are passed to plt.subplots
    title_kwargs = get_attr_kwargs(plt_kwargs, 'title')

    # Always get an array of axes back, so a single-axis layout can also be flattened
    plt_kwargs.setdefault('squeeze', False)
    _, axes = plt.subplots(n_rows, n_cols, figsize=figsize, **plt_kwargs)

    # Compare against None, so that an explicit spacing of 0 is still applied
    if wspace is not None or hspace is not None:
        plt.subplots_adjust(wspace=wspace, hspace=hspace)

    # Turn off axes for any extra subplots in last row
    for ax in axes.ravel()[n_axes:]:
        ax.axis('off')

    if title:
        # The 'title_' prefix has already been stripped from title_kwargs keys
        plt.suptitle(title,
                     fontsize=title_kwargs.pop('fontsize', 24),
                     **title_kwargs)

    return axes.flatten()
def make_grid(nrows, ncols, title=None, **plt_kwargs):
    """Create a plot grid.

    Parameters
    ----------
    nrows, ncols : int
        The number of rows and columns in the grid.
    title : str, optional
        A title to add to the figure.
    **plt_kwargs
        Additional arguments to pass into `gridspec.GridSpec`.
        'figsize' is instead used to create the figure. Arguments prefixed
        with 'title_' (e.g. 'title_fontsize', 'title_y') are passed, without
        the prefix, to `plt.suptitle`.

    Returns
    -------
    grid : matplotlib.gridspec.GridSpec
        The created plot grid layout.
    """

    # Separate out title settings before the rest are passed to GridSpec
    title_kwargs = get_attr_kwargs(plt_kwargs, 'title')

    plt.figure(figsize=plt_kwargs.pop('figsize', None))
    grid = gridspec.GridSpec(nrows, ncols, **plt_kwargs)

    if title:
        # The 'title_' prefix has already been stripped from title_kwargs keys
        plt.suptitle(title,
                     fontsize=title_kwargs.pop('fontsize', 24),
                     y=title_kwargs.pop('y', 0.95),
                     **title_kwargs)

    return grid
def get_grid_subplot(grid, row, col, **plt_kwargs):
    """Create a subplot occupying a section of a grid layout.

    Parameters
    ----------
    grid : matplotlib.gridspec.GridSpec
        A predefined plot grid layout.
    row, col : int or slice
        The row(s) and column(s) in which to place the subplot.
    **plt_kwargs
        Additional arguments to pass into `plt.subplot`.

    Returns
    -------
    ax : matplotlib.axes.Axes
        Subplot axis.
    """

    cell = grid[row, col]

    return plt.subplot(cell, **plt_kwargs)