axisgrid.py 84 KB


  1. from itertools import product
  2. import warnings
  3. from textwrap import dedent
  4. from distutils.version import LooseVersion
  5. import numpy as np
  6. import pandas as pd
  7. from scipy import stats
  8. import matplotlib as mpl
  9. import matplotlib.pyplot as plt
  10. from . import utils
  11. from .palettes import color_palette, blend_palette
  12. from .distributions import distplot, kdeplot, _freedman_diaconis_bins
# Public names exported by ``from <this module> import *``.
__all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"]
  14. class Grid(object):
  15. """Base class for grids of subplots."""
  16. _margin_titles = False
  17. _legend_out = True
  18. def set(self, **kwargs):
  19. """Set attributes on each subplot Axes."""
  20. for ax in self.axes.flat:
  21. ax.set(**kwargs)
  22. return self
  23. def savefig(self, *args, **kwargs):
  24. """Save the figure."""
  25. kwargs = kwargs.copy()
  26. kwargs.setdefault("bbox_inches", "tight")
  27. self.fig.savefig(*args, **kwargs)
  28. def add_legend(self, legend_data=None, title=None, label_order=None,
  29. **kwargs):
  30. """Draw a legend, maybe placing it outside axes and resizing the figure.
  31. Parameters
  32. ----------
  33. legend_data : dict, optional
  34. Dictionary mapping label names (or two-element tuples where the
  35. second element is a label name) to matplotlib artist handles. The
  36. default reads from ``self._legend_data``.
  37. title : string, optional
  38. Title for the legend. The default reads from ``self._hue_var``.
  39. label_order : list of labels, optional
  40. The order that the legend entries should appear in. The default
  41. reads from ``self.hue_names``.
  42. kwargs : key, value pairings
  43. Other keyword arguments are passed to the underlying legend methods
  44. on the Figure or Axes object.
  45. Returns
  46. -------
  47. self : Grid instance
  48. Returns self for easy chaining.
  49. """
  50. # Find the data for the legend
  51. if legend_data is None:
  52. legend_data = self._legend_data
  53. if label_order is None:
  54. if self.hue_names is None:
  55. label_order = list(legend_data.keys())
  56. else:
  57. label_order = list(map(utils.to_utf8, self.hue_names))
  58. blank_handle = mpl.patches.Patch(alpha=0, linewidth=0)
  59. handles = [legend_data.get(l, blank_handle) for l in label_order]
  60. title = self._hue_var if title is None else title
  61. if LooseVersion(mpl.__version__) < LooseVersion("3.0"):
  62. try:
  63. title_size = mpl.rcParams["axes.labelsize"] * .85
  64. except TypeError: # labelsize is something like "large"
  65. title_size = mpl.rcParams["axes.labelsize"]
  66. else:
  67. title_size = mpl.rcParams["legend.title_fontsize"]
  68. # Unpack nested labels from a hierarchical legend
  69. labels = []
  70. for entry in label_order:
  71. if isinstance(entry, tuple):
  72. _, label = entry
  73. else:
  74. label = entry
  75. labels.append(label)
  76. # Set default legend kwargs
  77. kwargs.setdefault("scatterpoints", 1)
  78. if self._legend_out:
  79. kwargs.setdefault("frameon", False)
  80. kwargs.setdefault("loc", "center right")
  81. # Draw a full-figure legend outside the grid
  82. figlegend = self.fig.legend(handles, labels, **kwargs)
  83. self._legend = figlegend
  84. figlegend.set_title(title, prop={"size": title_size})
  85. # Draw the plot to set the bounding boxes correctly
  86. if hasattr(self.fig.canvas, "get_renderer"):
  87. self.fig.draw(self.fig.canvas.get_renderer())
  88. # Calculate and set the new width of the figure so the legend fits
  89. legend_width = figlegend.get_window_extent().width / self.fig.dpi
  90. fig_width, fig_height = self.fig.get_size_inches()
  91. self.fig.set_size_inches(fig_width + legend_width, fig_height)
  92. # Draw the plot again to get the new transformations
  93. if hasattr(self.fig.canvas, "get_renderer"):
  94. self.fig.draw(self.fig.canvas.get_renderer())
  95. # Now calculate how much space we need on the right side
  96. legend_width = figlegend.get_window_extent().width / self.fig.dpi
  97. space_needed = legend_width / (fig_width + legend_width)
  98. margin = .04 if self._margin_titles else .01
  99. self._space_needed = margin + space_needed
  100. right = 1 - self._space_needed
  101. # Place the subplot axes to give space for the legend
  102. self.fig.subplots_adjust(right=right)
  103. else:
  104. # Draw a legend in the first axis
  105. ax = self.axes.flat[0]
  106. kwargs.setdefault("loc", "best")
  107. leg = ax.legend(handles, labels, **kwargs)
  108. leg.set_title(title, prop={"size": title_size})
  109. self._legend = leg
  110. return self
  111. def _clean_axis(self, ax):
  112. """Turn off axis labels and legend."""
  113. ax.set_xlabel("")
  114. ax.set_ylabel("")
  115. ax.legend_ = None
  116. return self
  117. def _update_legend_data(self, ax):
  118. """Extract the legend data from an axes object and save it."""
  119. handles, labels = ax.get_legend_handles_labels()
  120. data = {l: h for h, l in zip(handles, labels)}
  121. self._legend_data.update(data)
  122. def _get_palette(self, data, hue, hue_order, palette):
  123. """Get a list of colors for the hue variable."""
  124. if hue is None:
  125. palette = color_palette(n_colors=1)
  126. else:
  127. hue_names = utils.categorical_order(data[hue], hue_order)
  128. n_colors = len(hue_names)
  129. # By default use either the current color palette or HUSL
  130. if palette is None:
  131. current_palette = utils.get_color_cycle()
  132. if n_colors > len(current_palette):
  133. colors = color_palette("husl", n_colors)
  134. else:
  135. colors = color_palette(n_colors=n_colors)
  136. # Allow for palette to map from hue variable names
  137. elif isinstance(palette, dict):
  138. color_names = [palette[h] for h in hue_names]
  139. colors = color_palette(color_names, n_colors)
  140. # Otherwise act as if we just got a list of colors
  141. else:
  142. colors = color_palette(palette, n_colors)
  143. palette = color_palette(colors, n_colors)
  144. return palette
  145. _facet_docs = dict(
  146. data=dedent("""\
  147. data : DataFrame
  148. Tidy ("long-form") dataframe where each column is a variable and each
  149. row is an observation.\
  150. """),
  151. col_wrap=dedent("""\
  152. col_wrap : int, optional
  153. "Wrap" the column variable at this width, so that the column facets
  154. span multiple rows. Incompatible with a ``row`` facet.\
  155. """),
  156. share_xy=dedent("""\
  157. share{x,y} : bool, 'col', or 'row' optional
  158. If true, the facets will share y axes across columns and/or x axes
  159. across rows.\
  160. """),
  161. height=dedent("""\
  162. height : scalar, optional
  163. Height (in inches) of each facet. See also: ``aspect``.\
  164. """),
  165. aspect=dedent("""\
  166. aspect : scalar, optional
  167. Aspect ratio of each facet, so that ``aspect * height`` gives the width
  168. of each facet in inches.\
  169. """),
  170. palette=dedent("""\
  171. palette : palette name, list, or dict, optional
  172. Colors to use for the different levels of the ``hue`` variable. Should
  173. be something that can be interpreted by :func:`color_palette`, or a
  174. dictionary mapping hue levels to matplotlib colors.\
  175. """),
  176. legend_out=dedent("""\
  177. legend_out : bool, optional
  178. If ``True``, the figure size will be extended, and the legend will be
  179. drawn outside the plot on the center right.\
  180. """),
  181. margin_titles=dedent("""\
  182. margin_titles : bool, optional
  183. If ``True``, the titles for the row variable are drawn to the right of
  184. the last column. This option is experimental and may not work in all
  185. cases.\
  186. """),
  187. )
  188. class FacetGrid(Grid):
  189. """Multi-plot grid for plotting conditional relationships."""
  190. def __init__(self, data, row=None, col=None, hue=None, col_wrap=None,
  191. sharex=True, sharey=True, height=3, aspect=1, palette=None,
  192. row_order=None, col_order=None, hue_order=None, hue_kws=None,
  193. dropna=True, legend_out=True, despine=True,
  194. margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
  195. gridspec_kws=None, size=None):
  196. # Handle deprecations
  197. if size is not None:
  198. height = size
  199. msg = ("The `size` parameter has been renamed to `height`; "
  200. "please update your code.")
  201. warnings.warn(msg, UserWarning)
  202. # Determine the hue facet layer information
  203. hue_var = hue
  204. if hue is None:
  205. hue_names = None
  206. else:
  207. hue_names = utils.categorical_order(data[hue], hue_order)
  208. colors = self._get_palette(data, hue, hue_order, palette)
  209. # Set up the lists of names for the row and column facet variables
  210. if row is None:
  211. row_names = []
  212. else:
  213. row_names = utils.categorical_order(data[row], row_order)
  214. if col is None:
  215. col_names = []
  216. else:
  217. col_names = utils.categorical_order(data[col], col_order)
  218. # Additional dict of kwarg -> list of values for mapping the hue var
  219. hue_kws = hue_kws if hue_kws is not None else {}
  220. # Make a boolean mask that is True anywhere there is an NA
  221. # value in one of the faceting variables, but only if dropna is True
  222. none_na = np.zeros(len(data), np.bool)
  223. if dropna:
  224. row_na = none_na if row is None else data[row].isnull()
  225. col_na = none_na if col is None else data[col].isnull()
  226. hue_na = none_na if hue is None else data[hue].isnull()
  227. not_na = ~(row_na | col_na | hue_na)
  228. else:
  229. not_na = ~none_na
  230. # Compute the grid shape
  231. ncol = 1 if col is None else len(col_names)
  232. nrow = 1 if row is None else len(row_names)
  233. self._n_facets = ncol * nrow
  234. self._col_wrap = col_wrap
  235. if col_wrap is not None:
  236. if row is not None:
  237. err = "Cannot use `row` and `col_wrap` together."
  238. raise ValueError(err)
  239. ncol = col_wrap
  240. nrow = int(np.ceil(len(col_names) / col_wrap))
  241. self._ncol = ncol
  242. self._nrow = nrow
  243. # Calculate the base figure size
  244. # This can get stretched later by a legend
  245. # TODO this doesn't account for axis labels
  246. figsize = (ncol * height * aspect, nrow * height)
  247. # Validate some inputs
  248. if col_wrap is not None:
  249. margin_titles = False
  250. # Build the subplot keyword dictionary
  251. subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
  252. gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy()
  253. if xlim is not None:
  254. subplot_kws["xlim"] = xlim
  255. if ylim is not None:
  256. subplot_kws["ylim"] = ylim
  257. # Initialize the subplot grid
  258. if col_wrap is None:
  259. kwargs = dict(figsize=figsize, squeeze=False,
  260. sharex=sharex, sharey=sharey,
  261. subplot_kw=subplot_kws,
  262. gridspec_kw=gridspec_kws)
  263. fig, axes = plt.subplots(nrow, ncol, **kwargs)
  264. self.axes = axes
  265. else:
  266. # If wrapping the col variable we need to make the grid ourselves
  267. if gridspec_kws:
  268. warnings.warn("`gridspec_kws` ignored when using `col_wrap`")
  269. n_axes = len(col_names)
  270. fig = plt.figure(figsize=figsize)
  271. axes = np.empty(n_axes, object)
  272. axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws)
  273. if sharex:
  274. subplot_kws["sharex"] = axes[0]
  275. if sharey:
  276. subplot_kws["sharey"] = axes[0]
  277. for i in range(1, n_axes):
  278. axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws)
  279. self.axes = axes
  280. # Now we turn off labels on the inner axes
  281. if sharex:
  282. for ax in self._not_bottom_axes:
  283. for label in ax.get_xticklabels():
  284. label.set_visible(False)
  285. ax.xaxis.offsetText.set_visible(False)
  286. if sharey:
  287. for ax in self._not_left_axes:
  288. for label in ax.get_yticklabels():
  289. label.set_visible(False)
  290. ax.yaxis.offsetText.set_visible(False)
  291. # Set up the class attributes
  292. # ---------------------------
  293. # First the public API
  294. self.data = data
  295. self.fig = fig
  296. self.axes = axes
  297. self.row_names = row_names
  298. self.col_names = col_names
  299. self.hue_names = hue_names
  300. self.hue_kws = hue_kws
  301. # Next the private variables
  302. self._nrow = nrow
  303. self._row_var = row
  304. self._ncol = ncol
  305. self._col_var = col
  306. self._margin_titles = margin_titles
  307. self._col_wrap = col_wrap
  308. self._hue_var = hue_var
  309. self._colors = colors
  310. self._legend_out = legend_out
  311. self._legend = None
  312. self._legend_data = {}
  313. self._x_var = None
  314. self._y_var = None
  315. self._dropna = dropna
  316. self._not_na = not_na
  317. # Make the axes look good
  318. fig.tight_layout()
  319. if despine:
  320. self.despine()
  321. __init__.__doc__ = dedent("""\
  322. Initialize the matplotlib figure and FacetGrid object.
  323. This class maps a dataset onto multiple axes arrayed in a grid of rows
  324. and columns that correspond to *levels* of variables in the dataset.
  325. The plots it produces are often called "lattice", "trellis", or
  326. "small-multiple" graphics.
  327. It can also represent levels of a third variable with the ``hue``
  328. parameter, which plots different subsets of data in different colors.
  329. This uses color to resolve elements on a third dimension, but only
  330. draws subsets on top of each other and will not tailor the ``hue``
  331. parameter for the specific visualization the way that axes-level
  332. functions that accept ``hue`` will.
  333. When using seaborn functions that infer semantic mappings from a
  334. dataset, care must be taken to synchronize those mappings across
  335. facets (e.g., by defing the ``hue`` mapping with a palette dict or
  336. setting the data type of the variables to ``category``). In most cases,
  337. it will be better to use a figure-level function (e.g. :func:`relplot`
  338. or :func:`catplot`) than to use :class:`FacetGrid` directly.
  339. The basic workflow is to initialize the :class:`FacetGrid` object with
  340. the dataset and the variables that are used to structure the grid. Then
  341. one or more plotting functions can be applied to each subset by calling
  342. :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe`. Finally, the
  343. plot can be tweaked with other methods to do things like change the
  344. axis labels, use different ticks, or add a legend. See the detailed
  345. code examples below for more information.
  346. See the :ref:`tutorial <grid_tutorial>` for more information.
  347. Parameters
  348. ----------
  349. {data}
  350. row, col, hue : strings
  351. Variables that define subsets of the data, which will be drawn on
  352. separate facets in the grid. See the ``*_order`` parameters to
  353. control the order of levels of this variable.
  354. {col_wrap}
  355. {share_xy}
  356. {height}
  357. {aspect}
  358. {palette}
  359. {{row,col,hue}}_order : lists, optional
  360. Order for the levels of the faceting variables. By default, this
  361. will be the order that the levels appear in ``data`` or, if the
  362. variables are pandas categoricals, the category order.
  363. hue_kws : dictionary of param -> list of values mapping
  364. Other keyword arguments to insert into the plotting call to let
  365. other plot attributes vary across levels of the hue variable (e.g.
  366. the markers in a scatterplot).
  367. {legend_out}
  368. despine : boolean, optional
  369. Remove the top and right spines from the plots.
  370. {margin_titles}
  371. {{x, y}}lim: tuples, optional
  372. Limits for each of the axes on each facet (only relevant when
  373. share{{x, y}} is True).
  374. subplot_kws : dict, optional
  375. Dictionary of keyword arguments passed to matplotlib subplot(s)
  376. methods.
  377. gridspec_kws : dict, optional
  378. Dictionary of keyword arguments passed to matplotlib's ``gridspec``
  379. module (via ``plt.subplots``). Ignored if ``col_wrap`` is not
  380. ``None``.
  381. See Also
  382. --------
  383. PairGrid : Subplot grid for plotting pairwise relationships.
  384. relplot : Combine a relational plot and a :class:`FacetGrid`.
  385. catplot : Combine a categorical plot and a :class:`FacetGrid`.
  386. lmplot : Combine a regression plot and a :class:`FacetGrid`.
  387. Examples
  388. --------
  389. Initialize a 2x2 grid of facets using the tips dataset:
  390. .. plot::
  391. :context: close-figs
  392. >>> import seaborn as sns; sns.set(style="ticks", color_codes=True)
  393. >>> tips = sns.load_dataset("tips")
  394. >>> g = sns.FacetGrid(tips, col="time", row="smoker")
  395. Draw a univariate plot on each facet:
  396. .. plot::
  397. :context: close-figs
  398. >>> import matplotlib.pyplot as plt
  399. >>> g = sns.FacetGrid(tips, col="time", row="smoker")
  400. >>> g = g.map(plt.hist, "total_bill")
  401. (Note that it's not necessary to re-catch the returned variable; it's
  402. the same object, but doing so in the examples makes dealing with the
  403. doctests somewhat less annoying).
  404. Pass additional keyword arguments to the mapped function:
  405. .. plot::
  406. :context: close-figs
  407. >>> import numpy as np
  408. >>> bins = np.arange(0, 65, 5)
  409. >>> g = sns.FacetGrid(tips, col="time", row="smoker")
  410. >>> g = g.map(plt.hist, "total_bill", bins=bins, color="r")
  411. Plot a bivariate function on each facet:
  412. .. plot::
  413. :context: close-figs
  414. >>> g = sns.FacetGrid(tips, col="time", row="smoker")
  415. >>> g = g.map(plt.scatter, "total_bill", "tip", edgecolor="w")
  416. Assign one of the variables to the color of the plot elements:
  417. .. plot::
  418. :context: close-figs
  419. >>> g = sns.FacetGrid(tips, col="time", hue="smoker")
  420. >>> g = (g.map(plt.scatter, "total_bill", "tip", edgecolor="w")
  421. ... .add_legend())
  422. Change the height and aspect ratio of each facet:
  423. .. plot::
  424. :context: close-figs
  425. >>> g = sns.FacetGrid(tips, col="day", height=4, aspect=.5)
  426. >>> g = g.map(plt.hist, "total_bill", bins=bins)
  427. Specify the order for plot elements:
  428. .. plot::
  429. :context: close-figs
  430. >>> g = sns.FacetGrid(tips, col="smoker", col_order=["Yes", "No"])
  431. >>> g = g.map(plt.hist, "total_bill", bins=bins, color="m")
  432. Use a different color palette:
  433. .. plot::
  434. :context: close-figs
  435. >>> kws = dict(s=50, linewidth=.5, edgecolor="w")
  436. >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette="Set1",
  437. ... hue_order=["Dinner", "Lunch"])
  438. >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws)
  439. ... .add_legend())
  440. Use a dictionary mapping hue levels to colors:
  441. .. plot::
  442. :context: close-figs
  443. >>> pal = dict(Lunch="seagreen", Dinner="gray")
  444. >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette=pal,
  445. ... hue_order=["Dinner", "Lunch"])
  446. >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws)
  447. ... .add_legend())
  448. Additionally use a different marker for the hue levels:
  449. .. plot::
  450. :context: close-figs
  451. >>> g = sns.FacetGrid(tips, col="sex", hue="time", palette=pal,
  452. ... hue_order=["Dinner", "Lunch"],
  453. ... hue_kws=dict(marker=["^", "v"]))
  454. >>> g = (g.map(plt.scatter, "total_bill", "tip", **kws)
  455. ... .add_legend())
  456. "Wrap" a column variable with many levels into the rows:
  457. .. plot::
  458. :context: close-figs
  459. >>> att = sns.load_dataset("attention")
  460. >>> g = sns.FacetGrid(att, col="subject", col_wrap=5, height=1.5)
  461. >>> g = g.map(plt.plot, "solutions", "score", marker=".")
  462. Define a custom bivariate function to map onto the grid:
  463. .. plot::
  464. :context: close-figs
  465. >>> from scipy import stats
  466. >>> def qqplot(x, y, **kwargs):
  467. ... _, xr = stats.probplot(x, fit=False)
  468. ... _, yr = stats.probplot(y, fit=False)
  469. ... sns.scatterplot(xr, yr, **kwargs)
  470. >>> g = sns.FacetGrid(tips, col="smoker", hue="sex")
  471. >>> g = (g.map(qqplot, "total_bill", "tip", **kws)
  472. ... .add_legend())
  473. Define a custom function that uses a ``DataFrame`` object and accepts
  474. column names as positional variables:
  475. .. plot::
  476. :context: close-figs
  477. >>> import pandas as pd
  478. >>> df = pd.DataFrame(
  479. ... data=np.random.randn(90, 4),
  480. ... columns=pd.Series(list("ABCD"), name="walk"),
  481. ... index=pd.date_range("2015-01-01", "2015-03-31",
  482. ... name="date"))
  483. >>> df = df.cumsum(axis=0).stack().reset_index(name="val")
  484. >>> def dateplot(x, y, **kwargs):
  485. ... ax = plt.gca()
  486. ... data = kwargs.pop("data")
  487. ... data.plot(x=x, y=y, ax=ax, grid=False, **kwargs)
  488. >>> g = sns.FacetGrid(df, col="walk", col_wrap=2, height=3.5)
  489. >>> g = g.map_dataframe(dateplot, "date", "val")
  490. Use different axes labels after plotting:
  491. .. plot::
  492. :context: close-figs
  493. >>> g = sns.FacetGrid(tips, col="smoker", row="sex")
  494. >>> g = (g.map(plt.scatter, "total_bill", "tip", color="g", **kws)
  495. ... .set_axis_labels("Total bill (US Dollars)", "Tip"))
  496. Set other attributes that are shared across the facetes:
  497. .. plot::
  498. :context: close-figs
  499. >>> g = sns.FacetGrid(tips, col="smoker", row="sex")
  500. >>> g = (g.map(plt.scatter, "total_bill", "tip", color="r", **kws)
  501. ... .set(xlim=(0, 60), ylim=(0, 12),
  502. ... xticks=[10, 30, 50], yticks=[2, 6, 10]))
  503. Use a different template for the facet titles:
  504. .. plot::
  505. :context: close-figs
  506. >>> g = sns.FacetGrid(tips, col="size", col_wrap=3)
  507. >>> g = (g.map(plt.hist, "tip", bins=np.arange(0, 13), color="c")
  508. ... .set_titles("{{col_name}} diners"))
  509. Tighten the facets:
  510. .. plot::
  511. :context: close-figs
  512. >>> g = sns.FacetGrid(tips, col="smoker", row="sex",
  513. ... margin_titles=True)
  514. >>> g = (g.map(plt.scatter, "total_bill", "tip", color="m", **kws)
  515. ... .set(xlim=(0, 60), ylim=(0, 12),
  516. ... xticks=[10, 30, 50], yticks=[2, 6, 10])
  517. ... .fig.subplots_adjust(wspace=.05, hspace=.05))
  518. """).format(**_facet_docs)
  519. def facet_data(self):
  520. """Generator for name indices and data subsets for each facet.
  521. Yields
  522. ------
  523. (i, j, k), data_ijk : tuple of ints, DataFrame
  524. The ints provide an index into the {row, col, hue}_names attribute,
  525. and the dataframe contains a subset of the full data corresponding
  526. to each facet. The generator yields subsets that correspond with
  527. the self.axes.flat iterator, or self.axes[i, j] when `col_wrap`
  528. is None.
  529. """
  530. data = self.data
  531. # Construct masks for the row variable
  532. if self.row_names:
  533. row_masks = [data[self._row_var] == n for n in self.row_names]
  534. else:
  535. row_masks = [np.repeat(True, len(self.data))]
  536. # Construct masks for the column variable
  537. if self.col_names:
  538. col_masks = [data[self._col_var] == n for n in self.col_names]
  539. else:
  540. col_masks = [np.repeat(True, len(self.data))]
  541. # Construct masks for the hue variable
  542. if self.hue_names:
  543. hue_masks = [data[self._hue_var] == n for n in self.hue_names]
  544. else:
  545. hue_masks = [np.repeat(True, len(self.data))]
  546. # Here is the main generator loop
  547. for (i, row), (j, col), (k, hue) in product(enumerate(row_masks),
  548. enumerate(col_masks),
  549. enumerate(hue_masks)):
  550. data_ijk = data[row & col & hue & self._not_na]
  551. yield (i, j, k), data_ijk
  552. def map(self, func, *args, **kwargs):
  553. """Apply a plotting function to each facet's subset of the data.
  554. Parameters
  555. ----------
  556. func : callable
  557. A plotting function that takes data and keyword arguments. It
  558. must plot to the currently active matplotlib Axes and take a
  559. `color` keyword argument. If faceting on the `hue` dimension,
  560. it must also take a `label` keyword argument.
  561. args : strings
  562. Column names in self.data that identify variables with data to
  563. plot. The data for each variable is passed to `func` in the
  564. order the variables are specified in the call.
  565. kwargs : keyword arguments
  566. All keyword arguments are passed to the plotting function.
  567. Returns
  568. -------
  569. self : object
  570. Returns self.
  571. """
  572. # If color was a keyword argument, grab it here
  573. kw_color = kwargs.pop("color", None)
  574. if hasattr(func, "__module__"):
  575. func_module = str(func.__module__)
  576. else:
  577. func_module = ""
  578. # Check for categorical plots without order information
  579. if func_module == "seaborn.categorical":
  580. if "order" not in kwargs:
  581. warning = ("Using the {} function without specifying "
  582. "`order` is likely to produce an incorrect "
  583. "plot.".format(func.__name__))
  584. warnings.warn(warning)
  585. if len(args) == 3 and "hue_order" not in kwargs:
  586. warning = ("Using the {} function without specifying "
  587. "`hue_order` is likely to produce an incorrect "
  588. "plot.".format(func.__name__))
  589. warnings.warn(warning)
  590. # Iterate over the data subsets
  591. for (row_i, col_j, hue_k), data_ijk in self.facet_data():
  592. # If this subset is null, move on
  593. if not data_ijk.values.size:
  594. continue
  595. # Get the current axis
  596. ax = self.facet_axis(row_i, col_j)
  597. # Decide what color to plot with
  598. kwargs["color"] = self._facet_color(hue_k, kw_color)
  599. # Insert the other hue aesthetics if appropriate
  600. for kw, val_list in self.hue_kws.items():
  601. kwargs[kw] = val_list[hue_k]
  602. # Insert a label in the keyword arguments for the legend
  603. if self._hue_var is not None:
  604. kwargs["label"] = utils.to_utf8(self.hue_names[hue_k])
  605. # Get the actual data we are going to plot with
  606. plot_data = data_ijk[list(args)]
  607. if self._dropna:
  608. plot_data = plot_data.dropna()
  609. plot_args = [v for k, v in plot_data.iteritems()]
  610. # Some matplotlib functions don't handle pandas objects correctly
  611. if func_module.startswith("matplotlib"):
  612. plot_args = [v.values for v in plot_args]
  613. # Draw the plot
  614. self._facet_plot(func, ax, plot_args, kwargs)
  615. # Finalize the annotations and layout
  616. self._finalize_grid(args[:2])
  617. return self
  618. def map_dataframe(self, func, *args, **kwargs):
  619. """Like ``.map`` but passes args as strings and inserts data in kwargs.
  620. This method is suitable for plotting with functions that accept a
  621. long-form DataFrame as a `data` keyword argument and access the
  622. data in that DataFrame using string variable names.
  623. Parameters
  624. ----------
  625. func : callable
  626. A plotting function that takes data and keyword arguments. Unlike
  627. the `map` method, a function used here must "understand" Pandas
  628. objects. It also must plot to the currently active matplotlib Axes
  629. and take a `color` keyword argument. If faceting on the `hue`
  630. dimension, it must also take a `label` keyword argument.
  631. args : strings
  632. Column names in self.data that identify variables with data to
  633. plot. The data for each variable is passed to `func` in the
  634. order the variables are specified in the call.
  635. kwargs : keyword arguments
  636. All keyword arguments are passed to the plotting function.
  637. Returns
  638. -------
  639. self : object
  640. Returns self.
  641. """
  642. # If color was a keyword argument, grab it here
  643. kw_color = kwargs.pop("color", None)
  644. # Iterate over the data subsets
  645. for (row_i, col_j, hue_k), data_ijk in self.facet_data():
  646. # If this subset is null, move on
  647. if not data_ijk.values.size:
  648. continue
  649. # Get the current axis
  650. ax = self.facet_axis(row_i, col_j)
  651. # Decide what color to plot with
  652. kwargs["color"] = self._facet_color(hue_k, kw_color)
  653. # Insert the other hue aesthetics if appropriate
  654. for kw, val_list in self.hue_kws.items():
  655. kwargs[kw] = val_list[hue_k]
  656. # Insert a label in the keyword arguments for the legend
  657. if self._hue_var is not None:
  658. kwargs["label"] = self.hue_names[hue_k]
  659. # Stick the facet dataframe into the kwargs
  660. if self._dropna:
  661. data_ijk = data_ijk.dropna()
  662. kwargs["data"] = data_ijk
  663. # Draw the plot
  664. self._facet_plot(func, ax, args, kwargs)
  665. # Finalize the annotations and layout
  666. self._finalize_grid(args[:2])
  667. return self
  668. def _facet_color(self, hue_index, kw_color):
  669. color = self._colors[hue_index]
  670. if kw_color is not None:
  671. return kw_color
  672. elif color is not None:
  673. return color
  674. def _facet_plot(self, func, ax, plot_args, plot_kwargs):
  675. # Draw the plot
  676. func(*plot_args, **plot_kwargs)
  677. # Sort out the supporting information
  678. self._update_legend_data(ax)
  679. self._clean_axis(ax)
  680. def _finalize_grid(self, axlabels):
  681. """Finalize the annotations and layout."""
  682. self.set_axis_labels(*axlabels)
  683. self.set_titles()
  684. self.fig.tight_layout()
  685. def facet_axis(self, row_i, col_j):
  686. """Make the axis identified by these indices active and return it."""
  687. # Calculate the actual indices of the axes to plot on
  688. if self._col_wrap is not None:
  689. ax = self.axes.flat[col_j]
  690. else:
  691. ax = self.axes[row_i, col_j]
  692. # Get a reference to the axes object we want, and make it active
  693. plt.sca(ax)
  694. return ax
  695. def despine(self, **kwargs):
  696. """Remove axis spines from the facets."""
  697. utils.despine(self.fig, **kwargs)
  698. return self
  699. def set_axis_labels(self, x_var=None, y_var=None):
  700. """Set axis labels on the left column and bottom row of the grid."""
  701. if x_var is not None:
  702. self._x_var = x_var
  703. self.set_xlabels(x_var)
  704. if y_var is not None:
  705. self._y_var = y_var
  706. self.set_ylabels(y_var)
  707. return self
  708. def set_xlabels(self, label=None, **kwargs):
  709. """Label the x axis on the bottom row of the grid."""
  710. if label is None:
  711. label = self._x_var
  712. for ax in self._bottom_axes:
  713. ax.set_xlabel(label, **kwargs)
  714. return self
  715. def set_ylabels(self, label=None, **kwargs):
  716. """Label the y axis on the left column of the grid."""
  717. if label is None:
  718. label = self._y_var
  719. for ax in self._left_axes:
  720. ax.set_ylabel(label, **kwargs)
  721. return self
  722. def set_xticklabels(self, labels=None, step=None, **kwargs):
  723. """Set x axis tick labels of the grid."""
  724. for ax in self.axes.flat:
  725. if labels is None:
  726. curr_labels = [l.get_text() for l in ax.get_xticklabels()]
  727. if step is not None:
  728. xticks = ax.get_xticks()[::step]
  729. curr_labels = curr_labels[::step]
  730. ax.set_xticks(xticks)
  731. ax.set_xticklabels(curr_labels, **kwargs)
  732. else:
  733. ax.set_xticklabels(labels, **kwargs)
  734. return self
  735. def set_yticklabels(self, labels=None, **kwargs):
  736. """Set y axis tick labels on the left column of the grid."""
  737. for ax in self.axes.flat:
  738. if labels is None:
  739. curr_labels = [l.get_text() for l in ax.get_yticklabels()]
  740. ax.set_yticklabels(curr_labels, **kwargs)
  741. else:
  742. ax.set_yticklabels(labels, **kwargs)
  743. return self
  744. def set_titles(self, template=None, row_template=None, col_template=None,
  745. **kwargs):
  746. """Draw titles either above each facet or on the grid margins.
  747. Parameters
  748. ----------
  749. template : string
  750. Template for all titles with the formatting keys {col_var} and
  751. {col_name} (if using a `col` faceting variable) and/or {row_var}
  752. and {row_name} (if using a `row` faceting variable).
  753. row_template:
  754. Template for the row variable when titles are drawn on the grid
  755. margins. Must have {row_var} and {row_name} formatting keys.
  756. col_template:
  757. Template for the row variable when titles are drawn on the grid
  758. margins. Must have {col_var} and {col_name} formatting keys.
  759. Returns
  760. -------
  761. self: object
  762. Returns self.
  763. """
  764. args = dict(row_var=self._row_var, col_var=self._col_var)
  765. kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"])
  766. # Establish default templates
  767. if row_template is None:
  768. row_template = "{row_var} = {row_name}"
  769. if col_template is None:
  770. col_template = "{col_var} = {col_name}"
  771. if template is None:
  772. if self._row_var is None:
  773. template = col_template
  774. elif self._col_var is None:
  775. template = row_template
  776. else:
  777. template = " | ".join([row_template, col_template])
  778. row_template = utils.to_utf8(row_template)
  779. col_template = utils.to_utf8(col_template)
  780. template = utils.to_utf8(template)
  781. if self._margin_titles:
  782. if self.row_names is not None:
  783. # Draw the row titles on the right edge of the grid
  784. for i, row_name in enumerate(self.row_names):
  785. ax = self.axes[i, -1]
  786. args.update(dict(row_name=row_name))
  787. title = row_template.format(**args)
  788. bgcolor = self.fig.get_facecolor()
  789. ax.annotate(title, xy=(1.02, .5), xycoords="axes fraction",
  790. rotation=270, ha="left", va="center",
  791. backgroundcolor=bgcolor, **kwargs)
  792. if self.col_names is not None:
  793. # Draw the column titles as normal titles
  794. for j, col_name in enumerate(self.col_names):
  795. args.update(dict(col_name=col_name))
  796. title = col_template.format(**args)
  797. self.axes[0, j].set_title(title, **kwargs)
  798. return self
  799. # Otherwise title each facet with all the necessary information
  800. if (self._row_var is not None) and (self._col_var is not None):
  801. for i, row_name in enumerate(self.row_names):
  802. for j, col_name in enumerate(self.col_names):
  803. args.update(dict(row_name=row_name, col_name=col_name))
  804. title = template.format(**args)
  805. self.axes[i, j].set_title(title, **kwargs)
  806. elif self.row_names is not None and len(self.row_names):
  807. for i, row_name in enumerate(self.row_names):
  808. args.update(dict(row_name=row_name))
  809. title = template.format(**args)
  810. self.axes[i, 0].set_title(title, **kwargs)
  811. elif self.col_names is not None and len(self.col_names):
  812. for i, col_name in enumerate(self.col_names):
  813. args.update(dict(col_name=col_name))
  814. title = template.format(**args)
  815. # Index the flat array so col_wrap works
  816. self.axes.flat[i].set_title(title, **kwargs)
  817. return self
  818. @property
  819. def ax(self):
  820. """Easy access to single axes."""
  821. if self.axes.shape == (1, 1):
  822. return self.axes[0, 0]
  823. else:
  824. err = ("You must use the `.axes` attribute (an array) when "
  825. "there is more than one plot.")
  826. raise AttributeError(err)
  827. @property
  828. def _inner_axes(self):
  829. """Return a flat array of the inner axes."""
  830. if self._col_wrap is None:
  831. return self.axes[:-1, 1:].flat
  832. else:
  833. axes = []
  834. n_empty = self._nrow * self._ncol - self._n_facets
  835. for i, ax in enumerate(self.axes):
  836. append = (i % self._ncol and
  837. i < (self._ncol * (self._nrow - 1)) and
  838. i < (self._ncol * (self._nrow - 1) - n_empty))
  839. if append:
  840. axes.append(ax)
  841. return np.array(axes, object).flat
  842. @property
  843. def _left_axes(self):
  844. """Return a flat array of the left column of axes."""
  845. if self._col_wrap is None:
  846. return self.axes[:, 0].flat
  847. else:
  848. axes = []
  849. for i, ax in enumerate(self.axes):
  850. if not i % self._ncol:
  851. axes.append(ax)
  852. return np.array(axes, object).flat
  853. @property
  854. def _not_left_axes(self):
  855. """Return a flat array of axes that aren't on the left column."""
  856. if self._col_wrap is None:
  857. return self.axes[:, 1:].flat
  858. else:
  859. axes = []
  860. for i, ax in enumerate(self.axes):
  861. if i % self._ncol:
  862. axes.append(ax)
  863. return np.array(axes, object).flat
  864. @property
  865. def _bottom_axes(self):
  866. """Return a flat array of the bottom row of axes."""
  867. if self._col_wrap is None:
  868. return self.axes[-1, :].flat
  869. else:
  870. axes = []
  871. n_empty = self._nrow * self._ncol - self._n_facets
  872. for i, ax in enumerate(self.axes):
  873. append = (i >= (self._ncol * (self._nrow - 1)) or
  874. i >= (self._ncol * (self._nrow - 1) - n_empty))
  875. if append:
  876. axes.append(ax)
  877. return np.array(axes, object).flat
  878. @property
  879. def _not_bottom_axes(self):
  880. """Return a flat array of axes that aren't on the bottom row."""
  881. if self._col_wrap is None:
  882. return self.axes[:-1, :].flat
  883. else:
  884. axes = []
  885. n_empty = self._nrow * self._ncol - self._n_facets
  886. for i, ax in enumerate(self.axes):
  887. append = (i < (self._ncol * (self._nrow - 1)) and
  888. i < (self._ncol * (self._nrow - 1) - n_empty))
  889. if append:
  890. axes.append(ax)
  891. return np.array(axes, object).flat
  892. class PairGrid(Grid):
  893. """Subplot grid for plotting pairwise relationships in a dataset.
  894. This class maps each variable in a dataset onto a column and row in a
  895. grid of multiple axes. Different axes-level plotting functions can be
  896. used to draw bivariate plots in the upper and lower triangles, and the
  897. the marginal distribution of each variable can be shown on the diagonal.
  898. It can also represent an additional level of conditionalization with the
  899. ``hue`` parameter, which plots different subsets of data in different
  900. colors. This uses color to resolve elements on a third dimension, but
  901. only draws subsets on top of each other and will not tailor the ``hue``
  902. parameter for the specific visualization the way that axes-level functions
  903. that accept ``hue`` will.
  904. See the :ref:`tutorial <grid_tutorial>` for more information.
  905. """
  906. def __init__(self, data, hue=None, hue_order=None, palette=None,
  907. hue_kws=None, vars=None, x_vars=None, y_vars=None,
  908. corner=False, diag_sharey=True, height=2.5, aspect=1,
  909. layout_pad=0, despine=True, dropna=True, size=None):
  910. """Initialize the plot figure and PairGrid object.
  911. Parameters
  912. ----------
  913. data : DataFrame
  914. Tidy (long-form) dataframe where each column is a variable and
  915. each row is an observation.
  916. hue : string (variable name), optional
  917. Variable in ``data`` to map plot aspects to different colors. This
  918. variable will be excluded from the default x and y variables.
  919. hue_order : list of strings
  920. Order for the levels of the hue variable in the palette
  921. palette : dict or seaborn color palette
  922. Set of colors for mapping the ``hue`` variable. If a dict, keys
  923. should be values in the ``hue`` variable.
  924. hue_kws : dictionary of param -> list of values mapping
  925. Other keyword arguments to insert into the plotting call to let
  926. other plot attributes vary across levels of the hue variable (e.g.
  927. the markers in a scatterplot).
  928. vars : list of variable names, optional
  929. Variables within ``data`` to use, otherwise use every column with
  930. a numeric datatype.
  931. {x, y}_vars : lists of variable names, optional
  932. Variables within ``data`` to use separately for the rows and
  933. columns of the figure; i.e. to make a non-square plot.
  934. corner : bool, optional
  935. If True, don't add axes to the upper (off-diagonal) triangle of the
  936. grid, making this a "corner" plot.
  937. height : scalar, optional
  938. Height (in inches) of each facet.
  939. aspect : scalar, optional
  940. Aspect * height gives the width (in inches) of each facet.
  941. layout_pad : scalar, optional
  942. Padding between axes; passed to ``fig.tight_layout``.
  943. despine : boolean, optional
  944. Remove the top and right spines from the plots.
  945. dropna : boolean, optional
  946. Drop missing values from the data before plotting.
  947. See Also
  948. --------
  949. pairplot : Easily drawing common uses of :class:`PairGrid`.
  950. FacetGrid : Subplot grid for plotting conditional relationships.
  951. Examples
  952. --------
  953. Draw a scatterplot for each pairwise relationship:
  954. .. plot::
  955. :context: close-figs
  956. >>> import matplotlib.pyplot as plt
  957. >>> import seaborn as sns; sns.set()
  958. >>> iris = sns.load_dataset("iris")
  959. >>> g = sns.PairGrid(iris)
  960. >>> g = g.map(plt.scatter)
  961. Show a univariate distribution on the diagonal:
  962. .. plot::
  963. :context: close-figs
  964. >>> g = sns.PairGrid(iris)
  965. >>> g = g.map_diag(plt.hist)
  966. >>> g = g.map_offdiag(plt.scatter)
  967. (It's not actually necessary to catch the return value every time,
  968. as it is the same object, but it makes it easier to deal with the
  969. doctests).
  970. Color the points using a categorical variable:
  971. .. plot::
  972. :context: close-figs
  973. >>> g = sns.PairGrid(iris, hue="species")
  974. >>> g = g.map_diag(plt.hist)
  975. >>> g = g.map_offdiag(plt.scatter)
  976. >>> g = g.add_legend()
  977. Use a different style to show multiple histograms:
  978. .. plot::
  979. :context: close-figs
  980. >>> g = sns.PairGrid(iris, hue="species")
  981. >>> g = g.map_diag(plt.hist, histtype="step", linewidth=3)
  982. >>> g = g.map_offdiag(plt.scatter)
  983. >>> g = g.add_legend()
  984. Plot a subset of variables
  985. .. plot::
  986. :context: close-figs
  987. >>> g = sns.PairGrid(iris, vars=["sepal_length", "sepal_width"])
  988. >>> g = g.map(plt.scatter)
  989. Pass additional keyword arguments to the functions
  990. .. plot::
  991. :context: close-figs
  992. >>> g = sns.PairGrid(iris)
  993. >>> g = g.map_diag(plt.hist, edgecolor="w")
  994. >>> g = g.map_offdiag(plt.scatter, edgecolor="w", s=40)
  995. Use different variables for the rows and columns:
  996. .. plot::
  997. :context: close-figs
  998. >>> g = sns.PairGrid(iris,
  999. ... x_vars=["sepal_length", "sepal_width"],
  1000. ... y_vars=["petal_length", "petal_width"])
  1001. >>> g = g.map(plt.scatter)
  1002. Use different functions on the upper and lower triangles:
  1003. .. plot::
  1004. :context: close-figs
  1005. >>> g = sns.PairGrid(iris)
  1006. >>> g = g.map_upper(sns.scatterplot)
  1007. >>> g = g.map_lower(sns.kdeplot, colors="C0")
  1008. >>> g = g.map_diag(sns.kdeplot, lw=2)
  1009. Use different colors and markers for each categorical level:
  1010. .. plot::
  1011. :context: close-figs
  1012. >>> g = sns.PairGrid(iris, hue="species", palette="Set2",
  1013. ... hue_kws={"marker": ["o", "s", "D"]})
  1014. >>> g = g.map(sns.scatterplot, linewidths=1, edgecolor="w", s=40)
  1015. >>> g = g.add_legend()
  1016. """
  1017. # Handle deprecations
  1018. if size is not None:
  1019. height = size
  1020. msg = ("The `size` parameter has been renamed to `height`; "
  1021. "please update your code.")
  1022. warnings.warn(UserWarning(msg))
  1023. # Sort out the variables that define the grid
  1024. if vars is not None:
  1025. x_vars = list(vars)
  1026. y_vars = list(vars)
  1027. elif (x_vars is not None) or (y_vars is not None):
  1028. if (x_vars is None) or (y_vars is None):
  1029. raise ValueError("Must specify `x_vars` and `y_vars`")
  1030. else:
  1031. numeric_cols = self._find_numeric_cols(data)
  1032. if hue in numeric_cols:
  1033. numeric_cols.remove(hue)
  1034. x_vars = numeric_cols
  1035. y_vars = numeric_cols
  1036. if np.isscalar(x_vars):
  1037. x_vars = [x_vars]
  1038. if np.isscalar(y_vars):
  1039. y_vars = [y_vars]
  1040. self.x_vars = list(x_vars)
  1041. self.y_vars = list(y_vars)
  1042. self.square_grid = self.x_vars == self.y_vars
  1043. # Create the figure and the array of subplots
  1044. figsize = len(x_vars) * height * aspect, len(y_vars) * height
  1045. fig, axes = plt.subplots(len(y_vars), len(x_vars),
  1046. figsize=figsize,
  1047. sharex="col", sharey="row",
  1048. squeeze=False)
  1049. # Possibly remove upper axes to make a corner grid
  1050. # Note: setting up the axes is usually the most time-intensive part
  1051. # of using the PairGrid. We are foregoing the speed improvement that
  1052. # we would get by just not setting up the hidden axes so that we can
  1053. # avoid implementing plt.subplots ourselves. But worth thinking about.
  1054. self._corner = corner
  1055. if corner:
  1056. hide_indices = np.triu_indices_from(axes, 1)
  1057. for i, j in zip(*hide_indices):
  1058. axes[i, j].remove()
  1059. axes[i, j] = None
  1060. self.fig = fig
  1061. self.axes = axes
  1062. self.data = data
  1063. # Save what we are going to do with the diagonal
  1064. self.diag_sharey = diag_sharey
  1065. self.diag_vars = None
  1066. self.diag_axes = None
  1067. self._dropna = dropna
  1068. # Label the axes
  1069. self._add_axis_labels()
  1070. # Sort out the hue variable
  1071. self._hue_var = hue
  1072. if hue is None:
  1073. self.hue_names = ["_nolegend_"]
  1074. self.hue_vals = pd.Series(["_nolegend_"] * len(data),
  1075. index=data.index)
  1076. else:
  1077. hue_names = utils.categorical_order(data[hue], hue_order)
  1078. if dropna:
  1079. # Filter NA from the list of unique hue names
  1080. hue_names = list(filter(pd.notnull, hue_names))
  1081. self.hue_names = hue_names
  1082. self.hue_vals = data[hue]
  1083. # Additional dict of kwarg -> list of values for mapping the hue var
  1084. self.hue_kws = hue_kws if hue_kws is not None else {}
  1085. self.palette = self._get_palette(data, hue, hue_order, palette)
  1086. self._legend_data = {}
  1087. # Make the plot look nice
  1088. if despine:
  1089. self._despine = True
  1090. utils.despine(fig=fig)
  1091. fig.tight_layout(pad=layout_pad)
  1092. def map(self, func, **kwargs):
  1093. """Plot with the same function in every subplot.
  1094. Parameters
  1095. ----------
  1096. func : callable plotting function
  1097. Must take x, y arrays as positional arguments and draw onto the
  1098. "currently active" matplotlib Axes. Also needs to accept kwargs
  1099. called ``color`` and ``label``.
  1100. """
  1101. row_indices, col_indices = np.indices(self.axes.shape)
  1102. indices = zip(row_indices.flat, col_indices.flat)
  1103. self._map_bivariate(func, indices, **kwargs)
  1104. return self
  1105. def map_lower(self, func, **kwargs):
  1106. """Plot with a bivariate function on the lower diagonal subplots.
  1107. Parameters
  1108. ----------
  1109. func : callable plotting function
  1110. Must take x, y arrays as positional arguments and draw onto the
  1111. "currently active" matplotlib Axes. Also needs to accept kwargs
  1112. called ``color`` and ``label``.
  1113. """
  1114. indices = zip(*np.tril_indices_from(self.axes, -1))
  1115. self._map_bivariate(func, indices, **kwargs)
  1116. return self
  1117. def map_upper(self, func, **kwargs):
  1118. """Plot with a bivariate function on the upper diagonal subplots.
  1119. Parameters
  1120. ----------
  1121. func : callable plotting function
  1122. Must take x, y arrays as positional arguments and draw onto the
  1123. "currently active" matplotlib Axes. Also needs to accept kwargs
  1124. called ``color`` and ``label``.
  1125. """
  1126. indices = zip(*np.triu_indices_from(self.axes, 1))
  1127. self._map_bivariate(func, indices, **kwargs)
  1128. return self
  1129. def map_offdiag(self, func, **kwargs):
  1130. """Plot with a bivariate function on the off-diagonal subplots.
  1131. Parameters
  1132. ----------
  1133. func : callable plotting function
  1134. Must take x, y arrays as positional arguments and draw onto the
  1135. "currently active" matplotlib Axes. Also needs to accept kwargs
  1136. called ``color`` and ``label``.
  1137. """
  1138. self.map_lower(func, **kwargs)
  1139. if not self._corner:
  1140. self.map_upper(func, **kwargs)
  1141. return self
  1142. def map_diag(self, func, **kwargs):
  1143. """Plot with a univariate function on each diagonal subplot.
  1144. Parameters
  1145. ----------
  1146. func : callable plotting function
  1147. Must take an x array as a positional argument and draw onto the
  1148. "currently active" matplotlib Axes. Also needs to accept kwargs
  1149. called ``color`` and ``label``.
  1150. """
  1151. # Add special diagonal axes for the univariate plot
  1152. if self.diag_axes is None:
  1153. diag_vars = []
  1154. diag_axes = []
  1155. for i, y_var in enumerate(self.y_vars):
  1156. for j, x_var in enumerate(self.x_vars):
  1157. if x_var == y_var:
  1158. # Make the density axes
  1159. diag_vars.append(x_var)
  1160. ax = self.axes[i, j]
  1161. diag_ax = ax.twinx()
  1162. diag_ax.set_axis_off()
  1163. diag_axes.append(diag_ax)
  1164. # Work around matplotlib bug
  1165. # https://github.com/matplotlib/matplotlib/issues/15188
  1166. if not plt.rcParams.get("ytick.left", True):
  1167. for tick in ax.yaxis.majorTicks:
  1168. tick.tick1line.set_visible(False)
  1169. # Remove main y axis from density axes in a corner plot
  1170. if self._corner:
  1171. ax.yaxis.set_visible(False)
  1172. if self._despine:
  1173. utils.despine(ax=ax, left=True)
  1174. # TODO add optional density ticks (on the right)
  1175. # when drawing a corner plot?
  1176. if self.diag_sharey:
  1177. # This may change in future matplotlibs
  1178. # See https://github.com/matplotlib/matplotlib/pull/9923
  1179. group = diag_axes[0].get_shared_y_axes()
  1180. for ax in diag_axes[1:]:
  1181. group.join(ax, diag_axes[0])
  1182. self.diag_vars = np.array(diag_vars, np.object)
  1183. self.diag_axes = np.array(diag_axes, np.object)
  1184. # Plot on each of the diagonal axes
  1185. fixed_color = kwargs.pop("color", None)
  1186. for var, ax in zip(self.diag_vars, self.diag_axes):
  1187. hue_grouped = self.data[var].groupby(self.hue_vals)
  1188. plt.sca(ax)
  1189. for k, label_k in enumerate(self.hue_names):
  1190. # Attempt to get data for this level, allowing for empty
  1191. try:
  1192. # TODO newer matplotlib(?) doesn't need array for hist
  1193. data_k = np.asarray(hue_grouped.get_group(label_k))
  1194. except KeyError:
  1195. data_k = np.array([])
  1196. if fixed_color is None:
  1197. color = self.palette[k]
  1198. else:
  1199. color = fixed_color
  1200. if self._dropna:
  1201. data_k = utils.remove_na(data_k)
  1202. func(data_k, label=label_k, color=color, **kwargs)
  1203. self._clean_axis(ax)
  1204. self._add_axis_labels()
  1205. return self
  1206. def _map_bivariate(self, func, indices, **kwargs):
  1207. """Draw a bivariate plot on the indicated axes."""
  1208. kws = kwargs.copy() # Use copy as we insert other kwargs
  1209. kw_color = kws.pop("color", None)
  1210. for i, j in indices:
  1211. x_var = self.x_vars[j]
  1212. y_var = self.y_vars[i]
  1213. ax = self.axes[i, j]
  1214. self._plot_bivariate(x_var, y_var, ax, func, kw_color, **kws)
  1215. self._add_axis_labels()
  1216. def _plot_bivariate(self, x_var, y_var, ax, func, kw_color, **kwargs):
  1217. """Draw a bivariate plot on the specified axes."""
  1218. plt.sca(ax)
  1219. if x_var == y_var:
  1220. axes_vars = [x_var]
  1221. else:
  1222. axes_vars = [x_var, y_var]
  1223. hue_grouped = self.data.groupby(self.hue_vals)
  1224. for k, label_k in enumerate(self.hue_names):
  1225. # Attempt to get data for this level, allowing for empty
  1226. try:
  1227. data_k = hue_grouped.get_group(label_k)
  1228. except KeyError:
  1229. data_k = pd.DataFrame(columns=axes_vars,
  1230. dtype=np.float)
  1231. if self._dropna:
  1232. data_k = data_k[axes_vars].dropna()
  1233. x = data_k[x_var]
  1234. y = data_k[y_var]
  1235. for kw, val_list in self.hue_kws.items():
  1236. kwargs[kw] = val_list[k]
  1237. color = self.palette[k] if kw_color is None else kw_color
  1238. func(x, y, label=label_k, color=color, **kwargs)
  1239. self._clean_axis(ax)
  1240. self._update_legend_data(ax)
  1241. def _add_axis_labels(self):
  1242. """Add labels to the left and bottom Axes."""
  1243. for ax, label in zip(self.axes[-1, :], self.x_vars):
  1244. ax.set_xlabel(label)
  1245. for ax, label in zip(self.axes[:, 0], self.y_vars):
  1246. ax.set_ylabel(label)
  1247. if self._corner:
  1248. self.axes[0, 0].set_ylabel("")
  1249. def _find_numeric_cols(self, data):
  1250. """Find which variables in a DataFrame are numeric."""
  1251. # This can't be the best way to do this, but I do not
  1252. # know what the best way might be, so this seems ok
  1253. numeric_cols = []
  1254. for col in data:
  1255. try:
  1256. data[col].astype(np.float)
  1257. numeric_cols.append(col)
  1258. except (ValueError, TypeError):
  1259. pass
  1260. return numeric_cols
  1261. class JointGrid(object):
  1262. """Grid for drawing a bivariate plot with marginal univariate plots."""
  1263. def __init__(self, x, y, data=None, height=6, ratio=5, space=.2,
  1264. dropna=True, xlim=None, ylim=None, size=None):
  1265. """Set up the grid of subplots.
  1266. Parameters
  1267. ----------
  1268. x, y : strings or vectors
  1269. Data or names of variables in ``data``.
  1270. data : DataFrame, optional
  1271. DataFrame when ``x`` and ``y`` are variable names.
  1272. height : numeric
  1273. Size of each side of the figure in inches (it will be square).
  1274. ratio : numeric
  1275. Ratio of joint axes size to marginal axes height.
  1276. space : numeric, optional
  1277. Space between the joint and marginal axes
  1278. dropna : bool, optional
  1279. If True, remove observations that are missing from `x` and `y`.
  1280. {x, y}lim : two-tuples, optional
  1281. Axis limits to set before plotting.
  1282. See Also
  1283. --------
  1284. jointplot : High-level interface for drawing bivariate plots with
  1285. several different default plot kinds.
  1286. Examples
  1287. --------
  1288. Initialize the figure but don't draw any plots onto it:
  1289. .. plot::
  1290. :context: close-figs
  1291. >>> import seaborn as sns; sns.set(style="ticks", color_codes=True)
  1292. >>> tips = sns.load_dataset("tips")
  1293. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
  1294. Add plots using default parameters:
  1295. .. plot::
  1296. :context: close-figs
  1297. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
  1298. >>> g = g.plot(sns.regplot, sns.distplot)
  1299. Draw the join and marginal plots separately, which allows finer-level
  1300. control other parameters:
  1301. .. plot::
  1302. :context: close-figs
  1303. >>> import matplotlib.pyplot as plt
  1304. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
  1305. >>> g = g.plot_joint(sns.scatterplot, color=".5")
  1306. >>> g = g.plot_marginals(sns.distplot, kde=False, color=".5")
  1307. Draw the two marginal plots separately:
  1308. .. plot::
  1309. :context: close-figs
  1310. >>> import numpy as np
  1311. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips)
  1312. >>> g = g.plot_joint(sns.scatterplot, color="m")
  1313. >>> _ = g.ax_marg_x.hist(tips["total_bill"], color="b", alpha=.6,
  1314. ... bins=np.arange(0, 60, 5))
  1315. >>> _ = g.ax_marg_y.hist(tips["tip"], color="r", alpha=.6,
  1316. ... orientation="horizontal",
  1317. ... bins=np.arange(0, 12, 1))
  1318. Remove the space between the joint and marginal axes:
  1319. .. plot::
  1320. :context: close-figs
  1321. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips, space=0)
  1322. >>> g = g.plot_joint(sns.kdeplot, cmap="Blues_d")
  1323. >>> g = g.plot_marginals(sns.kdeplot, shade=True)
  1324. Draw a smaller plot with relatively larger marginal axes:
  1325. .. plot::
  1326. :context: close-figs
  1327. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips,
  1328. ... height=5, ratio=2)
  1329. >>> g = g.plot_joint(sns.kdeplot, cmap="Reds_d")
  1330. >>> g = g.plot_marginals(sns.kdeplot, color="r", shade=True)
  1331. Set limits on the axes:
  1332. .. plot::
  1333. :context: close-figs
  1334. >>> g = sns.JointGrid(x="total_bill", y="tip", data=tips,
  1335. ... xlim=(0, 50), ylim=(0, 8))
  1336. >>> g = g.plot_joint(sns.kdeplot, cmap="Purples_d")
  1337. >>> g = g.plot_marginals(sns.kdeplot, color="m", shade=True)
  1338. """
  1339. # Handle deprecations
  1340. if size is not None:
  1341. height = size
  1342. msg = ("The `size` parameter has been renamed to `height`; "
  1343. "please update your code.")
  1344. warnings.warn(msg, UserWarning)
  1345. # Set up the subplot grid
  1346. f = plt.figure(figsize=(height, height))
  1347. gs = plt.GridSpec(ratio + 1, ratio + 1)
  1348. ax_joint = f.add_subplot(gs[1:, :-1])
  1349. ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
  1350. ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)
  1351. self.fig = f
  1352. self.ax_joint = ax_joint
  1353. self.ax_marg_x = ax_marg_x
  1354. self.ax_marg_y = ax_marg_y
  1355. # Turn off tick visibility for the measure axis on the marginal plots
  1356. plt.setp(ax_marg_x.get_xticklabels(), visible=False)
  1357. plt.setp(ax_marg_y.get_yticklabels(), visible=False)
  1358. # Turn off the ticks on the density axis for the marginal plots
  1359. plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
  1360. plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
  1361. plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
  1362. plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
  1363. plt.setp(ax_marg_x.get_yticklabels(), visible=False)
  1364. plt.setp(ax_marg_y.get_xticklabels(), visible=False)
  1365. ax_marg_x.yaxis.grid(False)
  1366. ax_marg_y.xaxis.grid(False)
  1367. # Possibly extract the variables from a DataFrame
  1368. if data is not None:
  1369. x = data.get(x, x)
  1370. y = data.get(y, y)
  1371. for var in [x, y]:
  1372. if isinstance(var, str):
  1373. err = "Could not interpret input '{}'".format(var)
  1374. raise ValueError(err)
  1375. # Find the names of the variables
  1376. if hasattr(x, "name"):
  1377. xlabel = x.name
  1378. ax_joint.set_xlabel(xlabel)
  1379. if hasattr(y, "name"):
  1380. ylabel = y.name
  1381. ax_joint.set_ylabel(ylabel)
  1382. # Convert the x and y data to arrays for indexing and plotting
  1383. x_array = np.asarray(x)
  1384. y_array = np.asarray(y)
  1385. # Possibly drop NA
  1386. if dropna:
  1387. not_na = pd.notnull(x_array) & pd.notnull(y_array)
  1388. x_array = x_array[not_na]
  1389. y_array = y_array[not_na]
  1390. self.x = x_array
  1391. self.y = y_array
  1392. if xlim is not None:
  1393. ax_joint.set_xlim(xlim)
  1394. if ylim is not None:
  1395. ax_joint.set_ylim(ylim)
  1396. # Make the grid look nice
  1397. utils.despine(f)
  1398. utils.despine(ax=ax_marg_x, left=True)
  1399. utils.despine(ax=ax_marg_y, bottom=True)
  1400. f.tight_layout()
  1401. f.subplots_adjust(hspace=space, wspace=space)
  1402. def plot(self, joint_func, marginal_func, annot_func=None):
  1403. """Shortcut to draw the full plot.
  1404. Use `plot_joint` and `plot_marginals` directly for more control.
  1405. Parameters
  1406. ----------
  1407. joint_func, marginal_func: callables
  1408. Functions to draw the bivariate and univariate plots.
  1409. Returns
  1410. -------
  1411. self : JointGrid instance
  1412. Returns `self`.
  1413. """
  1414. self.plot_marginals(marginal_func)
  1415. self.plot_joint(joint_func)
  1416. if annot_func is not None:
  1417. self.annotate(annot_func)
  1418. return self
  1419. def plot_joint(self, func, **kwargs):
  1420. """Draw a bivariate plot of `x` and `y`.
  1421. Parameters
  1422. ----------
  1423. func : plotting callable
  1424. This must take two 1d arrays of data as the first two
  1425. positional arguments, and it must plot on the "current" axes.
  1426. kwargs : key, value mappings
  1427. Keyword argument are passed to the plotting function.
  1428. Returns
  1429. -------
  1430. self : JointGrid instance
  1431. Returns `self`.
  1432. """
  1433. plt.sca(self.ax_joint)
  1434. func(self.x, self.y, **kwargs)
  1435. return self
  1436. def plot_marginals(self, func, **kwargs):
  1437. """Draw univariate plots for `x` and `y` separately.
  1438. Parameters
  1439. ----------
  1440. func : plotting callable
  1441. This must take a 1d array of data as the first positional
  1442. argument, it must plot on the "current" axes, and it must
  1443. accept a "vertical" keyword argument to orient the measure
  1444. dimension of the plot vertically.
  1445. kwargs : key, value mappings
  1446. Keyword argument are passed to the plotting function.
  1447. Returns
  1448. -------
  1449. self : JointGrid instance
  1450. Returns `self`.
  1451. """
  1452. kwargs["vertical"] = False
  1453. plt.sca(self.ax_marg_x)
  1454. func(self.x, **kwargs)
  1455. kwargs["vertical"] = True
  1456. plt.sca(self.ax_marg_y)
  1457. func(self.y, **kwargs)
  1458. return self
  1459. def annotate(self, func, template=None, stat=None, loc="best", **kwargs):
  1460. """Annotate the plot with a statistic about the relationship.
  1461. *Deprecated and will be removed in a future version*.
  1462. Parameters
  1463. ----------
  1464. func : callable
  1465. Statistical function that maps the x, y vectors either to (val, p)
  1466. or to val.
  1467. template : string format template, optional
  1468. The template must have the format keys "stat" and "val";
  1469. if `func` returns a p value, it should also have the key "p".
  1470. stat : string, optional
  1471. Name to use for the statistic in the annotation, by default it
  1472. uses the name of `func`.
  1473. loc : string or int, optional
  1474. Matplotlib legend location code; used to place the annotation.
  1475. kwargs : key, value mappings
  1476. Other keyword arguments are passed to `ax.legend`, which formats
  1477. the annotation.
  1478. Returns
  1479. -------
  1480. self : JointGrid instance.
  1481. Returns `self`.
  1482. """
  1483. msg = ("JointGrid annotation is deprecated and will be removed "
  1484. "in a future release.")
  1485. warnings.warn(UserWarning(msg))
  1486. default_template = "{stat} = {val:.2g}; p = {p:.2g}"
  1487. # Call the function and determine the form of the return value(s)
  1488. out = func(self.x, self.y)
  1489. try:
  1490. val, p = out
  1491. except TypeError:
  1492. val, p = out, None
  1493. default_template, _ = default_template.split(";")
  1494. # Set the default template
  1495. if template is None:
  1496. template = default_template
  1497. # Default to name of the function
  1498. if stat is None:
  1499. stat = func.__name__
  1500. # Format the annotation
  1501. if p is None:
  1502. annotation = template.format(stat=stat, val=val)
  1503. else:
  1504. annotation = template.format(stat=stat, val=val, p=p)
  1505. # Draw an invisible plot and use the legend to draw the annotation
  1506. # This is a bit of a hack, but `loc=best` works nicely and is not
  1507. # easily abstracted.
  1508. phantom, = self.ax_joint.plot(self.x, self.y, linestyle="", alpha=0)
  1509. self.ax_joint.legend([phantom], [annotation], loc=loc, **kwargs)
  1510. phantom.remove()
  1511. return self
  1512. def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
  1513. """Set the axis labels on the bivariate axes.
  1514. Parameters
  1515. ----------
  1516. xlabel, ylabel : strings
  1517. Label names for the x and y variables.
  1518. kwargs : key, value mappings
  1519. Other keyword arguments are passed to the set_xlabel or
  1520. set_ylabel.
  1521. Returns
  1522. -------
  1523. self : JointGrid instance
  1524. returns `self`
  1525. """
  1526. self.ax_joint.set_xlabel(xlabel, **kwargs)
  1527. self.ax_joint.set_ylabel(ylabel, **kwargs)
  1528. return self
  1529. def savefig(self, *args, **kwargs):
  1530. """Wrap figure.savefig defaulting to tight bounding box."""
  1531. kwargs.setdefault("bbox_inches", "tight")
  1532. self.fig.savefig(*args, **kwargs)
  1533. def pairplot(data, hue=None, hue_order=None, palette=None,
  1534. vars=None, x_vars=None, y_vars=None,
  1535. kind="scatter", diag_kind="auto", markers=None,
  1536. height=2.5, aspect=1, corner=False, dropna=True,
  1537. plot_kws=None, diag_kws=None, grid_kws=None, size=None):
  1538. """Plot pairwise relationships in a dataset.
  1539. By default, this function will create a grid of Axes such that each numeric
  1540. variable in ``data`` will by shared in the y-axis across a single row and
  1541. in the x-axis across a single column. The diagonal Axes are treated
  1542. differently, drawing a plot to show the univariate distribution of the data
  1543. for the variable in that column.
  1544. It is also possible to show a subset of variables or plot different
  1545. variables on the rows and columns.
  1546. This is a high-level interface for :class:`PairGrid` that is intended to
  1547. make it easy to draw a few common styles. You should use :class:`PairGrid`
  1548. directly if you need more flexibility.
  1549. Parameters
  1550. ----------
  1551. data : DataFrame
  1552. Tidy (long-form) dataframe where each column is a variable and
  1553. each row is an observation.
  1554. hue : string (variable name), optional
  1555. Variable in ``data`` to map plot aspects to different colors.
  1556. hue_order : list of strings
  1557. Order for the levels of the hue variable in the palette
  1558. palette : dict or seaborn color palette
  1559. Set of colors for mapping the ``hue`` variable. If a dict, keys
  1560. should be values in the ``hue`` variable.
  1561. vars : list of variable names, optional
  1562. Variables within ``data`` to use, otherwise use every column with
  1563. a numeric datatype.
  1564. {x, y}_vars : lists of variable names, optional
  1565. Variables within ``data`` to use separately for the rows and
  1566. columns of the figure; i.e. to make a non-square plot.
  1567. kind : {'scatter', 'reg'}, optional
  1568. Kind of plot for the non-identity relationships.
  1569. diag_kind : {'auto', 'hist', 'kde', None}, optional
  1570. Kind of plot for the diagonal subplots. The default depends on whether
  1571. ``"hue"`` is used or not.
  1572. markers : single matplotlib marker code or list, optional
  1573. Either the marker to use for all datapoints or a list of markers with
  1574. a length the same as the number of levels in the hue variable so that
  1575. differently colored points will also have different scatterplot
  1576. markers.
  1577. height : scalar, optional
  1578. Height (in inches) of each facet.
  1579. aspect : scalar, optional
  1580. Aspect * height gives the width (in inches) of each facet.
  1581. corner : bool, optional
  1582. If True, don't add axes to the upper (off-diagonal) triangle of the
  1583. grid, making this a "corner" plot.
  1584. dropna : boolean, optional
  1585. Drop missing values from the data before plotting.
  1586. {plot, diag, grid}_kws : dicts, optional
  1587. Dictionaries of keyword arguments. ``plot_kws`` are passed to the
  1588. bivariate plotting function, ``diag_kws`` are passed to the univariate
  1589. plotting function, and ``grid_kws`` are passed to the :class:`PairGrid`
  1590. constructor.
  1591. Returns
  1592. -------
  1593. grid : :class:`PairGrid`
  1594. Returns the underlying :class:`PairGrid` instance for further tweaking.
  1595. See Also
  1596. --------
  1597. PairGrid : Subplot grid for more flexible plotting of pairwise
  1598. relationships.
  1599. Examples
  1600. --------
  1601. Draw scatterplots for joint relationships and histograms for univariate
  1602. distributions:
  1603. .. plot::
  1604. :context: close-figs
  1605. >>> import seaborn as sns; sns.set(style="ticks", color_codes=True)
  1606. >>> iris = sns.load_dataset("iris")
  1607. >>> g = sns.pairplot(iris)
  1608. Show different levels of a categorical variable by the color of plot
  1609. elements:
  1610. .. plot::
  1611. :context: close-figs
  1612. >>> g = sns.pairplot(iris, hue="species")
  1613. Use a different color palette:
  1614. .. plot::
  1615. :context: close-figs
  1616. >>> g = sns.pairplot(iris, hue="species", palette="husl")
  1617. Use different markers for each level of the hue variable:
  1618. .. plot::
  1619. :context: close-figs
  1620. >>> g = sns.pairplot(iris, hue="species", markers=["o", "s", "D"])
  1621. Plot a subset of variables:
  1622. .. plot::
  1623. :context: close-figs
  1624. >>> g = sns.pairplot(iris, vars=["sepal_width", "sepal_length"])
  1625. Draw larger plots:
  1626. .. plot::
  1627. :context: close-figs
  1628. >>> g = sns.pairplot(iris, height=3,
  1629. ... vars=["sepal_width", "sepal_length"])
  1630. Plot different variables in the rows and columns:
  1631. .. plot::
  1632. :context: close-figs
  1633. >>> g = sns.pairplot(iris,
  1634. ... x_vars=["sepal_width", "sepal_length"],
  1635. ... y_vars=["petal_width", "petal_length"])
  1636. Plot only the lower triangle of bivariate axes:
  1637. .. plot::
  1638. :context: close-figs
  1639. >>> g = sns.pairplot(iris, corner=True)
  1640. Use kernel density estimates for univariate plots:
  1641. .. plot::
  1642. :context: close-figs
  1643. >>> g = sns.pairplot(iris, diag_kind="kde")
  1644. Fit linear regression models to the scatter plots:
  1645. .. plot::
  1646. :context: close-figs
  1647. >>> g = sns.pairplot(iris, kind="reg")
  1648. Pass keyword arguments down to the underlying functions (it may be easier
  1649. to use :class:`PairGrid` directly):
  1650. .. plot::
  1651. :context: close-figs
  1652. >>> g = sns.pairplot(iris, diag_kind="kde", markers="+",
  1653. ... plot_kws=dict(s=50, edgecolor="b", linewidth=1),
  1654. ... diag_kws=dict(shade=True))
  1655. """
  1656. # Handle deprecations
  1657. if size is not None:
  1658. height = size
  1659. msg = ("The `size` parameter has been renamed to `height`; "
  1660. "please update your code.")
  1661. warnings.warn(msg, UserWarning)
  1662. if not isinstance(data, pd.DataFrame):
  1663. raise TypeError(
  1664. "'data' must be pandas DataFrame object, not: {typefound}".format(
  1665. typefound=type(data)))
  1666. plot_kws = {} if plot_kws is None else plot_kws.copy()
  1667. diag_kws = {} if diag_kws is None else diag_kws.copy()
  1668. grid_kws = {} if grid_kws is None else grid_kws.copy()
  1669. # Set up the PairGrid
  1670. grid_kws.setdefault("diag_sharey", diag_kind == "hist")
  1671. grid = PairGrid(data, vars=vars, x_vars=x_vars, y_vars=y_vars, hue=hue,
  1672. hue_order=hue_order, palette=palette, corner=corner,
  1673. height=height, aspect=aspect, dropna=dropna, **grid_kws)
  1674. # Add the markers here as PairGrid has figured out how many levels of the
  1675. # hue variable are needed and we don't want to duplicate that process
  1676. if markers is not None:
  1677. if grid.hue_names is None:
  1678. n_markers = 1
  1679. else:
  1680. n_markers = len(grid.hue_names)
  1681. if not isinstance(markers, list):
  1682. markers = [markers] * n_markers
  1683. if len(markers) != n_markers:
  1684. raise ValueError(("markers must be a singleton or a list of "
  1685. "markers for each level of the hue variable"))
  1686. grid.hue_kws = {"marker": markers}
  1687. # Maybe plot on the diagonal
  1688. if diag_kind == "auto":
  1689. diag_kind = "hist" if hue is None else "kde"
  1690. diag_kws = diag_kws.copy()
  1691. if grid.square_grid:
  1692. if diag_kind == "hist":
  1693. grid.map_diag(plt.hist, **diag_kws)
  1694. elif diag_kind == "kde":
  1695. diag_kws.setdefault("shade", True)
  1696. diag_kws["legend"] = False
  1697. grid.map_diag(kdeplot, **diag_kws)
  1698. # Maybe plot on the off-diagonals
  1699. if grid.square_grid and diag_kind is not None:
  1700. plotter = grid.map_offdiag
  1701. else:
  1702. plotter = grid.map
  1703. if kind == "scatter":
  1704. from .relational import scatterplot # Avoid circular import
  1705. plotter(scatterplot, **plot_kws)
  1706. elif kind == "reg":
  1707. from .regression import regplot # Avoid circular import
  1708. plotter(regplot, **plot_kws)
  1709. # Add a legend
  1710. if hue is not None:
  1711. grid.add_legend()
  1712. return grid
  1713. def jointplot(x, y, data=None, kind="scatter", stat_func=None,
  1714. color=None, height=6, ratio=5, space=.2,
  1715. dropna=True, xlim=None, ylim=None,
  1716. joint_kws=None, marginal_kws=None, annot_kws=None, **kwargs):
  1717. """Draw a plot of two variables with bivariate and univariate graphs.
  1718. This function provides a convenient interface to the :class:`JointGrid`
  1719. class, with several canned plot kinds. This is intended to be a fairly
  1720. lightweight wrapper; if you need more flexibility, you should use
  1721. :class:`JointGrid` directly.
  1722. Parameters
  1723. ----------
  1724. x, y : strings or vectors
  1725. Data or names of variables in ``data``.
  1726. data : DataFrame, optional
  1727. DataFrame when ``x`` and ``y`` are variable names.
  1728. kind : { "scatter" | "reg" | "resid" | "kde" | "hex" }, optional
  1729. Kind of plot to draw.
  1730. stat_func : callable or None, optional
  1731. *Deprecated*
  1732. color : matplotlib color, optional
  1733. Color used for the plot elements.
  1734. height : numeric, optional
  1735. Size of the figure (it will be square).
  1736. ratio : numeric, optional
  1737. Ratio of joint axes height to marginal axes height.
  1738. space : numeric, optional
  1739. Space between the joint and marginal axes
  1740. dropna : bool, optional
  1741. If True, remove observations that are missing from ``x`` and ``y``.
  1742. {x, y}lim : two-tuples, optional
  1743. Axis limits to set before plotting.
  1744. {joint, marginal, annot}_kws : dicts, optional
  1745. Additional keyword arguments for the plot components.
  1746. kwargs : key, value pairings
  1747. Additional keyword arguments are passed to the function used to
  1748. draw the plot on the joint Axes, superseding items in the
  1749. ``joint_kws`` dictionary.
  1750. Returns
  1751. -------
  1752. grid : :class:`JointGrid`
  1753. :class:`JointGrid` object with the plot on it.
  1754. See Also
  1755. --------
  1756. JointGrid : The Grid class used for drawing this plot. Use it directly if
  1757. you need more flexibility.
  1758. Examples
  1759. --------
  1760. Draw a scatterplot with marginal histograms:
  1761. .. plot::
  1762. :context: close-figs
  1763. >>> import numpy as np, pandas as pd; np.random.seed(0)
  1764. >>> import seaborn as sns; sns.set(style="white", color_codes=True)
  1765. >>> tips = sns.load_dataset("tips")
  1766. >>> g = sns.jointplot(x="total_bill", y="tip", data=tips)
  1767. Add regression and kernel density fits:
  1768. .. plot::
  1769. :context: close-figs
  1770. >>> g = sns.jointplot("total_bill", "tip", data=tips, kind="reg")
  1771. Replace the scatterplot with a joint histogram using hexagonal bins:
  1772. .. plot::
  1773. :context: close-figs
  1774. >>> g = sns.jointplot("total_bill", "tip", data=tips, kind="hex")
  1775. Replace the scatterplots and histograms with density estimates and align
  1776. the marginal Axes tightly with the joint Axes:
  1777. .. plot::
  1778. :context: close-figs
  1779. >>> iris = sns.load_dataset("iris")
  1780. >>> g = sns.jointplot("sepal_width", "petal_length", data=iris,
  1781. ... kind="kde", space=0, color="g")
  1782. Draw a scatterplot, then add a joint density estimate:
  1783. .. plot::
  1784. :context: close-figs
  1785. >>> g = (sns.jointplot("sepal_length", "sepal_width",
  1786. ... data=iris, color="k")
  1787. ... .plot_joint(sns.kdeplot, zorder=0, n_levels=6))
  1788. Pass vectors in directly without using Pandas, then name the axes:
  1789. .. plot::
  1790. :context: close-figs
  1791. >>> x, y = np.random.randn(2, 300)
  1792. >>> g = (sns.jointplot(x, y, kind="hex")
  1793. ... .set_axis_labels("x", "y"))
  1794. Draw a smaller figure with more space devoted to the marginal plots:
  1795. .. plot::
  1796. :context: close-figs
  1797. >>> g = sns.jointplot("total_bill", "tip", data=tips,
  1798. ... height=5, ratio=3, color="g")
  1799. Pass keyword arguments down to the underlying plots:
  1800. .. plot::
  1801. :context: close-figs
  1802. >>> g = sns.jointplot("petal_length", "sepal_length", data=iris,
  1803. ... marginal_kws=dict(bins=15, rug=True),
  1804. ... annot_kws=dict(stat="r"),
  1805. ... s=40, edgecolor="w", linewidth=1)
  1806. """
  1807. # Handle deprecations
  1808. if "size" in kwargs:
  1809. height = kwargs.pop("size")
  1810. msg = ("The `size` parameter has been renamed to `height`; "
  1811. "please update your code.")
  1812. warnings.warn(msg, UserWarning)
  1813. # Set up empty default kwarg dicts
  1814. joint_kws = {} if joint_kws is None else joint_kws.copy()
  1815. joint_kws.update(kwargs)
  1816. marginal_kws = {} if marginal_kws is None else marginal_kws.copy()
  1817. annot_kws = {} if annot_kws is None else annot_kws.copy()
  1818. # Make a colormap based off the plot color
  1819. if color is None:
  1820. color = color_palette()[0]
  1821. color_rgb = mpl.colors.colorConverter.to_rgb(color)
  1822. colors = [utils.set_hls_values(color_rgb, l=l) # noqa
  1823. for l in np.linspace(1, 0, 12)]
  1824. cmap = blend_palette(colors, as_cmap=True)
  1825. # Initialize the JointGrid object
  1826. grid = JointGrid(x, y, data, dropna=dropna,
  1827. height=height, ratio=ratio, space=space,
  1828. xlim=xlim, ylim=ylim)
  1829. # Plot the data using the grid
  1830. if kind == "scatter":
  1831. joint_kws.setdefault("color", color)
  1832. grid.plot_joint(plt.scatter, **joint_kws)
  1833. marginal_kws.setdefault("kde", False)
  1834. marginal_kws.setdefault("color", color)
  1835. grid.plot_marginals(distplot, **marginal_kws)
  1836. elif kind.startswith("hex"):
  1837. x_bins = min(_freedman_diaconis_bins(grid.x), 50)
  1838. y_bins = min(_freedman_diaconis_bins(grid.y), 50)
  1839. gridsize = int(np.mean([x_bins, y_bins]))
  1840. joint_kws.setdefault("gridsize", gridsize)
  1841. joint_kws.setdefault("cmap", cmap)
  1842. grid.plot_joint(plt.hexbin, **joint_kws)
  1843. marginal_kws.setdefault("kde", False)
  1844. marginal_kws.setdefault("color", color)
  1845. grid.plot_marginals(distplot, **marginal_kws)
  1846. elif kind.startswith("kde"):
  1847. joint_kws.setdefault("shade", True)
  1848. joint_kws.setdefault("cmap", cmap)
  1849. grid.plot_joint(kdeplot, **joint_kws)
  1850. marginal_kws.setdefault("shade", True)
  1851. marginal_kws.setdefault("color", color)
  1852. grid.plot_marginals(kdeplot, **marginal_kws)
  1853. elif kind.startswith("reg"):
  1854. from .regression import regplot
  1855. marginal_kws.setdefault("color", color)
  1856. grid.plot_marginals(distplot, **marginal_kws)
  1857. joint_kws.setdefault("color", color)
  1858. grid.plot_joint(regplot, **joint_kws)
  1859. elif kind.startswith("resid"):
  1860. from .regression import residplot
  1861. joint_kws.setdefault("color", color)
  1862. grid.plot_joint(residplot, **joint_kws)
  1863. x, y = grid.ax_joint.collections[0].get_offsets().T
  1864. marginal_kws.setdefault("color", color)
  1865. marginal_kws.setdefault("kde", False)
  1866. distplot(x, ax=grid.ax_marg_x, **marginal_kws)
  1867. distplot(y, vertical=True, fit=stats.norm, ax=grid.ax_marg_y,
  1868. **marginal_kws)
  1869. stat_func = None
  1870. else:
  1871. msg = "kind must be either 'scatter', 'reg', 'resid', 'kde', or 'hex'"
  1872. raise ValueError(msg)
  1873. if stat_func is not None:
  1874. grid.annotate(stat_func, **annot_kws)
  1875. return grid