categorical.py 136 KB


  1. from textwrap import dedent
  2. import colorsys
  3. import numpy as np
  4. from scipy import stats
  5. import pandas as pd
  6. import matplotlib as mpl
  7. from matplotlib.collections import PatchCollection
  8. import matplotlib.patches as Patches
  9. import matplotlib.pyplot as plt
  10. import warnings
  11. from distutils.version import LooseVersion
  12. from . import utils
  13. from .utils import iqr, categorical_order, remove_na
  14. from .algorithms import bootstrap
  15. from .palettes import color_palette, husl_palette, light_palette, dark_palette
  16. from .axisgrid import FacetGrid, _facet_docs
# Public API of this module: the names exported by a star-import.
__all__ = [
    "catplot", "factorplot",
    "stripplot", "swarmplot",
    "boxplot", "violinplot", "boxenplot", "lvplot",
    "pointplot", "barplot", "countplot",
]
  23. class _CategoricalPlotter(object):
  24. width = .8
  25. default_palette = "light"
  26. def establish_variables(self, x=None, y=None, hue=None, data=None,
  27. orient=None, order=None, hue_order=None,
  28. units=None):
  29. """Convert input specification into a common representation."""
  30. # Option 1:
  31. # We are plotting a wide-form dataset
  32. # -----------------------------------
  33. if x is None and y is None:
  34. # Do a sanity check on the inputs
  35. if hue is not None:
  36. error = "Cannot use `hue` without `x` or `y`"
  37. raise ValueError(error)
  38. # No hue grouping with wide inputs
  39. plot_hues = None
  40. hue_title = None
  41. hue_names = None
  42. # No statistical units with wide inputs
  43. plot_units = None
  44. # We also won't get a axes labels here
  45. value_label = None
  46. group_label = None
  47. # Option 1a:
  48. # The input data is a Pandas DataFrame
  49. # ------------------------------------
  50. if isinstance(data, pd.DataFrame):
  51. # Order the data correctly
  52. if order is None:
  53. order = []
  54. # Reduce to just numeric columns
  55. for col in data:
  56. try:
  57. data[col].astype(np.float)
  58. order.append(col)
  59. except ValueError:
  60. pass
  61. plot_data = data[order]
  62. group_names = order
  63. group_label = data.columns.name
  64. # Convert to a list of arrays, the common representation
  65. iter_data = plot_data.iteritems()
  66. plot_data = [np.asarray(s, np.float) for k, s in iter_data]
  67. # Option 1b:
  68. # The input data is an array or list
  69. # ----------------------------------
  70. else:
  71. # We can't reorder the data
  72. if order is not None:
  73. error = "Input data must be a pandas object to reorder"
  74. raise ValueError(error)
  75. # The input data is an array
  76. if hasattr(data, "shape"):
  77. if len(data.shape) == 1:
  78. if np.isscalar(data[0]):
  79. plot_data = [data]
  80. else:
  81. plot_data = list(data)
  82. elif len(data.shape) == 2:
  83. nr, nc = data.shape
  84. if nr == 1 or nc == 1:
  85. plot_data = [data.ravel()]
  86. else:
  87. plot_data = [data[:, i] for i in range(nc)]
  88. else:
  89. error = ("Input `data` can have no "
  90. "more than 2 dimensions")
  91. raise ValueError(error)
  92. # Check if `data` is None to let us bail out here (for testing)
  93. elif data is None:
  94. plot_data = [[]]
  95. # The input data is a flat list
  96. elif np.isscalar(data[0]):
  97. plot_data = [data]
  98. # The input data is a nested list
  99. # This will catch some things that might fail later
  100. # but exhaustive checks are hard
  101. else:
  102. plot_data = data
  103. # Convert to a list of arrays, the common representation
  104. plot_data = [np.asarray(d, np.float) for d in plot_data]
  105. # The group names will just be numeric indices
  106. group_names = list(range((len(plot_data))))
  107. # Figure out the plotting orientation
  108. orient = "h" if str(orient).startswith("h") else "v"
  109. # Option 2:
  110. # We are plotting a long-form dataset
  111. # -----------------------------------
  112. else:
  113. # See if we need to get variables from `data`
  114. if data is not None:
  115. x = data.get(x, x)
  116. y = data.get(y, y)
  117. hue = data.get(hue, hue)
  118. units = data.get(units, units)
  119. # Validate the inputs
  120. for var in [x, y, hue, units]:
  121. if isinstance(var, str):
  122. err = "Could not interpret input '{}'".format(var)
  123. raise ValueError(err)
  124. # Figure out the plotting orientation
  125. orient = self.infer_orient(x, y, orient)
  126. # Option 2a:
  127. # We are plotting a single set of data
  128. # ------------------------------------
  129. if x is None or y is None:
  130. # Determine where the data are
  131. vals = y if x is None else x
  132. # Put them into the common representation
  133. plot_data = [np.asarray(vals)]
  134. # Get a label for the value axis
  135. if hasattr(vals, "name"):
  136. value_label = vals.name
  137. else:
  138. value_label = None
  139. # This plot will not have group labels or hue nesting
  140. groups = None
  141. group_label = None
  142. group_names = []
  143. plot_hues = None
  144. hue_names = None
  145. hue_title = None
  146. plot_units = None
  147. # Option 2b:
  148. # We are grouping the data values by another variable
  149. # ---------------------------------------------------
  150. else:
  151. # Determine which role each variable will play
  152. if orient == "v":
  153. vals, groups = y, x
  154. else:
  155. vals, groups = x, y
  156. # Get the categorical axis label
  157. group_label = None
  158. if hasattr(groups, "name"):
  159. group_label = groups.name
  160. # Get the order on the categorical axis
  161. group_names = categorical_order(groups, order)
  162. # Group the numeric data
  163. plot_data, value_label = self._group_longform(vals, groups,
  164. group_names)
  165. # Now handle the hue levels for nested ordering
  166. if hue is None:
  167. plot_hues = None
  168. hue_title = None
  169. hue_names = None
  170. else:
  171. # Get the order of the hue levels
  172. hue_names = categorical_order(hue, hue_order)
  173. # Group the hue data
  174. plot_hues, hue_title = self._group_longform(hue, groups,
  175. group_names)
  176. # Now handle the units for nested observations
  177. if units is None:
  178. plot_units = None
  179. else:
  180. plot_units, _ = self._group_longform(units, groups,
  181. group_names)
  182. # Assign object attributes
  183. # ------------------------
  184. self.orient = orient
  185. self.plot_data = plot_data
  186. self.group_label = group_label
  187. self.value_label = value_label
  188. self.group_names = group_names
  189. self.plot_hues = plot_hues
  190. self.hue_title = hue_title
  191. self.hue_names = hue_names
  192. self.plot_units = plot_units
  193. def _group_longform(self, vals, grouper, order):
  194. """Group a long-form variable by another with correct order."""
  195. # Ensure that the groupby will work
  196. if not isinstance(vals, pd.Series):
  197. if isinstance(grouper, pd.Series):
  198. index = grouper.index
  199. else:
  200. index = None
  201. vals = pd.Series(vals, index=index)
  202. # Group the val data
  203. grouped_vals = vals.groupby(grouper)
  204. out_data = []
  205. for g in order:
  206. try:
  207. g_vals = grouped_vals.get_group(g)
  208. except KeyError:
  209. g_vals = np.array([])
  210. out_data.append(g_vals)
  211. # Get the vals axis label
  212. label = vals.name
  213. return out_data, label
  214. def establish_colors(self, color, palette, saturation):
  215. """Get a list of colors for the main component of the plots."""
  216. if self.hue_names is None:
  217. n_colors = len(self.plot_data)
  218. else:
  219. n_colors = len(self.hue_names)
  220. # Determine the main colors
  221. if color is None and palette is None:
  222. # Determine whether the current palette will have enough values
  223. # If not, we'll default to the husl palette so each is distinct
  224. current_palette = utils.get_color_cycle()
  225. if n_colors <= len(current_palette):
  226. colors = color_palette(n_colors=n_colors)
  227. else:
  228. colors = husl_palette(n_colors, l=.7) # noqa
  229. elif palette is None:
  230. # When passing a specific color, the interpretation depends
  231. # on whether there is a hue variable or not.
  232. # If so, we will make a blend palette so that the different
  233. # levels have some amount of variation.
  234. if self.hue_names is None:
  235. colors = [color] * n_colors
  236. else:
  237. if self.default_palette == "light":
  238. colors = light_palette(color, n_colors)
  239. elif self.default_palette == "dark":
  240. colors = dark_palette(color, n_colors)
  241. else:
  242. raise RuntimeError("No default palette specified")
  243. else:
  244. # Let `palette` be a dict mapping level to color
  245. if isinstance(palette, dict):
  246. if self.hue_names is None:
  247. levels = self.group_names
  248. else:
  249. levels = self.hue_names
  250. palette = [palette[l] for l in levels]
  251. colors = color_palette(palette, n_colors)
  252. # Desaturate a bit because these are patches
  253. if saturation < 1:
  254. colors = color_palette(colors, desat=saturation)
  255. # Convert the colors to a common representations
  256. rgb_colors = color_palette(colors)
  257. # Determine the gray color to use for the lines framing the plot
  258. light_vals = [colorsys.rgb_to_hls(*c)[1] for c in rgb_colors]
  259. lum = min(light_vals) * .6
  260. gray = mpl.colors.rgb2hex((lum, lum, lum))
  261. # Assign object attributes
  262. self.colors = rgb_colors
  263. self.gray = gray
  264. def infer_orient(self, x, y, orient=None):
  265. """Determine how the plot should be oriented based on the data."""
  266. orient = str(orient)
  267. def is_categorical(s):
  268. return pd.api.types.is_categorical_dtype(s)
  269. def is_not_numeric(s):
  270. try:
  271. np.asarray(s, dtype=np.float)
  272. except ValueError:
  273. return True
  274. return False
  275. no_numeric = "Neither the `x` nor `y` variable appears to be numeric."
  276. if orient.startswith("v"):
  277. return "v"
  278. elif orient.startswith("h"):
  279. return "h"
  280. elif x is None:
  281. return "v"
  282. elif y is None:
  283. return "h"
  284. elif is_categorical(y):
  285. if is_categorical(x):
  286. raise ValueError(no_numeric)
  287. else:
  288. return "h"
  289. elif is_not_numeric(y):
  290. if is_not_numeric(x):
  291. raise ValueError(no_numeric)
  292. else:
  293. return "h"
  294. else:
  295. return "v"
  296. @property
  297. def hue_offsets(self):
  298. """A list of center positions for plots when hue nesting is used."""
  299. n_levels = len(self.hue_names)
  300. if self.dodge:
  301. each_width = self.width / n_levels
  302. offsets = np.linspace(0, self.width - each_width, n_levels)
  303. offsets -= offsets.mean()
  304. else:
  305. offsets = np.zeros(n_levels)
  306. return offsets
  307. @property
  308. def nested_width(self):
  309. """A float with the width of plot elements when hue nesting is used."""
  310. if self.dodge:
  311. width = self.width / len(self.hue_names) * .98
  312. else:
  313. width = self.width
  314. return width
  315. def annotate_axes(self, ax):
  316. """Add descriptive labels to an Axes object."""
  317. if self.orient == "v":
  318. xlabel, ylabel = self.group_label, self.value_label
  319. else:
  320. xlabel, ylabel = self.value_label, self.group_label
  321. if xlabel is not None:
  322. ax.set_xlabel(xlabel)
  323. if ylabel is not None:
  324. ax.set_ylabel(ylabel)
  325. if self.orient == "v":
  326. ax.set_xticks(np.arange(len(self.plot_data)))
  327. ax.set_xticklabels(self.group_names)
  328. else:
  329. ax.set_yticks(np.arange(len(self.plot_data)))
  330. ax.set_yticklabels(self.group_names)
  331. if self.orient == "v":
  332. ax.xaxis.grid(False)
  333. ax.set_xlim(-.5, len(self.plot_data) - .5, auto=None)
  334. else:
  335. ax.yaxis.grid(False)
  336. ax.set_ylim(-.5, len(self.plot_data) - .5, auto=None)
  337. if self.hue_names is not None:
  338. leg = ax.legend(loc="best", title=self.hue_title)
  339. if self.hue_title is not None:
  340. if LooseVersion(mpl.__version__) < "3.0":
  341. # Old Matplotlib has no legend title size rcparam
  342. try:
  343. title_size = mpl.rcParams["axes.labelsize"] * .85
  344. except TypeError: # labelsize is something like "large"
  345. title_size = mpl.rcParams["axes.labelsize"]
  346. prop = mpl.font_manager.FontProperties(size=title_size)
  347. leg.set_title(self.hue_title, prop=prop)
  348. def add_legend_data(self, ax, color, label):
  349. """Add a dummy patch object so we can get legend data."""
  350. rect = plt.Rectangle([0, 0], 0, 0,
  351. linewidth=self.linewidth / 2,
  352. edgecolor=self.gray,
  353. facecolor=color,
  354. label=label)
  355. ax.add_patch(rect)
  356. class _BoxPlotter(_CategoricalPlotter):
  357. def __init__(self, x, y, hue, data, order, hue_order,
  358. orient, color, palette, saturation,
  359. width, dodge, fliersize, linewidth):
  360. self.establish_variables(x, y, hue, data, orient, order, hue_order)
  361. self.establish_colors(color, palette, saturation)
  362. self.dodge = dodge
  363. self.width = width
  364. self.fliersize = fliersize
  365. if linewidth is None:
  366. linewidth = mpl.rcParams["lines.linewidth"]
  367. self.linewidth = linewidth
  368. def draw_boxplot(self, ax, kws):
  369. """Use matplotlib to draw a boxplot on an Axes."""
  370. vert = self.orient == "v"
  371. props = {}
  372. for obj in ["box", "whisker", "cap", "median", "flier"]:
  373. props[obj] = kws.pop(obj + "props", {})
  374. for i, group_data in enumerate(self.plot_data):
  375. if self.plot_hues is None:
  376. # Handle case where there is data at this level
  377. if group_data.size == 0:
  378. continue
  379. # Draw a single box or a set of boxes
  380. # with a single level of grouping
  381. box_data = np.asarray(remove_na(group_data))
  382. # Handle case where there is no non-null data
  383. if box_data.size == 0:
  384. continue
  385. artist_dict = ax.boxplot(box_data,
  386. vert=vert,
  387. patch_artist=True,
  388. positions=[i],
  389. widths=self.width,
  390. **kws)
  391. color = self.colors[i]
  392. self.restyle_boxplot(artist_dict, color, props)
  393. else:
  394. # Draw nested groups of boxes
  395. offsets = self.hue_offsets
  396. for j, hue_level in enumerate(self.hue_names):
  397. # Add a legend for this hue level
  398. if not i:
  399. self.add_legend_data(ax, self.colors[j], hue_level)
  400. # Handle case where there is data at this level
  401. if group_data.size == 0:
  402. continue
  403. hue_mask = self.plot_hues[i] == hue_level
  404. box_data = np.asarray(remove_na(group_data[hue_mask]))
  405. # Handle case where there is no non-null data
  406. if box_data.size == 0:
  407. continue
  408. center = i + offsets[j]
  409. artist_dict = ax.boxplot(box_data,
  410. vert=vert,
  411. patch_artist=True,
  412. positions=[center],
  413. widths=self.nested_width,
  414. **kws)
  415. self.restyle_boxplot(artist_dict, self.colors[j], props)
  416. # Add legend data, but just for one set of boxes
  417. def restyle_boxplot(self, artist_dict, color, props):
  418. """Take a drawn matplotlib boxplot and make it look nice."""
  419. for box in artist_dict["boxes"]:
  420. box.update(dict(facecolor=color,
  421. zorder=.9,
  422. edgecolor=self.gray,
  423. linewidth=self.linewidth))
  424. box.update(props["box"])
  425. for whisk in artist_dict["whiskers"]:
  426. whisk.update(dict(color=self.gray,
  427. linewidth=self.linewidth,
  428. linestyle="-"))
  429. whisk.update(props["whisker"])
  430. for cap in artist_dict["caps"]:
  431. cap.update(dict(color=self.gray,
  432. linewidth=self.linewidth))
  433. cap.update(props["cap"])
  434. for med in artist_dict["medians"]:
  435. med.update(dict(color=self.gray,
  436. linewidth=self.linewidth))
  437. med.update(props["median"])
  438. for fly in artist_dict["fliers"]:
  439. fly.update(dict(markerfacecolor=self.gray,
  440. marker="d",
  441. markeredgecolor=self.gray,
  442. markersize=self.fliersize))
  443. fly.update(props["flier"])
  444. def plot(self, ax, boxplot_kws):
  445. """Make the plot."""
  446. self.draw_boxplot(ax, boxplot_kws)
  447. self.annotate_axes(ax)
  448. if self.orient == "h":
  449. ax.invert_yaxis()
  450. class _ViolinPlotter(_CategoricalPlotter):
  451. def __init__(self, x, y, hue, data, order, hue_order,
  452. bw, cut, scale, scale_hue, gridsize,
  453. width, inner, split, dodge, orient, linewidth,
  454. color, palette, saturation):
  455. self.establish_variables(x, y, hue, data, orient, order, hue_order)
  456. self.establish_colors(color, palette, saturation)
  457. self.estimate_densities(bw, cut, scale, scale_hue, gridsize)
  458. self.gridsize = gridsize
  459. self.width = width
  460. self.dodge = dodge
  461. if inner is not None:
  462. if not any([inner.startswith("quart"),
  463. inner.startswith("box"),
  464. inner.startswith("stick"),
  465. inner.startswith("point")]):
  466. err = "Inner style '{}' not recognized".format(inner)
  467. raise ValueError(err)
  468. self.inner = inner
  469. if split and self.hue_names is not None and len(self.hue_names) != 2:
  470. msg = "There must be exactly two hue levels to use `split`.'"
  471. raise ValueError(msg)
  472. self.split = split
  473. if linewidth is None:
  474. linewidth = mpl.rcParams["lines.linewidth"]
  475. self.linewidth = linewidth
  476. def estimate_densities(self, bw, cut, scale, scale_hue, gridsize):
  477. """Find the support and density for all of the data."""
  478. # Initialize data structures to keep track of plotting data
  479. if self.hue_names is None:
  480. support = []
  481. density = []
  482. counts = np.zeros(len(self.plot_data))
  483. max_density = np.zeros(len(self.plot_data))
  484. else:
  485. support = [[] for _ in self.plot_data]
  486. density = [[] for _ in self.plot_data]
  487. size = len(self.group_names), len(self.hue_names)
  488. counts = np.zeros(size)
  489. max_density = np.zeros(size)
  490. for i, group_data in enumerate(self.plot_data):
  491. # Option 1: we have a single level of grouping
  492. # --------------------------------------------
  493. if self.plot_hues is None:
  494. # Strip missing datapoints
  495. kde_data = remove_na(group_data)
  496. # Handle special case of no data at this level
  497. if kde_data.size == 0:
  498. support.append(np.array([]))
  499. density.append(np.array([1.]))
  500. counts[i] = 0
  501. max_density[i] = 0
  502. continue
  503. # Handle special case of a single unique datapoint
  504. elif np.unique(kde_data).size == 1:
  505. support.append(np.unique(kde_data))
  506. density.append(np.array([1.]))
  507. counts[i] = 1
  508. max_density[i] = 0
  509. continue
  510. # Fit the KDE and get the used bandwidth size
  511. kde, bw_used = self.fit_kde(kde_data, bw)
  512. # Determine the support grid and get the density over it
  513. support_i = self.kde_support(kde_data, bw_used, cut, gridsize)
  514. density_i = kde.evaluate(support_i)
  515. # Update the data structures with these results
  516. support.append(support_i)
  517. density.append(density_i)
  518. counts[i] = kde_data.size
  519. max_density[i] = density_i.max()
  520. # Option 2: we have nested grouping by a hue variable
  521. # ---------------------------------------------------
  522. else:
  523. for j, hue_level in enumerate(self.hue_names):
  524. # Handle special case of no data at this category level
  525. if not group_data.size:
  526. support[i].append(np.array([]))
  527. density[i].append(np.array([1.]))
  528. counts[i, j] = 0
  529. max_density[i, j] = 0
  530. continue
  531. # Select out the observations for this hue level
  532. hue_mask = self.plot_hues[i] == hue_level
  533. # Strip missing datapoints
  534. kde_data = remove_na(group_data[hue_mask])
  535. # Handle special case of no data at this level
  536. if kde_data.size == 0:
  537. support[i].append(np.array([]))
  538. density[i].append(np.array([1.]))
  539. counts[i, j] = 0
  540. max_density[i, j] = 0
  541. continue
  542. # Handle special case of a single unique datapoint
  543. elif np.unique(kde_data).size == 1:
  544. support[i].append(np.unique(kde_data))
  545. density[i].append(np.array([1.]))
  546. counts[i, j] = 1
  547. max_density[i, j] = 0
  548. continue
  549. # Fit the KDE and get the used bandwidth size
  550. kde, bw_used = self.fit_kde(kde_data, bw)
  551. # Determine the support grid and get the density over it
  552. support_ij = self.kde_support(kde_data, bw_used,
  553. cut, gridsize)
  554. density_ij = kde.evaluate(support_ij)
  555. # Update the data structures with these results
  556. support[i].append(support_ij)
  557. density[i].append(density_ij)
  558. counts[i, j] = kde_data.size
  559. max_density[i, j] = density_ij.max()
  560. # Scale the height of the density curve.
  561. # For a violinplot the density is non-quantitative.
  562. # The objective here is to scale the curves relative to 1 so that
  563. # they can be multiplied by the width parameter during plotting.
  564. if scale == "area":
  565. self.scale_area(density, max_density, scale_hue)
  566. elif scale == "width":
  567. self.scale_width(density)
  568. elif scale == "count":
  569. self.scale_count(density, counts, scale_hue)
  570. else:
  571. raise ValueError("scale method '{}' not recognized".format(scale))
  572. # Set object attributes that will be used while plotting
  573. self.support = support
  574. self.density = density
  575. def fit_kde(self, x, bw):
  576. """Estimate a KDE for a vector of data with flexible bandwidth."""
  577. kde = stats.gaussian_kde(x, bw)
  578. # Extract the numeric bandwidth from the KDE object
  579. bw_used = kde.factor
  580. # At this point, bw will be a numeric scale factor.
  581. # To get the actual bandwidth of the kernel, we multiple by the
  582. # unbiased standard deviation of the data, which we will use
  583. # elsewhere to compute the range of the support.
  584. bw_used = bw_used * x.std(ddof=1)
  585. return kde, bw_used
  586. def kde_support(self, x, bw, cut, gridsize):
  587. """Define a grid of support for the violin."""
  588. support_min = x.min() - bw * cut
  589. support_max = x.max() + bw * cut
  590. return np.linspace(support_min, support_max, gridsize)
  591. def scale_area(self, density, max_density, scale_hue):
  592. """Scale the relative area under the KDE curve.
  593. This essentially preserves the "standard" KDE scaling, but the
  594. resulting maximum density will be 1 so that the curve can be
  595. properly multiplied by the violin width.
  596. """
  597. if self.hue_names is None:
  598. for d in density:
  599. if d.size > 1:
  600. d /= max_density.max()
  601. else:
  602. for i, group in enumerate(density):
  603. for d in group:
  604. if scale_hue:
  605. max = max_density[i].max()
  606. else:
  607. max = max_density.max()
  608. if d.size > 1:
  609. d /= max
  610. def scale_width(self, density):
  611. """Scale each density curve to the same height."""
  612. if self.hue_names is None:
  613. for d in density:
  614. d /= d.max()
  615. else:
  616. for group in density:
  617. for d in group:
  618. d /= d.max()
  619. def scale_count(self, density, counts, scale_hue):
  620. """Scale each density curve by the number of observations."""
  621. if self.hue_names is None:
  622. if counts.max() == 0:
  623. d = 0
  624. else:
  625. for count, d in zip(counts, density):
  626. d /= d.max()
  627. d *= count / counts.max()
  628. else:
  629. for i, group in enumerate(density):
  630. for j, d in enumerate(group):
  631. if counts[i].max() == 0:
  632. d = 0
  633. else:
  634. count = counts[i, j]
  635. if scale_hue:
  636. scaler = count / counts[i].max()
  637. else:
  638. scaler = count / counts.max()
  639. d /= d.max()
  640. d *= scaler
  641. @property
  642. def dwidth(self):
  643. if self.hue_names is None or not self.dodge:
  644. return self.width / 2
  645. elif self.split:
  646. return self.width / 2
  647. else:
  648. return self.width / (2 * len(self.hue_names))
  649. def draw_violins(self, ax):
  650. """Draw the violins onto `ax`."""
  651. fill_func = ax.fill_betweenx if self.orient == "v" else ax.fill_between
  652. for i, group_data in enumerate(self.plot_data):
  653. kws = dict(edgecolor=self.gray, linewidth=self.linewidth)
  654. # Option 1: we have a single level of grouping
  655. # --------------------------------------------
  656. if self.plot_hues is None:
  657. support, density = self.support[i], self.density[i]
  658. # Handle special case of no observations in this bin
  659. if support.size == 0:
  660. continue
  661. # Handle special case of a single observation
  662. elif support.size == 1:
  663. val = support.item()
  664. d = density.item()
  665. self.draw_single_observation(ax, i, val, d)
  666. continue
  667. # Draw the violin for this group
  668. grid = np.ones(self.gridsize) * i
  669. fill_func(support,
  670. grid - density * self.dwidth,
  671. grid + density * self.dwidth,
  672. facecolor=self.colors[i],
  673. **kws)
  674. # Draw the interior representation of the data
  675. if self.inner is None:
  676. continue
  677. # Get a nan-free vector of datapoints
  678. violin_data = remove_na(group_data)
  679. # Draw box and whisker information
  680. if self.inner.startswith("box"):
  681. self.draw_box_lines(ax, violin_data, support, density, i)
  682. # Draw quartile lines
  683. elif self.inner.startswith("quart"):
  684. self.draw_quartiles(ax, violin_data, support, density, i)
  685. # Draw stick observations
  686. elif self.inner.startswith("stick"):
  687. self.draw_stick_lines(ax, violin_data, support, density, i)
  688. # Draw point observations
  689. elif self.inner.startswith("point"):
  690. self.draw_points(ax, violin_data, i)
  691. # Option 2: we have nested grouping by a hue variable
  692. # ---------------------------------------------------
  693. else:
  694. offsets = self.hue_offsets
  695. for j, hue_level in enumerate(self.hue_names):
  696. support, density = self.support[i][j], self.density[i][j]
  697. kws["facecolor"] = self.colors[j]
  698. # Add legend data, but just for one set of violins
  699. if not i:
  700. self.add_legend_data(ax, self.colors[j], hue_level)
  701. # Handle the special case where we have no observations
  702. if support.size == 0:
  703. continue
  704. # Handle the special case where we have one observation
  705. elif support.size == 1:
  706. val = support.item()
  707. d = density.item()
  708. if self.split:
  709. d = d / 2
  710. at_group = i + offsets[j]
  711. self.draw_single_observation(ax, at_group, val, d)
  712. continue
  713. # Option 2a: we are drawing a single split violin
  714. # -----------------------------------------------
  715. if self.split:
  716. grid = np.ones(self.gridsize) * i
  717. if j:
  718. fill_func(support,
  719. grid,
  720. grid + density * self.dwidth,
  721. **kws)
  722. else:
  723. fill_func(support,
  724. grid - density * self.dwidth,
  725. grid,
  726. **kws)
  727. # Draw the interior representation of the data
  728. if self.inner is None:
  729. continue
  730. # Get a nan-free vector of datapoints
  731. hue_mask = self.plot_hues[i] == hue_level
  732. violin_data = remove_na(group_data[hue_mask])
  733. # Draw quartile lines
  734. if self.inner.startswith("quart"):
  735. self.draw_quartiles(ax, violin_data,
  736. support, density, i,
  737. ["left", "right"][j])
  738. # Draw stick observations
  739. elif self.inner.startswith("stick"):
  740. self.draw_stick_lines(ax, violin_data,
  741. support, density, i,
  742. ["left", "right"][j])
  743. # The box and point interior plots are drawn for
  744. # all data at the group level, so we just do that once
  745. if not j:
  746. continue
  747. # Get the whole vector for this group level
  748. violin_data = remove_na(group_data)
  749. # Draw box and whisker information
  750. if self.inner.startswith("box"):
  751. self.draw_box_lines(ax, violin_data,
  752. support, density, i)
  753. # Draw point observations
  754. elif self.inner.startswith("point"):
  755. self.draw_points(ax, violin_data, i)
  756. # Option 2b: we are drawing full nested violins
  757. # -----------------------------------------------
  758. else:
  759. grid = np.ones(self.gridsize) * (i + offsets[j])
  760. fill_func(support,
  761. grid - density * self.dwidth,
  762. grid + density * self.dwidth,
  763. **kws)
  764. # Draw the interior representation
  765. if self.inner is None:
  766. continue
  767. # Get a nan-free vector of datapoints
  768. hue_mask = self.plot_hues[i] == hue_level
  769. violin_data = remove_na(group_data[hue_mask])
  770. # Draw box and whisker information
  771. if self.inner.startswith("box"):
  772. self.draw_box_lines(ax, violin_data,
  773. support, density,
  774. i + offsets[j])
  775. # Draw quartile lines
  776. elif self.inner.startswith("quart"):
  777. self.draw_quartiles(ax, violin_data,
  778. support, density,
  779. i + offsets[j])
  780. # Draw stick observations
  781. elif self.inner.startswith("stick"):
  782. self.draw_stick_lines(ax, violin_data,
  783. support, density,
  784. i + offsets[j])
  785. # Draw point observations
  786. elif self.inner.startswith("point"):
  787. self.draw_points(ax, violin_data, i + offsets[j])
  788. def draw_single_observation(self, ax, at_group, at_quant, density):
  789. """Draw a line to mark a single observation."""
  790. d_width = density * self.dwidth
  791. if self.orient == "v":
  792. ax.plot([at_group - d_width, at_group + d_width],
  793. [at_quant, at_quant],
  794. color=self.gray,
  795. linewidth=self.linewidth)
  796. else:
  797. ax.plot([at_quant, at_quant],
  798. [at_group - d_width, at_group + d_width],
  799. color=self.gray,
  800. linewidth=self.linewidth)
  801. def draw_box_lines(self, ax, data, support, density, center):
  802. """Draw boxplot information at center of the density."""
  803. # Compute the boxplot statistics
  804. q25, q50, q75 = np.percentile(data, [25, 50, 75])
  805. whisker_lim = 1.5 * iqr(data)
  806. h1 = np.min(data[data >= (q25 - whisker_lim)])
  807. h2 = np.max(data[data <= (q75 + whisker_lim)])
  808. # Draw a boxplot using lines and a point
  809. if self.orient == "v":
  810. ax.plot([center, center], [h1, h2],
  811. linewidth=self.linewidth,
  812. color=self.gray)
  813. ax.plot([center, center], [q25, q75],
  814. linewidth=self.linewidth * 3,
  815. color=self.gray)
  816. ax.scatter(center, q50,
  817. zorder=3,
  818. color="white",
  819. edgecolor=self.gray,
  820. s=np.square(self.linewidth * 2))
  821. else:
  822. ax.plot([h1, h2], [center, center],
  823. linewidth=self.linewidth,
  824. color=self.gray)
  825. ax.plot([q25, q75], [center, center],
  826. linewidth=self.linewidth * 3,
  827. color=self.gray)
  828. ax.scatter(q50, center,
  829. zorder=3,
  830. color="white",
  831. edgecolor=self.gray,
  832. s=np.square(self.linewidth * 2))
  833. def draw_quartiles(self, ax, data, support, density, center, split=False):
  834. """Draw the quartiles as lines at width of density."""
  835. q25, q50, q75 = np.percentile(data, [25, 50, 75])
  836. self.draw_to_density(ax, center, q25, support, density, split,
  837. linewidth=self.linewidth,
  838. dashes=[self.linewidth * 1.5] * 2)
  839. self.draw_to_density(ax, center, q50, support, density, split,
  840. linewidth=self.linewidth,
  841. dashes=[self.linewidth * 3] * 2)
  842. self.draw_to_density(ax, center, q75, support, density, split,
  843. linewidth=self.linewidth,
  844. dashes=[self.linewidth * 1.5] * 2)
  845. def draw_points(self, ax, data, center):
  846. """Draw individual observations as points at middle of the violin."""
  847. kws = dict(s=np.square(self.linewidth * 2),
  848. color=self.gray,
  849. edgecolor=self.gray)
  850. grid = np.ones(len(data)) * center
  851. if self.orient == "v":
  852. ax.scatter(grid, data, **kws)
  853. else:
  854. ax.scatter(data, grid, **kws)
  855. def draw_stick_lines(self, ax, data, support, density,
  856. center, split=False):
  857. """Draw individual observations as sticks at width of density."""
  858. for val in data:
  859. self.draw_to_density(ax, center, val, support, density, split,
  860. linewidth=self.linewidth * .5)
  861. def draw_to_density(self, ax, center, val, support, density, split, **kws):
  862. """Draw a line orthogonal to the value axis at width of density."""
  863. idx = np.argmin(np.abs(support - val))
  864. width = self.dwidth * density[idx] * .99
  865. kws["color"] = self.gray
  866. if self.orient == "v":
  867. if split == "left":
  868. ax.plot([center - width, center], [val, val], **kws)
  869. elif split == "right":
  870. ax.plot([center, center + width], [val, val], **kws)
  871. else:
  872. ax.plot([center - width, center + width], [val, val], **kws)
  873. else:
  874. if split == "left":
  875. ax.plot([val, val], [center - width, center], **kws)
  876. elif split == "right":
  877. ax.plot([val, val], [center, center + width], **kws)
  878. else:
  879. ax.plot([val, val], [center - width, center + width], **kws)
  880. def plot(self, ax):
  881. """Make the violin plot."""
  882. self.draw_violins(ax)
  883. self.annotate_axes(ax)
  884. if self.orient == "h":
  885. ax.invert_yaxis()
  886. class _CategoricalScatterPlotter(_CategoricalPlotter):
  887. default_palette = "dark"
  888. @property
  889. def point_colors(self):
  890. """Return an index into the palette for each scatter point."""
  891. point_colors = []
  892. for i, group_data in enumerate(self.plot_data):
  893. # Initialize the array for this group level
  894. group_colors = np.empty(group_data.size, np.int)
  895. if isinstance(group_data, pd.Series):
  896. group_colors = pd.Series(group_colors, group_data.index)
  897. if self.plot_hues is None:
  898. # Use the same color for all points at this level
  899. # group_color = self.colors[i]
  900. group_colors[:] = i
  901. else:
  902. # Color the points based on the hue level
  903. for j, level in enumerate(self.hue_names):
  904. # hue_color = self.colors[j]
  905. if group_data.size:
  906. group_colors[self.plot_hues[i] == level] = j
  907. point_colors.append(group_colors)
  908. return point_colors
  909. def add_legend_data(self, ax):
  910. """Add empty scatterplot artists with labels for the legend."""
  911. if self.hue_names is not None:
  912. for rgb, label in zip(self.colors, self.hue_names):
  913. ax.scatter([], [],
  914. color=mpl.colors.rgb2hex(rgb),
  915. label=label,
  916. s=60)
  917. class _StripPlotter(_CategoricalScatterPlotter):
  918. """1-d scatterplot with categorical organization."""
  919. def __init__(self, x, y, hue, data, order, hue_order,
  920. jitter, dodge, orient, color, palette):
  921. """Initialize the plotter."""
  922. self.establish_variables(x, y, hue, data, orient, order, hue_order)
  923. self.establish_colors(color, palette, 1)
  924. # Set object attributes
  925. self.dodge = dodge
  926. self.width = .8
  927. if jitter == 1: # Use a good default for `jitter = True`
  928. jlim = 0.1
  929. else:
  930. jlim = float(jitter)
  931. if self.hue_names is not None and dodge:
  932. jlim /= len(self.hue_names)
  933. self.jitterer = stats.uniform(-jlim, jlim * 2).rvs
  934. def draw_stripplot(self, ax, kws):
  935. """Draw the points onto `ax`."""
  936. palette = np.asarray(self.colors)
  937. for i, group_data in enumerate(self.plot_data):
  938. if self.plot_hues is None or not self.dodge:
  939. if self.hue_names is None:
  940. hue_mask = np.ones(group_data.size, np.bool)
  941. else:
  942. hue_mask = np.array([h in self.hue_names
  943. for h in self.plot_hues[i]], np.bool)
  944. # Broken on older numpys
  945. # hue_mask = np.in1d(self.plot_hues[i], self.hue_names)
  946. strip_data = group_data[hue_mask]
  947. point_colors = np.asarray(self.point_colors[i][hue_mask])
  948. # Plot the points in centered positions
  949. cat_pos = np.ones(strip_data.size) * i
  950. cat_pos += self.jitterer(len(strip_data))
  951. kws.update(c=palette[point_colors])
  952. if self.orient == "v":
  953. ax.scatter(cat_pos, strip_data, **kws)
  954. else:
  955. ax.scatter(strip_data, cat_pos, **kws)
  956. else:
  957. offsets = self.hue_offsets
  958. for j, hue_level in enumerate(self.hue_names):
  959. hue_mask = self.plot_hues[i] == hue_level
  960. strip_data = group_data[hue_mask]
  961. point_colors = np.asarray(self.point_colors[i][hue_mask])
  962. # Plot the points in centered positions
  963. center = i + offsets[j]
  964. cat_pos = np.ones(strip_data.size) * center
  965. cat_pos += self.jitterer(len(strip_data))
  966. kws.update(c=palette[point_colors])
  967. if self.orient == "v":
  968. ax.scatter(cat_pos, strip_data, **kws)
  969. else:
  970. ax.scatter(strip_data, cat_pos, **kws)
  971. def plot(self, ax, kws):
  972. """Make the plot."""
  973. self.draw_stripplot(ax, kws)
  974. self.add_legend_data(ax)
  975. self.annotate_axes(ax)
  976. if self.orient == "h":
  977. ax.invert_yaxis()
  978. class _SwarmPlotter(_CategoricalScatterPlotter):
  979. def __init__(self, x, y, hue, data, order, hue_order,
  980. dodge, orient, color, palette):
  981. """Initialize the plotter."""
  982. self.establish_variables(x, y, hue, data, orient, order, hue_order)
  983. self.establish_colors(color, palette, 1)
  984. # Set object attributes
  985. self.dodge = dodge
  986. self.width = .8
  987. def could_overlap(self, xy_i, swarm, d):
  988. """Return a list of all swarm points that could overlap with target.
  989. Assumes that swarm is a sorted list of all points below xy_i.
  990. """
  991. _, y_i = xy_i
  992. neighbors = []
  993. for xy_j in reversed(swarm):
  994. _, y_j = xy_j
  995. if (y_i - y_j) < d:
  996. neighbors.append(xy_j)
  997. else:
  998. break
  999. return np.array(list(reversed(neighbors)))
  1000. def position_candidates(self, xy_i, neighbors, d):
  1001. """Return a list of (x, y) coordinates that might be valid."""
  1002. candidates = [xy_i]
  1003. x_i, y_i = xy_i
  1004. left_first = True
  1005. for x_j, y_j in neighbors:
  1006. dy = y_i - y_j
  1007. dx = np.sqrt(max(d ** 2 - dy ** 2, 0)) * 1.05
  1008. cl, cr = (x_j - dx, y_i), (x_j + dx, y_i)
  1009. if left_first:
  1010. new_candidates = [cl, cr]
  1011. else:
  1012. new_candidates = [cr, cl]
  1013. candidates.extend(new_candidates)
  1014. left_first = not left_first
  1015. return np.array(candidates)
  1016. def first_non_overlapping_candidate(self, candidates, neighbors, d):
  1017. """Remove candidates from the list if they overlap with the swarm."""
  1018. # IF we have no neighbours, all candidates are good.
  1019. if len(neighbors) == 0:
  1020. return candidates[0]
  1021. neighbors_x = neighbors[:, 0]
  1022. neighbors_y = neighbors[:, 1]
  1023. d_square = d ** 2
  1024. for xy_i in candidates:
  1025. x_i, y_i = xy_i
  1026. dx = neighbors_x - x_i
  1027. dy = neighbors_y - y_i
  1028. sq_distances = np.power(dx, 2.0) + np.power(dy, 2.0)
  1029. # good candidate does not overlap any of neighbors
  1030. # which means that squared distance between candidate
  1031. # and any of the neighbours has to be at least
  1032. # square of the diameter
  1033. good_candidate = np.all(sq_distances >= d_square)
  1034. if good_candidate:
  1035. return xy_i
  1036. # If `position_candidates` works well
  1037. # this should never happen
  1038. raise Exception('No non-overlapping candidates found. '
  1039. 'This should not happen.')
  1040. def beeswarm(self, orig_xy, d):
  1041. """Adjust x position of points to avoid overlaps."""
  1042. # In this method, ``x`` is always the categorical axis
  1043. # Center of the swarm, in point coordinates
  1044. midline = orig_xy[0, 0]
  1045. # Start the swarm with the first point
  1046. swarm = [orig_xy[0]]
  1047. # Loop over the remaining points
  1048. for xy_i in orig_xy[1:]:
  1049. # Find the points in the swarm that could possibly
  1050. # overlap with the point we are currently placing
  1051. neighbors = self.could_overlap(xy_i, swarm, d)
  1052. # Find positions that would be valid individually
  1053. # with respect to each of the swarm neighbors
  1054. candidates = self.position_candidates(xy_i, neighbors, d)
  1055. # Sort candidates by their centrality
  1056. offsets = np.abs(candidates[:, 0] - midline)
  1057. candidates = candidates[np.argsort(offsets)]
  1058. # Find the first candidate that does not overlap any neighbours
  1059. new_xy_i = self.first_non_overlapping_candidate(candidates,
  1060. neighbors, d)
  1061. # Place it into the swarm
  1062. swarm.append(new_xy_i)
  1063. return np.array(swarm)
  1064. def add_gutters(self, points, center, width):
  1065. """Stop points from extending beyond their territory."""
  1066. half_width = width / 2
  1067. low_gutter = center - half_width
  1068. off_low = points < low_gutter
  1069. if off_low.any():
  1070. points[off_low] = low_gutter
  1071. high_gutter = center + half_width
  1072. off_high = points > high_gutter
  1073. if off_high.any():
  1074. points[off_high] = high_gutter
  1075. return points
  1076. def swarm_points(self, ax, points, center, width, s, **kws):
  1077. """Find new positions on the categorical axis for each point."""
  1078. # Convert from point size (area) to diameter
  1079. default_lw = mpl.rcParams["patch.linewidth"]
  1080. lw = kws.get("linewidth", kws.get("lw", default_lw))
  1081. dpi = ax.figure.dpi
  1082. d = (np.sqrt(s) + lw) * (dpi / 72)
  1083. # Transform the data coordinates to point coordinates.
  1084. # We'll figure out the swarm positions in the latter
  1085. # and then convert back to data coordinates and replot
  1086. orig_xy = ax.transData.transform(points.get_offsets())
  1087. # Order the variables so that x is the categorical axis
  1088. if self.orient == "h":
  1089. orig_xy = orig_xy[:, [1, 0]]
  1090. # Do the beeswarm in point coordinates
  1091. new_xy = self.beeswarm(orig_xy, d)
  1092. # Transform the point coordinates back to data coordinates
  1093. if self.orient == "h":
  1094. new_xy = new_xy[:, [1, 0]]
  1095. new_x, new_y = ax.transData.inverted().transform(new_xy).T
  1096. # Add gutters
  1097. if self.orient == "v":
  1098. self.add_gutters(new_x, center, width)
  1099. else:
  1100. self.add_gutters(new_y, center, width)
  1101. # Reposition the points so they do not overlap
  1102. points.set_offsets(np.c_[new_x, new_y])
  1103. def draw_swarmplot(self, ax, kws):
  1104. """Plot the data."""
  1105. s = kws.pop("s")
  1106. centers = []
  1107. swarms = []
  1108. palette = np.asarray(self.colors)
  1109. # Set the categorical axes limits here for the swarm math
  1110. if self.orient == "v":
  1111. ax.set_xlim(-.5, len(self.plot_data) - .5)
  1112. else:
  1113. ax.set_ylim(-.5, len(self.plot_data) - .5)
  1114. # Plot each swarm
  1115. for i, group_data in enumerate(self.plot_data):
  1116. if self.plot_hues is None or not self.dodge:
  1117. width = self.width
  1118. if self.hue_names is None:
  1119. hue_mask = np.ones(group_data.size, np.bool)
  1120. else:
  1121. hue_mask = np.array([h in self.hue_names
  1122. for h in self.plot_hues[i]], np.bool)
  1123. # Broken on older numpys
  1124. # hue_mask = np.in1d(self.plot_hues[i], self.hue_names)
  1125. swarm_data = np.asarray(group_data[hue_mask])
  1126. point_colors = np.asarray(self.point_colors[i][hue_mask])
  1127. # Sort the points for the beeswarm algorithm
  1128. sorter = np.argsort(swarm_data)
  1129. swarm_data = swarm_data[sorter]
  1130. point_colors = point_colors[sorter]
  1131. # Plot the points in centered positions
  1132. cat_pos = np.ones(swarm_data.size) * i
  1133. kws.update(c=palette[point_colors])
  1134. if self.orient == "v":
  1135. points = ax.scatter(cat_pos, swarm_data, s=s, **kws)
  1136. else:
  1137. points = ax.scatter(swarm_data, cat_pos, s=s, **kws)
  1138. centers.append(i)
  1139. swarms.append(points)
  1140. else:
  1141. offsets = self.hue_offsets
  1142. width = self.nested_width
  1143. for j, hue_level in enumerate(self.hue_names):
  1144. hue_mask = self.plot_hues[i] == hue_level
  1145. swarm_data = np.asarray(group_data[hue_mask])
  1146. point_colors = np.asarray(self.point_colors[i][hue_mask])
  1147. # Sort the points for the beeswarm algorithm
  1148. sorter = np.argsort(swarm_data)
  1149. swarm_data = swarm_data[sorter]
  1150. point_colors = point_colors[sorter]
  1151. # Plot the points in centered positions
  1152. center = i + offsets[j]
  1153. cat_pos = np.ones(swarm_data.size) * center
  1154. kws.update(c=palette[point_colors])
  1155. if self.orient == "v":
  1156. points = ax.scatter(cat_pos, swarm_data, s=s, **kws)
  1157. else:
  1158. points = ax.scatter(swarm_data, cat_pos, s=s, **kws)
  1159. centers.append(center)
  1160. swarms.append(points)
  1161. # Autoscale the valus axis to set the data/axes transforms properly
  1162. ax.autoscale_view(scalex=self.orient == "h", scaley=self.orient == "v")
  1163. # Update the position of each point on the categorical axis
  1164. # Do this after plotting so that the numerical axis limits are correct
  1165. for center, swarm in zip(centers, swarms):
  1166. if swarm.get_offsets().size:
  1167. self.swarm_points(ax, swarm, center, width, s, **kws)
  1168. def plot(self, ax, kws):
  1169. """Make the full plot."""
  1170. self.draw_swarmplot(ax, kws)
  1171. self.add_legend_data(ax)
  1172. self.annotate_axes(ax)
  1173. if self.orient == "h":
  1174. ax.invert_yaxis()
  1175. class _CategoricalStatPlotter(_CategoricalPlotter):
  1176. @property
  1177. def nested_width(self):
  1178. """A float with the width of plot elements when hue nesting is used."""
  1179. if self.dodge:
  1180. width = self.width / len(self.hue_names)
  1181. else:
  1182. width = self.width
  1183. return width
  1184. def estimate_statistic(self, estimator, ci, n_boot, seed):
  1185. if self.hue_names is None:
  1186. statistic = []
  1187. confint = []
  1188. else:
  1189. statistic = [[] for _ in self.plot_data]
  1190. confint = [[] for _ in self.plot_data]
  1191. for i, group_data in enumerate(self.plot_data):
  1192. # Option 1: we have a single layer of grouping
  1193. # --------------------------------------------
  1194. if self.plot_hues is None:
  1195. if self.plot_units is None:
  1196. stat_data = remove_na(group_data)
  1197. unit_data = None
  1198. else:
  1199. unit_data = self.plot_units[i]
  1200. have = pd.notnull(np.c_[group_data, unit_data]).all(axis=1)
  1201. stat_data = group_data[have]
  1202. unit_data = unit_data[have]
  1203. # Estimate a statistic from the vector of data
  1204. if not stat_data.size:
  1205. statistic.append(np.nan)
  1206. else:
  1207. statistic.append(estimator(stat_data))
  1208. # Get a confidence interval for this estimate
  1209. if ci is not None:
  1210. if stat_data.size < 2:
  1211. confint.append([np.nan, np.nan])
  1212. continue
  1213. if ci == "sd":
  1214. estimate = estimator(stat_data)
  1215. sd = np.std(stat_data)
  1216. confint.append((estimate - sd, estimate + sd))
  1217. else:
  1218. boots = bootstrap(stat_data, func=estimator,
  1219. n_boot=n_boot,
  1220. units=unit_data,
  1221. seed=seed)
  1222. confint.append(utils.ci(boots, ci))
  1223. # Option 2: we are grouping by a hue layer
  1224. # ----------------------------------------
  1225. else:
  1226. for j, hue_level in enumerate(self.hue_names):
  1227. if not self.plot_hues[i].size:
  1228. statistic[i].append(np.nan)
  1229. if ci is not None:
  1230. confint[i].append((np.nan, np.nan))
  1231. continue
  1232. hue_mask = self.plot_hues[i] == hue_level
  1233. if self.plot_units is None:
  1234. stat_data = remove_na(group_data[hue_mask])
  1235. unit_data = None
  1236. else:
  1237. group_units = self.plot_units[i]
  1238. have = pd.notnull(
  1239. np.c_[group_data, group_units]
  1240. ).all(axis=1)
  1241. stat_data = group_data[hue_mask & have]
  1242. unit_data = group_units[hue_mask & have]
  1243. # Estimate a statistic from the vector of data
  1244. if not stat_data.size:
  1245. statistic[i].append(np.nan)
  1246. else:
  1247. statistic[i].append(estimator(stat_data))
  1248. # Get a confidence interval for this estimate
  1249. if ci is not None:
  1250. if stat_data.size < 2:
  1251. confint[i].append([np.nan, np.nan])
  1252. continue
  1253. if ci == "sd":
  1254. estimate = estimator(stat_data)
  1255. sd = np.std(stat_data)
  1256. confint[i].append((estimate - sd, estimate + sd))
  1257. else:
  1258. boots = bootstrap(stat_data, func=estimator,
  1259. n_boot=n_boot,
  1260. units=unit_data,
  1261. seed=seed)
  1262. confint[i].append(utils.ci(boots, ci))
  1263. # Save the resulting values for plotting
  1264. self.statistic = np.array(statistic)
  1265. self.confint = np.array(confint)
  1266. def draw_confints(self, ax, at_group, confint, colors,
  1267. errwidth=None, capsize=None, **kws):
  1268. if errwidth is not None:
  1269. kws.setdefault("lw", errwidth)
  1270. else:
  1271. kws.setdefault("lw", mpl.rcParams["lines.linewidth"] * 1.8)
  1272. for at, (ci_low, ci_high), color in zip(at_group,
  1273. confint,
  1274. colors):
  1275. if self.orient == "v":
  1276. ax.plot([at, at], [ci_low, ci_high], color=color, **kws)
  1277. if capsize is not None:
  1278. ax.plot([at - capsize / 2, at + capsize / 2],
  1279. [ci_low, ci_low], color=color, **kws)
  1280. ax.plot([at - capsize / 2, at + capsize / 2],
  1281. [ci_high, ci_high], color=color, **kws)
  1282. else:
  1283. ax.plot([ci_low, ci_high], [at, at], color=color, **kws)
  1284. if capsize is not None:
  1285. ax.plot([ci_low, ci_low],
  1286. [at - capsize / 2, at + capsize / 2],
  1287. color=color, **kws)
  1288. ax.plot([ci_high, ci_high],
  1289. [at - capsize / 2, at + capsize / 2],
  1290. color=color, **kws)
  1291. class _BarPlotter(_CategoricalStatPlotter):
  1292. """Show point estimates and confidence intervals with bars."""
  1293. def __init__(self, x, y, hue, data, order, hue_order,
  1294. estimator, ci, n_boot, units, seed,
  1295. orient, color, palette, saturation, errcolor,
  1296. errwidth, capsize, dodge):
  1297. """Initialize the plotter."""
  1298. self.establish_variables(x, y, hue, data, orient,
  1299. order, hue_order, units)
  1300. self.establish_colors(color, palette, saturation)
  1301. self.estimate_statistic(estimator, ci, n_boot, seed)
  1302. self.dodge = dodge
  1303. self.errcolor = errcolor
  1304. self.errwidth = errwidth
  1305. self.capsize = capsize
  1306. def draw_bars(self, ax, kws):
  1307. """Draw the bars onto `ax`."""
  1308. # Get the right matplotlib function depending on the orientation
  1309. barfunc = ax.bar if self.orient == "v" else ax.barh
  1310. barpos = np.arange(len(self.statistic))
  1311. if self.plot_hues is None:
  1312. # Draw the bars
  1313. barfunc(barpos, self.statistic, self.width,
  1314. color=self.colors, align="center", **kws)
  1315. # Draw the confidence intervals
  1316. errcolors = [self.errcolor] * len(barpos)
  1317. self.draw_confints(ax,
  1318. barpos,
  1319. self.confint,
  1320. errcolors,
  1321. self.errwidth,
  1322. self.capsize)
  1323. else:
  1324. for j, hue_level in enumerate(self.hue_names):
  1325. # Draw the bars
  1326. offpos = barpos + self.hue_offsets[j]
  1327. barfunc(offpos, self.statistic[:, j], self.nested_width,
  1328. color=self.colors[j], align="center",
  1329. label=hue_level, **kws)
  1330. # Draw the confidence intervals
  1331. if self.confint.size:
  1332. confint = self.confint[:, j]
  1333. errcolors = [self.errcolor] * len(offpos)
  1334. self.draw_confints(ax,
  1335. offpos,
  1336. confint,
  1337. errcolors,
  1338. self.errwidth,
  1339. self.capsize)
  1340. def plot(self, ax, bar_kws):
  1341. """Make the plot."""
  1342. self.draw_bars(ax, bar_kws)
  1343. self.annotate_axes(ax)
  1344. if self.orient == "h":
  1345. ax.invert_yaxis()
  1346. class _PointPlotter(_CategoricalStatPlotter):
  1347. default_palette = "dark"
  1348. """Show point estimates and confidence intervals with (joined) points."""
  1349. def __init__(self, x, y, hue, data, order, hue_order,
  1350. estimator, ci, n_boot, units, seed,
  1351. markers, linestyles, dodge, join, scale,
  1352. orient, color, palette, errwidth=None, capsize=None):
  1353. """Initialize the plotter."""
  1354. self.establish_variables(x, y, hue, data, orient,
  1355. order, hue_order, units)
  1356. self.establish_colors(color, palette, 1)
  1357. self.estimate_statistic(estimator, ci, n_boot, seed)
  1358. # Override the default palette for single-color plots
  1359. if hue is None and color is None and palette is None:
  1360. self.colors = [color_palette()[0]] * len(self.colors)
  1361. # Don't join single-layer plots with different colors
  1362. if hue is None and palette is not None:
  1363. join = False
  1364. # Use a good default for `dodge=True`
  1365. if dodge is True and self.hue_names is not None:
  1366. dodge = .025 * len(self.hue_names)
  1367. # Make sure we have a marker for each hue level
  1368. if isinstance(markers, str):
  1369. markers = [markers] * len(self.colors)
  1370. self.markers = markers
  1371. # Make sure we have a line style for each hue level
  1372. if isinstance(linestyles, str):
  1373. linestyles = [linestyles] * len(self.colors)
  1374. self.linestyles = linestyles
  1375. # Set the other plot components
  1376. self.dodge = dodge
  1377. self.join = join
  1378. self.scale = scale
  1379. self.errwidth = errwidth
  1380. self.capsize = capsize
  1381. @property
  1382. def hue_offsets(self):
  1383. """Offsets relative to the center position for each hue level."""
  1384. if self.dodge:
  1385. offset = np.linspace(0, self.dodge, len(self.hue_names))
  1386. offset -= offset.mean()
  1387. else:
  1388. offset = np.zeros(len(self.hue_names))
  1389. return offset
  1390. def draw_points(self, ax):
  1391. """Draw the main data components of the plot."""
  1392. # Get the center positions on the categorical axis
  1393. pointpos = np.arange(len(self.statistic))
  1394. # Get the size of the plot elements
  1395. lw = mpl.rcParams["lines.linewidth"] * 1.8 * self.scale
  1396. mew = lw * .75
  1397. markersize = np.pi * np.square(lw) * 2
  1398. if self.plot_hues is None:
  1399. # Draw lines joining each estimate point
  1400. if self.join:
  1401. color = self.colors[0]
  1402. ls = self.linestyles[0]
  1403. if self.orient == "h":
  1404. ax.plot(self.statistic, pointpos,
  1405. color=color, ls=ls, lw=lw)
  1406. else:
  1407. ax.plot(pointpos, self.statistic,
  1408. color=color, ls=ls, lw=lw)
  1409. # Draw the confidence intervals
  1410. self.draw_confints(ax, pointpos, self.confint, self.colors,
  1411. self.errwidth, self.capsize)
  1412. # Draw the estimate points
  1413. marker = self.markers[0]
  1414. colors = [mpl.colors.colorConverter.to_rgb(c) for c in self.colors]
  1415. if self.orient == "h":
  1416. x, y = self.statistic, pointpos
  1417. else:
  1418. x, y = pointpos, self.statistic
  1419. ax.scatter(x, y,
  1420. linewidth=mew, marker=marker, s=markersize,
  1421. facecolor=colors, edgecolor=colors)
  1422. else:
  1423. offsets = self.hue_offsets
  1424. for j, hue_level in enumerate(self.hue_names):
  1425. # Determine the values to plot for this level
  1426. statistic = self.statistic[:, j]
  1427. # Determine the position on the categorical and z axes
  1428. offpos = pointpos + offsets[j]
  1429. z = j + 1
  1430. # Draw lines joining each estimate point
  1431. if self.join:
  1432. color = self.colors[j]
  1433. ls = self.linestyles[j]
  1434. if self.orient == "h":
  1435. ax.plot(statistic, offpos, color=color,
  1436. zorder=z, ls=ls, lw=lw)
  1437. else:
  1438. ax.plot(offpos, statistic, color=color,
  1439. zorder=z, ls=ls, lw=lw)
  1440. # Draw the confidence intervals
  1441. if self.confint.size:
  1442. confint = self.confint[:, j]
  1443. errcolors = [self.colors[j]] * len(offpos)
  1444. self.draw_confints(ax, offpos, confint, errcolors,
  1445. self.errwidth, self.capsize,
  1446. zorder=z)
  1447. # Draw the estimate points
  1448. n_points = len(remove_na(offpos))
  1449. marker = self.markers[j]
  1450. color = mpl.colors.colorConverter.to_rgb(self.colors[j])
  1451. if self.orient == "h":
  1452. x, y = statistic, offpos
  1453. else:
  1454. x, y = offpos, statistic
  1455. if not len(remove_na(statistic)):
  1456. x = y = [np.nan] * n_points
  1457. ax.scatter(x, y, label=hue_level,
  1458. facecolor=color, edgecolor=color,
  1459. linewidth=mew, marker=marker, s=markersize,
  1460. zorder=z)
  1461. def plot(self, ax):
  1462. """Make the plot."""
  1463. self.draw_points(ax)
  1464. self.annotate_axes(ax)
  1465. if self.orient == "h":
  1466. ax.invert_yaxis()
  1467. class _LVPlotter(_CategoricalPlotter):
  1468. def __init__(self, x, y, hue, data, order, hue_order,
  1469. orient, color, palette, saturation,
  1470. width, dodge, k_depth, linewidth, scale, outlier_prop,
  1471. showfliers=True):
  1472. # TODO assigning variables for None is unneccesary
  1473. if width is None:
  1474. width = .8
  1475. self.width = width
  1476. self.dodge = dodge
  1477. if saturation is None:
  1478. saturation = .75
  1479. self.saturation = saturation
  1480. if k_depth is None:
  1481. k_depth = 'proportion'
  1482. self.k_depth = k_depth
  1483. if linewidth is None:
  1484. linewidth = mpl.rcParams["lines.linewidth"]
  1485. self.linewidth = linewidth
  1486. if scale is None:
  1487. scale = 'exponential'
  1488. self.scale = scale
  1489. self.outlier_prop = outlier_prop
  1490. self.showfliers = showfliers
  1491. self.establish_variables(x, y, hue, data, orient, order, hue_order)
  1492. self.establish_colors(color, palette, saturation)
  1493. def _lv_box_ends(self, vals, k_depth='proportion', outlier_prop=None):
  1494. """Get the number of data points and calculate `depth` of
  1495. letter-value plot."""
  1496. vals = np.asarray(vals)
  1497. vals = vals[np.isfinite(vals)]
  1498. n = len(vals)
  1499. # If p is not set, calculate it so that 8 points are outliers
  1500. if not outlier_prop:
  1501. # Conventional boxplots assume this proportion of the data are
  1502. # outliers.
  1503. p = 0.007
  1504. else:
  1505. if ((outlier_prop > 1.) or (outlier_prop < 0.)):
  1506. raise ValueError('outlier_prop not in range [0, 1]!')
  1507. p = outlier_prop
  1508. # Select the depth, i.e. number of boxes to draw, based on the method
  1509. k_dict = {'proportion': (np.log2(n)) - int(np.log2(n*p)) + 1,
  1510. 'tukey': (np.log2(n)) - 3,
  1511. 'trustworthy': (np.log2(n) -
  1512. np.log2(2*stats.norm.ppf((1-p))**2)) + 1}
  1513. k = k_dict[k_depth]
  1514. try:
  1515. k = int(k)
  1516. except ValueError:
  1517. k = 1
  1518. # If the number happens to be less than 0, set k to 0
  1519. if k < 1.:
  1520. k = 1
  1521. # Calculate the upper box ends
  1522. upper = [100*(1 - 0.5**(i+2)) for i in range(k, -1, -1)]
  1523. # Calculate the lower box ends
  1524. lower = [100*(0.5**(i+2)) for i in range(k, -1, -1)]
  1525. # Stitch the box ends together
  1526. percentile_ends = [(i, j) for i, j in zip(lower, upper)]
  1527. box_ends = [np.percentile(vals, q) for q in percentile_ends]
  1528. return box_ends, k
  1529. def _lv_outliers(self, vals, k):
  1530. """Find the outliers based on the letter value depth."""
  1531. perc_ends = (100*(0.5**(k+2)), 100*(1 - 0.5**(k+2)))
  1532. edges = np.percentile(vals, perc_ends)
  1533. lower_out = vals[np.where(vals < edges[0])[0]]
  1534. upper_out = vals[np.where(vals > edges[1])[0]]
  1535. return np.concatenate((lower_out, upper_out))
  1536. def _width_functions(self, width_func):
  1537. # Dictionary of functions for computing the width of the boxes
  1538. width_functions = {'linear': lambda h, i, k: (i + 1.) / k,
  1539. 'exponential': lambda h, i, k: 2**(-k+i-1),
  1540. 'area': lambda h, i, k: (1 - 2**(-k+i-2)) / h}
  1541. return width_functions[width_func]
  1542. def _lvplot(self, box_data, positions,
  1543. color=[255. / 256., 185. / 256., 0.],
  1544. vert=True, widths=1, k_depth='proportion',
  1545. ax=None, outlier_prop=None, scale='exponential',
  1546. showfliers=True, **kws):
  1547. x = positions[0]
  1548. box_data = np.asarray(box_data)
  1549. # If we only have one data point, plot a line
  1550. if len(box_data) == 1:
  1551. kws.update({'color': self.gray, 'linestyle': '-'})
  1552. ys = [box_data[0], box_data[0]]
  1553. xs = [x - widths / 2, x + widths / 2]
  1554. if vert:
  1555. xx, yy = xs, ys
  1556. else:
  1557. xx, yy = ys, xs
  1558. ax.plot(xx, yy, **kws)
  1559. else:
  1560. # Get the number of data points and calculate "depth" of
  1561. # letter-value plot
  1562. box_ends, k = self._lv_box_ends(box_data, k_depth=k_depth,
  1563. outlier_prop=outlier_prop)
  1564. # Anonymous functions for calculating the width and height
  1565. # of the letter value boxes
  1566. width = self._width_functions(scale)
  1567. # Function to find height of boxes
  1568. def height(b):
  1569. return b[1] - b[0]
  1570. # Functions to construct the letter value boxes
  1571. def vert_perc_box(x, b, i, k, w):
  1572. rect = Patches.Rectangle((x - widths*w / 2, b[0]),
  1573. widths*w,
  1574. height(b), fill=True)
  1575. return rect
  1576. def horz_perc_box(x, b, i, k, w):
  1577. rect = Patches.Rectangle((b[0], x - widths*w / 2),
  1578. height(b), widths*w,
  1579. fill=True)
  1580. return rect
  1581. # Scale the width of the boxes so the biggest starts at 1
  1582. w_area = np.array([width(height(b), i, k)
  1583. for i, b in enumerate(box_ends)])
  1584. w_area = w_area / np.max(w_area)
  1585. # Calculate the medians
  1586. y = np.median(box_data)
  1587. # Calculate the outliers and plot (only if showfliers == True)
  1588. outliers = []
  1589. if self.showfliers:
  1590. outliers = self._lv_outliers(box_data, k)
  1591. hex_color = mpl.colors.rgb2hex(color)
  1592. if vert:
  1593. boxes = [vert_perc_box(x, b[0], i, k, b[1])
  1594. for i, b in enumerate(zip(box_ends, w_area))]
  1595. # Plot the medians
  1596. ax.plot([x - widths / 2, x + widths / 2], [y, y],
  1597. c='.15', alpha=.45, **kws)
  1598. ax.scatter(np.repeat(x, len(outliers)), outliers,
  1599. marker='d', c=hex_color, **kws)
  1600. else:
  1601. boxes = [horz_perc_box(x, b[0], i, k, b[1])
  1602. for i, b in enumerate(zip(box_ends, w_area))]
  1603. # Plot the medians
  1604. ax.plot([y, y], [x - widths / 2, x + widths / 2],
  1605. c='.15', alpha=.45, **kws)
  1606. ax.scatter(outliers, np.repeat(x, len(outliers)),
  1607. marker='d', c=hex_color, **kws)
  1608. # Construct a color map from the input color
  1609. rgb = [[1, 1, 1], hex_color]
  1610. cmap = mpl.colors.LinearSegmentedColormap.from_list('new_map', rgb)
  1611. collection = PatchCollection(boxes, cmap=cmap)
  1612. # Set the color gradation
  1613. collection.set_array(np.array(np.linspace(0, 1, len(boxes))))
  1614. # Plot the boxes
  1615. ax.add_collection(collection)
  1616. def draw_letter_value_plot(self, ax, kws):
  1617. """Use matplotlib to draw a letter value plot on an Axes."""
  1618. vert = self.orient == "v"
  1619. for i, group_data in enumerate(self.plot_data):
  1620. if self.plot_hues is None:
  1621. # Handle case where there is data at this level
  1622. if group_data.size == 0:
  1623. continue
  1624. # Draw a single box or a set of boxes
  1625. # with a single level of grouping
  1626. box_data = remove_na(group_data)
  1627. # Handle case where there is no non-null data
  1628. if box_data.size == 0:
  1629. continue
  1630. color = self.colors[i]
  1631. self._lvplot(box_data,
  1632. positions=[i],
  1633. color=color,
  1634. vert=vert,
  1635. widths=self.width,
  1636. k_depth=self.k_depth,
  1637. ax=ax,
  1638. scale=self.scale,
  1639. outlier_prop=self.outlier_prop,
  1640. showfliers=self.showfliers,
  1641. **kws)
  1642. else:
  1643. # Draw nested groups of boxes
  1644. offsets = self.hue_offsets
  1645. for j, hue_level in enumerate(self.hue_names):
  1646. # Add a legend for this hue level
  1647. if not i:
  1648. self.add_legend_data(ax, self.colors[j], hue_level)
  1649. # Handle case where there is data at this level
  1650. if group_data.size == 0:
  1651. continue
  1652. hue_mask = self.plot_hues[i] == hue_level
  1653. box_data = remove_na(group_data[hue_mask])
  1654. # Handle case where there is no non-null data
  1655. if box_data.size == 0:
  1656. continue
  1657. color = self.colors[j]
  1658. center = i + offsets[j]
  1659. self._lvplot(box_data,
  1660. positions=[center],
  1661. color=color,
  1662. vert=vert,
  1663. widths=self.nested_width,
  1664. k_depth=self.k_depth,
  1665. ax=ax,
  1666. scale=self.scale,
  1667. outlier_prop=self.outlier_prop,
  1668. **kws)
  1669. def plot(self, ax, boxplot_kws):
  1670. """Make the plot."""
  1671. self.draw_letter_value_plot(ax, boxplot_kws)
  1672. self.annotate_axes(ax)
  1673. if self.orient == "h":
  1674. ax.invert_yaxis()
  1675. _categorical_docs = dict(
  1676. # Shared narrative docs
  1677. categorical_narrative=dedent("""\
  1678. This function always treats one of the variables as categorical and
  1679. draws data at ordinal positions (0, 1, ... n) on the relevant axis, even
  1680. when the data has a numeric or date type.
  1681. See the :ref:`tutorial <categorical_tutorial>` for more information.\
  1682. """),
  1683. main_api_narrative=dedent("""\
  1684. Input data can be passed in a variety of formats, including:
  1685. - Vectors of data represented as lists, numpy arrays, or pandas Series
  1686. objects passed directly to the ``x``, ``y``, and/or ``hue`` parameters.
  1687. - A "long-form" DataFrame, in which case the ``x``, ``y``, and ``hue``
  1688. variables will determine how the data are plotted.
  1689. - A "wide-form" DataFrame, such that each numeric column will be plotted.
  1690. - An array or list of vectors.
  1691. In most cases, it is possible to use numpy or Python objects, but pandas
  1692. objects are preferable because the associated names will be used to
  1693. annotate the axes. Additionally, you can use Categorical types for the
  1694. grouping variables to control the order of plot elements.\
  1695. """),
  1696. # Shared function parameters
  1697. input_params=dedent("""\
  1698. x, y, hue : names of variables in ``data`` or vector data, optional
  1699. Inputs for plotting long-form data. See examples for interpretation.\
  1700. """),
  1701. string_input_params=dedent("""\
  1702. x, y, hue : names of variables in ``data``
  1703. Inputs for plotting long-form data. See examples for interpretation.\
  1704. """),
  1705. categorical_data=dedent("""\
  1706. data : DataFrame, array, or list of arrays, optional
  1707. Dataset for plotting. If ``x`` and ``y`` are absent, this is
  1708. interpreted as wide-form. Otherwise it is expected to be long-form.\
  1709. """),
  1710. long_form_data=dedent("""\
  1711. data : DataFrame
  1712. Long-form (tidy) dataset for plotting. Each column should correspond
  1713. to a variable, and each row should correspond to an observation.\
  1714. """),
  1715. order_vars=dedent("""\
  1716. order, hue_order : lists of strings, optional
  1717. Order to plot the categorical levels in, otherwise the levels are
  1718. inferred from the data objects.\
  1719. """),
  1720. stat_api_params=dedent("""\
  1721. estimator : callable that maps vector -> scalar, optional
  1722. Statistical function to estimate within each categorical bin.
  1723. ci : float or "sd" or None, optional
  1724. Size of confidence intervals to draw around estimated values. If
  1725. "sd", skip bootstrapping and draw the standard deviation of the
  1726. observations. If ``None``, no bootstrapping will be performed, and
  1727. error bars will not be drawn.
  1728. n_boot : int, optional
  1729. Number of bootstrap iterations to use when computing confidence
  1730. intervals.
  1731. units : name of variable in ``data`` or vector data, optional
  1732. Identifier of sampling units, which will be used to perform a
  1733. multilevel bootstrap and account for repeated measures design.
  1734. seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
  1735. Seed or random number generator for reproducible bootstrapping.\
  1736. """),
  1737. orient=dedent("""\
  1738. orient : "v" | "h", optional
  1739. Orientation of the plot (vertical or horizontal). This is usually
  1740. inferred from the dtype of the input variables, but can be used to
        specify when the "categorical" variable is numeric or when plotting
  1742. wide-form data.\
  1743. """),
  1744. color=dedent("""\
  1745. color : matplotlib color, optional
  1746. Color for all of the elements, or seed for a gradient palette.\
  1747. """),
  1748. palette=dedent("""\
  1749. palette : palette name, list, or dict, optional
  1750. Color palette that maps either the grouping variable or the hue
  1751. variable. If the palette is a dictionary, keys should be names of
  1752. levels and values should be matplotlib colors.\
  1753. """),
  1754. saturation=dedent("""\
  1755. saturation : float, optional
  1756. Proportion of the original saturation to draw colors at. Large patches
  1757. often look better with slightly desaturated colors, but set this to
  1758. ``1`` if you want the plot colors to perfectly match the input color
  1759. spec.\
  1760. """),
  1761. capsize=dedent("""\
  1762. capsize : float, optional
  1763. Width of the "caps" on error bars.
  1764. """),
  1765. errwidth=dedent("""\
  1766. errwidth : float, optional
  1767. Thickness of error bar lines (and caps).\
  1768. """),
  1769. width=dedent("""\
  1770. width : float, optional
  1771. Width of a full element when not using hue nesting, or width of all the
  1772. elements for one level of the major grouping variable.\
  1773. """),
  1774. dodge=dedent("""\
  1775. dodge : bool, optional
  1776. When hue nesting is used, whether elements should be shifted along the
  1777. categorical axis.\
  1778. """),
  1779. linewidth=dedent("""\
  1780. linewidth : float, optional
  1781. Width of the gray lines that frame the plot elements.\
  1782. """),
  1783. ax_in=dedent("""\
  1784. ax : matplotlib Axes, optional
  1785. Axes object to draw the plot onto, otherwise uses the current Axes.\
  1786. """),
  1787. ax_out=dedent("""\
  1788. ax : matplotlib Axes
  1789. Returns the Axes object with the plot drawn onto it.\
  1790. """),
  1791. # Shared see also
  1792. boxplot=dedent("""\
  1793. boxplot : A traditional box-and-whisker plot with a similar API.\
  1794. """),
  1795. violinplot=dedent("""\
  1796. violinplot : A combination of boxplot and kernel density estimation.\
  1797. """),
  1798. stripplot=dedent("""\
  1799. stripplot : A scatterplot where one variable is categorical. Can be used
  1800. in conjunction with other plots to show each observation.\
  1801. """),
  1802. swarmplot=dedent("""\
  1803. swarmplot : A categorical scatterplot where the points do not overlap. Can
  1804. be used with other plots to show each observation.\
  1805. """),
  1806. barplot=dedent("""\
  1807. barplot : Show point estimates and confidence intervals using bars.\
  1808. """),
  1809. countplot=dedent("""\
  1810. countplot : Show the counts of observations in each categorical bin.\
  1811. """),
  1812. pointplot=dedent("""\
  1813. pointplot : Show point estimates and confidence intervals using scatterplot
  1814. glyphs.\
  1815. """),
  1816. catplot=dedent("""\
  1817. catplot : Combine a categorical plot with a :class:`FacetGrid`.\
  1818. """),
  1819. boxenplot=dedent("""\
  1820. boxenplot : An enhanced boxplot for larger datasets.\
  1821. """),
  1822. )
  1823. _categorical_docs.update(_facet_docs)
  1824. def boxplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  1825. orient=None, color=None, palette=None, saturation=.75,
  1826. width=.8, dodge=True, fliersize=5, linewidth=None,
  1827. whis=1.5, ax=None, **kwargs):
  1828. plotter = _BoxPlotter(x, y, hue, data, order, hue_order,
  1829. orient, color, palette, saturation,
  1830. width, dodge, fliersize, linewidth)
  1831. if ax is None:
  1832. ax = plt.gca()
  1833. kwargs.update(dict(whis=whis))
  1834. plotter.plot(ax, kwargs)
  1835. return ax
  1836. boxplot.__doc__ = dedent("""\
  1837. Draw a box plot to show distributions with respect to categories.
  1838. A box plot (or box-and-whisker plot) shows the distribution of quantitative
  1839. data in a way that facilitates comparisons between variables or across
  1840. levels of a categorical variable. The box shows the quartiles of the
  1841. dataset while the whiskers extend to show the rest of the distribution,
  1842. except for points that are determined to be "outliers" using a method
  1843. that is a function of the inter-quartile range.
  1844. {main_api_narrative}
  1845. {categorical_narrative}
  1846. Parameters
  1847. ----------
  1848. {input_params}
  1849. {categorical_data}
  1850. {order_vars}
  1851. {orient}
  1852. {color}
  1853. {palette}
  1854. {saturation}
  1855. {width}
  1856. {dodge}
  1857. fliersize : float, optional
  1858. Size of the markers used to indicate outlier observations.
  1859. {linewidth}
  1860. whis : float, optional
  1861. Proportion of the IQR past the low and high quartiles to extend the
  1862. plot whiskers. Points outside this range will be identified as
  1863. outliers.
  1864. {ax_in}
  1865. kwargs : key, value mappings
  1866. Other keyword arguments are passed through to
  1867. :meth:`matplotlib.axes.Axes.boxplot`.
  1868. Returns
  1869. -------
  1870. {ax_out}
  1871. See Also
  1872. --------
  1873. {violinplot}
  1874. {stripplot}
  1875. {swarmplot}
  1876. {catplot}
  1877. Examples
  1878. --------
  1879. Draw a single horizontal boxplot:
  1880. .. plot::
  1881. :context: close-figs
  1882. >>> import seaborn as sns
  1883. >>> sns.set(style="whitegrid")
  1884. >>> tips = sns.load_dataset("tips")
  1885. >>> ax = sns.boxplot(x=tips["total_bill"])
  1886. Draw a vertical boxplot grouped by a categorical variable:
  1887. .. plot::
  1888. :context: close-figs
  1889. >>> ax = sns.boxplot(x="day", y="total_bill", data=tips)
  1890. Draw a boxplot with nested grouping by two categorical variables:
  1891. .. plot::
  1892. :context: close-figs
  1893. >>> ax = sns.boxplot(x="day", y="total_bill", hue="smoker",
  1894. ... data=tips, palette="Set3")
  1895. Draw a boxplot with nested grouping when some bins are empty:
  1896. .. plot::
  1897. :context: close-figs
  1898. >>> ax = sns.boxplot(x="day", y="total_bill", hue="time",
  1899. ... data=tips, linewidth=2.5)
  1900. Control box order by passing an explicit order:
  1901. .. plot::
  1902. :context: close-figs
  1903. >>> ax = sns.boxplot(x="time", y="tip", data=tips,
  1904. ... order=["Dinner", "Lunch"])
  1905. Draw a boxplot for each numeric variable in a DataFrame:
  1906. .. plot::
  1907. :context: close-figs
  1908. >>> iris = sns.load_dataset("iris")
  1909. >>> ax = sns.boxplot(data=iris, orient="h", palette="Set2")
  1910. Use ``hue`` without changing box position or width:
  1911. .. plot::
  1912. :context: close-figs
  1913. >>> tips["weekend"] = tips["day"].isin(["Sat", "Sun"])
  1914. >>> ax = sns.boxplot(x="day", y="total_bill", hue="weekend",
  1915. ... data=tips, dodge=False)
  1916. Use :func:`swarmplot` to show the datapoints on top of the boxes:
  1917. .. plot::
  1918. :context: close-figs
  1919. >>> ax = sns.boxplot(x="day", y="total_bill", data=tips)
  1920. >>> ax = sns.swarmplot(x="day", y="total_bill", data=tips, color=".25")
  1921. Use :func:`catplot` to combine a :func:`boxplot` and a
  1922. :class:`FacetGrid`. This allows grouping within additional categorical
  1923. variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
  1924. directly, as it ensures synchronization of variable order across facets:
  1925. .. plot::
  1926. :context: close-figs
  1927. >>> g = sns.catplot(x="sex", y="total_bill",
  1928. ... hue="smoker", col="time",
  1929. ... data=tips, kind="box",
  1930. ... height=4, aspect=.7);
  1931. """).format(**_categorical_docs)
  1932. def violinplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  1933. bw="scott", cut=2, scale="area", scale_hue=True, gridsize=100,
  1934. width=.8, inner="box", split=False, dodge=True, orient=None,
  1935. linewidth=None, color=None, palette=None, saturation=.75,
  1936. ax=None, **kwargs):
  1937. plotter = _ViolinPlotter(x, y, hue, data, order, hue_order,
  1938. bw, cut, scale, scale_hue, gridsize,
  1939. width, inner, split, dodge, orient, linewidth,
  1940. color, palette, saturation)
  1941. if ax is None:
  1942. ax = plt.gca()
  1943. plotter.plot(ax)
  1944. return ax
  1945. violinplot.__doc__ = dedent("""\
  1946. Draw a combination of boxplot and kernel density estimate.
  1947. A violin plot plays a similar role as a box and whisker plot. It shows the
  1948. distribution of quantitative data across several levels of one (or more)
  1949. categorical variables such that those distributions can be compared. Unlike
  1950. a box plot, in which all of the plot components correspond to actual
  1951. datapoints, the violin plot features a kernel density estimation of the
  1952. underlying distribution.
  1953. This can be an effective and attractive way to show multiple distributions
  1954. of data at once, but keep in mind that the estimation procedure is
  1955. influenced by the sample size, and violins for relatively small samples
  1956. might look misleadingly smooth.
  1957. {main_api_narrative}
  1958. {categorical_narrative}
  1959. Parameters
  1960. ----------
  1961. {input_params}
  1962. {categorical_data}
  1963. {order_vars}
  1964. bw : {{'scott', 'silverman', float}}, optional
  1965. Either the name of a reference rule or the scale factor to use when
  1966. computing the kernel bandwidth. The actual kernel size will be
  1967. determined by multiplying the scale factor by the standard deviation of
  1968. the data within each bin.
  1969. cut : float, optional
  1970. Distance, in units of bandwidth size, to extend the density past the
  1971. extreme datapoints. Set to 0 to limit the violin range within the range
  1972. of the observed data (i.e., to have the same effect as ``trim=True`` in
        ``ggplot``).
  1974. scale : {{"area", "count", "width"}}, optional
  1975. The method used to scale the width of each violin. If ``area``, each
  1976. violin will have the same area. If ``count``, the width of the violins
  1977. will be scaled by the number of observations in that bin. If ``width``,
  1978. each violin will have the same width.
  1979. scale_hue : bool, optional
  1980. When nesting violins using a ``hue`` variable, this parameter
  1981. determines whether the scaling is computed within each level of the
  1982. major grouping variable (``scale_hue=True``) or across all the violins
  1983. on the plot (``scale_hue=False``).
  1984. gridsize : int, optional
  1985. Number of points in the discrete grid used to compute the kernel
  1986. density estimate.
  1987. {width}
  1988. inner : {{"box", "quartile", "point", "stick", None}}, optional
  1989. Representation of the datapoints in the violin interior. If ``box``,
        draw a miniature boxplot. If ``quartile``, draw the quartiles of the
  1991. distribution. If ``point`` or ``stick``, show each underlying
  1992. datapoint. Using ``None`` will draw unadorned violins.
  1993. split : bool, optional
  1994. When using hue nesting with a variable that takes two levels, setting
  1995. ``split`` to True will draw half of a violin for each level. This can
  1996. make it easier to directly compare the distributions.
  1997. {dodge}
  1998. {orient}
  1999. {linewidth}
  2000. {color}
  2001. {palette}
  2002. {saturation}
  2003. {ax_in}
  2004. Returns
  2005. -------
  2006. {ax_out}
  2007. See Also
  2008. --------
  2009. {boxplot}
  2010. {stripplot}
  2011. {swarmplot}
  2012. {catplot}
  2013. Examples
  2014. --------
  2015. Draw a single horizontal violinplot:
  2016. .. plot::
  2017. :context: close-figs
  2018. >>> import seaborn as sns
  2019. >>> sns.set(style="whitegrid")
  2020. >>> tips = sns.load_dataset("tips")
  2021. >>> ax = sns.violinplot(x=tips["total_bill"])
  2022. Draw a vertical violinplot grouped by a categorical variable:
  2023. .. plot::
  2024. :context: close-figs
  2025. >>> ax = sns.violinplot(x="day", y="total_bill", data=tips)
  2026. Draw a violinplot with nested grouping by two categorical variables:
  2027. .. plot::
  2028. :context: close-figs
  2029. >>> ax = sns.violinplot(x="day", y="total_bill", hue="smoker",
  2030. ... data=tips, palette="muted")
    Draw split violins to compare across the hue variable:
  2032. .. plot::
  2033. :context: close-figs
  2034. >>> ax = sns.violinplot(x="day", y="total_bill", hue="smoker",
  2035. ... data=tips, palette="muted", split=True)
  2036. Control violin order by passing an explicit order:
  2037. .. plot::
  2038. :context: close-figs
  2039. >>> ax = sns.violinplot(x="time", y="tip", data=tips,
  2040. ... order=["Dinner", "Lunch"])
  2041. Scale the violin width by the number of observations in each bin:
  2042. .. plot::
  2043. :context: close-figs
  2044. >>> ax = sns.violinplot(x="day", y="total_bill", hue="sex",
  2045. ... data=tips, palette="Set2", split=True,
  2046. ... scale="count")
  2047. Draw the quartiles as horizontal lines instead of a mini-box:
  2048. .. plot::
  2049. :context: close-figs
  2050. >>> ax = sns.violinplot(x="day", y="total_bill", hue="sex",
  2051. ... data=tips, palette="Set2", split=True,
  2052. ... scale="count", inner="quartile")
  2053. Show each observation with a stick inside the violin:
  2054. .. plot::
  2055. :context: close-figs
  2056. >>> ax = sns.violinplot(x="day", y="total_bill", hue="sex",
  2057. ... data=tips, palette="Set2", split=True,
  2058. ... scale="count", inner="stick")
  2059. Scale the density relative to the counts across all bins:
  2060. .. plot::
  2061. :context: close-figs
  2062. >>> ax = sns.violinplot(x="day", y="total_bill", hue="sex",
  2063. ... data=tips, palette="Set2", split=True,
  2064. ... scale="count", inner="stick", scale_hue=False)
  2065. Use a narrow bandwidth to reduce the amount of smoothing:
  2066. .. plot::
  2067. :context: close-figs
  2068. >>> ax = sns.violinplot(x="day", y="total_bill", hue="sex",
  2069. ... data=tips, palette="Set2", split=True,
  2070. ... scale="count", inner="stick",
  2071. ... scale_hue=False, bw=.2)
  2072. Draw horizontal violins:
  2073. .. plot::
  2074. :context: close-figs
  2075. >>> planets = sns.load_dataset("planets")
  2076. >>> ax = sns.violinplot(x="orbital_period", y="method",
  2077. ... data=planets[planets.orbital_period < 1000],
  2078. ... scale="width", palette="Set3")
  2079. Don't let density extend past extreme values in the data:
  2080. .. plot::
  2081. :context: close-figs
  2082. >>> ax = sns.violinplot(x="orbital_period", y="method",
  2083. ... data=planets[planets.orbital_period < 1000],
  2084. ... cut=0, scale="width", palette="Set3")
  2085. Use ``hue`` without changing violin position or width:
  2086. .. plot::
  2087. :context: close-figs
  2088. >>> tips["weekend"] = tips["day"].isin(["Sat", "Sun"])
  2089. >>> ax = sns.violinplot(x="day", y="total_bill", hue="weekend",
  2090. ... data=tips, dodge=False)
  2091. Use :func:`catplot` to combine a :func:`violinplot` and a
  2092. :class:`FacetGrid`. This allows grouping within additional categorical
  2093. variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
  2094. directly, as it ensures synchronization of variable order across facets:
  2095. .. plot::
  2096. :context: close-figs
  2097. >>> g = sns.catplot(x="sex", y="total_bill",
  2098. ... hue="smoker", col="time",
  2099. ... data=tips, kind="violin", split=True,
  2100. ... height=4, aspect=.7);
  2101. """).format(**_categorical_docs)
  2102. def lvplot(*args, **kwargs):
  2103. """Deprecated; please use `boxenplot`."""
  2104. msg = (
  2105. "The `lvplot` function has been renamed to `boxenplot`. The original "
  2106. "name will be removed in a future release. Please update your code. "
  2107. )
  2108. warnings.warn(msg)
  2109. return boxenplot(*args, **kwargs)
  2110. def boxenplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  2111. orient=None, color=None, palette=None, saturation=.75,
  2112. width=.8, dodge=True, k_depth='proportion', linewidth=None,
  2113. scale='exponential', outlier_prop=None, showfliers=True, ax=None,
  2114. **kwargs):
  2115. plotter = _LVPlotter(x, y, hue, data, order, hue_order,
  2116. orient, color, palette, saturation,
  2117. width, dodge, k_depth, linewidth, scale,
  2118. outlier_prop, showfliers)
  2119. if ax is None:
  2120. ax = plt.gca()
  2121. plotter.plot(ax, kwargs)
  2122. return ax
  2123. boxenplot.__doc__ = dedent("""\
  2124. Draw an enhanced box plot for larger datasets.
  2125. This style of plot was originally named a "letter value" plot because it
  2126. shows a large number of quantiles that are defined as "letter values". It
  2127. is similar to a box plot in plotting a nonparametric representation of a
  2128. distribution in which all features correspond to actual observations. By
  2129. plotting more quantiles, it provides more information about the shape of
  2130. the distribution, particularly in the tails. For a more extensive
  2131. explanation, you can read the paper that introduced the plot:
  2132. https://vita.had.co.nz/papers/letter-value-plot.html
  2133. {main_api_narrative}
  2134. {categorical_narrative}
  2135. Parameters
  2136. ----------
  2137. {input_params}
  2138. {categorical_data}
  2139. {order_vars}
  2140. {orient}
  2141. {color}
  2142. {palette}
  2143. {saturation}
  2144. {width}
  2145. {dodge}
  2146. k_depth : "proportion" | "tukey" | "trustworthy", optional
  2147. The number of boxes, and by extension number of percentiles, to draw.
  2148. All methods are detailed in Wickham's paper. Each makes different
  2149. assumptions about the number of outliers and leverages different
  2150. statistical properties.
  2151. {linewidth}
    scale : "linear" | "exponential" | "area", optional
  2153. Method to use for the width of the letter value boxes. All give similar
  2154. results visually. "linear" reduces the width by a constant linear
  2155. factor, "exponential" uses the proportion of data not covered, "area"
  2156. is proportional to the percentage of data covered.
  2157. outlier_prop : float, optional
  2158. Proportion of data believed to be outliers. Used in conjunction with
  2159. k_depth to determine the number of percentiles to draw. Defaults to
  2160. 0.007 as a proportion of outliers. Should be in range [0, 1].
  2161. showfliers : bool, optional
  2162. If False, suppress the plotting of outliers.
  2163. {ax_in}
  2164. kwargs : key, value mappings
  2165. Other keyword arguments are passed through to
  2166. :meth:`matplotlib.axes.Axes.plot` and
  2167. :meth:`matplotlib.axes.Axes.scatter`.
  2168. Returns
  2169. -------
  2170. {ax_out}
  2171. See Also
  2172. --------
  2173. {violinplot}
  2174. {boxplot}
  2175. {catplot}
  2176. Examples
  2177. --------
  2178. Draw a single horizontal boxen plot:
  2179. .. plot::
  2180. :context: close-figs
  2181. >>> import seaborn as sns
  2182. >>> sns.set(style="whitegrid")
  2183. >>> tips = sns.load_dataset("tips")
  2184. >>> ax = sns.boxenplot(x=tips["total_bill"])
  2185. Draw a vertical boxen plot grouped by a categorical variable:
  2186. .. plot::
  2187. :context: close-figs
  2188. >>> ax = sns.boxenplot(x="day", y="total_bill", data=tips)
    Draw a boxen plot with nested grouping by two categorical variables:
  2190. .. plot::
  2191. :context: close-figs
  2192. >>> ax = sns.boxenplot(x="day", y="total_bill", hue="smoker",
  2193. ... data=tips, palette="Set3")
  2194. Draw a boxen plot with nested grouping when some bins are empty:
  2195. .. plot::
  2196. :context: close-figs
  2197. >>> ax = sns.boxenplot(x="day", y="total_bill", hue="time",
  2198. ... data=tips, linewidth=2.5)
  2199. Control box order by passing an explicit order:
  2200. .. plot::
  2201. :context: close-figs
  2202. >>> ax = sns.boxenplot(x="time", y="tip", data=tips,
  2203. ... order=["Dinner", "Lunch"])
  2204. Draw a boxen plot for each numeric variable in a DataFrame:
  2205. .. plot::
  2206. :context: close-figs
  2207. >>> iris = sns.load_dataset("iris")
  2208. >>> ax = sns.boxenplot(data=iris, orient="h", palette="Set2")
  2209. Use :func:`stripplot` to show the datapoints on top of the boxes:
  2210. .. plot::
  2211. :context: close-figs
  2212. >>> ax = sns.boxenplot(x="day", y="total_bill", data=tips)
  2213. >>> ax = sns.stripplot(x="day", y="total_bill", data=tips,
  2214. ... size=4, color="gray")
  2215. Use :func:`catplot` to combine :func:`boxenplot` and a :class:`FacetGrid`.
  2216. This allows grouping within additional categorical variables. Using
  2217. :func:`catplot` is safer than using :class:`FacetGrid` directly, as it
  2218. ensures synchronization of variable order across facets:
  2219. .. plot::
  2220. :context: close-figs
  2221. >>> g = sns.catplot(x="sex", y="total_bill",
  2222. ... hue="smoker", col="time",
  2223. ... data=tips, kind="boxen",
  2224. ... height=4, aspect=.7);
  2225. """).format(**_categorical_docs)
  2226. def stripplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  2227. jitter=True, dodge=False, orient=None, color=None, palette=None,
  2228. size=5, edgecolor="gray", linewidth=0, ax=None, **kwargs):
  2229. if "split" in kwargs:
  2230. dodge = kwargs.pop("split")
  2231. msg = "The `split` parameter has been renamed to `dodge`."
  2232. warnings.warn(msg, UserWarning)
  2233. plotter = _StripPlotter(x, y, hue, data, order, hue_order,
  2234. jitter, dodge, orient, color, palette)
  2235. if ax is None:
  2236. ax = plt.gca()
  2237. kwargs.setdefault("zorder", 3)
  2238. size = kwargs.get("s", size)
  2239. if linewidth is None:
  2240. linewidth = size / 10
  2241. if edgecolor == "gray":
  2242. edgecolor = plotter.gray
  2243. kwargs.update(dict(s=size ** 2,
  2244. edgecolor=edgecolor,
  2245. linewidth=linewidth))
  2246. plotter.plot(ax, kwargs)
  2247. return ax
  2248. stripplot.__doc__ = dedent("""\
  2249. Draw a scatterplot where one variable is categorical.
  2250. A strip plot can be drawn on its own, but it is also a good complement
  2251. to a box or violin plot in cases where you want to show all observations
  2252. along with some representation of the underlying distribution.
  2253. {main_api_narrative}
  2254. {categorical_narrative}
  2255. Parameters
  2256. ----------
  2257. {input_params}
  2258. {categorical_data}
  2259. {order_vars}
  2260. jitter : float, ``True``/``1`` is special-cased, optional
  2261. Amount of jitter (only along the categorical axis) to apply. This
  2262. can be useful when you have many points and they overlap, so that
  2263. it is easier to see the distribution. You can specify the amount
  2264. of jitter (half the width of the uniform random variable support),
  2265. or just use ``True`` for a good default.
  2266. dodge : bool, optional
  2267. When using ``hue`` nesting, setting this to ``True`` will separate
  2268. the strips for different hue levels along the categorical axis.
  2269. Otherwise, the points for each level will be plotted on top of
  2270. each other.
  2271. {orient}
  2272. {color}
  2273. {palette}
  2274. size : float, optional
  2275. Radius of the markers, in points.
  2276. edgecolor : matplotlib color, "gray" is special-cased, optional
  2277. Color of the lines around each point. If you pass ``"gray"``, the
  2278. brightness is determined by the color palette used for the body
  2279. of the points.
  2280. {linewidth}
  2281. {ax_in}
  2282. kwargs : key, value mappings
  2283. Other keyword arguments are passed through to
  2284. :meth:`matplotlib.axes.Axes.scatter`.
  2285. Returns
  2286. -------
  2287. {ax_out}
  2288. See Also
  2289. --------
  2290. {swarmplot}
  2291. {boxplot}
  2292. {violinplot}
  2293. {catplot}
  2294. Examples
  2295. --------
  2296. Draw a single horizontal strip plot:
  2297. .. plot::
  2298. :context: close-figs
  2299. >>> import seaborn as sns
  2300. >>> sns.set(style="whitegrid")
  2301. >>> tips = sns.load_dataset("tips")
  2302. >>> ax = sns.stripplot(x=tips["total_bill"])
  2303. Group the strips by a categorical variable:
  2304. .. plot::
  2305. :context: close-figs
  2306. >>> ax = sns.stripplot(x="day", y="total_bill", data=tips)
  2307. Use a smaller amount of jitter:
  2308. .. plot::
  2309. :context: close-figs
  2310. >>> ax = sns.stripplot(x="day", y="total_bill", data=tips, jitter=0.05)
  2311. Draw horizontal strips:
  2312. .. plot::
  2313. :context: close-figs
  2314. >>> ax = sns.stripplot(x="total_bill", y="day", data=tips)
  2315. Draw outlines around the points:
  2316. .. plot::
  2317. :context: close-figs
  2318. >>> ax = sns.stripplot(x="total_bill", y="day", data=tips,
  2319. ... linewidth=1)
  2320. Nest the strips within a second categorical variable:
  2321. .. plot::
  2322. :context: close-figs
  2323. >>> ax = sns.stripplot(x="sex", y="total_bill", hue="day", data=tips)
  2324. Draw each level of the ``hue`` variable at different locations on the
  2325. major categorical axis:
  2326. .. plot::
  2327. :context: close-figs
  2328. >>> ax = sns.stripplot(x="day", y="total_bill", hue="smoker",
  2329. ... data=tips, palette="Set2", dodge=True)
  2330. Control strip order by passing an explicit order:
  2331. .. plot::
  2332. :context: close-figs
  2333. >>> ax = sns.stripplot(x="time", y="tip", data=tips,
  2334. ... order=["Dinner", "Lunch"])
  2335. Draw strips with large points and different aesthetics:
  2336. .. plot::
  2337. :context: close-figs
  2338. >>> ax = sns.stripplot("day", "total_bill", "smoker", data=tips,
  2339. ... palette="Set2", size=20, marker="D",
  2340. ... edgecolor="gray", alpha=.25)
  2341. Draw strips of observations on top of a box plot:
  2342. .. plot::
  2343. :context: close-figs
  2344. >>> import numpy as np
  2345. >>> ax = sns.boxplot(x="tip", y="day", data=tips, whis=np.inf)
  2346. >>> ax = sns.stripplot(x="tip", y="day", data=tips, color=".3")
  2347. Draw strips of observations on top of a violin plot:
  2348. .. plot::
  2349. :context: close-figs
  2350. >>> ax = sns.violinplot(x="day", y="total_bill", data=tips,
  2351. ... inner=None, color=".8")
  2352. >>> ax = sns.stripplot(x="day", y="total_bill", data=tips)
  2353. Use :func:`catplot` to combine a :func:`stripplot` and a
  2354. :class:`FacetGrid`. This allows grouping within additional categorical
  2355. variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
  2356. directly, as it ensures synchronization of variable order across facets:
  2357. .. plot::
  2358. :context: close-figs
  2359. >>> g = sns.catplot(x="sex", y="total_bill",
  2360. ... hue="smoker", col="time",
  2361. ... data=tips, kind="strip",
  2362. ... height=4, aspect=.7);
  2363. """).format(**_categorical_docs)
  2364. def swarmplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  2365. dodge=False, orient=None, color=None, palette=None,
  2366. size=5, edgecolor="gray", linewidth=0, ax=None, **kwargs):
  2367. if "split" in kwargs:
  2368. dodge = kwargs.pop("split")
  2369. msg = "The `split` parameter has been renamed to `dodge`."
  2370. warnings.warn(msg, UserWarning)
  2371. plotter = _SwarmPlotter(x, y, hue, data, order, hue_order,
  2372. dodge, orient, color, palette)
  2373. if ax is None:
  2374. ax = plt.gca()
  2375. kwargs.setdefault("zorder", 3)
  2376. size = kwargs.get("s", size)
  2377. if linewidth is None:
  2378. linewidth = size / 10
  2379. if edgecolor == "gray":
  2380. edgecolor = plotter.gray
  2381. kwargs.update(dict(s=size ** 2,
  2382. edgecolor=edgecolor,
  2383. linewidth=linewidth))
  2384. plotter.plot(ax, kwargs)
  2385. return ax
  2386. swarmplot.__doc__ = dedent("""\
  2387. Draw a categorical scatterplot with non-overlapping points.
  2388. This function is similar to :func:`stripplot`, but the points are adjusted
  2389. (only along the categorical axis) so that they don't overlap. This gives a
  2390. better representation of the distribution of values, but it does not scale
  2391. well to large numbers of observations. This style of plot is sometimes
  2392. called a "beeswarm".
  2393. A swarm plot can be drawn on its own, but it is also a good complement
  2394. to a box or violin plot in cases where you want to show all observations
  2395. along with some representation of the underlying distribution.
  2396. Arranging the points properly requires an accurate transformation between
  2397. data and point coordinates. This means that non-default axis limits must
  2398. be set *before* drawing the plot.
  2399. {main_api_narrative}
  2400. {categorical_narrative}
  2401. Parameters
  2402. ----------
  2403. {input_params}
  2404. {categorical_data}
  2405. {order_vars}
  2406. dodge : bool, optional
  2407. When using ``hue`` nesting, setting this to ``True`` will separate
  2408. the swarms for different hue levels along the categorical axis.
  2409. Otherwise, the points for each level will be plotted in one swarm.
  2410. {orient}
  2411. {color}
  2412. {palette}
  2413. size : float, optional
  2414. Radius of the markers, in points.
  2415. edgecolor : matplotlib color, "gray" is special-cased, optional
  2416. Color of the lines around each point. If you pass ``"gray"``, the
  2417. brightness is determined by the color palette used for the body
  2418. of the points.
  2419. {linewidth}
  2420. {ax_in}
  2421. kwargs : key, value mappings
  2422. Other keyword arguments are passed through to
  2423. :meth:`matplotlib.axes.Axes.scatter`.
  2424. Returns
  2425. -------
  2426. {ax_out}
  2427. See Also
  2428. --------
  2429. {boxplot}
  2430. {violinplot}
  2431. {stripplot}
  2432. {catplot}
  2433. Examples
  2434. --------
  2435. Draw a single horizontal swarm plot:
  2436. .. plot::
  2437. :context: close-figs
  2438. >>> import seaborn as sns
  2439. >>> sns.set(style="whitegrid")
  2440. >>> tips = sns.load_dataset("tips")
  2441. >>> ax = sns.swarmplot(x=tips["total_bill"])
  2442. Group the swarms by a categorical variable:
  2443. .. plot::
  2444. :context: close-figs
  2445. >>> ax = sns.swarmplot(x="day", y="total_bill", data=tips)
  2446. Draw horizontal swarms:
  2447. .. plot::
  2448. :context: close-figs
  2449. >>> ax = sns.swarmplot(x="total_bill", y="day", data=tips)
  2450. Color the points using a second categorical variable:
  2451. .. plot::
  2452. :context: close-figs
  2453. >>> ax = sns.swarmplot(x="day", y="total_bill", hue="sex", data=tips)
  2454. Split each level of the ``hue`` variable along the categorical axis:
  2455. .. plot::
  2456. :context: close-figs
  2457. >>> ax = sns.swarmplot(x="day", y="total_bill", hue="smoker",
  2458. ... data=tips, palette="Set2", dodge=True)
  2459. Control swarm order by passing an explicit order:
  2460. .. plot::
  2461. :context: close-figs
  2462. >>> ax = sns.swarmplot(x="time", y="tip", data=tips,
  2463. ... order=["Dinner", "Lunch"])
  2464. Plot using larger points:
  2465. .. plot::
  2466. :context: close-figs
  2467. >>> ax = sns.swarmplot(x="time", y="tip", data=tips, size=6)
  2468. Draw swarms of observations on top of a box plot:
  2469. .. plot::
  2470. :context: close-figs
  2471. >>> ax = sns.boxplot(x="tip", y="day", data=tips, whis=np.inf)
  2472. >>> ax = sns.swarmplot(x="tip", y="day", data=tips, color=".2")
  2473. Draw swarms of observations on top of a violin plot:
  2474. .. plot::
  2475. :context: close-figs
  2476. >>> ax = sns.violinplot(x="day", y="total_bill", data=tips, inner=None)
  2477. >>> ax = sns.swarmplot(x="day", y="total_bill", data=tips,
  2478. ... color="white", edgecolor="gray")
  2479. Use :func:`catplot` to combine a :func:`swarmplot` and a
  2480. :class:`FacetGrid`. This allows grouping within additional categorical
  2481. variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
  2482. directly, as it ensures synchronization of variable order across facets:
  2483. .. plot::
  2484. :context: close-figs
  2485. >>> g = sns.catplot(x="sex", y="total_bill",
  2486. ... hue="smoker", col="time",
  2487. ... data=tips, kind="swarm",
  2488. ... height=4, aspect=.7);
  2489. """).format(**_categorical_docs)
  2490. def barplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  2491. estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None,
  2492. orient=None, color=None, palette=None, saturation=.75,
  2493. errcolor=".26", errwidth=None, capsize=None, dodge=True,
  2494. ax=None, **kwargs):
  2495. plotter = _BarPlotter(x, y, hue, data, order, hue_order,
  2496. estimator, ci, n_boot, units, seed,
  2497. orient, color, palette, saturation,
  2498. errcolor, errwidth, capsize, dodge)
  2499. if ax is None:
  2500. ax = plt.gca()
  2501. plotter.plot(ax, kwargs)
  2502. return ax
  2503. barplot.__doc__ = dedent("""\
  2504. Show point estimates and confidence intervals as rectangular bars.
  2505. A bar plot represents an estimate of central tendency for a numeric
  2506. variable with the height of each rectangle and provides some indication of
  2507. the uncertainty around that estimate using error bars. Bar plots include 0
  2508. in the quantitative axis range, and they are a good choice when 0 is a
  2509. meaningful value for the quantitative variable, and you want to make
  2510. comparisons against it.
  2511. For datasets where 0 is not a meaningful value, a point plot will allow you
  2512. to focus on differences between levels of one or more categorical
  2513. variables.
  2514. It is also important to keep in mind that a bar plot shows only the mean
  2515. (or other estimator) value, but in many cases it may be more informative to
  2516. show the distribution of values at each level of the categorical variables.
  2517. In that case, other approaches such as a box or violin plot may be more
  2518. appropriate.
  2519. {main_api_narrative}
  2520. {categorical_narrative}
  2521. Parameters
  2522. ----------
  2523. {input_params}
  2524. {categorical_data}
  2525. {order_vars}
  2526. {stat_api_params}
  2527. {orient}
  2528. {color}
  2529. {palette}
  2530. {saturation}
  2531. errcolor : matplotlib color
  2532. Color for the lines that represent the confidence interval.
  2533. {errwidth}
  2534. {capsize}
  2535. {dodge}
  2536. {ax_in}
  2537. kwargs : key, value mappings
  2538. Other keyword arguments are passed through to
  2539. :meth:`matplotlib.axes.Axes.bar`.
  2540. Returns
  2541. -------
  2542. {ax_out}
  2543. See Also
  2544. --------
  2545. {countplot}
  2546. {pointplot}
  2547. {catplot}
  2548. Examples
  2549. --------
  2550. Draw a set of vertical bar plots grouped by a categorical variable:
  2551. .. plot::
  2552. :context: close-figs
  2553. >>> import seaborn as sns
  2554. >>> sns.set(style="whitegrid")
  2555. >>> tips = sns.load_dataset("tips")
  2556. >>> ax = sns.barplot(x="day", y="total_bill", data=tips)
  2557. Draw a set of vertical bars with nested grouping by two variables:
  2558. .. plot::
  2559. :context: close-figs
  2560. >>> ax = sns.barplot(x="day", y="total_bill", hue="sex", data=tips)
  2561. Draw a set of horizontal bars:
  2562. .. plot::
  2563. :context: close-figs
  2564. >>> ax = sns.barplot(x="tip", y="day", data=tips)
  2565. Control bar order by passing an explicit order:
  2566. .. plot::
  2567. :context: close-figs
  2568. >>> ax = sns.barplot(x="time", y="tip", data=tips,
  2569. ... order=["Dinner", "Lunch"])
  2570. Use median as the estimate of central tendency:
  2571. .. plot::
  2572. :context: close-figs
  2573. >>> from numpy import median
  2574. >>> ax = sns.barplot(x="day", y="tip", data=tips, estimator=median)
  2575. Show the standard error of the mean with the error bars:
  2576. .. plot::
  2577. :context: close-figs
  2578. >>> ax = sns.barplot(x="day", y="tip", data=tips, ci=68)
  2579. Show standard deviation of observations instead of a confidence interval:
  2580. .. plot::
  2581. :context: close-figs
  2582. >>> ax = sns.barplot(x="day", y="tip", data=tips, ci="sd")
  2583. Add "caps" to the error bars:
  2584. .. plot::
  2585. :context: close-figs
  2586. >>> ax = sns.barplot(x="day", y="tip", data=tips, capsize=.2)
  2587. Use a different color palette for the bars:
  2588. .. plot::
  2589. :context: close-figs
  2590. >>> ax = sns.barplot("size", y="total_bill", data=tips,
  2591. ... palette="Blues_d")
  2592. Use ``hue`` without changing bar position or width:
  2593. .. plot::
  2594. :context: close-figs
  2595. >>> tips["weekend"] = tips["day"].isin(["Sat", "Sun"])
  2596. >>> ax = sns.barplot(x="day", y="total_bill", hue="weekend",
  2597. ... data=tips, dodge=False)
  2598. Plot all bars in a single color:
  2599. .. plot::
  2600. :context: close-figs
  2601. >>> ax = sns.barplot("size", y="total_bill", data=tips,
  2602. ... color="salmon", saturation=.5)
  2603. Use :meth:`matplotlib.axes.Axes.bar` parameters to control the style.
  2604. .. plot::
  2605. :context: close-figs
  2606. >>> ax = sns.barplot("day", "total_bill", data=tips,
  2607. ... linewidth=2.5, facecolor=(1, 1, 1, 0),
  2608. ... errcolor=".2", edgecolor=".2")
  2609. Use :func:`catplot` to combine a :func:`barplot` and a :class:`FacetGrid`.
  2610. This allows grouping within additional categorical variables. Using
  2611. :func:`catplot` is safer than using :class:`FacetGrid` directly, as it
  2612. ensures synchronization of variable order across facets:
  2613. .. plot::
  2614. :context: close-figs
  2615. >>> g = sns.catplot(x="sex", y="total_bill",
  2616. ... hue="smoker", col="time",
  2617. ... data=tips, kind="bar",
  2618. ... height=4, aspect=.7);
  2619. """).format(**_categorical_docs)
  2620. def pointplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  2621. estimator=np.mean, ci=95, n_boot=1000, units=None, seed=None,
  2622. markers="o", linestyles="-", dodge=False, join=True, scale=1,
  2623. orient=None, color=None, palette=None, errwidth=None,
  2624. capsize=None, ax=None, **kwargs):
  2625. plotter = _PointPlotter(x, y, hue, data, order, hue_order,
  2626. estimator, ci, n_boot, units, seed,
  2627. markers, linestyles, dodge, join, scale,
  2628. orient, color, palette, errwidth, capsize)
  2629. if ax is None:
  2630. ax = plt.gca()
  2631. plotter.plot(ax)
  2632. return ax
  2633. pointplot.__doc__ = dedent("""\
  2634. Show point estimates and confidence intervals using scatter plot glyphs.
  2635. A point plot represents an estimate of central tendency for a numeric
  2636. variable by the position of scatter plot points and provides some
  2637. indication of the uncertainty around that estimate using error bars.
  2638. Point plots can be more useful than bar plots for focusing comparisons
  2639. between different levels of one or more categorical variables. They are
  2640. particularly adept at showing interactions: how the relationship between
  2641. levels of one categorical variable changes across levels of a second
  2642. categorical variable. The lines that join each point from the same ``hue``
  2643. level allow interactions to be judged by differences in slope, which is
  2644. easier for the eyes than comparing the heights of several groups of points
  2645. or bars.
  2646. It is important to keep in mind that a point plot shows only the mean (or
  2647. other estimator) value, but in many cases it may be more informative to
  2648. show the distribution of values at each level of the categorical variables.
  2649. In that case, other approaches such as a box or violin plot may be more
  2650. appropriate.
  2651. {main_api_narrative}
  2652. {categorical_narrative}
  2653. Parameters
  2654. ----------
  2655. {input_params}
  2656. {categorical_data}
  2657. {order_vars}
  2658. {stat_api_params}
  2659. markers : string or list of strings, optional
  2660. Markers to use for each of the ``hue`` levels.
  2661. linestyles : string or list of strings, optional
  2662. Line styles to use for each of the ``hue`` levels.
  2663. dodge : bool or float, optional
  2664. Amount to separate the points for each level of the ``hue`` variable
  2665. along the categorical axis.
  2666. join : bool, optional
  2667. If ``True``, lines will be drawn between point estimates at the same
  2668. ``hue`` level.
  2669. scale : float, optional
  2670. Scale factor for the plot elements.
  2671. {orient}
  2672. {color}
  2673. {palette}
  2674. {errwidth}
  2675. {capsize}
  2676. {ax_in}
  2677. Returns
  2678. -------
  2679. {ax_out}
  2680. See Also
  2681. --------
  2682. {barplot}
  2683. {catplot}
  2684. Examples
  2685. --------
  2686. Draw a set of vertical point plots grouped by a categorical variable:
  2687. .. plot::
  2688. :context: close-figs
  2689. >>> import seaborn as sns
  2690. >>> sns.set(style="darkgrid")
  2691. >>> tips = sns.load_dataset("tips")
  2692. >>> ax = sns.pointplot(x="time", y="total_bill", data=tips)
  2693. Draw a set of vertical points with nested grouping by two variables:
  2694. .. plot::
  2695. :context: close-figs
  2696. >>> ax = sns.pointplot(x="time", y="total_bill", hue="smoker",
  2697. ... data=tips)
  2698. Separate the points for different hue levels along the categorical axis:
  2699. .. plot::
  2700. :context: close-figs
  2701. >>> ax = sns.pointplot(x="time", y="total_bill", hue="smoker",
  2702. ... data=tips, dodge=True)
  2703. Use a different marker and line style for the hue levels:
  2704. .. plot::
  2705. :context: close-figs
  2706. >>> ax = sns.pointplot(x="time", y="total_bill", hue="smoker",
  2707. ... data=tips,
  2708. ... markers=["o", "x"],
  2709. ... linestyles=["-", "--"])
  2710. Draw a set of horizontal points:
  2711. .. plot::
  2712. :context: close-figs
  2713. >>> ax = sns.pointplot(x="tip", y="day", data=tips)
  2714. Don't draw a line connecting each point:
  2715. .. plot::
  2716. :context: close-figs
  2717. >>> ax = sns.pointplot(x="tip", y="day", data=tips, join=False)
  2718. Use a different color for a single-layer plot:
  2719. .. plot::
  2720. :context: close-figs
  2721. >>> ax = sns.pointplot("time", y="total_bill", data=tips,
  2722. ... color="#bb3f3f")
  2723. Use a different color palette for the points:
  2724. .. plot::
  2725. :context: close-figs
  2726. >>> ax = sns.pointplot(x="time", y="total_bill", hue="smoker",
  2727. ... data=tips, palette="Set2")
  2728. Control point order by passing an explicit order:
  2729. .. plot::
  2730. :context: close-figs
  2731. >>> ax = sns.pointplot(x="time", y="tip", data=tips,
  2732. ... order=["Dinner", "Lunch"])
  2733. Use median as the estimate of central tendency:
  2734. .. plot::
  2735. :context: close-figs
  2736. >>> from numpy import median
  2737. >>> ax = sns.pointplot(x="day", y="tip", data=tips, estimator=median)
  2738. Show the standard error of the mean with the error bars:
  2739. .. plot::
  2740. :context: close-figs
  2741. >>> ax = sns.pointplot(x="day", y="tip", data=tips, ci=68)
  2742. Show standard deviation of observations instead of a confidence interval:
  2743. .. plot::
  2744. :context: close-figs
  2745. >>> ax = sns.pointplot(x="day", y="tip", data=tips, ci="sd")
  2746. Add "caps" to the error bars:
  2747. .. plot::
  2748. :context: close-figs
  2749. >>> ax = sns.pointplot(x="day", y="tip", data=tips, capsize=.2)
  2750. Use :func:`catplot` to combine a :func:`pointplot` and a
  2751. :class:`FacetGrid`. This allows grouping within additional categorical
  2752. variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
  2753. directly, as it ensures synchronization of variable order across facets:
  2754. .. plot::
  2755. :context: close-figs
  2756. >>> g = sns.catplot(x="sex", y="total_bill",
  2757. ... hue="smoker", col="time",
  2758. ... data=tips, kind="point",
  2759. ... dodge=True,
  2760. ... height=4, aspect=.7);
  2761. """).format(**_categorical_docs)
  2762. def countplot(x=None, y=None, hue=None, data=None, order=None, hue_order=None,
  2763. orient=None, color=None, palette=None, saturation=.75,
  2764. dodge=True, ax=None, **kwargs):
  2765. estimator = len
  2766. ci = None
  2767. n_boot = 0
  2768. units = None
  2769. seed = None
  2770. errcolor = None
  2771. errwidth = None
  2772. capsize = None
  2773. if x is None and y is not None:
  2774. orient = "h"
  2775. x = y
  2776. elif y is None and x is not None:
  2777. orient = "v"
  2778. y = x
  2779. elif x is not None and y is not None:
  2780. raise TypeError("Cannot pass values for both `x` and `y`")
  2781. else:
  2782. raise TypeError("Must pass values for either `x` or `y`")
  2783. plotter = _BarPlotter(x, y, hue, data, order, hue_order,
  2784. estimator, ci, n_boot, units, seed,
  2785. orient, color, palette, saturation,
  2786. errcolor, errwidth, capsize, dodge)
  2787. plotter.value_label = "count"
  2788. if ax is None:
  2789. ax = plt.gca()
  2790. plotter.plot(ax, kwargs)
  2791. return ax
  2792. countplot.__doc__ = dedent("""\
  2793. Show the counts of observations in each categorical bin using bars.
  2794. A count plot can be thought of as a histogram across a categorical, instead
  2795. of quantitative, variable. The basic API and options are identical to those
  2796. for :func:`barplot`, so you can compare counts across nested variables.
  2797. {main_api_narrative}
  2798. {categorical_narrative}
  2799. Parameters
  2800. ----------
  2801. {input_params}
  2802. {categorical_data}
  2803. {order_vars}
  2804. {orient}
  2805. {color}
  2806. {palette}
  2807. {saturation}
  2808. {dodge}
  2809. {ax_in}
  2810. kwargs : key, value mappings
  2811. Other keyword arguments are passed through to
  2812. :meth:`matplotlib.axes.Axes.bar`.
  2813. Returns
  2814. -------
  2815. {ax_out}
  2816. See Also
  2817. --------
  2818. {barplot}
  2819. {catplot}
  2820. Examples
  2821. --------
  2822. Show value counts for a single categorical variable:
  2823. .. plot::
  2824. :context: close-figs
  2825. >>> import seaborn as sns
  2826. >>> sns.set(style="darkgrid")
  2827. >>> titanic = sns.load_dataset("titanic")
  2828. >>> ax = sns.countplot(x="class", data=titanic)
  2829. Show value counts for two categorical variables:
  2830. .. plot::
  2831. :context: close-figs
  2832. >>> ax = sns.countplot(x="class", hue="who", data=titanic)
  2833. Plot the bars horizontally:
  2834. .. plot::
  2835. :context: close-figs
  2836. >>> ax = sns.countplot(y="class", hue="who", data=titanic)
  2837. Use a different color palette:
  2838. .. plot::
  2839. :context: close-figs
  2840. >>> ax = sns.countplot(x="who", data=titanic, palette="Set3")
  2841. Use :meth:`matplotlib.axes.Axes.bar` parameters to control the style.
  2842. .. plot::
  2843. :context: close-figs
  2844. >>> ax = sns.countplot(x="who", data=titanic,
  2845. ... facecolor=(0, 0, 0, 0),
  2846. ... linewidth=5,
  2847. ... edgecolor=sns.color_palette("dark", 3))
  2848. Use :func:`catplot` to combine a :func:`countplot` and a
  2849. :class:`FacetGrid`. This allows grouping within additional categorical
  2850. variables. Using :func:`catplot` is safer than using :class:`FacetGrid`
  2851. directly, as it ensures synchronization of variable order across facets:
  2852. .. plot::
  2853. :context: close-figs
  2854. >>> g = sns.catplot(x="class", hue="who", col="survived",
  2855. ... data=titanic, kind="count",
  2856. ... height=4, aspect=.7);
  2857. """).format(**_categorical_docs)
  2858. def factorplot(*args, **kwargs):
  2859. """Deprecated; please use `catplot` instead."""
  2860. msg = (
  2861. "The `factorplot` function has been renamed to `catplot`. The "
  2862. "original name will be removed in a future release. Please update "
  2863. "your code. Note that the default `kind` in `factorplot` (`'point'`) "
  2864. "has changed `'strip'` in `catplot`."
  2865. )
  2866. warnings.warn(msg)
  2867. if "size" in kwargs:
  2868. kwargs["height"] = kwargs.pop("size")
  2869. msg = ("The `size` parameter has been renamed to `height`; "
  2870. "please update your code.")
  2871. warnings.warn(msg, UserWarning)
  2872. kwargs.setdefault("kind", "point")
  2873. return catplot(*args, **kwargs)
  2874. def catplot(x=None, y=None, hue=None, data=None, row=None, col=None,
  2875. col_wrap=None, estimator=np.mean, ci=95, n_boot=1000,
  2876. units=None, seed=None, order=None, hue_order=None, row_order=None,
  2877. col_order=None, kind="strip", height=5, aspect=1,
  2878. orient=None, color=None, palette=None,
  2879. legend=True, legend_out=True, sharex=True, sharey=True,
  2880. margin_titles=False, facet_kws=None, **kwargs):
  2881. # Handle deprecations
  2882. if "size" in kwargs:
  2883. height = kwargs.pop("size")
  2884. msg = ("The `size` parameter has been renamed to `height`; "
  2885. "please update your code.")
  2886. warnings.warn(msg, UserWarning)
  2887. # Determine the plotting function
  2888. try:
  2889. plot_func = globals()[kind + "plot"]
  2890. except KeyError:
  2891. err = "Plot kind '{}' is not recognized".format(kind)
  2892. raise ValueError(err)
  2893. # Alias the input variables to determine categorical order and palette
  2894. # correctly in the case of a count plot
  2895. if kind == "count":
  2896. if x is None and y is not None:
  2897. x_, y_, orient = y, y, "h"
  2898. elif y is None and x is not None:
  2899. x_, y_, orient = x, x, "v"
  2900. else:
  2901. raise ValueError("Either `x` or `y` must be None for count plots")
  2902. else:
  2903. x_, y_ = x, y
  2904. # Check for attempt to plot onto specific axes and warn
  2905. if "ax" in kwargs:
  2906. msg = ("catplot is a figure-level function and does not accept "
  2907. "target axes. You may wish to try {}".format(kind + "plot"))
  2908. warnings.warn(msg, UserWarning)
  2909. kwargs.pop("ax")
  2910. # Determine the order for the whole dataset, which will be used in all
  2911. # facets to ensure representation of all data in the final plot
  2912. p = _CategoricalPlotter()
  2913. p.establish_variables(x_, y_, hue, data, orient, order, hue_order)
  2914. order = p.group_names
  2915. hue_order = p.hue_names
  2916. # Determine the palette to use
  2917. # (FacetGrid will pass a value for ``color`` to the plotting function
  2918. # so we need to define ``palette`` to get default behavior for the
  2919. # categorical functions
  2920. p.establish_colors(color, palette, 1)
  2921. if kind != "point" or hue is not None:
  2922. palette = p.colors
  2923. # Determine keyword arguments for the facets
  2924. facet_kws = {} if facet_kws is None else facet_kws
  2925. facet_kws.update(
  2926. data=data, row=row, col=col,
  2927. row_order=row_order, col_order=col_order,
  2928. col_wrap=col_wrap, height=height, aspect=aspect,
  2929. sharex=sharex, sharey=sharey,
  2930. legend_out=legend_out, margin_titles=margin_titles,
  2931. dropna=False,
  2932. )
  2933. # Determine keyword arguments for the plotting function
  2934. plot_kws = dict(
  2935. order=order, hue_order=hue_order,
  2936. orient=orient, color=color, palette=palette,
  2937. )
  2938. plot_kws.update(kwargs)
  2939. if kind in ["bar", "point"]:
  2940. plot_kws.update(
  2941. estimator=estimator, ci=ci, n_boot=n_boot, units=units, seed=seed,
  2942. )
  2943. # Initialize the facets
  2944. g = FacetGrid(**facet_kws)
  2945. # Draw the plot onto the facets
  2946. g.map_dataframe(plot_func, x, y, hue, **plot_kws)
  2947. # Special case axis labels for a count type plot
  2948. if kind == "count":
  2949. if x is None:
  2950. g.set_axis_labels(x_var="count")
  2951. if y is None:
  2952. g.set_axis_labels(y_var="count")
  2953. if legend and (hue is not None) and (hue not in [x, row, col]):
  2954. hue_order = list(map(utils.to_utf8, hue_order))
  2955. g.add_legend(title=hue, label_order=hue_order)
  2956. return g
  2957. catplot.__doc__ = dedent("""\
  2958. Figure-level interface for drawing categorical plots onto a
  2959. :class:`FacetGrid`.
  2960. This function provides access to several axes-level functions that
  2961. show the relationship between a numerical and one or more categorical
  2962. variables using one of several visual representations. The ``kind``
  2963. parameter selects the underlying axes-level function to use:
  2964. Categorical scatterplots:
  2965. - :func:`stripplot` (with ``kind="strip"``; the default)
  2966. - :func:`swarmplot` (with ``kind="swarm"``)
  2967. Categorical distribution plots:
  2968. - :func:`boxplot` (with ``kind="box"``)
  2969. - :func:`violinplot` (with ``kind="violin"``)
  2970. - :func:`boxenplot` (with ``kind="boxen"``)
  2971. Categorical estimate plots:
  2972. - :func:`pointplot` (with ``kind="point"``)
  2973. - :func:`barplot` (with ``kind="bar"``)
  2974. - :func:`countplot` (with ``kind="count"``)
  2975. Extra keyword arguments are passed to the underlying function, so you
  2976. should refer to the documentation for each to see kind-specific options.
  2977. Note that unlike when using the axes-level functions directly, data must be
  2978. passed in a long-form DataFrame with variables specified by passing strings
  2979. to ``x``, ``y``, ``hue``, etc.
  2980. As in the case with the underlying plot functions, if variables have a
  2981. ``categorical`` data type, the levels of the categorical variables, and
  2982. their order, will be inferred from the objects. Otherwise you may have to
  2983. alter the dataframe sorting or use the function parameters (``orient``,
  2984. ``order``, ``hue_order``, etc.) to set up the plot correctly.
  2985. {categorical_narrative}
  2986. After plotting, the :class:`FacetGrid` with the plot is returned and can
  2987. be used directly to tweak supporting plot details or add other layers.
  2988. Parameters
  2989. ----------
  2990. {string_input_params}
  2991. {long_form_data}
  2992. row, col : names of variables in ``data``, optional
  2993. Categorical variables that will determine the faceting of the grid.
  2994. {col_wrap}
  2995. {stat_api_params}
  2996. {order_vars}
  2997. row_order, col_order : lists of strings, optional
  2998. Order to organize the rows and/or columns of the grid in, otherwise the
  2999. orders are inferred from the data objects.
  3000. kind : string, optional
  3001. The kind of plot to draw (corresponds to the name of a categorical
  3002. plotting function). Options are: "point", "bar", "count", "strip",
  3003. "swarm", "box", "violin", or "boxen".
  3004. {height}
  3005. {aspect}
  3006. {orient}
  3007. {color}
  3008. {palette}
  3009. legend : bool, optional
  3010. If ``True`` and there is a ``hue`` variable, draw a legend on the plot.
  3011. {legend_out}
  3012. {share_xy}
  3013. {margin_titles}
  3014. facet_kws : dict, optional
  3015. Dictionary of other keyword arguments to pass to :class:`FacetGrid`.
  3016. kwargs : key, value pairings
  3017. Other keyword arguments are passed through to the underlying plotting
  3018. function.
  3019. Returns
  3020. -------
  3021. g : :class:`FacetGrid`
  3022. Returns the :class:`FacetGrid` object with the plot on it for further
  3023. tweaking.
  3024. Examples
  3025. --------
  3026. Draw a single facet to use the :class:`FacetGrid` legend placement:
  3027. .. plot::
  3028. :context: close-figs
  3029. >>> import seaborn as sns
  3030. >>> sns.set(style="ticks")
  3031. >>> exercise = sns.load_dataset("exercise")
  3032. >>> g = sns.catplot(x="time", y="pulse", hue="kind", data=exercise)
  3033. Use a different plot kind to visualize the same data:
  3034. .. plot::
  3035. :context: close-figs
  3036. >>> g = sns.catplot(x="time", y="pulse", hue="kind",
  3037. ... data=exercise, kind="violin")
  3038. Facet along the columns to show a third categorical variable:
  3039. .. plot::
  3040. :context: close-figs
  3041. >>> g = sns.catplot(x="time", y="pulse", hue="kind",
  3042. ... col="diet", data=exercise)
  3043. Use a different height and aspect ratio for the facets:
  3044. .. plot::
  3045. :context: close-figs
  3046. >>> g = sns.catplot(x="time", y="pulse", hue="kind",
  3047. ... col="diet", data=exercise,
  3048. ... height=5, aspect=.8)
  3049. Make many column facets and wrap them into the rows of the grid:
  3050. .. plot::
  3051. :context: close-figs
  3052. >>> titanic = sns.load_dataset("titanic")
  3053. >>> g = sns.catplot("alive", col="deck", col_wrap=4,
  3054. ... data=titanic[titanic.deck.notnull()],
  3055. ... kind="count", height=2.5, aspect=.8)
  3056. Plot horizontally and pass other keyword arguments to the plot function:
  3057. .. plot::
  3058. :context: close-figs
  3059. >>> g = sns.catplot(x="age", y="embark_town",
  3060. ... hue="sex", row="class",
  3061. ... data=titanic[titanic.embark_town.notnull()],
  3062. ... orient="h", height=2, aspect=3, palette="Set3",
  3063. ... kind="violin", dodge=True, cut=0, bw=.2)
  3064. Use methods on the returned :class:`FacetGrid` to tweak the presentation:
  3065. .. plot::
  3066. :context: close-figs
  3067. >>> g = sns.catplot(x="who", y="survived", col="class",
  3068. ... data=titanic, saturation=.5,
  3069. ... kind="bar", ci=None, aspect=.6)
  3070. >>> (g.set_axis_labels("", "Survival Rate")
  3071. ... .set_xticklabels(["Men", "Women", "Children"])
  3072. ... .set_titles("{{col_name}} {{col_var}}")
  3073. ... .set(ylim=(0, 1))
  3074. ... .despine(left=True)) #doctest: +ELLIPSIS
  3075. <seaborn.axisgrid.FacetGrid object at 0x...>
  3076. """).format(**_categorical_docs)