regression.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069
  1. """Plotting functions for linear models (broadly construed)."""
  2. import copy
  3. from textwrap import dedent
  4. import warnings
  5. import numpy as np
  6. import pandas as pd
  7. from scipy.spatial import distance
  8. import matplotlib as mpl
  9. import matplotlib.pyplot as plt
  10. try:
  11. import statsmodels
  12. assert statsmodels
  13. _has_statsmodels = True
  14. except ImportError:
  15. _has_statsmodels = False
  16. from . import utils
  17. from . import algorithms as algo
  18. from .axisgrid import FacetGrid, _facet_docs
# Public names exported by a star-import of this module; the plotter classes
# defined below are private (underscore-prefixed) implementation helpers.
__all__ = ["lmplot", "regplot", "residplot"]
  20. class _LinearPlotter(object):
  21. """Base class for plotting relational data in tidy format.
  22. To get anything useful done you'll have to inherit from this, but setup
  23. code that can be abstracted out should be put here.
  24. """
  25. def establish_variables(self, data, **kws):
  26. """Extract variables from data or use directly."""
  27. self.data = data
  28. # Validate the inputs
  29. any_strings = any([isinstance(v, str) for v in kws.values()])
  30. if any_strings and data is None:
  31. raise ValueError("Must pass `data` if using named variables.")
  32. # Set the variables
  33. for var, val in kws.items():
  34. if isinstance(val, str):
  35. vector = data[val]
  36. elif isinstance(val, list):
  37. vector = np.asarray(val)
  38. else:
  39. vector = val
  40. if vector is not None and vector.shape != (1,):
  41. vector = np.squeeze(vector)
  42. if np.ndim(vector) > 1:
  43. err = "regplot inputs must be 1d"
  44. raise ValueError(err)
  45. setattr(self, var, vector)
  46. def dropna(self, *vars):
  47. """Remove observations with missing data."""
  48. vals = [getattr(self, var) for var in vars]
  49. vals = [v for v in vals if v is not None]
  50. not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
  51. for var in vars:
  52. val = getattr(self, var)
  53. if val is not None:
  54. setattr(self, var, val[not_na])
  55. def plot(self, ax):
  56. raise NotImplementedError
  57. class _RegressionPlotter(_LinearPlotter):
  58. """Plotter for numeric independent variables with regression model.
  59. This does the computations and drawing for the `regplot` function, and
  60. is thus also used indirectly by `lmplot`.
  61. """
  62. def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
  63. x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
  64. units=None, seed=None, order=1, logistic=False, lowess=False,
  65. robust=False, logx=False, x_partial=None, y_partial=None,
  66. truncate=False, dropna=True, x_jitter=None, y_jitter=None,
  67. color=None, label=None):
  68. # Set member attributes
  69. self.x_estimator = x_estimator
  70. self.ci = ci
  71. self.x_ci = ci if x_ci == "ci" else x_ci
  72. self.n_boot = n_boot
  73. self.seed = seed
  74. self.scatter = scatter
  75. self.fit_reg = fit_reg
  76. self.order = order
  77. self.logistic = logistic
  78. self.lowess = lowess
  79. self.robust = robust
  80. self.logx = logx
  81. self.truncate = truncate
  82. self.x_jitter = x_jitter
  83. self.y_jitter = y_jitter
  84. self.color = color
  85. self.label = label
  86. # Validate the regression options:
  87. if sum((order > 1, logistic, robust, lowess, logx)) > 1:
  88. raise ValueError("Mutually exclusive regression options.")
  89. # Extract the data vals from the arguments or passed dataframe
  90. self.establish_variables(data, x=x, y=y, units=units,
  91. x_partial=x_partial, y_partial=y_partial)
  92. # Drop null observations
  93. if dropna:
  94. self.dropna("x", "y", "units", "x_partial", "y_partial")
  95. # Regress nuisance variables out of the data
  96. if self.x_partial is not None:
  97. self.x = self.regress_out(self.x, self.x_partial)
  98. if self.y_partial is not None:
  99. self.y = self.regress_out(self.y, self.y_partial)
  100. # Possibly bin the predictor variable, which implies a point estimate
  101. if x_bins is not None:
  102. self.x_estimator = np.mean if x_estimator is None else x_estimator
  103. x_discrete, x_bins = self.bin_predictor(x_bins)
  104. self.x_discrete = x_discrete
  105. else:
  106. self.x_discrete = self.x
  107. # Disable regression in case of singleton inputs
  108. if len(self.x) <= 1:
  109. self.fit_reg = False
  110. # Save the range of the x variable for the grid later
  111. if self.fit_reg:
  112. self.x_range = self.x.min(), self.x.max()
  113. @property
  114. def scatter_data(self):
  115. """Data where each observation is a point."""
  116. x_j = self.x_jitter
  117. if x_j is None:
  118. x = self.x
  119. else:
  120. x = self.x + np.random.uniform(-x_j, x_j, len(self.x))
  121. y_j = self.y_jitter
  122. if y_j is None:
  123. y = self.y
  124. else:
  125. y = self.y + np.random.uniform(-y_j, y_j, len(self.y))
  126. return x, y
  127. @property
  128. def estimate_data(self):
  129. """Data with a point estimate and CI for each discrete x value."""
  130. x, y = self.x_discrete, self.y
  131. vals = sorted(np.unique(x))
  132. points, cis = [], []
  133. for val in vals:
  134. # Get the point estimate of the y variable
  135. _y = y[x == val]
  136. est = self.x_estimator(_y)
  137. points.append(est)
  138. # Compute the confidence interval for this estimate
  139. if self.x_ci is None:
  140. cis.append(None)
  141. else:
  142. units = None
  143. if self.x_ci == "sd":
  144. sd = np.std(_y)
  145. _ci = est - sd, est + sd
  146. else:
  147. if self.units is not None:
  148. units = self.units[x == val]
  149. boots = algo.bootstrap(_y,
  150. func=self.x_estimator,
  151. n_boot=self.n_boot,
  152. units=units,
  153. seed=self.seed)
  154. _ci = utils.ci(boots, self.x_ci)
  155. cis.append(_ci)
  156. return vals, points, cis
  157. def fit_regression(self, ax=None, x_range=None, grid=None):
  158. """Fit the regression model."""
  159. # Create the grid for the regression
  160. if grid is None:
  161. if self.truncate:
  162. x_min, x_max = self.x_range
  163. else:
  164. if ax is None:
  165. x_min, x_max = x_range
  166. else:
  167. x_min, x_max = ax.get_xlim()
  168. grid = np.linspace(x_min, x_max, 100)
  169. ci = self.ci
  170. # Fit the regression
  171. if self.order > 1:
  172. yhat, yhat_boots = self.fit_poly(grid, self.order)
  173. elif self.logistic:
  174. from statsmodels.genmod.generalized_linear_model import GLM
  175. from statsmodels.genmod.families import Binomial
  176. yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
  177. family=Binomial())
  178. elif self.lowess:
  179. ci = None
  180. grid, yhat = self.fit_lowess()
  181. elif self.robust:
  182. from statsmodels.robust.robust_linear_model import RLM
  183. yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
  184. elif self.logx:
  185. yhat, yhat_boots = self.fit_logx(grid)
  186. else:
  187. yhat, yhat_boots = self.fit_fast(grid)
  188. # Compute the confidence interval at each grid point
  189. if ci is None:
  190. err_bands = None
  191. else:
  192. err_bands = utils.ci(yhat_boots, ci, axis=0)
  193. return grid, yhat, err_bands
  194. def fit_fast(self, grid):
  195. """Low-level regression and prediction using linear algebra."""
  196. def reg_func(_x, _y):
  197. return np.linalg.pinv(_x).dot(_y)
  198. X, y = np.c_[np.ones(len(self.x)), self.x], self.y
  199. grid = np.c_[np.ones(len(grid)), grid]
  200. yhat = grid.dot(reg_func(X, y))
  201. if self.ci is None:
  202. return yhat, None
  203. beta_boots = algo.bootstrap(X, y,
  204. func=reg_func,
  205. n_boot=self.n_boot,
  206. units=self.units,
  207. seed=self.seed).T
  208. yhat_boots = grid.dot(beta_boots).T
  209. return yhat, yhat_boots
  210. def fit_poly(self, grid, order):
  211. """Regression using numpy polyfit for higher-order trends."""
  212. def reg_func(_x, _y):
  213. return np.polyval(np.polyfit(_x, _y, order), grid)
  214. x, y = self.x, self.y
  215. yhat = reg_func(x, y)
  216. if self.ci is None:
  217. return yhat, None
  218. yhat_boots = algo.bootstrap(x, y,
  219. func=reg_func,
  220. n_boot=self.n_boot,
  221. units=self.units,
  222. seed=self.seed)
  223. return yhat, yhat_boots
  224. def fit_statsmodels(self, grid, model, **kwargs):
  225. """More general regression function using statsmodels objects."""
  226. import statsmodels.genmod.generalized_linear_model as glm
  227. X, y = np.c_[np.ones(len(self.x)), self.x], self.y
  228. grid = np.c_[np.ones(len(grid)), grid]
  229. def reg_func(_x, _y):
  230. try:
  231. yhat = model(_y, _x, **kwargs).fit().predict(grid)
  232. except glm.PerfectSeparationError:
  233. yhat = np.empty(len(grid))
  234. yhat.fill(np.nan)
  235. return yhat
  236. yhat = reg_func(X, y)
  237. if self.ci is None:
  238. return yhat, None
  239. yhat_boots = algo.bootstrap(X, y,
  240. func=reg_func,
  241. n_boot=self.n_boot,
  242. units=self.units,
  243. seed=self.seed)
  244. return yhat, yhat_boots
  245. def fit_lowess(self):
  246. """Fit a locally-weighted regression, which returns its own grid."""
  247. from statsmodels.nonparametric.smoothers_lowess import lowess
  248. grid, yhat = lowess(self.y, self.x).T
  249. return grid, yhat
  250. def fit_logx(self, grid):
  251. """Fit the model in log-space."""
  252. X, y = np.c_[np.ones(len(self.x)), self.x], self.y
  253. grid = np.c_[np.ones(len(grid)), np.log(grid)]
  254. def reg_func(_x, _y):
  255. _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
  256. return np.linalg.pinv(_x).dot(_y)
  257. yhat = grid.dot(reg_func(X, y))
  258. if self.ci is None:
  259. return yhat, None
  260. beta_boots = algo.bootstrap(X, y,
  261. func=reg_func,
  262. n_boot=self.n_boot,
  263. units=self.units,
  264. seed=self.seed).T
  265. yhat_boots = grid.dot(beta_boots).T
  266. return yhat, yhat_boots
  267. def bin_predictor(self, bins):
  268. """Discretize a predictor by assigning value to closest bin."""
  269. x = self.x
  270. if np.isscalar(bins):
  271. percentiles = np.linspace(0, 100, bins + 2)[1:-1]
  272. bins = np.c_[np.percentile(x, percentiles)]
  273. else:
  274. bins = np.c_[np.ravel(bins)]
  275. dist = distance.cdist(np.c_[x], bins)
  276. x_binned = bins[np.argmin(dist, axis=1)].ravel()
  277. return x_binned, bins.ravel()
  278. def regress_out(self, a, b):
  279. """Regress b from a keeping a's original mean."""
  280. a_mean = a.mean()
  281. a = a - a_mean
  282. b = b - b.mean()
  283. b = np.c_[b]
  284. a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
  285. return np.asarray(a_prime + a_mean).reshape(a.shape)
  286. def plot(self, ax, scatter_kws, line_kws):
  287. """Draw the full plot."""
  288. # Insert the plot label into the correct set of keyword arguments
  289. if self.scatter:
  290. scatter_kws["label"] = self.label
  291. else:
  292. line_kws["label"] = self.label
  293. # Use the current color cycle state as a default
  294. if self.color is None:
  295. lines, = ax.plot([], [])
  296. color = lines.get_color()
  297. lines.remove()
  298. else:
  299. color = self.color
  300. # Ensure that color is hex to avoid matplotlib weirdness
  301. color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))
  302. # Let color in keyword arguments override overall plot color
  303. scatter_kws.setdefault("color", color)
  304. line_kws.setdefault("color", color)
  305. # Draw the constituent plots
  306. if self.scatter:
  307. self.scatterplot(ax, scatter_kws)
  308. if self.fit_reg:
  309. self.lineplot(ax, line_kws)
  310. # Label the axes
  311. if hasattr(self.x, "name"):
  312. ax.set_xlabel(self.x.name)
  313. if hasattr(self.y, "name"):
  314. ax.set_ylabel(self.y.name)
  315. def scatterplot(self, ax, kws):
  316. """Draw the data."""
  317. # Treat the line-based markers specially, explicitly setting larger
  318. # linewidth than is provided by the seaborn style defaults.
  319. # This would ideally be handled better in matplotlib (i.e., distinguish
  320. # between edgewidth for solid glyphs and linewidth for line glyphs
  321. # but this should do for now.
  322. line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
  323. if self.x_estimator is None:
  324. if "marker" in kws and kws["marker"] in line_markers:
  325. lw = mpl.rcParams["lines.linewidth"]
  326. else:
  327. lw = mpl.rcParams["lines.markeredgewidth"]
  328. kws.setdefault("linewidths", lw)
  329. if not hasattr(kws['color'], 'shape') or kws['color'].shape[1] < 4:
  330. kws.setdefault("alpha", .8)
  331. x, y = self.scatter_data
  332. ax.scatter(x, y, **kws)
  333. else:
  334. # TODO abstraction
  335. ci_kws = {"color": kws["color"]}
  336. ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
  337. kws.setdefault("s", 50)
  338. xs, ys, cis = self.estimate_data
  339. if [ci for ci in cis if ci is not None]:
  340. for x, ci in zip(xs, cis):
  341. ax.plot([x, x], ci, **ci_kws)
  342. ax.scatter(xs, ys, **kws)
  343. def lineplot(self, ax, kws):
  344. """Draw the model."""
  345. # Fit the regression model
  346. grid, yhat, err_bands = self.fit_regression(ax)
  347. edges = grid[0], grid[-1]
  348. # Get set default aesthetics
  349. fill_color = kws["color"]
  350. lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
  351. kws.setdefault("linewidth", lw)
  352. # Draw the regression line and confidence interval
  353. line, = ax.plot(grid, yhat, **kws)
  354. line.sticky_edges.x[:] = edges # Prevent mpl from adding margin
  355. if err_bands is not None:
  356. ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)
  357. _regression_docs = dict(
  358. model_api=dedent("""\
  359. There are a number of mutually exclusive options for estimating the
  360. regression model. See the :ref:`tutorial <regression_tutorial>` for more
  361. information.\
  362. """),
  363. regplot_vs_lmplot=dedent("""\
  364. The :func:`regplot` and :func:`lmplot` functions are closely related, but
  365. the former is an axes-level function while the latter is a figure-level
  366. function that combines :func:`regplot` and :class:`FacetGrid`.\
  367. """),
  368. x_estimator=dedent("""\
  369. x_estimator : callable that maps vector -> scalar, optional
  370. Apply this function to each unique value of ``x`` and plot the
  371. resulting estimate. This is useful when ``x`` is a discrete variable.
  372. If ``x_ci`` is given, this estimate will be bootstrapped and a
  373. confidence interval will be drawn.\
  374. """),
  375. x_bins=dedent("""\
  376. x_bins : int or vector, optional
  377. Bin the ``x`` variable into discrete bins and then estimate the central
  378. tendency and a confidence interval. This binning only influences how
  379. the scatterplot is drawn; the regression is still fit to the original
  380. data. This parameter is interpreted either as the number of
  381. evenly-sized (not necessary spaced) bins or the positions of the bin
  382. centers. When this parameter is used, it implies that the default of
  383. ``x_estimator`` is ``numpy.mean``.\
  384. """),
  385. x_ci=dedent("""\
  386. x_ci : "ci", "sd", int in [0, 100] or None, optional
  387. Size of the confidence interval used when plotting a central tendency
  388. for discrete values of ``x``. If ``"ci"``, defer to the value of the
  389. ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
  390. standard deviation of the observations in each bin.\
  391. """),
  392. scatter=dedent("""\
  393. scatter : bool, optional
  394. If ``True``, draw a scatterplot with the underlying observations (or
  395. the ``x_estimator`` values).\
  396. """),
  397. fit_reg=dedent("""\
  398. fit_reg : bool, optional
  399. If ``True``, estimate and plot a regression model relating the ``x``
  400. and ``y`` variables.\
  401. """),
  402. ci=dedent("""\
  403. ci : int in [0, 100] or None, optional
  404. Size of the confidence interval for the regression estimate. This will
  405. be drawn using translucent bands around the regression line. The
  406. confidence interval is estimated using a bootstrap; for large
  407. datasets, it may be advisable to avoid that computation by setting
  408. this parameter to None.\
  409. """),
  410. n_boot=dedent("""\
  411. n_boot : int, optional
  412. Number of bootstrap resamples used to estimate the ``ci``. The default
  413. value attempts to balance time and stability; you may want to increase
  414. this value for "final" versions of plots.\
  415. """),
  416. units=dedent("""\
  417. units : variable name in ``data``, optional
  418. If the ``x`` and ``y`` observations are nested within sampling units,
  419. those can be specified here. This will be taken into account when
  420. computing the confidence intervals by performing a multilevel bootstrap
  421. that resamples both units and observations (within unit). This does not
  422. otherwise influence how the regression is estimated or drawn.\
  423. """),
  424. seed=dedent("""\
  425. seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
  426. Seed or random number generator for reproducible bootstrapping.\
  427. """),
  428. order=dedent("""\
  429. order : int, optional
  430. If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
  431. polynomial regression.\
  432. """),
  433. logistic=dedent("""\
  434. logistic : bool, optional
  435. If ``True``, assume that ``y`` is a binary variable and use
  436. ``statsmodels`` to estimate a logistic regression model. Note that this
  437. is substantially more computationally intensive than linear regression,
  438. so you may wish to decrease the number of bootstrap resamples
  439. (``n_boot``) or set ``ci`` to None.\
  440. """),
  441. lowess=dedent("""\
  442. lowess : bool, optional
  443. If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
  444. model (locally weighted linear regression). Note that confidence
  445. intervals cannot currently be drawn for this kind of model.\
  446. """),
  447. robust=dedent("""\
  448. robust : bool, optional
  449. If ``True``, use ``statsmodels`` to estimate a robust regression. This
  450. will de-weight outliers. Note that this is substantially more
  451. computationally intensive than standard linear regression, so you may
  452. wish to decrease the number of bootstrap resamples (``n_boot``) or set
  453. ``ci`` to None.\
  454. """),
  455. logx=dedent("""\
  456. logx : bool, optional
  457. If ``True``, estimate a linear regression of the form y ~ log(x), but
  458. plot the scatterplot and regression model in the input space. Note that
  459. ``x`` must be positive for this to work.\
  460. """),
  461. xy_partial=dedent("""\
  462. {x,y}_partial : strings in ``data`` or matrices
  463. Confounding variables to regress out of the ``x`` or ``y`` variables
  464. before plotting.\
  465. """),
  466. truncate=dedent("""\
  467. truncate : bool, optional
  468. If ``True``, the regression line is bounded by the data limits. If
  469. ``False``, it extends to the ``x`` axis limits.
  470. """),
  471. xy_jitter=dedent("""\
  472. {x,y}_jitter : floats, optional
  473. Add uniform random noise of this size to either the ``x`` or ``y``
  474. variables. The noise is added to a copy of the data after fitting the
  475. regression, and only influences the look of the scatterplot. This can
  476. be helpful when plotting variables that take discrete values.\
  477. """),
  478. scatter_line_kws=dedent("""\
  479. {scatter,line}_kws : dictionaries
  480. Additional keyword arguments to pass to ``plt.scatter`` and
  481. ``plt.plot``.\
  482. """),
  483. )
# Merge in the shared FacetGrid parameter fragments. The docstrings below use
# placeholders not defined in the dict above (e.g. {data}, {height},
# {legend_out}), which are expected to come from ``_facet_docs``.
_regression_docs.update(_facet_docs)
  485. def lmplot(x, y, data, hue=None, col=None, row=None, palette=None,
  486. col_wrap=None, height=5, aspect=1, markers="o", sharex=True,
  487. sharey=True, hue_order=None, col_order=None, row_order=None,
  488. legend=True, legend_out=True, x_estimator=None, x_bins=None,
  489. x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
  490. units=None, seed=None, order=1, logistic=False, lowess=False,
  491. robust=False, logx=False, x_partial=None, y_partial=None,
  492. truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
  493. line_kws=None, size=None):
  494. # Handle deprecations
  495. if size is not None:
  496. height = size
  497. msg = ("The `size` parameter has been renamed to `height`; "
  498. "please update your code.")
  499. warnings.warn(msg, UserWarning)
  500. # Reduce the dataframe to only needed columns
  501. need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
  502. cols = np.unique([a for a in need_cols if a is not None]).tolist()
  503. data = data[cols]
  504. # Initialize the grid
  505. facets = FacetGrid(data, row, col, hue, palette=palette,
  506. row_order=row_order, col_order=col_order,
  507. hue_order=hue_order, height=height, aspect=aspect,
  508. col_wrap=col_wrap, sharex=sharex, sharey=sharey,
  509. legend_out=legend_out)
  510. # Add the markers here as FacetGrid has figured out how many levels of the
  511. # hue variable are needed and we don't want to duplicate that process
  512. if facets.hue_names is None:
  513. n_markers = 1
  514. else:
  515. n_markers = len(facets.hue_names)
  516. if not isinstance(markers, list):
  517. markers = [markers] * n_markers
  518. if len(markers) != n_markers:
  519. raise ValueError(("markers must be a singeton or a list of markers "
  520. "for each level of the hue variable"))
  521. facets.hue_kws = {"marker": markers}
  522. # Hack to set the x limits properly, which needs to happen here
  523. # because the extent of the regression estimate is determined
  524. # by the limits of the plot
  525. if sharex:
  526. for ax in facets.axes.flat:
  527. ax.scatter(data[x], np.ones(len(data)) * data[y].mean()).remove()
  528. # Draw the regression plot on each facet
  529. regplot_kws = dict(
  530. x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
  531. scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
  532. seed=seed, order=order, logistic=logistic, lowess=lowess,
  533. robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
  534. truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
  535. scatter_kws=scatter_kws, line_kws=line_kws,
  536. )
  537. facets.map_dataframe(regplot, x, y, **regplot_kws)
  538. # Add a legend
  539. if legend and (hue is not None) and (hue not in [col, row]):
  540. facets.add_legend()
  541. return facets
  542. lmplot.__doc__ = dedent("""\
  543. Plot data and regression model fits across a FacetGrid.
  544. This function combines :func:`regplot` and :class:`FacetGrid`. It is
  545. intended as a convenient interface to fit regression models across
  546. conditional subsets of a dataset.
  547. When thinking about how to assign variables to different facets, a general
  548. rule is that it makes sense to use ``hue`` for the most important
  549. comparison, followed by ``col`` and ``row``. However, always think about
  550. your particular dataset and the goals of the visualization you are
  551. creating.
  552. {model_api}
  553. The parameters to this function span most of the options in
  554. :class:`FacetGrid`, although there may be occasional cases where you will
  555. want to use that class and :func:`regplot` directly.
  556. Parameters
  557. ----------
  558. x, y : strings, optional
  559. Input variables; these should be column names in ``data``.
  560. {data}
  561. hue, col, row : strings
  562. Variables that define subsets of the data, which will be drawn on
  563. separate facets in the grid. See the ``*_order`` parameters to control
  564. the order of levels of this variable.
  565. {palette}
  566. {col_wrap}
  567. {height}
  568. {aspect}
  569. markers : matplotlib marker code or list of marker codes, optional
  570. Markers for the scatterplot. If a list, each marker in the list will be
  571. used for each level of the ``hue`` variable.
  572. {share_xy}
  573. {{hue,col,row}}_order : lists, optional
  574. Order for the levels of the faceting variables. By default, this will
  575. be the order that the levels appear in ``data`` or, if the variables
  576. are pandas categoricals, the category order.
  577. legend : bool, optional
  578. If ``True`` and there is a ``hue`` variable, add a legend.
  579. {legend_out}
  580. {x_estimator}
  581. {x_bins}
  582. {x_ci}
  583. {scatter}
  584. {fit_reg}
  585. {ci}
  586. {n_boot}
  587. {units}
  588. {seed}
  589. {order}
  590. {logistic}
  591. {lowess}
  592. {robust}
  593. {logx}
  594. {xy_partial}
  595. {truncate}
  596. {xy_jitter}
  597. {scatter_line_kws}
  598. See Also
  599. --------
  600. regplot : Plot data and a conditional model fit.
  601. FacetGrid : Subplot grid for plotting conditional relationships.
  602. pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
  603. ``kind="reg"``).
  604. Notes
  605. -----
  606. {regplot_vs_lmplot}
  607. Examples
  608. --------
  609. These examples focus on basic regression model plots to exhibit the
  610. various faceting options; see the :func:`regplot` docs for demonstrations
  611. of the other options for plotting the data and models. There are also
  612. other examples for how to manipulate plot using the returned object on
  613. the :class:`FacetGrid` docs.
  614. Plot a simple linear relationship between two variables:
  615. .. plot::
  616. :context: close-figs
  617. >>> import seaborn as sns; sns.set(color_codes=True)
  618. >>> tips = sns.load_dataset("tips")
  619. >>> g = sns.lmplot(x="total_bill", y="tip", data=tips)
  620. Condition on a third variable and plot the levels in different colors:
  621. .. plot::
  622. :context: close-figs
  623. >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips)
  624. Use different markers as well as colors so the plot will reproduce to
  625. black-and-white more easily:
  626. .. plot::
  627. :context: close-figs
  628. >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
  629. ... markers=["o", "x"])
  630. Use a different color palette:
  631. .. plot::
  632. :context: close-figs
  633. >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
  634. ... palette="Set1")
  635. Map ``hue`` levels to colors with a dictionary:
  636. .. plot::
  637. :context: close-figs
  638. >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
  639. ... palette=dict(Yes="g", No="m"))
  640. Plot the levels of the third variable across different columns:
  641. .. plot::
  642. :context: close-figs
  643. >>> g = sns.lmplot(x="total_bill", y="tip", col="smoker", data=tips)
  644. Change the height and aspect ratio of the facets:
  645. .. plot::
  646. :context: close-figs
  647. >>> g = sns.lmplot(x="size", y="total_bill", hue="day", col="day",
  648. ... data=tips, height=6, aspect=.4, x_jitter=.1)
  649. Wrap the levels of the column variable into multiple rows:
  650. .. plot::
  651. :context: close-figs
  652. >>> g = sns.lmplot(x="total_bill", y="tip", col="day", hue="day",
  653. ... data=tips, col_wrap=2, height=3)
  654. Condition on two variables to make a full grid:
  655. .. plot::
  656. :context: close-figs
  657. >>> g = sns.lmplot(x="total_bill", y="tip", row="sex", col="time",
  658. ... data=tips, height=3)
  659. Use methods on the returned :class:`FacetGrid` instance to further tweak
  660. the plot:
  661. .. plot::
  662. :context: close-figs
  663. >>> g = sns.lmplot(x="total_bill", y="tip", row="sex", col="time",
  664. ... data=tips, height=3)
  665. >>> g = (g.set_axis_labels("Total bill (US Dollars)", "Tip")
  666. ... .set(xlim=(0, 60), ylim=(0, 12),
  667. ... xticks=[10, 30, 50], yticks=[2, 6, 10])
  668. ... .fig.subplots_adjust(wspace=.02))
  669. """).format(**_regression_docs)
  670. def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci",
  671. scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
  672. seed=None, order=1, logistic=False, lowess=False, robust=False,
  673. logx=False, x_partial=None, y_partial=None,
  674. truncate=True, dropna=True, x_jitter=None, y_jitter=None,
  675. label=None, color=None, marker="o",
  676. scatter_kws=None, line_kws=None, ax=None):
  677. plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci,
  678. scatter, fit_reg, ci, n_boot, units, seed,
  679. order, logistic, lowess, robust, logx,
  680. x_partial, y_partial, truncate, dropna,
  681. x_jitter, y_jitter, color, label)
  682. if ax is None:
  683. ax = plt.gca()
  684. scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
  685. scatter_kws["marker"] = marker
  686. line_kws = {} if line_kws is None else copy.copy(line_kws)
  687. plotter.plot(ax, scatter_kws, line_kws)
  688. return ax
  689. regplot.__doc__ = dedent("""\
  690. Plot data and a linear regression model fit.
  691. {model_api}
  692. Parameters
  693. ----------
  694. x, y: string, series, or vector array
  695. Input variables. If strings, these should correspond with column names
  696. in ``data``. When pandas objects are used, axes will be labeled with
  697. the series name.
  698. {data}
  699. {x_estimator}
  700. {x_bins}
  701. {x_ci}
  702. {scatter}
  703. {fit_reg}
  704. {ci}
  705. {n_boot}
  706. {units}
  707. {seed}
  708. {order}
  709. {logistic}
  710. {lowess}
  711. {robust}
  712. {logx}
  713. {xy_partial}
  714. {truncate}
  715. {xy_jitter}
  716. label : string
  717. Label to apply to either the scatterplot or regression line (if
  718. ``scatter`` is ``False``) for use in a legend.
  719. color : matplotlib color
  720. Color to apply to all plot elements; will be superseded by colors
  721. passed in ``scatter_kws`` or ``line_kws``.
  722. marker : matplotlib marker code
  723. Marker to use for the scatterplot glyphs.
  724. {scatter_line_kws}
  725. ax : matplotlib Axes, optional
  726. Axes object to draw the plot onto, otherwise uses the current Axes.
  727. Returns
  728. -------
  729. ax : matplotlib Axes
  730. The Axes object containing the plot.
  731. See Also
  732. --------
  733. lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
  734. linear relationships in a dataset.
  735. jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
  736. ``kind="reg"``).
  737. pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
  738. ``kind="reg"``).
  739. residplot : Plot the residuals of a linear regression model.
  740. Notes
  741. -----
  742. {regplot_vs_lmplot}
  743. It's also easy to combine combine :func:`regplot` and :class:`JointGrid` or
  744. :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
  745. functions, although these do not directly accept all of :func:`regplot`'s
  746. parameters.
  747. Examples
  748. --------
  749. Plot the relationship between two variables in a DataFrame:
  750. .. plot::
  751. :context: close-figs
  752. >>> import seaborn as sns; sns.set(color_codes=True)
  753. >>> tips = sns.load_dataset("tips")
  754. >>> ax = sns.regplot(x="total_bill", y="tip", data=tips)
  755. Plot with two variables defined as numpy arrays; use a different color:
  756. .. plot::
  757. :context: close-figs
  758. >>> import numpy as np; np.random.seed(8)
  759. >>> mean, cov = [4, 6], [(1.5, .7), (.7, 1)]
  760. >>> x, y = np.random.multivariate_normal(mean, cov, 80).T
  761. >>> ax = sns.regplot(x=x, y=y, color="g")
  762. Plot with two variables defined as pandas Series; use a different marker:
  763. .. plot::
  764. :context: close-figs
  765. >>> import pandas as pd
  766. >>> x, y = pd.Series(x, name="x_var"), pd.Series(y, name="y_var")
  767. >>> ax = sns.regplot(x=x, y=y, marker="+")
  768. Use a 68% confidence interval, which corresponds with the standard error
  769. of the estimate, and extend the regression line to the axis limits:
  770. .. plot::
  771. :context: close-figs
  772. >>> ax = sns.regplot(x=x, y=y, ci=68, truncate=False)
  773. Plot with a discrete ``x`` variable and add some jitter:
  774. .. plot::
  775. :context: close-figs
  776. >>> ax = sns.regplot(x="size", y="total_bill", data=tips, x_jitter=.1)
  777. Plot with a discrete ``x`` variable showing means and confidence intervals
  778. for unique values:
  779. .. plot::
  780. :context: close-figs
  781. >>> ax = sns.regplot(x="size", y="total_bill", data=tips,
  782. ... x_estimator=np.mean)
  783. Plot with a continuous variable divided into discrete bins:
  784. .. plot::
  785. :context: close-figs
  786. >>> ax = sns.regplot(x=x, y=y, x_bins=4)
  787. Fit a higher-order polynomial regression:
  788. .. plot::
  789. :context: close-figs
  790. >>> ans = sns.load_dataset("anscombe")
  791. >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "II"],
  792. ... scatter_kws={{"s": 80}},
  793. ... order=2, ci=None)
  794. Fit a robust regression and don't plot a confidence interval:
  795. .. plot::
  796. :context: close-figs
  797. >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "III"],
  798. ... scatter_kws={{"s": 80}},
  799. ... robust=True, ci=None)
  800. Fit a logistic regression; jitter the y variable and use fewer bootstrap
  801. iterations:
  802. .. plot::
  803. :context: close-figs
  804. >>> tips["big_tip"] = (tips.tip / tips.total_bill) > .175
  805. >>> ax = sns.regplot(x="total_bill", y="big_tip", data=tips,
  806. ... logistic=True, n_boot=500, y_jitter=.03)
  807. Fit the regression model using log(x):
  808. .. plot::
  809. :context: close-figs
  810. >>> ax = sns.regplot(x="size", y="total_bill", data=tips,
  811. ... x_estimator=np.mean, logx=True)
  812. """).format(**_regression_docs)
  813. def residplot(x, y, data=None, lowess=False, x_partial=None, y_partial=None,
  814. order=1, robust=False, dropna=True, label=None, color=None,
  815. scatter_kws=None, line_kws=None, ax=None):
  816. """Plot the residuals of a linear regression.
  817. This function will regress y on x (possibly as a robust or polynomial
  818. regression) and then draw a scatterplot of the residuals. You can
  819. optionally fit a lowess smoother to the residual plot, which can
  820. help in determining if there is structure to the residuals.
  821. Parameters
  822. ----------
  823. x : vector or string
  824. Data or column name in `data` for the predictor variable.
  825. y : vector or string
  826. Data or column name in `data` for the response variable.
  827. data : DataFrame, optional
  828. DataFrame to use if `x` and `y` are column names.
  829. lowess : boolean, optional
  830. Fit a lowess smoother to the residual scatterplot.
  831. {x, y}_partial : matrix or string(s) , optional
  832. Matrix with same first dimension as `x`, or column name(s) in `data`.
  833. These variables are treated as confounding and are removed from
  834. the `x` or `y` variables before plotting.
  835. order : int, optional
  836. Order of the polynomial to fit when calculating the residuals.
  837. robust : boolean, optional
  838. Fit a robust linear regression when calculating the residuals.
  839. dropna : boolean, optional
  840. If True, ignore observations with missing data when fitting and
  841. plotting.
  842. label : string, optional
  843. Label that will be used in any plot legends.
  844. color : matplotlib color, optional
  845. Color to use for all elements of the plot.
  846. {scatter, line}_kws : dictionaries, optional
  847. Additional keyword arguments passed to scatter() and plot() for drawing
  848. the components of the plot.
  849. ax : matplotlib axis, optional
  850. Plot into this axis, otherwise grab the current axis or make a new
  851. one if not existing.
  852. Returns
  853. -------
  854. ax: matplotlib axes
  855. Axes with the regression plot.
  856. See Also
  857. --------
  858. regplot : Plot a simple linear regression model.
  859. jointplot : Draw a :func:`residplot` with univariate marginal distributions
  860. (when used with ``kind="resid"``).
  861. """
  862. plotter = _RegressionPlotter(x, y, data, ci=None,
  863. order=order, robust=robust,
  864. x_partial=x_partial, y_partial=y_partial,
  865. dropna=dropna, color=color, label=label)
  866. if ax is None:
  867. ax = plt.gca()
  868. # Calculate the residual from a linear regression
  869. _, yhat, _ = plotter.fit_regression(grid=plotter.x)
  870. plotter.y = plotter.y - yhat
  871. # Set the regression option on the plotter
  872. if lowess:
  873. plotter.lowess = True
  874. else:
  875. plotter.fit_reg = False
  876. # Plot a horizontal line at 0
  877. ax.axhline(0, ls=":", c=".2")
  878. # Draw the scatterplot
  879. scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
  880. line_kws = {} if line_kws is None else line_kws.copy()
  881. plotter.plot(ax, scatter_kws, line_kws)
  882. return ax