# blob: 63d8749f70b09e0a8cfc318d9553088a047df0e5
# Copyright 2015 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""Base classes for rendering figures based on matplotlib."""
import warnings
from safetynet import Any, Dict, List, Optional, Tuple, TypecheckMeta
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg
import matplotlib.figure
from . import _styles
class Figure(object):
  """Base class for a figure that can be rendered into an image."""
  __metaclass__ = TypecheckMeta

  DPI = 64.0
  """Dots per inches used for font-rendering."""
  LEFT_PADDING = 48.0
  """Left padding in pixels."""
  RIGHT_PADDING = 16.0
  """Right padding in pixels."""
  TOP_PADDING = 12.0
  """Top padding in pixels."""
  BOTTOM_PADDING = 32.0
  """Bottom padding in pixels."""
  AXIS_COLOR = "#444444"
  """Color of all axes and labels."""

  def __init__(self, width, height):
    """Create new figure image of specified size.

    After construction self.left, self.right, self.bottom and self.top hold
    the padded drawing area in matplotlib's relative 0..1 figure coordinates.
    If width or height is zero the figure keeps matplotlib's default size and
    no padding is applied along the zero-sized dimension (instead of dividing
    by zero).

    :param int width: Width of image in pixels.
    :param int height: Height of image in pixels.
    """
    self.figure = matplotlib.figure.Figure(facecolor=[1, 1, 1], dpi=self.DPI)
    self.axes_list = []
    self.width = float(width)
    self.height = float(height)
    if self.width and self.height:
      self.figure.set_size_inches(self.width / self.DPI,
                                  self.height / self.DPI)
    # Convert pixel paddings into relative figure coordinates.
    if self.width:
      self.left = self.LEFT_PADDING / self.width
      self.right = 1.0 - (self.RIGHT_PADDING / self.width)
    else:
      self.left, self.right = 0.0, 1.0
    if self.height:
      self.bottom = self.BOTTOM_PADDING / self.height
      self.top = 1.0 - (self.TOP_PADDING / self.height)
    else:
      self.bottom, self.top = 0.0, 1.0

  def Save(self, filename, show_legend):
    """Save figure to image.

    :param str filename: Path to file to store image to.
    :param bool show_legend: True if the image should contain a legend.
    """
    canvas = self._CreateCanvas(show_legend)
    canvas.print_figure(filename, dpi=self.DPI)

  def AsImage(self, show_legend):
    """Same as Save, but return image as a numpy array instead.

    The result is an RGBA array of shape (height, width, 4) in pixels, as
    reported by the canvas.

    :param bool show_legend: True if the image should contain a legend.
    :rtype np.ndarray
    """
    canvas = self._CreateCanvas(show_legend)
    raw, (width, height) = canvas.print_to_buffer()
    # Depending on the matplotlib version, print_to_buffer returns a buffer
    # object or plain bytes. np.frombuffer handles both, whereas np.reshape
    # cannot turn a bytes object into pixel data. Copy the result so callers
    # get a writable array that does not alias the renderer's memory.
    pixels = np.frombuffer(raw, dtype=np.uint8)
    return pixels.reshape((height, width, 4)).copy()

  def _CreateAxes(self, left, bottom, width, height):
    """Create new axes instance with specified border.

    left, bottom, width and height are specified in the matplotlib 0..1 relative
    coordinate system. The new axes are styled with AXIS_COLOR and remembered
    in self.axes_list so that _CreateCanvas can manage their legends.

    :type left: float
    :type bottom: float
    :type width: float
    :type height: float
    :rtype matplotlib.axes.Axes
    """
    axes = self.figure.add_axes([left, bottom, width, height])
    for spine in ("top", "bottom", "left", "right"):
      axes.spines[spine].set_color(self.AXIS_COLOR)
    axes.tick_params(colors=self.AXIS_COLOR)
    axes.xaxis.label.set_color(self.AXIS_COLOR)
    axes.yaxis.label.set_color(self.AXIS_COLOR)
    self.axes_list.append(axes)
    return axes

  def _CreateCanvas(self, show_legend):
    """Create matplotlib canvas, optionally including legend.

    A legend is (re)created for every axes in self.axes_list and shown or
    hidden according to show_legend.

    :param bool show_legend: True if the canvas should show a legend.
    :rtype FigureCanvasAgg
    """
    for axes in self.axes_list:
      with warnings.catch_warnings():
        # legend() will issue a warning if there are no labels in this plot.
        # Which happens, and we are ok with.
        warnings.simplefilter("ignore", UserWarning)
        legend = axes.legend()
      # Some matplotlib versions return None when there is nothing to show.
      if legend is not None:
        legend.set_visible(show_legend)
    return FigureCanvasAgg(self.figure)
class SingleAxesFigure(Figure):
  """A figure whose whole padded drawing area is one pair of Axes."""

  def __init__(self, width, height, x_label, y_label):
    """Create the figure and its single Axes (available as self.axes).

    Axis labels are only applied when non-empty.

    :param int width: Width of image in pixels.
    :param int height: Height of image in pixels.
    :param Optional[str] x_label: Label of x-axis
    :param Optional[str] y_label: Label of y-axis
    """
    super(SingleAxesFigure, self).__init__(width, height)
    plot_width = self.right - self.left
    plot_height = self.top - self.bottom
    self.axes = self._CreateAxes(self.left, self.bottom, plot_width,
                                 plot_height)
    # x label first, then y label, each with a tight 1pt pad.
    for text, apply_label in ((x_label, self.axes.set_xlabel),
                              (y_label, self.axes.set_ylabel)):
      if text:
        apply_label(text, labelpad=1)
class PrimarySecondaryAxesFigure(Figure):
  """Figure that is split horizontally.

  The figure shows a primary pair of Axes above a secondary pair. Both are
  meant to display the same x-range: only the secondary (bottom) axes show
  x tick labels and the x-axis label, while each pair has its own y-axis
  label. The two x-axes are not linked inside matplotlib, so the same x
  limits have to be applied to both.
  """

  def __init__(self, width, prim_height, second_height, x_label, prim_label,
               second_label):
    """Create the figure with its primary and secondary axes.

    The padded drawing area is split vertically in proportion to
    prim_height and second_height. The axes are available as self.prim_axes
    and self.second_axes.

    :param int width: Width of image in pixels.
    :param int prim_height: Height of primary axes in pixels.
    :param int second_height: Height of secondary axes in pixels.
    :param Optional[str] x_label: Label of shared x-axis
    :param Optional[str] prim_label: Label of primary y-axis
    :param Optional[str] second_label: Label of secondary y-axis
    """
    overall_height = prim_height + second_height
    super(PrimarySecondaryAxesFigure, self).__init__(width, overall_height)
    axes_width = self.right - self.left
    axes_height = self.top - self.bottom
    # Relative (0..1) heights get their own names so the pixel-valued
    # parameters are not overwritten.
    second_ratio = float(second_height) / float(overall_height)
    second_rel_height = second_ratio * axes_height
    prim_rel_height = axes_height * (1 - second_ratio)

    # Secondary axes at the bottom; they carry the x-axis label.
    self.second_axes = self._CreateAxes(self.left, self.bottom, axes_width,
                                        second_rel_height)
    if x_label:
      self.second_axes.set_xlabel(x_label, labelpad=4)
    if second_label:
      self.second_axes.set_ylabel(second_label, labelpad=8)

    # Primary axes directly above the secondary axes.
    prim_bottom = self.bottom + second_rel_height
    self.prim_axes = self._CreateAxes(self.left, prim_bottom, axes_width,
                                      prim_rel_height)
    if prim_label:
      self.prim_axes.set_ylabel(prim_label, labelpad=8)
    # Hide the primary x tick labels. Use a boolean: matplotlib deprecated the
    # legacy "on"/"off" strings, and once they are no longer translated the
    # non-empty string "off" is truthy and would leave the labels visible.
    self.prim_axes.tick_params(axis="x", labelbottom=False)
class TimeSeriesAxes(object):
  """Wrapper around matplotlib.axes.Axes for visualizing time-series data.

  Provides methods to plot time-series data and limit the view of data to a
  certain time range. It also takes care of converting time values from high
  speed camera frames into milliseconds.

  All style parameters are a dictionary of strings mapping to strings that
  are passed directly as keyword parameters to the matplotlib plotting
  functions.
  """
  __metaclass__ = TypecheckMeta

  X_AXIS_PADDING = 0.01
  """X axis padding as a factor of the full x axis range."""
  Y_AXIS_PADDING = 0.05
  """Y axis padding as a factor of the full y axis range."""

  def __init__(self, mpl_axes, ms_per_frame):
    """Wrap an existing matplotlib axes.

    :param matplotlib.axes.Axes mpl_axes: matplotlib axes to wrap.
    :param float ms_per_frame: milliseconds per high speed camera frame.
    """
    self.mpl_axes = mpl_axes
    self.ms_per_frame = ms_per_frame
    # Frame range set via SetLimits(); None means "not limited".
    self.begin_frame = None
    self.end_frame = None
    # Text artist of the "no data" message while it is displayed.
    self.no_data_label = None

  def AddVerticalLine(self, frame_index, style):
    """Draw vertical line at specified time.

    :param int frame_index: time in camera frames.
    :param Dict[str, Any] style: Dictionary of style parameters.
    """
    time = self._InMS(frame_index)
    self.mpl_axes.axvline(time, **style)

  def AddCenteredText(self, text, style):
    """Draw text in the center of the axes (in axes-relative coordinates).

    :param str text: text to draw
    :param Dict[str, Any] style: Dictionary of style parameters.
    """
    self.mpl_axes.text(0.5, 0.5, text, ha="center", va="center",
                       transform=self.mpl_axes.transAxes, **style)

  def AddText(self, frame_index, y, text, style):
    """Draw text centered at the specified data location.

    :param int frame_index: time in camera frames.
    :param float y: y-location of text.
    :param str text: text to draw.
    :param Dict[str, Any] style: Dictionary of style parameters.
    """
    time = self._InMS(frame_index)
    self.mpl_axes.text(time, y, text, ha="center", va="center", **style)

  def SetLimits(self, begin_frame, end_frame):
    """Specify start and end time of axes.

    The limits take effect on the next call to UpdateAxes(). Frame 0 is a
    valid limit.

    :param int begin_frame: begin time in camera frames.
    :param int end_frame: end time in camera frames.
    """
    self.begin_frame = begin_frame
    self.end_frame = end_frame

  def UpdateAxes(self):
    """Update axes format.

    This method is to be called after all plotting and setting of limits is
    done. It applies the limits and shows a "no data" message if the axes
    contain no lines.
    """
    self._UpdateLimits()
    # Remove a previously shown message so repeated calls don't stack them.
    if self.no_data_label is not None:
      self.no_data_label.remove()
      self.no_data_label = None
    # NOTE(review): get_lines() also contains lines drawn by AddVerticalLine,
    # so axes holding only vertical lines do not show the message.
    if not self.mpl_axes.get_lines():
      self.no_data_label = self.mpl_axes.text(0.5, 0.5, "no data",
                                              ha="center", va="center",
                                              transform=self.mpl_axes.transAxes,
                                              **_styles.MESSAGE_STYLE)

  def _UpdateLimits(self):
    """Update limits of axes to show requested timeframe.

    The x-axis is limited only when both begin_frame and end_frame were set
    (frame 0 counts as set). The y-axis is automatically scaled to show all
    plot points.
    """
    if self.begin_frame is not None and self.end_frame is not None:
      begin_time = self._InMS(self.begin_frame)
      end_time = self._InMS(self.end_frame)
      padding = (end_time - begin_time) * self.X_AXIS_PADDING
      self.mpl_axes.set_xlim(begin_time - padding, end_time + padding)
    (min_value, max_value) = self._GetMinMaxValue()
    padding = (max_value - min_value) * self.Y_AXIS_PADDING
    if padding > 0:
      self.mpl_axes.set_ylim(min_value - padding, max_value + padding)

  def _GetMinMaxValue(self):
    """Returns min and max value of all plots of this axes.

    Only samples inside the frame range (begin_frame..end_frame, both
    inclusive) are considered and NaN values are ignored. Without any values
    the range is (0, 1); the returned maximum is never below 1.

    :rtype Tuple[float, float]
    """
    begin = self.begin_frame if self.begin_frame is not None else 0
    chunks = []
    for line in self.mpl_axes.get_lines():
      # get_ydata() may return the plain sequence the line was created with
      # (e.g. a list), which cannot be filtered with a boolean mask.
      ydata = np.asarray(line.get_ydata(), dtype=float)
      # Assumes sample i belongs to camera frame i — TODO confirm with the
      # code that plots the lines.
      if self.end_frame is not None:
        # end_frame is inside the visible x-range, so include its sample.
        end = self.end_frame + 1
      else:
        end = len(ydata)
      window = ydata[begin:end]
      chunks.append(window[~np.isnan(window)])
    values = np.concatenate(chunks) if chunks else np.empty(0)
    (min_value, max_value) = (0, 1)
    if values.size > 0:
      min_value = np.min(values)
      max_value = np.max(values)
    max_value = max(max_value, 1)
    return (min_value, max_value)

  def _InMS(self, frame_index):
    """Convert high speed camera frame index into milliseconds."""
    return np.asarray(frame_index) * self.ms_per_frame