 |
- from itertools import product
- import warnings
- from textwrap import dedent
- from distutils.version import LooseVersion
- import numpy as np
- import pandas as pd
- from scipy import stats
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- from . import utils
- from .palettes import color_palette, blend_palette
- from .distributions import distplot, kdeplot, _freedman_diaconis_bins
# Public names exported by ``from <module> import *``
__all__ = [
    "FacetGrid",
    "PairGrid",
    "JointGrid",
    "pairplot",
    "jointplot",
]
class Grid(object):
    """Base class holding the helpers shared by the seaborn subplot grids.

    Subclasses are expected to supply the figure (``fig``), the array of
    subplots (``axes``) and the hue bookkeeping (``hue_names``,
    ``_hue_var``, ``_legend_data``) that the methods below read.
    """

    # When True, the figure-level legend gets a wider right-hand padding
    _margin_titles = False
    # When True, ``add_legend`` draws a figure legend outside the axes
    _legend_out = True

    def set(self, **kwargs):
        """Forward ``kwargs`` to ``Axes.set`` on every subplot.

        Returns
        -------
        self : Grid instance
            Returned to allow method chaining.
        """
        for subplot in self.axes.flat:
            subplot.set(**kwargs)
        return self

    def savefig(self, *args, **kwargs):
        """Save the figure via ``Figure.savefig``.

        All arguments are forwarded unchanged, except that ``bbox_inches``
        defaults to ``"tight"`` when the caller does not pass it.
        """
        save_kws = dict(kwargs)
        if "bbox_inches" not in save_kws:
            save_kws["bbox_inches"] = "tight"
        self.fig.savefig(*args, **save_kws)

    def add_legend(self, legend_data=None, title=None, label_order=None,
                   **kwargs):
        """Draw a legend, maybe placing it outside axes and resizing the figure.

        Parameters
        ----------
        legend_data : dict, optional
            Mapping from label names (or ``(group, label)`` tuples whose second
            element is the label name) to matplotlib artist handles. Defaults
            to ``self._legend_data``.
        title : string, optional
            Legend title. Defaults to ``self._hue_var``.
        label_order : list of labels, optional
            Order of the legend entries. Defaults to ``self.hue_names``
            (converted with ``utils.to_utf8``), or to the keys of
            ``legend_data`` when there are no hue names.
        kwargs : key, value pairings
            Forwarded to ``Figure.legend`` (when the legend is drawn outside)
            or to ``Axes.legend`` on the first subplot.

        Returns
        -------
        self : Grid instance
            Returns self for easy chaining.
        """
        if legend_data is None:
            legend_data = self._legend_data

        if label_order is None:
            if self.hue_names is not None:
                label_order = [utils.to_utf8(name) for name in self.hue_names]
            else:
                label_order = list(legend_data)

        # Labels without a recorded artist get an invisible placeholder handle
        invisible = mpl.patches.Patch(alpha=0, linewidth=0)
        handles = [legend_data.get(key, invisible) for key in label_order]

        if title is None:
            title = self._hue_var

        if LooseVersion(mpl.__version__) >= LooseVersion("3.0"):
            title_size = mpl.rcParams["legend.title_fontsize"]
        else:
            label_size = mpl.rcParams["axes.labelsize"]
            try:
                title_size = label_size * .85
            except TypeError:
                # Named sizes such as "large" cannot be scaled numerically
                title_size = label_size

        def leaf_label(entry):
            # Hierarchical legends key their entries as (group, label) pairs
            if isinstance(entry, tuple):
                _, label = entry
                return label
            return entry

        labels = [leaf_label(entry) for entry in label_order]

        kwargs.setdefault("scatterpoints", 1)

        if not self._legend_out:
            # Draw the legend inside the first subplot
            first_ax = self.axes.flat[0]
            kwargs.setdefault("loc", "best")
            legend = first_ax.legend(handles, labels, **kwargs)
            legend.set_title(title, prop={"size": title_size})
            self._legend = legend
            return self

        # Otherwise draw a figure-level legend to the right of the grid
        kwargs.setdefault("frameon", False)
        kwargs.setdefault("loc", "center right")
        figlegend = self.fig.legend(handles, labels, **kwargs)
        self._legend = figlegend
        figlegend.set_title(title, prop={"size": title_size})

        def render():
            # Drawing refreshes the legend's bounding box; skipped when the
            # canvas cannot hand out a renderer
            if hasattr(self.fig.canvas, "get_renderer"):
                self.fig.draw(self.fig.canvas.get_renderer())

        def legend_inches():
            return figlegend.get_window_extent().width / self.fig.dpi

        render()
        # Widen the figure by the legend's width so the legend fits beside it
        extra = legend_inches()
        fig_width, fig_height = self.fig.get_size_inches()
        self.fig.set_size_inches(fig_width + extra, fig_height)

        # Redraw and measure again under the new figure size
        render()
        extra = legend_inches()
        fraction = extra / (fig_width + extra)
        pad = .04 if self._margin_titles else .01
        self._space_needed = pad + fraction

        # Shrink the subplot area so the legend has room on the right
        self.fig.subplots_adjust(right=1 - self._space_needed)
        return self

    def _clean_axis(self, ax):
        """Blank both axis labels of ``ax`` and drop its legend."""
        for set_label in (ax.set_xlabel, ax.set_ylabel):
            set_label("")
        ax.legend_ = None
        return self

    def _update_legend_data(self, ax):
        """Merge the label -> handle pairs found on ``ax`` into the grid."""
        handles, labels = ax.get_legend_handles_labels()
        self._legend_data.update(zip(labels, handles))

    def _get_palette(self, data, hue, hue_order, palette):
        """Return the list of colors used for the levels of ``hue``.

        With no hue variable a single color is returned. Otherwise
        ``palette`` may be None (current cycle, or HUSL when the cycle is
        too short), a dict keyed by hue level, or anything accepted by
        ``color_palette``.
        """
        if hue is None:
            return color_palette(n_colors=1)

        levels = utils.categorical_order(data[hue], hue_order)
        n_levels = len(levels)

        if palette is None:
            # Fall back to HUSL when the active color cycle is too short
            if n_levels > len(utils.get_color_cycle()):
                colors = color_palette("husl", n_levels)
            else:
                colors = color_palette(n_colors=n_levels)
        elif isinstance(palette, dict):
            colors = color_palette([palette[lvl] for lvl in levels], n_levels)
        else:
            colors = color_palette(palette, n_levels)

        return color_palette(colors, n_levels)
# Reusable numpydoc parameter fragments. They are interpolated with
# ``str.format`` into the FacetGrid ``__init__`` docstring. Each fragment
# starts at column 0 and has no trailing whitespace, so it drops cleanly
# into an already-dedented template.
_facet_docs = dict(

    data=dedent("""\
    data : DataFrame
        Tidy ("long-form") dataframe where each column is a variable and each
        row is an observation."""),

    col_wrap=dedent("""\
    col_wrap : int, optional
        "Wrap" the column variable at this width, so that the column facets
        span multiple rows. Incompatible with a ``row`` facet."""),

    share_xy=dedent("""\
    share{x,y} : bool, 'col', or 'row', optional
        If true, the facets will share y axes across columns and/or x axes
        across rows."""),

    height=dedent("""\
    height : scalar, optional
        Height (in inches) of each facet. See also: ``aspect``."""),

    aspect=dedent("""\
    aspect : scalar, optional
        Aspect ratio of each facet, so that ``aspect * height`` gives the width
        of each facet in inches."""),

    palette=dedent("""\
    palette : palette name, list, or dict, optional
        Colors to use for the different levels of the ``hue`` variable. Should
        be something that can be interpreted by :func:`color_palette`, or a
        dictionary mapping hue levels to matplotlib colors."""),

    legend_out=dedent("""\
    legend_out : bool, optional
        If ``True``, the figure size will be extended, and the legend will be
        drawn outside the plot on the center right."""),

    margin_titles=dedent("""\
    margin_titles : bool, optional
        If ``True``, the titles for the row variable are drawn to the right of
        the last column. This option is experimental and may not work in all
        cases."""),

)
- class FacetGrid(Grid):
- """Multi-plot grid for plotting conditional relationships."""
- def __init__(self, data, row=None, col=None, hue=None, col_wrap=None,
- sharex=True, sharey=True, height=3, aspect=1, palette=None,
- row_order=None, col_order=None, hue_order=None, hue_kws=None,
- dropna=True, legend_out=True, despine=True,
- margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
- gridspec_kws=None, size=None):
- # Handle deprecations
- if size is not None:
- height = size
- msg = ("The `size` parameter has been renamed to `height`; "
- "please update your code.")
- warnings.warn(msg, UserWarning)
- # Determine the hue facet layer information
- hue_var = hue
- if hue is None:
- hue_names = None
- else:
- hue_names = utils.categorical_order(data[hue], hue_order)
- colors = self._get_palette(data, hue, hue_order, palette)
- # Set up the lists of names for the row and column facet variables
- if row is None:
- row_names = []
- else:
- row_names = utils.categorical_order(data[row], row_order)
- if col is None:
- col_names = []
- else:
- col_names = utils.categorical_order(data[col], col_order)
- # Additional dict of kwarg -> list of values for mapping the hue var
- hue_kws = hue_kws if hue_kws is not None else {}
- # Make a boolean mask that is True anywhere there is an NA
- # value in one of the faceting variables, but only if dropna is True
- none_na = np.zeros(len(data), np.bool)
- if dropna:
- row_na = none_na if row is None else data[row].isnull()
- col_na = none_na if col is None else data[col].isnull()
- hue_na = none_na if hue is None else data[hue].isnull()
- not_na = ~(row_na | col_na | hue_na)
- else:
- not_na = ~none_na
- # Compute the grid shape
- ncol = 1 if col is None else len(col_names)
- nrow = 1 if row is None else len(row_names)
- self._n_facets = ncol * nrow
- self._col_wrap = col_wrap
- if col_wrap is not None:
- if row is not None:
- err = "Cannot use `row` and `col_wrap` together."
- raise ValueError(err)
- ncol = col_wrap
- nrow = int(np.ceil(len(col_names) / col_wrap))
- self._ncol = ncol
- self._nrow = nrow
- # Calculate the base figure size
- # This can get stretched later by a legend
- # TODO this doesn't account for axis labels
- figsize = (ncol * height * aspect, nrow * height)
- # Validate some inputs
- if col_wrap is not None:
- margin_titles = False
- # Build the subplot keyword dictionary
- subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
- gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy()
- if xlim is not None:
- subplot_kws["xlim"] = xlim
- if ylim is not None:
- subplot_kws["ylim"] = ylim
- # Initialize the subplot grid
- if col_wrap is None:
- kwargs = dict(figsize=figsize, squeeze=False,
- sharex=sharex, sharey=sharey,
- subplot_kw=subplot_kws,
- gridspec_kw=gridspec_kws)
- fig, axes = plt.subplots(nrow, ncol, **kwargs)
- self.axes = axes
- else:
- # If wrapping the col variable we need to make the grid ourselves
- if gridspec_kws:
- warnings.warn("`gridspec_kws` ignored when using `col_wrap`")
- n_axes = len(col_names)
- fig = plt.figure(figsize=figsize)
- axes = np.empty(n_axes, object)
- axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws)
- if sharex:
- subplot_kws["sharex"] = axes[0]
- if sharey:
- subplot_kws["sharey"] = axes[0]
- for i in range(1, n_axes):
- axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws)
- self.axes = axes
- # Now we turn off labels on the inner axes
- if sharex:
- for ax in self._not_bottom_axes:
- for label in ax.get_xticklabels():
- label.set_visible(False)
- ax.xaxis.offsetText.set_visible(False)
- if sharey:
- for ax in self._not_left_axes:
- for label in ax.get_yticklabels():
- label.set_visible(False)
- ax.yaxis.offsetText.set_visible(False)
- # Set up the class attributes
- # ---------------------------
- # First the public API
- self.data = data
- self.fig = fig
- self.axes = axes
- self.row_names = row_names
- self.col_names = col_names
- self.hue_names = hue_names
- self.hue_kws = hue_kws
- # Next the private variables
- self._nrow = nrow
- self._row_var = row
- self._ncol = ncol
- self._col_var = col
- self._margin_titles = margin_titles
- self._col_wrap = col_wrap
- self._hue_var = hue_var
- self._colors = colors
- self._legend_out = legend_out
- self._legend = None
- self._legend_data = {}
- self._x_var = None
- self._y_var = None
- self._dropna = dropna
- self._not_na = not_na
- # Make the axes look good
- fig.tight_layout()
- if despine:
- self.despine()
- __init__.__doc__ = dedent("""\
- Initialize the matplotlib figure and FacetGrid object.
- This class maps a dataset onto multiple axes arrayed in a grid of rows
- and columns that correspond to *levels* of variables in the dataset.
- The plots it produces are often called "lattice", "trellis", or
- "small-multiple" graphics.
- It can also represent levels of a third variable with the ``hue``
- parameter, which plots different subsets of data in different colors.
- This uses color to resolve elements on a third dimension, but only
- draws subsets on top of each other and will not tailor the ``hue``
- parameter for the specific visualization the way that axes-level
- functions that accept ``hue`` will.
- When using seaborn functions that infer semantic mappings from a
- dataset, care must be taken to synchronize those mappings across
- facets (e.g., by defing the ``hue`` mapping with a palette dict or
- setting the data type of the variables to ``category``). In most cases,
- it will be better to use a figure-level function (e.g. :func:`relplot`
- or :func:`catplot`) than to use :class:`FacetGrid` directly.
- The basic workflow is to initialize the :class:`FacetGrid` object with
- the dataset and the variables that are used to structure the grid. Then
- one or more plotting functions can be applied to each subset by calling
- :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe`. Finally, the
- plot can be tweaked with other methods to do things like change the
- axis labels, use different ticks, or add a legend. See the detailed
- code examples below for more information.
- See the :ref:`tutorial <grid_tutorial>` for more information.
- Parameters
- ----------
- {data}
- row, col, hue : strings
- Variables that define subsets of the data, which will be drawn on
- separate facets in the grid. See the ``*_order`` parameters to
- control the order of levels of this variable.
- {col_wrap}
- {share_xy}
- {height}
- {aspect}
- {palette}
- {{row,col,hue}}_order : lists, optional
- Order for the levels of the faceting variables. By default, this
- will be the order that the levels appear in ``data`` or, if the
- variables are pandas categoricals, the category order.
- hue_kws : dictionary of param -> list of values mapping
- Other keyword arguments to insert into the plotting call to let
- other plot attributes vary across levels of the hue variable (e.g.
- the markers in a scatterplot).
- {legend_out}
- despine : boolean, optional
- Remove the top and right spines from the plots.
- {margin_titles}
- {{x, y}}lim: tuples, optional
- Limits for each of the axes on each facet (only relevant when
- share{{x, y}} is True).
- subplot_kws : dict, optional
- Dictionary of keyword arguments passed to matplotlib subplot(s)
- methods.
- gridspec_kws : dict, optional
- Dictionary of keyword arguments passed to matplotlib's ``gridspec``
- module (via ``plt.subplots``). Ignored if ``col_wrap`` is not
- ``None``.
- See Also
- --------
- PairGrid : Subplot grid for plotting pairwise relationships.
- relplot : Combine a relational plot and a :class:`FacetGrid`.
- catplot : Combine a categorical plot and a :class:`FacetGrid`.
- lmplot : Combine a regression plot and a :class:`FacetGrid`.
- Examples
- --------
- Initialize a 2x2 grid of facets using the tips dataset:
- .. plot::
- :context: close-figs
- >>> import seaborn as sns; sns.set(style="ticks", color_codes=True)
- >>> tips = sns.load_dataset("tips")
- >>> g = sns.FacetGrid(tips, col="time", row="smoker")
- Draw a univariate plot on each facet:
- .. plot::
- :context: close-figs
- >>> import matplotlib.pyplot as plt
- >>> g = sns.FacetGrid(tips, col="time", row="smoker")
- >>> g = g.map(plt.hist, "total_bill")
- (Note that it's not necessary to re-catch the returned variable; it's
- the same object, but doing so in the examples makes dealing with the
- doctests somewhat less annoying).
- Pass additional keyword arguments to the mapped function:
- .. plot::
- :context: close-figs
- >>> import numpy as np
- >>> bins = np.arange(0, 65, 5)
- >>> g = sns.FacetGrid(tips, col="time", row="smoker")
- >>> g = g.map(plt.hist, "total_bill", bins=bins, color="r")
- Plot a bivariate function on each facet:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="time", row="smoker")
- >>> g = g.map(plt.scatter, "total_bill", "tip", edgecolor="w")
- Assign one of the variables to the color of the plot elements:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="time", hue="smoker")
- >>> g = (g.map(plt.scatter, "total_bill", "tip", edgecolor="w")
- ... .add_legend())
- Change the height and aspect ratio of each facet:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="day", height=4, aspect=.5)
- >>> g = g.map(plt.hist, "total_bill", bins=bins)
- Specify the order for plot elements:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="smoker", col_order=["Yes", "No"])
- >>> g = g.map(plt.hist, "total_bill", bins=bins, color="m")
- Use a different color palette:
- .. plot::
- :context: close-figs
- >>> kws = dict(s=50, linewidth=.5, edgecolor="w")
- >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette="Set1",
- ... hue_order=["Dinner", "Lunch"])
- >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws)
- ... .add_legend())
- Use a dictionary mapping hue levels to colors:
- .. plot::
- :context: close-figs
- >>> pal = dict(Lunch="seagreen", Dinner="gray")
- >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette=pal,
- ... hue_order=["Dinner", "Lunch"])
- >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws)
- ... .add_legend())
- Additionally use a different marker for the hue levels:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette=pal,
- ... hue_order=["Dinner", "Lunch"],
- ... hue_kws=dict(marker=["^", "v"]))
- >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws)
- ... .add_legend())
- "Wrap" a column variable with many levels into the rows:
- .. plot::
- :context: close-figs
- >>> att = sns.load_dataset("attention")
- >>> g = sns.FacetGrid(att, col="subject", col_wrap=5, height=1.5)
- >>> g = g.map(plt.plot, "solutions", "score", marker=".")
- Define a custom bivariate function to map onto the grid:
- .. plot::
- :context: close-figs
- >>> from scipy import stats
- >>> def qqplot(x, y, **kwargs):
- ... _, xr = stats.probplot(x, fit=False)
- ... _, yr = stats.probplot(y, fit=False)
- ... sns.scatterplot(xr, yr, **kwargs)
- >>> g = sns.FacetGrid(tips, col="smoker", hue="sex")
- >>> g = (g.map(qqplot, "total_bill", "tip", **kws)
- ... .add_legend())
- Define a custom function that uses a ``DataFrame`` object and accepts
- column names as positional variables:
- .. plot::
- :context: close-figs
- >>> import pandas as pd
- >>> df = pd.DataFrame(
- ... data=np.random.randn(90, 4),
- ... columns=pd.Series(list("ABCD"), name="walk"),
- ... index=pd.date_range("2015-01-01", "2015-03-31",
- ... name="date"))
- >>> df = df.cumsum(axis=0).stack().reset_index(name="val")
- >>> def dateplot(x, y, **kwargs):
- ... ax = plt.gca()
- ... data = kwargs.pop("data")
- ... data.plot(x=x, y=y, ax=ax, grid=False, **kwargs)
- >>> g = sns.FacetGrid(df, col="walk", col_wrap=2, height=3.5)
- >>> g = g.map_dataframe(dateplot, "date", "val")
- Use different axes labels after plotting:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="smoker", row="sex")
- >>> g = (g.map(plt.scatter, "total_bill", "tip", color="g", **kws)
- ... .set_axis_labels("Total bill (US Dollars)", "Tip"))
- Set other attributes that are shared across the facetes:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="smoker", row="sex")
- >>> g = (g.map(plt.scatter, "total_bill", "tip", color="r", **kws)
- ... .set(xlim=(0, 60), ylim=(0, 12),
- ... xticks=[10, 30, 50], yticks=[2, 6, 10]))
- Use a different template for the facet titles:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="size", col_wrap=3)
- >>> g = (g.map(plt.hist, "tip", bins=np.arange(0, 13), color="c")
- ... .set_titles("{{col_name}} diners"))
- Tighten the facets:
- .. plot::
- :context: close-figs
- >>> g = sns.FacetGrid(tips, col="smoker", row="sex",
- ... margin_titles=True)
- >>> g = (g.map(plt.scatter, "total_bill", "tip", color="m", **kws)
- ... .set(xlim=(0, 60), ylim=(0, 12),
- ... xticks=[10, 30, 50], yticks=[2, 6, 10])
- ... .fig.subplots_adjust(wspace=.05, hspace=.05))
- """).format(**_facet_docs)
- def facet_data(self):
- """Generator for name indices and data subsets for each facet.
- Yields
- ------
- (i, j, k), data_ijk : tuple of ints, DataFrame
- The ints provide an index into the {row, col, hue}_names attribute,
- and the dataframe contains a subset of the full data corresponding
- to each facet. The generator yields subsets that correspond with
- the self.axes.flat iterator, or self.axes[i, j] when `col_wrap`
- is None.
- """
- data = self.data
- # Construct masks for the row variable
- if self.row_names:
- row_masks = [data[self._row_var] == n for n in self.row_names]
- else:
- row_masks = [np.repeat(True, len(self.data))]
- # Construct masks for the column variable
- if self.col_names:
- col_masks = [data[self._col_var] == n for n in self.col_names]
- else:
- col_masks = [np.repeat(True, len(self.data))]
- # Construct masks for the hue variable
- if self.hue_names:
- hue_masks = [data[self._hue_var] == n for n in self.hue_names]
- else:
- hue_masks = [np.repeat(True, len(self.data))]
- # Here is the main generator loop
- for (i, row), (j, col), (k, hue) in product(enumerate(row_masks),
- enumerate(col_masks),
- enumerate(hue_masks)):
- data_ijk = data[row & col & hue & self._not_na]
- yield (i, j, k), data_ijk
- def map(self, func, *args, **kwargs):
- """Apply a plotting function to each facet's subset of the data.
- Parameters
- ----------
- func : callable
- A plotting function that takes data and keyword arguments. It
- must plot to the currently active matplotlib Axes and take a
- `color` keyword argument. If faceting on the `hue` dimension,
- it must also take a `label` keyword argument.
- args : strings
- Column names in self.data that identify variables with data to
- plot. The data for each variable is passed to `func` in the
- order the variables are specified in the call.
- kwargs : keyword arguments
- All keyword arguments are passed to the plotting function.
- Returns
- -------
- self : object
- Returns self.
- """
- # If color was a keyword argument, grab it here
- kw_color = kwargs.pop("color", None)
- if hasattr(func, "__module__"):
- func_module = str(func.__module__)
- else:
- func_module = ""
- # Check for categorical plots without order information
- if func_module == "seaborn.categorical":
- if "order" not in kwargs:
- warning = ("Using the {} function without specifying "
- "`order` is likely to produce an incorrect "
- "plot.".format(func.__name__))
- warnings.warn(warning)
- if len(args) == 3 and "hue_order" not in kwargs:
- warning = ("Using the {} function without specifying "
- "`hue_order` is likely to produce an incorrect "
- "plot.".format(func.__name__))
- warnings.warn(warning)
- # Iterate over the data subsets
- for (row_i, col_j, hue_k), data_ijk in self.facet_data():
- # If this subset is null, move on
- if not data_ijk.values.size:
- continue
- # Get the current axis
- ax = self.facet_axis(row_i, col_j)
- # Decide what color to plot with
- kwargs["color"] = self._facet_color(hue_k, kw_color)
- # Insert the other hue aesthetics if appropriate
- for kw, val_list in self.hue_kws.items():
- kwargs[kw] = val_list[hue_k]
- # Insert a label in the keyword arguments for the legend
- if self._hue_var is not None:
- kwargs["label"] = utils.to_utf8(self.hue_names[hue_k])
- # Get the actual data we are going to plot with
- plot_data = data_ijk[list(args)]
- if self._dropna:
- plot_data = plot_data.dropna()
- plot_args = [v for k, v in plot_data.iteritems()]
- # Some matplotlib functions don't handle pandas objects correctly
- if func_module.startswith("matplotlib"):
- plot_args = [v.values for v in plot_args]
- # Draw the plot
- self._facet_plot(func, ax, plot_args, kwargs)
- # Finalize the annotations and layout
- self._finalize_grid(args[:2])
- return self
- def map_dataframe(self, func, *args, **kwargs):
- """Like ``.map`` but passes args as strings and inserts data in kwargs.
- This method is suitable for plotting with functions that accept a
- long-form DataFrame as a `data` keyword argument and access the
- data in that DataFrame using string variable names.
- Parameters
- ----------
- func : callable
- A plotting function that takes data and keyword arguments. Unlike
- the `map` method, a function used here must "understand" Pandas
- objects. It also must plot to the currently active matplotlib Axes
- and take a `color` keyword argument. If faceting on the `hue`
- dimension, it must also take a `label` keyword argument.
- args : strings
- Column names in self.data that identify variables with data to
- plot. The data for each variable is passed to `func` in the
- order the variables are specified in the call.
- kwargs : keyword arguments
- All keyword arguments are passed to the plotting function.
- Returns
- -------
- self : object
- Returns self.
- """
- # If color was a keyword argument, grab it here
- kw_color = kwargs.pop("color", None)
- # Iterate over the data subsets
- for (row_i, col_j, hue_k), data_ijk in self.facet_data():
- # If this subset is null, move on
- if not data_ijk.values.size:
- continue
- # Get the current axis
- ax = self.facet_axis(row_i, col_j)
- # Decide what color to plot with
- kwargs["color"] = self._facet_color(hue_k, kw_color)
- # Insert the other hue aesthetics if appropriate
- for kw, val_list in self.hue_kws.items():
- kwargs[kw] = val_list[hue_k]
- # Insert a label in the keyword arguments for the legend
- if self._hue_var is not None:
- kwargs["label"] = self.hue_names[hue_k]
- # Stick the facet dataframe into the kwargs
- if self._dropna:
- data_ijk = data_ijk.dropna()
- kwargs["data"] = data_ijk
- # Draw the plot
- self._facet_plot(func, ax, args, kwargs)
- # Finalize the annotations and layout
- self._finalize_grid(args[:2])
- return self
- def _facet_color(self, hue_index, kw_color):
- color = self._colors[hue_index]
- if kw_color is not None:
- return kw_color
- elif color is not None:
- return color
- def _facet_plot(self, func, ax, plot_args, plot_kwargs):
- # Draw the plot
- func(*plot_args, **plot_kwargs)
- # Sort out the supporting information
- self._update_legend_data(ax)
- self._clean_axis(ax)
- def _finalize_grid(self, axlabels):
- """Finalize the annotations and layout."""
- self.set_axis_labels(*axlabels)
- self.set_titles()
- self.fig.tight_layout()
- def facet_axis(self, row_i, col_j):
- """Make the axis identified by these indices active and return it."""
- # Calculate the actual indices of the axes to plot on
- if self._col_wrap is not None:
- ax = self.axes.flat[col_j]
- else:
- ax = self.axes[row_i, col_j]
- # Get a reference to the axes object we want, and make it active
- plt.sca(ax)
- return ax
- def despine(self, **kwargs):
- """Remove axis spines from the facets."""
- utils.despine(self.fig, **kwargs)
- return self
- def set_axis_labels(self, x_var=None, y_var=None):
- """Set axis labels on the left column and bottom row of the grid."""
- if x_var is not None:
- self._x_var = x_var
- self.set_xlabels(x_var)
- if y_var is not None:
- self._y_var = y_var
- self.set_ylabels(y_var)
- return self
- def set_xlabels(self, label=None, **kwargs):
- """Label the x axis on the bottom row of the grid."""
- if label is None:
- label = self._x_var
- for ax in self._bottom_axes:
- ax.set_xlabel(label, **kwargs)
- return self
- def set_ylabels(self, label=None, **kwargs):
- """Label the y axis on the left column of the grid."""
- if label is None:
- label = self._y_var
- for ax in self._left_axes:
- ax.set_ylabel(label, **kwargs)
- return self
- def set_xticklabels(self, labels=None, step=None, **kwargs):
- """Set x axis tick labels of the grid."""
- for ax in self.axes.flat:
- if labels is None:
- curr_labels = [l.get_text() for l in ax.get_xticklabels()]
- if step is not None:
- xticks = ax.get_xticks()[::step]
- curr_labels = curr_labels[::step]
- ax.set_xticks(xticks)
- ax.set_xticklabels(curr_labels, **kwargs)
- else:
- ax.set_xticklabels(labels, **kwargs)
- return self
- def set_yticklabels(self, labels=None, **kwargs):
- """Set y axis tick labels on the left column of the grid."""
- for ax in self.axes.flat:
- if labels is None:
- curr_labels = [l.get_text() for l in ax.get_yticklabels()]
- ax.set_yticklabels(curr_labels, **kwargs)
- else:
- ax.set_yticklabels(labels, **kwargs)
- return self
- def set_titles(self, template=None, row_template=None, col_template=None,
- **kwargs):
- """Draw titles either above each facet or on the grid margins.
- Parameters
- ----------
- template : string
- Template for all titles with the formatting keys {col_var} and
- {col_name} (if using a `col` faceting variable) and/or {row_var}
- and {row_name} (if using a `row` faceting variable).
- row_template:
- Template for the row variable when titles are drawn on the grid
- margins. Must have {row_var} and {row_name} formatting keys.
- col_template:
- Template for the row variable when titles are drawn on the grid
- margins. Must have {col_var} and {col_name} formatting keys.
- Returns
- -------
- self: object
- Returns self.
- """
- args = dict(row_var=self._row_var, col_var=self._col_var)
- kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"])
- # Establish default templates
- if row_template is None:
- row_template = "{row_var} = {row_name}"
- if col_template is None:
- col_template = "{col_var} = {col_name}"
- if template is None:
- if self._row_var is None:
- template = col_template
- elif self._col_var is None:
- template = row_template
- else:
- template = " | ".join([row_template, col_template])
- row_template = utils.to_utf8(row_template)
- col_template = utils.to_utf8(col_template)
- template = utils.to_utf8(template)
- if self._margin_titles:
- if self.row_names is not None:
- # Draw the row titles on the right edge of the grid
- for i, row_name in enumerate(self.row_names):
- ax = self.axes[i, -1]
- args.update(dict(row_name=row_name))
- title = row_template.format(**args)
- bgcolor = self.fig.get_facecolor()
- ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction",
- rotation=270, ha="left", va="center",
- backgroundcolor=bgcolor, **kwargs)
- if self.col_names is not None:
- # Draw the column titles as normal titles
- for j, col_name in enumerate(self.col_names):
- args.update(dict(col_name=col_name))
- title = col_template.format(**args)
- self.axes[0, j].set_title(title, **kwargs)
- return self
- # Otherwise title each facet with all the necessary information
- if (self._row_var is not None) and (self._col_var is not None):
- for i, row_name in enumerate(self.row_names):
- for j, col_name in enumerate(self.col_names):
- args.update(dict(row_name=row_name, col_name=col_name))
- title = template.format(**args)
- self.axes[i, j].set_title(title, **kwargs)
- elif self.row_names is not None and len(self.row_names):
- for i, row_name in enumerate(self.row_names):
- args.update(dict(row_name=row_name))
- title = template.format(**args)
- self.axes[i, 0].set_title(title, **kwargs)
- elif self.col_names is not None and len(self.col_names):
- for i, col_name in enumerate(self.col_names):
- args.update(dict(col_name=col_name))
- title = template.format(**args)
- # Index the flat array so col_wrap works
- self.axes.flat[i].set_title(title, **kwargs)
- return self
- @property
- def ax(self):
- """Easy access to single axes."""
- if self.axes.shape == (1, 1):
- return self.axes[0, 0]
- else:
- err = ("You must use the `.axes` attribute (an array) when "
- "there is more than one plot.")
- raise AttributeError(err)
- @property
- def _inner_axes(self):
- """Return a flat array of the inner axes."""
- if self._col_wrap is None:
- return self.axes[:-1, 1:].flat
- else:
- axes = []
- n_empty = self._nrow * self._ncol - self._n_facets
- for i, ax in enumerate(self.axes):
- append = (i % self._ncol and
- i < (self._ncol * (self._nrow - 1)) and
- i < (self._ncol * (self._nrow - 1) - n_empty))
- if append:
- axes.append(ax)
- return np.array(axes, object).flat
@property
def _left_axes(self):
    """Flat iterator over the axes in the leftmost column of the grid."""
    if self._col_wrap is None:
        return self.axes[:, 0].flat

    # With col_wrap the axes are stored linearly; a new row starts every
    # ``_ncol`` positions.
    n_cols = self._ncol
    left = [cell for idx, cell in enumerate(self.axes) if idx % n_cols == 0]
    return np.array(left, object).flat
@property
def _not_left_axes(self):
    """Flat iterator over every axes outside the leftmost column."""
    if self._col_wrap is None:
        return self.axes[:, 1:].flat

    # Linear (col_wrap) layout: skip positions that begin a row.
    n_cols = self._ncol
    rest = [cell for idx, cell in enumerate(self.axes) if idx % n_cols != 0]
    return np.array(rest, object).flat
@property
def _bottom_axes(self):
    """Flat iterator over the axes along the bottom edge of the grid.

    With ``col_wrap``, axes in the second-to-last row that sit above an
    empty slot of a short final row are included as well.
    """
    if self._col_wrap is None:
        return self.axes[-1, :].flat

    n_cols = self._ncol
    n_rows = self._nrow
    last_row_start = n_cols * (n_rows - 1)
    n_empty = n_rows * n_cols - self._n_facets
    bottom = [
        cell for idx, cell in enumerate(self.axes)
        if idx >= last_row_start or idx >= last_row_start - n_empty
    ]
    return np.array(bottom, object).flat
@property
def _not_bottom_axes(self):
    """Flat iterator over the axes that have another axes directly below."""
    if self._col_wrap is None:
        return self.axes[:-1, :].flat

    n_cols = self._ncol
    n_rows = self._nrow
    last_row_start = n_cols * (n_rows - 1)
    n_empty = n_rows * n_cols - self._n_facets
    upper = [
        cell for idx, cell in enumerate(self.axes)
        if idx < last_row_start and idx < last_row_start - n_empty
    ]
    return np.array(upper, object).flat
class PairGrid(Grid):
    """Subplot grid for plotting pairwise relationships in a dataset.

    This class maps each variable in a dataset onto a column and row in a
    grid of multiple axes. Different axes-level plotting functions can be
    used to draw bivariate plots in the upper and lower triangles, and the
    marginal distribution of each variable can be shown on the diagonal.

    It can also represent an additional level of conditionalization with the
    ``hue`` parameter, which plots different subsets of data in different
    colors. This uses color to resolve elements on a third dimension, but
    only draws subsets on top of each other and will not tailor the ``hue``
    parameter for the specific visualization the way that axes-level functions
    that accept ``hue`` will.

    See the :ref:`tutorial <grid_tutorial>` for more information.

    """

    def __init__(self, data, hue=None, hue_order=None, palette=None,
                 hue_kws=None, vars=None, x_vars=None, y_vars=None,
                 corner=False, diag_sharey=True, height=2.5, aspect=1,
                 layout_pad=0, despine=True, dropna=True, size=None):
        """Initialize the plot figure and PairGrid object.

        Parameters
        ----------
        data : DataFrame
            Tidy (long-form) dataframe where each column is a variable and
            each row is an observation.
        hue : string (variable name), optional
            Variable in ``data`` to map plot aspects to different colors. This
            variable will be excluded from the default x and y variables.
        hue_order : list of strings
            Order for the levels of the hue variable in the palette
        palette : dict or seaborn color palette
            Set of colors for mapping the ``hue`` variable. If a dict, keys
            should be values in the ``hue`` variable.
        hue_kws : dictionary of param -> list of values mapping
            Other keyword arguments to insert into the plotting call to let
            other plot attributes vary across levels of the hue variable (e.g.
            the markers in a scatterplot).
        vars : list of variable names, optional
            Variables within ``data`` to use, otherwise use every column with
            a numeric datatype.
        {x, y}_vars : lists of variable names, optional
            Variables within ``data`` to use separately for the rows and
            columns of the figure; i.e. to make a non-square plot.
        corner : bool, optional
            If True, don't add axes to the upper (off-diagonal) triangle of the
            grid, making this a "corner" plot.
        diag_sharey : bool, optional
            If True, the univariate axes created on the diagonal by
            :meth:`map_diag` share a common y axis.
        height : scalar, optional
            Height (in inches) of each facet.
        aspect : scalar, optional
            Aspect * height gives the width (in inches) of each facet.
        layout_pad : scalar, optional
            Padding between axes; passed to ``fig.tight_layout``.
        despine : boolean, optional
            Remove the top and right spines from the plots.
        dropna : boolean, optional
            Drop missing values from the data before plotting.
        size : scalar, optional
            Deprecated alias for ``height``; a warning is emitted when used.

        See Also
        --------
        pairplot : Easily drawing common uses of :class:`PairGrid`.
        FacetGrid : Subplot grid for plotting conditional relationships.

        Examples
        --------

        Draw a scatterplot for each pairwise relationship:

        .. plot::
            :context: close-figs

            >>> import matplotlib.pyplot as plt
            >>> import seaborn as sns; sns.set()
            >>> iris = sns.load_dataset("iris")
            >>> g = sns.PairGrid(iris)
            >>> g = g.map(plt.scatter)

        Show a univariate distribution on the diagonal:

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris)
            >>> g = g.map_diag(plt.hist)
            >>> g = g.map_offdiag(plt.scatter)

        (It's not actually necessary to catch the return value every time,
        as it is the same object, but it makes it easier to deal with the
        doctests).

        Color the points using a categorical variable:

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris, hue="species")
            >>> g = g.map_diag(plt.hist)
            >>> g = g.map_offdiag(plt.scatter)
            >>> g = g.add_legend()

        Use a different style to show multiple histograms:

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris, hue="species")
            >>> g = g.map_diag(plt.hist, histtype="step", linewidth=3)
            >>> g = g.map_offdiag(plt.scatter)
            >>> g = g.add_legend()

        Plot a subset of variables

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris, vars=["sepal_length", "sepal_width"])
            >>> g = g.map(plt.scatter)

        Pass additional keyword arguments to the functions

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris)
            >>> g = g.map_diag(plt.hist, edgecolor="w")
            >>> g = g.map_offdiag(plt.scatter, edgecolor="w", s=40)

        Use different variables for the rows and columns:

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris,
            ...                  x_vars=["sepal_length", "sepal_width"],
            ...                  y_vars=["petal_length", "petal_width"])
            >>> g = g.map(plt.scatter)

        Use different functions on the upper and lower triangles:

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris)
            >>> g = g.map_upper(sns.scatterplot)
            >>> g = g.map_lower(sns.kdeplot, colors="C0")
            >>> g = g.map_diag(sns.kdeplot, lw=2)

        Use different colors and markers for each categorical level:

        .. plot::
            :context: close-figs

            >>> g = sns.PairGrid(iris, hue="species", palette="Set2",
            ...                  hue_kws={"marker": ["o", "s", "D"]})
            >>> g = g.map(sns.scatterplot, linewidths=1, edgecolor="w", s=40)
            >>> g = g.add_legend()

        """
        # Handle deprecations
        if size is not None:
            height = size
            msg = ("The `size` parameter has been renamed to `height`; "
                   "please update your code.")
            warnings.warn(msg, UserWarning)

        # Sort out the variables that define the grid
        if vars is not None:
            x_vars = list(vars)
            y_vars = list(vars)
        elif (x_vars is not None) or (y_vars is not None):
            if (x_vars is None) or (y_vars is None):
                raise ValueError("Must specify `x_vars` and `y_vars`")
        else:
            numeric_cols = self._find_numeric_cols(data)
            if hue in numeric_cols:
                numeric_cols.remove(hue)
            x_vars = numeric_cols
            y_vars = numeric_cols

        # Accept a single variable name for either dimension
        if np.isscalar(x_vars):
            x_vars = [x_vars]
        if np.isscalar(y_vars):
            y_vars = [y_vars]

        self.x_vars = list(x_vars)
        self.y_vars = list(y_vars)
        self.square_grid = self.x_vars == self.y_vars

        # Create the figure and the array of subplots
        figsize = len(x_vars) * height * aspect, len(y_vars) * height

        fig, axes = plt.subplots(len(y_vars), len(x_vars),
                                 figsize=figsize,
                                 sharex="col", sharey="row",
                                 squeeze=False)

        # Possibly remove upper axes to make a corner grid
        # Note: setting up the axes is usually the most time-intensive part
        # of using the PairGrid. We are foregoing the speed improvement that
        # we would get by just not setting up the hidden axes so that we can
        # avoid implementing plt.subplots ourselves. But worth thinking about.
        self._corner = corner
        if corner:
            hide_indices = np.triu_indices_from(axes, 1)
            for i, j in zip(*hide_indices):
                axes[i, j].remove()
                axes[i, j] = None

        self.fig = fig
        self.axes = axes
        self.data = data

        # Save what we are going to do with the diagonal
        self.diag_sharey = diag_sharey
        self.diag_vars = None
        self.diag_axes = None

        self._dropna = dropna

        # Record the despine choice unconditionally: map_diag reads it when
        # drawing a corner plot, so it must exist even when despine=False.
        self._despine = despine

        # Label the axes
        self._add_axis_labels()

        # Sort out the hue variable
        self._hue_var = hue
        if hue is None:
            self.hue_names = ["_nolegend_"]
            self.hue_vals = pd.Series(["_nolegend_"] * len(data),
                                      index=data.index)
        else:
            hue_names = utils.categorical_order(data[hue], hue_order)
            if dropna:
                # Filter NA from the list of unique hue names
                hue_names = list(filter(pd.notnull, hue_names))
            self.hue_names = hue_names
            self.hue_vals = data[hue]

        # Additional dict of kwarg -> list of values for mapping the hue var
        self.hue_kws = hue_kws if hue_kws is not None else {}

        self.palette = self._get_palette(data, hue, hue_order, palette)
        self._legend_data = {}

        # Make the plot look nice
        if despine:
            utils.despine(fig=fig)
        fig.tight_layout(pad=layout_pad)

    def map(self, func, **kwargs):
        """Plot with the same function in every subplot.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid instance
            Returns self for easy chaining.

        """
        row_indices, col_indices = np.indices(self.axes.shape)
        indices = zip(row_indices.flat, col_indices.flat)
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_lower(self, func, **kwargs):
        """Plot with a bivariate function on the lower diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid instance
            Returns self for easy chaining.

        """
        indices = zip(*np.tril_indices_from(self.axes, -1))
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_upper(self, func, **kwargs):
        """Plot with a bivariate function on the upper diagonal subplots.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid instance
            Returns self for easy chaining.

        """
        indices = zip(*np.triu_indices_from(self.axes, 1))
        self._map_bivariate(func, indices, **kwargs)
        return self

    def map_offdiag(self, func, **kwargs):
        """Plot with a bivariate function on the off-diagonal subplots.

        In a corner grid only the lower triangle exists, so the upper
        triangle is skipped.

        Parameters
        ----------
        func : callable plotting function
            Must take x, y arrays as positional arguments and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid instance
            Returns self for easy chaining.

        """
        self.map_lower(func, **kwargs)
        if not self._corner:
            self.map_upper(func, **kwargs)
        return self

    def map_diag(self, func, **kwargs):
        """Plot with a univariate function on each diagonal subplot.

        On first use, a twin Axes (with its axis turned off) is created on
        top of every subplot whose row and column variables match; the
        univariate plots are drawn onto those twins.

        Parameters
        ----------
        func : callable plotting function
            Must take an x array as a positional argument and draw onto the
            "currently active" matplotlib Axes. Also needs to accept kwargs
            called ``color`` and ``label``.

        Returns
        -------
        self : PairGrid instance
            Returns self for easy chaining.

        """
        # Add special diagonal axes for the univariate plot
        if self.diag_axes is None:
            diag_vars = []
            diag_axes = []
            for i, y_var in enumerate(self.y_vars):
                for j, x_var in enumerate(self.x_vars):
                    if x_var == y_var:

                        # Make the density axes
                        diag_vars.append(x_var)
                        ax = self.axes[i, j]
                        diag_ax = ax.twinx()
                        diag_ax.set_axis_off()
                        diag_axes.append(diag_ax)

                        # Work around matplotlib bug
                        # https://github.com/matplotlib/matplotlib/issues/15188
                        if not plt.rcParams.get("ytick.left", True):
                            for tick in ax.yaxis.majorTicks:
                                tick.tick1line.set_visible(False)

                        # Remove main y axis from density axes in a corner plot
                        if self._corner:
                            ax.yaxis.set_visible(False)
                            if self._despine:
                                utils.despine(ax=ax, left=True)
                            # TODO add optional density ticks (on the right)
                            # when drawing a corner plot?

            # A grid with no variable shared between rows and columns has no
            # diagonal axes; skip sharing rather than indexing an empty list.
            if self.diag_sharey and diag_axes:
                reference = diag_axes[0]
                for other in diag_axes[1:]:
                    if hasattr(other, "sharey"):
                        # Axes.sharey exists from matplotlib 3.3; the older
                        # Grouper.join route is unavailable in newer releases.
                        other.sharey(reference)
                    else:
                        group = reference.get_shared_y_axes()
                        group.join(other, reference)

            self.diag_vars = np.array(diag_vars, object)
            self.diag_axes = np.array(diag_axes, object)

        # Plot on each of the diagonal axes
        fixed_color = kwargs.pop("color", None)

        for var, ax in zip(self.diag_vars, self.diag_axes):
            hue_grouped = self.data[var].groupby(self.hue_vals)

            plt.sca(ax)

            for k, label_k in enumerate(self.hue_names):

                # Attempt to get data for this level, allowing for empty
                try:
                    # TODO newer matplotlib(?) doesn't need array for hist
                    data_k = np.asarray(hue_grouped.get_group(label_k))
                except KeyError:
                    data_k = np.array([])

                if fixed_color is None:
                    color = self.palette[k]
                else:
                    color = fixed_color

                if self._dropna:
                    data_k = utils.remove_na(data_k)

                func(data_k, label=label_k, color=color, **kwargs)

            self._clean_axis(ax)

        self._add_axis_labels()

        return self

    def _map_bivariate(self, func, indices, **kwargs):
        """Draw a bivariate plot on the indicated (row, col) axes."""
        kws = kwargs.copy()  # Use copy as we insert other kwargs
        kw_color = kws.pop("color", None)
        for i, j in indices:
            x_var = self.x_vars[j]
            y_var = self.y_vars[i]
            ax = self.axes[i, j]
            self._plot_bivariate(x_var, y_var, ax, func, kw_color, **kws)
        self._add_axis_labels()

    def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs):
        """Draw a bivariate plot on the specified axes, one call per hue level.

        ``kw_color``, when given, overrides the palette color for every level.
        """
        plt.sca(ax)
        if x_var == y_var:
            axes_vars = [x_var]
        else:
            axes_vars = [x_var, y_var]

        hue_grouped = self.data.groupby(self.hue_vals)
        for k, label_k in enumerate(self.hue_names):

            # Attempt to get data for this level, allowing for empty
            try:
                data_k = hue_grouped.get_group(label_k)
            except KeyError:
                data_k = pd.DataFrame(columns=axes_vars, dtype=float)

            if self._dropna:
                data_k = data_k[axes_vars].dropna()

            x = data_k[x_var]
            y = data_k[y_var]

            # hue_kws values are indexed by the position of the hue level
            for kw, val_list in self.hue_kws.items():
                kwargs[kw] = val_list[k]
            color = self.palette[k] if kw_color is None else kw_color

            func(x, y, label=label_k, color=color, **kwargs)

        self._clean_axis(ax)
        self._update_legend_data(ax)

    def _add_axis_labels(self):
        """Add labels to the left and bottom Axes."""
        for ax, label in zip(self.axes[-1, :], self.x_vars):
            ax.set_xlabel(label)
        for ax, label in zip(self.axes[:, 0], self.y_vars):
            ax.set_ylabel(label)
        if self._corner:
            self.axes[0, 0].set_ylabel("")

    def _find_numeric_cols(self, data):
        """Return the names of columns in ``data`` castable to float."""
        # This can't be the best way to do this, but I do not
        # know what the best way might be, so this seems ok
        numeric_cols = []
        for col in data:
            try:
                data[col].astype(float)
                numeric_cols.append(col)
            except (ValueError, TypeError):
                pass
        return numeric_cols
class JointGrid(object):
    """Grid for drawing a bivariate plot with marginal univariate plots."""

    def __init__(self, x, y, data=None, height=6, ratio=5, space=.2,
                 dropna=True, xlim=None, ylim=None, size=None):
        """Set up the grid of subplots.

        Parameters
        ----------
        x, y : strings or vectors
            Data or names of variables in ``data``.
        data : DataFrame, optional
            DataFrame when ``x`` and ``y`` are variable names.
        height : numeric
            Size of each side of the figure in inches (it will be square).
        ratio : numeric
            Ratio of joint axes size to marginal axes height.
        space : numeric, optional
            Space between the joint and marginal axes
        dropna : bool, optional
            If True, remove observations that are missing from `x` and `y`.
        {x, y}lim : two-tuples, optional
            Axis limits to set before plotting.
        size : numeric, optional
            Deprecated alias for ``height``; a warning is emitted when used.

        Raises
        ------
        ValueError
            If ``x`` or ``y`` is a string that does not name a column of
            ``data`` (or is a string and no ``data`` was given).

        See Also
        --------
        jointplot : High-level interface for drawing bivariate plots with
                    several different default plot kinds.

        Examples
        --------

        Initialize the figure but don't draw any plots onto it:

        .. plot::
            :context: close-figs

            >>> import seaborn as sns; sns.set(style="ticks", color_codes=True)
            >>> tips = sns.load_dataset("tips")
            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)

        Add plots using default parameters:

        .. plot::
            :context: close-figs

            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
            >>> g = g.plot(sns.regplot, sns.distplot)

        Draw the joint and marginal plots separately, which allows finer-level
        control over other parameters:

        .. plot::
            :context: close-figs

            >>> import matplotlib.pyplot as plt
            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
            >>> g = g.plot_joint(sns.scatterplot, color=".5")
            >>> g = g.plot_marginals(sns.distplot, kde=False, color=".5")

        Draw the two marginal plots separately:

        .. plot::
            :context: close-figs

            >>> import numpy as np
            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
            >>> g = g.plot_joint(sns.scatterplot, color="m")
            >>> _ = g.ax_marg_x.hist(tips["total_bill"], color="b", alpha=.6,
            ...                      bins=np.arange(0, 60, 5))
            >>> _ = g.ax_marg_y.hist(tips["tip"], color="r", alpha=.6,
            ...                      orientation="horizontal",
            ...                      bins=np.arange(0, 12, 1))

        Remove the space between the joint and marginal axes:

        .. plot::
            :context: close-figs

            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips, space=0)
            >>> g = g.plot_joint(sns.kdeplot, cmap="Blues_d")
            >>> g = g.plot_marginals(sns.kdeplot, shade=True)

        Draw a smaller plot with relatively larger marginal axes:

        .. plot::
            :context: close-figs

            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips,
            ...                   height=5, ratio=2)
            >>> g = g.plot_joint(sns.kdeplot, cmap="Reds_d")
            >>> g = g.plot_marginals(sns.kdeplot, color="r", shade=True)

        Set limits on the axes:

        .. plot::
            :context: close-figs

            >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips,
            ...                   xlim=(0, 50), ylim=(0, 8))
            >>> g = g.plot_joint(sns.kdeplot, cmap="Purples_d")
            >>> g = g.plot_marginals(sns.kdeplot, color="m", shade=True)

        """
        # Handle deprecations
        if size is not None:
            height = size
            msg = ("The `size` parameter has been renamed to `height`; "
                   "please update your code.")
            warnings.warn(msg, UserWarning)

        # Possibly extract the variables from a DataFrame
        if data is not None:
            x = data.get(x, x)
            y = data.get(y, y)

        # A string left at this point is a variable name that could not be
        # resolved: either it is not a column of ``data`` or no ``data`` was
        # given. Validate before creating the figure so a bad call does not
        # leave an orphan figure registered with pyplot.
        for var in [x, y]:
            if isinstance(var, str):
                err = "Could not interpret input '{}'".format(var)
                raise ValueError(err)

        # Convert the x and y data to arrays for indexing and plotting
        x_array = np.asarray(x)
        y_array = np.asarray(y)

        # Possibly drop NA (keep an observation only if both x and y exist)
        if dropna:
            not_na = pd.notnull(x_array) & pd.notnull(y_array)
            x_array = x_array[not_na]
            y_array = y_array[not_na]

        self.x = x_array
        self.y = y_array

        # Set up the subplot grid: the joint axes fill the lower-left
        # ratio-by-ratio cells, the marginals take the top row / right column
        f = plt.figure(figsize=(height, height))
        gs = plt.GridSpec(ratio + 1, ratio + 1)

        ax_joint = f.add_subplot(gs[1:, :-1])
        ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
        ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)

        self.fig = f
        self.ax_joint = ax_joint
        self.ax_marg_x = ax_marg_x
        self.ax_marg_y = ax_marg_y

        # Turn off tick visibility for the measure axis on the marginal plots
        plt.setp(ax_marg_x.get_xticklabels(), visible=False)
        plt.setp(ax_marg_y.get_yticklabels(), visible=False)

        # Turn off the ticks on the density axis for the marginal plots
        plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
        plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
        plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
        plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
        plt.setp(ax_marg_x.get_yticklabels(), visible=False)
        plt.setp(ax_marg_y.get_xticklabels(), visible=False)
        ax_marg_x.yaxis.grid(False)
        ax_marg_y.xaxis.grid(False)

        # Label the joint axes when the inputs carry a name (e.g. a Series)
        if hasattr(x, "name"):
            ax_joint.set_xlabel(x.name)
        if hasattr(y, "name"):
            ax_joint.set_ylabel(y.name)

        if xlim is not None:
            ax_joint.set_xlim(xlim)
        if ylim is not None:
            ax_joint.set_ylim(ylim)

        # Make the grid look nice
        utils.despine(f)
        utils.despine(ax=ax_marg_x, left=True)
        utils.despine(ax=ax_marg_y, bottom=True)
        f.tight_layout()
        f.subplots_adjust(hspace=space, wspace=space)

    def plot(self, joint_func, marginal_func, annot_func=None):
        """Shortcut to draw the full plot.

        Use `plot_joint` and `plot_marginals` directly for more control.

        Parameters
        ----------
        joint_func, marginal_func: callables
            Functions to draw the bivariate and univariate plots.
        annot_func : callable, optional
            If given, passed to the deprecated :meth:`annotate` method.

        Returns
        -------
        self : JointGrid instance
            Returns `self`.

        """
        self.plot_marginals(marginal_func)
        self.plot_joint(joint_func)
        if annot_func is not None:
            self.annotate(annot_func)
        return self

    def plot_joint(self, func, **kwargs):
        """Draw a bivariate plot of `x` and `y`.

        Parameters
        ----------
        func : plotting callable
            This must take two 1d arrays of data as the first two
            positional arguments, and it must plot on the "current" axes.
        kwargs : key, value mappings
            Keyword argument are passed to the plotting function.

        Returns
        -------
        self : JointGrid instance
            Returns `self`.

        """
        plt.sca(self.ax_joint)
        func(self.x, self.y, **kwargs)

        return self

    def plot_marginals(self, func, **kwargs):
        """Draw univariate plots for `x` and `y` separately.

        Parameters
        ----------
        func : plotting callable
            This must take a 1d array of data as the first positional
            argument, it must plot on the "current" axes, and it must
            accept a "vertical" keyword argument to orient the measure
            dimension of the plot vertically.
        kwargs : key, value mappings
            Keyword argument are passed to the plotting function.

        Returns
        -------
        self : JointGrid instance
            Returns `self`.

        """
        kwargs["vertical"] = False
        plt.sca(self.ax_marg_x)
        func(self.x, **kwargs)

        kwargs["vertical"] = True
        plt.sca(self.ax_marg_y)
        func(self.y, **kwargs)

        return self

    def annotate(self, func, template=None, stat=None, loc="best", **kwargs):
        """Annotate the plot with a statistic about the relationship.

        *Deprecated and will be removed in a future version*.

        Parameters
        ----------
        func : callable
            Statistical function that maps the x, y vectors either to (val, p)
            or to val.
        template : string format template, optional
            The template must have the format keys "stat" and "val";
            if `func` returns a p value, it should also have the key "p".
        stat : string, optional
            Name to use for the statistic in the annotation, by default it
            uses the name of `func`.
        loc : string or int, optional
            Matplotlib legend location code; used to place the annotation.
        kwargs : key, value mappings
            Other keyword arguments are passed to `ax.legend`, which formats
            the annotation.

        Returns
        -------
        self : JointGrid instance.
            Returns `self`.

        """
        msg = ("JointGrid annotation is deprecated and will be removed "
               "in a future release.")
        warnings.warn(msg, UserWarning)

        default_template = "{stat} = {val:.2g}; p = {p:.2g}"

        # Call the function and determine the form of the return value(s)
        out = func(self.x, self.y)
        try:
            val, p = out
        except TypeError:
            # A non-iterable result is a lone statistic with no p value
            val, p = out, None
            default_template, _ = default_template.split(";")

        # Set the default template
        if template is None:
            template = default_template

        # Default to name of the function
        if stat is None:
            stat = func.__name__

        # Format the annotation
        if p is None:
            annotation = template.format(stat=stat, val=val)
        else:
            annotation = template.format(stat=stat, val=val, p=p)

        # Draw an invisible plot and use the legend to draw the annotation
        # This is a bit of a hack, but `loc=best` works nicely and is not
        # easily abstracted.
        phantom, = self.ax_joint.plot(self.x, self.y, linestyle="", alpha=0)
        self.ax_joint.legend([phantom], [annotation], loc=loc, **kwargs)
        phantom.remove()

        return self

    def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
        """Set the axis labels on the bivariate axes.

        Parameters
        ----------
        xlabel, ylabel : strings
            Label names for the x and y variables.
        kwargs : key, value mappings
            Other keyword arguments are passed to the set_xlabel or
            set_ylabel.

        Returns
        -------
        self : JointGrid instance
            returns `self`

        """
        self.ax_joint.set_xlabel(xlabel, **kwargs)
        self.ax_joint.set_ylabel(ylabel, **kwargs)
        return self

    def savefig(self, *args, **kwargs):
        """Wrap figure.savefig defaulting to tight bounding box."""
        kwargs.setdefault("bbox_inches", "tight")
        self.fig.savefig(*args, **kwargs)
def pairplot(data, hue=None, hue_order=None, palette=None,
             vars=None, x_vars=None, y_vars=None,
             kind="scatter", diag_kind="auto", markers=None,
             height=2.5, aspect=1, corner=False, dropna=True,
             plot_kws=None, diag_kws=None, grid_kws=None, size=None):
    """Plot pairwise relationships in a dataset.

    By default, this function will create a grid of Axes such that each numeric
    variable in ``data`` will be shared in the y-axis across a single row and
    in the x-axis across a single column. The diagonal Axes are treated
    differently, drawing a plot to show the univariate distribution of the data
    for the variable in that column.

    It is also possible to show a subset of variables or plot different
    variables on the rows and columns.

    This is a high-level interface for :class:`PairGrid` that is intended to
    make it easy to draw a few common styles. You should use :class:`PairGrid`
    directly if you need more flexibility.

    Parameters
    ----------
    data : DataFrame
        Tidy (long-form) dataframe where each column is a variable and
        each row is an observation.
    hue : string (variable name), optional
        Variable in ``data`` to map plot aspects to different colors.
    hue_order : list of strings
        Order for the levels of the hue variable in the palette
    palette : dict or seaborn color palette
        Set of colors for mapping the ``hue`` variable. If a dict, keys
        should be values in the ``hue`` variable.
    vars : list of variable names, optional
        Variables within ``data`` to use, otherwise use every column with
        a numeric datatype.
    {x, y}_vars : lists of variable names, optional
        Variables within ``data`` to use separately for the rows and
        columns of the figure; i.e. to make a non-square plot.
    kind : {'scatter', 'reg'}, optional
        Kind of plot for the non-identity relationships.
    diag_kind : {'auto', 'hist', 'kde', None}, optional
        Kind of plot for the diagonal subplots. The default ("auto") uses
        ``"hist"`` when ``hue`` is not given and ``"kde"`` otherwise.
    markers : single matplotlib marker code or list, optional
        Either the marker to use for all datapoints or a list of markers with
        a length the same as the number of levels in the hue variable so that
        differently colored points will also have different scatterplot
        markers.
    height : scalar, optional
        Height (in inches) of each facet.
    aspect : scalar, optional
        Aspect * height gives the width (in inches) of each facet.
    corner : bool, optional
        If True, don't add axes to the upper (off-diagonal) triangle of the
        grid, making this a "corner" plot.
    dropna : boolean, optional
        Drop missing values from the data before plotting.
    {plot, diag, grid}_kws : dicts, optional
        Dictionaries of keyword arguments. ``plot_kws`` are passed to the
        bivariate plotting function, ``diag_kws`` are passed to the univariate
        plotting function, and ``grid_kws`` are passed to the :class:`PairGrid`
        constructor. ``grid_kws`` may override the default ``diag_sharey``,
        which is True exactly when the resolved diagonal kind is "hist".
    size : scalar, optional
        Deprecated alias for ``height``; a warning is emitted when used.

    Returns
    -------
    grid : :class:`PairGrid`
        Returns the underlying :class:`PairGrid` instance for further tweaking.

    Raises
    ------
    TypeError
        If ``data`` is not a pandas DataFrame.
    ValueError
        If ``markers`` is a list whose length does not match the number of
        hue levels.

    See Also
    --------
    PairGrid : Subplot grid for more flexible plotting of pairwise
               relationships.

    Examples
    --------

    Draw scatterplots for joint relationships and histograms for univariate
    distributions:

    .. plot::
        :context: close-figs

        >>> import seaborn as sns; sns.set(style="ticks", color_codes=True)
        >>> iris = sns.load_dataset("iris")
        >>> g = sns.pairplot(iris)

    Show different levels of a categorical variable by the color of plot
    elements:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, hue="species")

    Use a different color palette:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, hue="species", palette="husl")

    Use different markers for each level of the hue variable:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, hue="species", markers=["o", "s", "D"])

    Plot a subset of variables:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, vars=["sepal_width", "sepal_length"])

    Draw larger plots:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, height=3,
        ...                  vars=["sepal_width", "sepal_length"])

    Plot different variables in the rows and columns:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris,
        ...                  x_vars=["sepal_width", "sepal_length"],
        ...                  y_vars=["petal_width", "petal_length"])

    Plot only the lower triangle of bivariate axes:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, corner=True)

    Use kernel density estimates for univariate plots:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, diag_kind="kde")

    Fit linear regression models to the scatter plots:

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, kind="reg")

    Pass keyword arguments down to the underlying functions (it may be easier
    to use :class:`PairGrid` directly):

    .. plot::
        :context: close-figs

        >>> g = sns.pairplot(iris, diag_kind="kde", markers="+",
        ...                  plot_kws=dict(s=50, edgecolor="b", linewidth=1),
        ...                  diag_kws=dict(shade=True))

    """
    # Handle deprecations
    if size is not None:
        height = size
        msg = ("The `size` parameter has been renamed to `height`; "
               "please update your code.")
        warnings.warn(msg, UserWarning)

    if not isinstance(data, pd.DataFrame):
        raise TypeError(
            "'data' must be pandas DataFrame object, not: {typefound}".format(
                typefound=type(data)))

    # Copy so that the caller's dicts are never mutated below
    plot_kws = {} if plot_kws is None else plot_kws.copy()
    diag_kws = {} if diag_kws is None else diag_kws.copy()
    grid_kws = {} if grid_kws is None else grid_kws.copy()

    # Resolve the "auto" diagonal kind before it drives any defaults; the
    # diag_sharey default below must see the concrete kind, not "auto".
    if diag_kind == "auto":
        diag_kind = "hist" if hue is None else "kde"

    # Set up the PairGrid
    grid_kws.setdefault("diag_sharey", diag_kind == "hist")
    grid = PairGrid(data, vars=vars, x_vars=x_vars, y_vars=y_vars, hue=hue,
                    hue_order=hue_order, palette=palette, corner=corner,
                    height=height, aspect=aspect, dropna=dropna, **grid_kws)

    # Add the markers here as PairGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if markers is not None:
        if grid.hue_names is None:
            n_markers = 1
        else:
            n_markers = len(grid.hue_names)
        if not isinstance(markers, list):
            markers = [markers] * n_markers
        if len(markers) != n_markers:
            raise ValueError(("markers must be a singleton or a list of "
                              "markers for each level of the hue variable"))
        grid.hue_kws = {"marker": markers}

    # Maybe plot on the diagonal
    if grid.square_grid:
        if diag_kind == "hist":
            grid.map_diag(plt.hist, **diag_kws)
        elif diag_kind == "kde":
            diag_kws.setdefault("shade", True)
            diag_kws["legend"] = False
            grid.map_diag(kdeplot, **diag_kws)

    # Maybe plot on the off-diagonals
    if grid.square_grid and diag_kind is not None:
        plotter = grid.map_offdiag
    else:
        plotter = grid.map

    if kind == "scatter":
        from .relational import scatterplot  # Avoid circular import
        plotter(scatterplot, **plot_kws)
    elif kind == "reg":
        from .regression import regplot  # Avoid circular import
        plotter(regplot, **plot_kws)

    # Add a legend
    if hue is not None:
        grid.add_legend()

    return grid
def jointplot(x, y, data=None, kind="scatter", stat_func=None,
              color=None, height=6, ratio=5, space=.2,
              dropna=True, xlim=None, ylim=None,
              joint_kws=None, marginal_kws=None, annot_kws=None, **kwargs):
    """Draw a plot of two variables with bivariate and univariate graphs.

    This function provides a convenient interface to the :class:`JointGrid`
    class, with several canned plot kinds. This is intended to be a fairly
    lightweight wrapper; if you need more flexibility, you should use
    :class:`JointGrid` directly.

    Parameters
    ----------
    x, y : strings or vectors
        Data or names of variables in ``data``.
    data : DataFrame, optional
        DataFrame when ``x`` and ``y`` are variable names.
    kind : { "scatter" | "reg" | "resid" | "kde" | "hex" }, optional
        Kind of plot to draw. ``"scatter"`` must match exactly; the other
        kinds are matched by prefix (e.g. ``"hexbin"`` selects ``"hex"``).
    stat_func : callable or None, optional
        *Deprecated*. If given, it is passed to :meth:`JointGrid.annotate`
        after plotting. It is ignored when ``kind`` is ``"resid"``.
    color : matplotlib color, optional
        Color used for the plot elements.
    height : numeric, optional
        Size of the figure (it will be square).
    ratio : numeric, optional
        Ratio of joint axes height to marginal axes height.
    space : numeric, optional
        Space between the joint and marginal axes.
    dropna : bool, optional
        If True, remove observations that are missing from ``x`` and ``y``.
    {x, y}lim : two-tuples, optional
        Axis limits to set before plotting.
    {joint, marginal, annot}_kws : dicts, optional
        Additional keyword arguments for the plot components. ``annot_kws``
        is only used when ``stat_func`` is given. The dictionaries are
        copied, so the caller's objects are not modified.
    kwargs : key, value pairings
        Additional keyword arguments are passed to the function used to
        draw the plot on the joint Axes, superseding items in the
        ``joint_kws`` dictionary.

    Returns
    -------
    grid : :class:`JointGrid`
        :class:`JointGrid` object with the plot on it.

    Raises
    ------
    ValueError
        If ``kind`` is not one of the supported plot kinds (including
        non-string values such as ``None``). This is checked before any
        grid is constructed.

    See Also
    --------
    JointGrid : The Grid class used for drawing this plot. Use it directly if
                you need more flexibility.

    Examples
    --------

    Draw a scatterplot with marginal histograms:

    .. plot::
        :context: close-figs

        >>> import numpy as np, pandas as pd; np.random.seed(0)
        >>> import seaborn as sns; sns.set(style="white", color_codes=True)
        >>> tips = sns.load_dataset("tips")
        >>> g = sns.jointplot(x="total_bill", y="tip", data=tips)

    Add regression and kernel density fits:

    .. plot::
        :context: close-figs

        >>> g = sns.jointplot("total_bill", "tip", data=tips, kind="reg")

    Replace the scatterplot with a joint histogram using hexagonal bins:

    .. plot::
        :context: close-figs

        >>> g = sns.jointplot("total_bill", "tip", data=tips, kind="hex")

    Replace the scatterplots and histograms with density estimates and align
    the marginal Axes tightly with the joint Axes:

    .. plot::
        :context: close-figs

        >>> iris = sns.load_dataset("iris")
        >>> g = sns.jointplot("sepal_width", "petal_length", data=iris,
        ...                   kind="kde", space=0, color="g")

    Draw a scatterplot, then add a joint density estimate:

    .. plot::
        :context: close-figs

        >>> g = (sns.jointplot("sepal_length", "sepal_width",
        ...                    data=iris, color="k")
        ...         .plot_joint(sns.kdeplot, zorder=0, n_levels=6))

    Pass vectors in directly without using Pandas, then name the axes:

    .. plot::
        :context: close-figs

        >>> x, y = np.random.randn(2, 300)
        >>> g = (sns.jointplot(x, y, kind="hex")
        ...         .set_axis_labels("x", "y"))

    Draw a smaller figure with more space devoted to the marginal plots:

    .. plot::
        :context: close-figs

        >>> g = sns.jointplot("total_bill", "tip", data=tips,
        ...                   height=5, ratio=3, color="g")

    Pass keyword arguments down to the underlying plots:

    .. plot::
        :context: close-figs

        >>> g = sns.jointplot("petal_length", "sepal_length", data=iris,
        ...                   marginal_kws=dict(bins=15, rug=True),
        ...                   s=40, edgecolor="w", linewidth=1)

    """
    # Handle deprecations
    if "size" in kwargs:
        height = kwargs.pop("size")
        msg = ("The `size` parameter has been renamed to `height`; "
               "please update your code.")
        warnings.warn(msg, UserWarning)

    # Resolve ``kind`` to a canonical name before doing any setup. An
    # invalid value then fails fast with ValueError, before the JointGrid
    # (and whatever figure it sets up) is created. A non-string kind such
    # as None used to escape as AttributeError from ``.startswith``.
    plot_kind = None
    if isinstance(kind, str):
        if kind == "scatter":
            plot_kind = "scatter"
        else:
            # Prefix matching keeps spellings like "hexbin" or "kdeplot".
            prefixes = ("hex", "kde", "reg", "resid")
            plot_kind = next((p for p in prefixes if kind.startswith(p)),
                             None)
    if plot_kind is None:
        msg = "kind must be either 'scatter', 'reg', 'resid', 'kde', or 'hex'"
        raise ValueError(msg)

    # Set up empty default kwarg dicts (copies, so the caller's dicts are
    # never mutated by the setdefault calls below)
    joint_kws = {} if joint_kws is None else joint_kws.copy()
    joint_kws.update(kwargs)
    marginal_kws = {} if marginal_kws is None else marginal_kws.copy()
    annot_kws = {} if annot_kws is None else annot_kws.copy()

    # Make a colormap based off the plot color. Build a 12-step ramp by
    # varying the lightness value from 1 down to 0 (presumably the HLS
    # lightness, given the helper's name). The "hex" and "kde" kinds use it.
    if color is None:
        color = color_palette()[0]
    color_rgb = mpl.colors.colorConverter.to_rgb(color)
    colors = [utils.set_hls_values(color_rgb, l=lightness)
              for lightness in np.linspace(1, 0, 12)]
    cmap = blend_palette(colors, as_cmap=True)

    # Initialize the JointGrid object
    grid = JointGrid(x, y, data, dropna=dropna,
                     height=height, ratio=ratio, space=space,
                     xlim=xlim, ylim=ylim)

    # Plot the data using the grid
    if plot_kind == "scatter":
        joint_kws.setdefault("color", color)
        grid.plot_joint(plt.scatter, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(distplot, **marginal_kws)

    elif plot_kind == "hex":
        # Choose a hexbin grid size from the Freedman-Diaconis bin counts of
        # each axis, capped at 50 per axis.
        x_bins = min(_freedman_diaconis_bins(grid.x), 50)
        y_bins = min(_freedman_diaconis_bins(grid.y), 50)
        gridsize = int(np.mean([x_bins, y_bins]))

        joint_kws.setdefault("gridsize", gridsize)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(plt.hexbin, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(distplot, **marginal_kws)

    elif plot_kind == "kde":
        joint_kws.setdefault("shade", True)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(kdeplot, **joint_kws)

        marginal_kws.setdefault("shade", True)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(kdeplot, **marginal_kws)

    elif plot_kind == "reg":
        from .regression import regplot  # local import, as in pairplot

        marginal_kws.setdefault("color", color)
        grid.plot_marginals(distplot, **marginal_kws)

        joint_kws.setdefault("color", color)
        grid.plot_joint(regplot, **joint_kws)

    else:  # plot_kind == "resid"
        from .regression import residplot  # local import, as in pairplot

        joint_kws.setdefault("color", color)
        grid.plot_joint(residplot, **joint_kws)

        # Read the plotted points back from the first collection on the
        # joint Axes. Presumably this is the residual scatter drawn by
        # residplot, so the marginals describe exactly what is shown.
        resid_x, resid_y = grid.ax_joint.collections[0].get_offsets().T
        marginal_kws.setdefault("color", color)
        marginal_kws.setdefault("kde", False)
        distplot(resid_x, ax=grid.ax_marg_x, **marginal_kws)
        distplot(resid_y, vertical=True, fit=stats.norm, ax=grid.ax_marg_y,
                 **marginal_kws)

    # The deprecated statistic annotation is never drawn for residual plots.
    if stat_func is not None and plot_kind != "resid":
        grid.annotate(stat_func, **annot_kws)

    return grid
|