12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069 |
- """Plotting functions for linear models (broadly construed)."""
- import copy
- from textwrap import dedent
- import warnings
- import numpy as np
- import pandas as pd
- from scipy.spatial import distance
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- try:
- import statsmodels
- assert statsmodels
- _has_statsmodels = True
- except ImportError:
- _has_statsmodels = False
- from . import utils
- from . import algorithms as algo
- from .axisgrid import FacetGrid, _facet_docs
- __all__ = ["lmplot", "regplot", "residplot"]
class _LinearPlotter(object):
    """Base class for plotting relational data in tidy format.

    To get anything useful done you'll have to inherit from this, but setup
    code that can be abstracted out should be put here.

    """
    def establish_variables(self, data, **kws):
        """Extract variables from data or use directly.

        Each keyword becomes an attribute on ``self`` holding the resolved
        vector (or ``None`` when the keyword value is ``None``).

        Parameters
        ----------
        data : object supporting ``data[name]`` lookup, or None
            Source for every variable that is given as a string.
        **kws
            Mapping of attribute name to a string key into ``data``, a list or
            tuple of values, an array-like object exposing ``.shape``, or
            ``None``.

        Raises
        ------
        ValueError
            If any variable is named by string while ``data`` is None, or if
            a variable still has more than one dimension after squeezing.

        """
        self.data = data

        # Validate the inputs
        any_strings = any(isinstance(v, str) for v in kws.values())
        if any_strings and data is None:
            raise ValueError("Must pass `data` if using named variables.")

        # Set the variables
        for var, val in kws.items():
            if isinstance(val, str):
                vector = data[val]
            elif isinstance(val, (list, tuple)):
                # Plain sequences carry no ``.shape``; convert them up front
                vector = np.asarray(val)
            else:
                vector = val
            # Length-1 vectors are left alone so they are not squeezed away
            if vector is not None and vector.shape != (1,):
                vector = np.squeeze(vector)
            if np.ndim(vector) > 1:
                err = "regplot inputs must be 1d"
                raise ValueError(err)
            setattr(self, var, vector)

    def dropna(self, *vars):
        """Remove observations with missing data.

        A row is kept only when it is non-null in every one of the named
        attributes that is not ``None``; each such attribute is replaced by
        its filtered version.

        """
        vals = [getattr(self, var) for var in vars]
        vals = [v for v in vals if v is not None]
        if not vals:
            # Nothing to filter; ``np.column_stack`` rejects an empty list
            return
        not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
        for var in vars:
            val = getattr(self, var)
            if val is not None:
                setattr(self, var, val[not_na])

    def plot(self, ax):
        """Draw the plot onto ``ax``; must be provided by subclasses."""
        raise NotImplementedError
class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.

    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.

    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):
        """Store the options, extract the data vectors and preprocess them.

        Raises
        ------
        ValueError
            If more than one of ``order > 1``, ``logistic``, ``robust``,
            ``lowess`` and ``logx`` is requested.

        """
        # Set member attributes
        self.x_estimator = x_estimator
        self.ci = ci
        # ``x_ci="ci"`` means "reuse the regression interval size"
        self.x_ci = ci if x_ci == "ci" else x_ci
        self.n_boot = n_boot
        self.seed = seed
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter
        self.color = color
        self.label = label

        # Validate the regression options:
        if sum((order > 1, logistic, robust, lowess, logx)) > 1:
            raise ValueError("Mutually exclusive regression options.")

        # Extract the data vals from the arguments or passed dataframe
        self.establish_variables(data, x=x, y=y, units=units,
                                 x_partial=x_partial, y_partial=y_partial)

        # Drop null observations
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")

        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)

        # Possibly bin the predictor variable, which implies a point estimate
        if x_bins is not None:
            self.x_estimator = np.mean if x_estimator is None else x_estimator
            # Only the binned predictor is needed; the bin centers are unused
            self.x_discrete, _ = self.bin_predictor(x_bins)
        else:
            self.x_discrete = self.x

        # Disable regression in case of singleton inputs
        if len(self.x) <= 1:
            self.fit_reg = False

        # Save the range of the x variable for the grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()

    @property
    def scatter_data(self):
        """Data where each observation is a point.

        Returns ``(x, y)``; uniform noise in ``[-jitter, jitter]`` is added to
        a variable when the corresponding jitter option is not None.

        """
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))

        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))

        return x, y

    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value.

        Returns ``(vals, points, cis)``: the sorted unique x values, the
        ``x_estimator`` of y at each value, and for each value either None
        (when ``x_ci`` is None) or a ``(low, high)`` interval.

        """
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []

        for val in vals:

            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)

            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    # Spread of the observations rather than a bootstrap
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)

        return vals, points, cis

    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            When given (and ``truncate`` is off) its x limits bound the grid.
        x_range : pair of numbers, optional
            Grid bounds used when ``ax`` is None and ``truncate`` is off.
        grid : vector, optional
            Explicit x positions at which to evaluate the model; overrides
            the other two arguments.

        Returns
        -------
        grid, yhat, err_bands
            Evaluation positions, fitted values, and the confidence band
            (None when ``ci`` is None or a lowess model is fit).

        """
        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            elif ax is not None:
                x_min, x_max = ax.get_xlim()
            elif x_range is not None:
                x_min, x_max = x_range
            else:
                # No axes or explicit range to go by: span the data instead
                # of failing while unpacking ``None``
                x_min, x_max = self.x.min(), self.x.max()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci

        # Fit the regression
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            # The lowess fit supplies its own grid and has no bootstrap
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)

        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)

        return grid, yhat, err_bands

    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra.

        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is None when ``ci`` is
        None.

        """
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)

        # Design matrices with an intercept column
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends.

        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is None when ``ci`` is
        None.

        """
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)

        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects.

        ``model`` is called as ``model(y, X, **kwargs)`` and the result's
        ``fit().predict(grid)`` provides the fitted values.

        """
        import statsmodels.genmod.generalized_linear_model as glm
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]

        def reg_func(_x, _y):
            try:
                yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except glm.PerfectSeparationError:
                # Report an all-NaN fit rather than aborting the bootstrap
                yhat = np.empty(len(grid))
                yhat.fill(np.nan)
            return yhat

        yhat = reg_func(X, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_lowess(self):
        """Fit a locally-weighted regression, which returns its own grid."""
        from statsmodels.nonparametric.smoothers_lowess import lowess
        grid, yhat = lowess(self.y, self.x).T
        return grid, yhat

    def fit_logx(self, grid):
        """Fit the model in log-space.

        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is None when ``ci`` is
        None.

        """
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]

        def reg_func(_x, _y):
            # Take the log inside so bootstrap resamples are transformed too
            _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
            return np.linalg.pinv(_x).dot(_y)

        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def bin_predictor(self, bins):
        """Discretize a predictor by assigning value to closest bin.

        ``bins`` is either a scalar count, in which case bin centers are
        taken at evenly spaced percentiles of ``x``, or a vector of centers.
        Returns ``(x_binned, centers)``.

        """
        x = self.x
        if np.isscalar(bins):
            # Interior percentiles only: drop the 0th and 100th
            percentiles = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.c_[np.percentile(x, percentiles)]
        else:
            bins = np.c_[np.ravel(bins)]

        dist = distance.cdist(np.c_[x], bins)
        x_binned = bins[np.argmin(dist, axis=1)].ravel()

        return x_binned, bins.ravel()

    def regress_out(self, a, b):
        """Regress b from a keeping a's original mean."""
        a_mean = a.mean()
        a = a - a_mean
        b = b - b.mean()
        b = np.c_[b]
        a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
        return np.asarray(a_prime + a_mean).reshape(a.shape)

    def plot(self, ax, scatter_kws, line_kws):
        """Draw the full plot.

        ``scatter_kws`` and ``line_kws`` are modified in place (label and
        default color are inserted), so callers should pass copies.

        """
        # Insert the plot label into the correct set of keyword arguments
        if self.scatter:
            scatter_kws["label"] = self.label
        else:
            line_kws["label"] = self.label

        # Use the current color cycle state as a default
        if self.color is None:
            lines, = ax.plot([], [])
            color = lines.get_color()
            lines.remove()
        else:
            color = self.color

        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))

        # Let color in keyword arguments override overall plot color
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)

        # Draw the constituent plots
        if self.scatter:
            self.scatterplot(ax, scatter_kws)

        if self.fit_reg:
            self.lineplot(ax, line_kws)

        # Label the axes
        if hasattr(self.x, "name"):
            ax.set_xlabel(self.x.name)
        if hasattr(self.y, "name"):
            ax.set_ylabel(self.y.name)

    def scatterplot(self, ax, kws):
        """Draw the data."""
        # Treat the line-based markers specially, explicitly setting larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., distinguish
        # between edgewidth for solid glyphs and linewidth for line glyphs
        # but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)

            # Skip the default alpha only for a 2-D color array that has an
            # alpha column; a 1-D array has no second axis to inspect
            color = kws["color"]
            has_alpha_column = (
                hasattr(color, "shape")
                and len(color.shape) > 1
                and color.shape[1] >= 4
            )
            if not has_alpha_column:
                kws.setdefault("alpha", .8)

            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)

            xs, ys, cis = self.estimate_data
            # Draw interval bars only if at least one interval was computed
            if [ci for ci in cis if ci is not None]:
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)

    def lineplot(self, ax, kws):
        """Draw the model."""
        # Fit the regression model
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]

        # Get set default aesthetics
        fill_color = kws["color"]
        lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
        kws.setdefault("linewidth", lw)

        # Draw the regression line and confidence interval
        line, = ax.plot(grid, yhat, **kws)
        line.sticky_edges.x[:] = edges  # Prevent mpl from adding margin
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)
- _regression_docs = dict(
- model_api=dedent("""\
- There are a number of mutually exclusive options for estimating the
- regression model. See the :ref:`tutorial <regression_tutorial>` for more
- information.\
- """),
- regplot_vs_lmplot=dedent("""\
- The :func:`regplot` and :func:`lmplot` functions are closely related, but
- the former is an axes-level function while the latter is a figure-level
- function that combines :func:`regplot` and :class:`FacetGrid`.\
- """),
- x_estimator=dedent("""\
- x_estimator : callable that maps vector -> scalar, optional
- Apply this function to each unique value of ``x`` and plot the
- resulting estimate. This is useful when ``x`` is a discrete variable.
- If ``x_ci`` is given, this estimate will be bootstrapped and a
- confidence interval will be drawn.\
- """),
- x_bins=dedent("""\
- x_bins : int or vector, optional
- Bin the ``x`` variable into discrete bins and then estimate the central
- tendency and a confidence interval. This binning only influences how
- the scatterplot is drawn; the regression is still fit to the original
- data. This parameter is interpreted either as the number of
- evenly-sized (not necessarily spaced) bins or the positions of the bin
- centers. When this parameter is used, it implies that the default of
- ``x_estimator`` is ``numpy.mean``.\
- """),
- x_ci=dedent("""\
- x_ci : "ci", "sd", int in [0, 100] or None, optional
- Size of the confidence interval used when plotting a central tendency
- for discrete values of ``x``. If ``"ci"``, defer to the value of the
- ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
- standard deviation of the observations in each bin.\
- """),
- scatter=dedent("""\
- scatter : bool, optional
- If ``True``, draw a scatterplot with the underlying observations (or
- the ``x_estimator`` values).\
- """),
- fit_reg=dedent("""\
- fit_reg : bool, optional
- If ``True``, estimate and plot a regression model relating the ``x``
- and ``y`` variables.\
- """),
- ci=dedent("""\
- ci : int in [0, 100] or None, optional
- Size of the confidence interval for the regression estimate. This will
- be drawn using translucent bands around the regression line. The
- confidence interval is estimated using a bootstrap; for large
- datasets, it may be advisable to avoid that computation by setting
- this parameter to None.\
- """),
- n_boot=dedent("""\
- n_boot : int, optional
- Number of bootstrap resamples used to estimate the ``ci``. The default
- value attempts to balance time and stability; you may want to increase
- this value for "final" versions of plots.\
- """),
- units=dedent("""\
- units : variable name in ``data``, optional
- If the ``x`` and ``y`` observations are nested within sampling units,
- those can be specified here. This will be taken into account when
- computing the confidence intervals by performing a multilevel bootstrap
- that resamples both units and observations (within unit). This does not
- otherwise influence how the regression is estimated or drawn.\
- """),
- seed=dedent("""\
- seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
- Seed or random number generator for reproducible bootstrapping.\
- """),
- order=dedent("""\
- order : int, optional
- If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
- polynomial regression.\
- """),
- logistic=dedent("""\
- logistic : bool, optional
- If ``True``, assume that ``y`` is a binary variable and use
- ``statsmodels`` to estimate a logistic regression model. Note that this
- is substantially more computationally intensive than linear regression,
- so you may wish to decrease the number of bootstrap resamples
- (``n_boot``) or set ``ci`` to None.\
- """),
- lowess=dedent("""\
- lowess : bool, optional
- If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
- model (locally weighted linear regression). Note that confidence
- intervals cannot currently be drawn for this kind of model.\
- """),
- robust=dedent("""\
- robust : bool, optional
- If ``True``, use ``statsmodels`` to estimate a robust regression. This
- will de-weight outliers. Note that this is substantially more
- computationally intensive than standard linear regression, so you may
- wish to decrease the number of bootstrap resamples (``n_boot``) or set
- ``ci`` to None.\
- """),
- logx=dedent("""\
- logx : bool, optional
- If ``True``, estimate a linear regression of the form y ~ log(x), but
- plot the scatterplot and regression model in the input space. Note that
- ``x`` must be positive for this to work.\
- """),
- xy_partial=dedent("""\
- {x,y}_partial : strings in ``data`` or matrices
- Confounding variables to regress out of the ``x`` or ``y`` variables
- before plotting.\
- """),
- truncate=dedent("""\
- truncate : bool, optional
- If ``True``, the regression line is bounded by the data limits. If
- ``False``, it extends to the ``x`` axis limits.
- """),
- xy_jitter=dedent("""\
- {x,y}_jitter : floats, optional
- Add uniform random noise of this size to either the ``x`` or ``y``
- variables. The noise is added to a copy of the data after fitting the
- regression, and only influences the look of the scatterplot. This can
- be helpful when plotting variables that take discrete values.\
- """),
- scatter_line_kws=dedent("""\
- {scatter,line}_kws : dictionaries
- Additional keyword arguments to pass to ``plt.scatter`` and
- ``plt.plot``.\
- """),
- )
- _regression_docs.update(_facet_docs)
def lmplot(x, y, data, hue=None, col=None, row=None, palette=None,
           col_wrap=None, height=5, aspect=1, markers="o", sharex=True,
           sharey=True, hue_order=None, col_order=None, row_order=None,
           legend=True, legend_out=True, x_estimator=None, x_bins=None,
           x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
           units=None, seed=None, order=1, logistic=False, lowess=False,
           robust=False, logx=False, x_partial=None, y_partial=None,
           truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
           line_kws=None, size=None):
    # NOTE: the user-facing docstring is assigned to ``lmplot.__doc__``
    # right after this definition, so none is written inline here.

    # Handle deprecations
    if size is not None:
        height = size
        msg = ("The `size` parameter has been renamed to `height`; "
               "please update your code.")
        warnings.warn(msg, UserWarning)

    # Reduce the dataframe to only needed columns
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = np.unique([a for a in need_cols if a is not None]).tolist()
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(data, row, col, hue, palette=palette,
                       row_order=row_order, col_order=col_order,
                       hue_order=hue_order, height=height, aspect=aspect,
                       col_wrap=col_wrap, sharex=sharex, sharey=sharey,
                       legend_out=legend_out)

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if facets.hue_names is None:
        n_markers = 1
    else:
        n_markers = len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError(("markers must be a singleton or a list of markers "
                          "for each level of the hue variable"))
    facets.hue_kws = {"marker": markers}

    # Hack to set the x limits properly, which needs to happen here
    # because the extent of the regression estimate is determined
    # by the limits of the plot
    if sharex:
        for ax in facets.axes.flat:
            # Draw and immediately discard a scatter spanning all x values
            ax.scatter(data[x], np.ones(len(data)) * data[y].mean()).remove()

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x, y, **regplot_kws)

    # Add a legend
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets
- lmplot.__doc__ = dedent("""\
- Plot data and regression model fits across a FacetGrid.
- This function combines :func:`regplot` and :class:`FacetGrid`. It is
- intended as a convenient interface to fit regression models across
- conditional subsets of a dataset.
- When thinking about how to assign variables to different facets, a general
- rule is that it makes sense to use ``hue`` for the most important
- comparison, followed by ``col`` and ``row``. However, always think about
- your particular dataset and the goals of the visualization you are
- creating.
- {model_api}
- The parameters to this function span most of the options in
- :class:`FacetGrid`, although there may be occasional cases where you will
- want to use that class and :func:`regplot` directly.
- Parameters
- ----------
- x, y : strings
- Input variables; these should be column names in ``data``.
- {data}
- hue, col, row : strings
- Variables that define subsets of the data, which will be drawn on
- separate facets in the grid. See the ``*_order`` parameters to control
- the order of levels of this variable.
- {palette}
- {col_wrap}
- {height}
- {aspect}
- markers : matplotlib marker code or list of marker codes, optional
- Markers for the scatterplot. If a list, each marker in the list will be
- used for each level of the ``hue`` variable.
- {share_xy}
- {{hue,col,row}}_order : lists, optional
- Order for the levels of the faceting variables. By default, this will
- be the order that the levels appear in ``data`` or, if the variables
- are pandas categoricals, the category order.
- legend : bool, optional
- If ``True`` and there is a ``hue`` variable, add a legend.
- {legend_out}
- {x_estimator}
- {x_bins}
- {x_ci}
- {scatter}
- {fit_reg}
- {ci}
- {n_boot}
- {units}
- {seed}
- {order}
- {logistic}
- {lowess}
- {robust}
- {logx}
- {xy_partial}
- {truncate}
- {xy_jitter}
- {scatter_line_kws}
- See Also
- --------
- regplot : Plot data and a conditional model fit.
- FacetGrid : Subplot grid for plotting conditional relationships.
- pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
- ``kind="reg"``).
- Notes
- -----
- {regplot_vs_lmplot}
- Examples
- --------
- These examples focus on basic regression model plots to exhibit the
- various faceting options; see the :func:`regplot` docs for demonstrations
- of the other options for plotting the data and models. There are also
- other examples for how to manipulate the plot using the returned object on
- the :class:`FacetGrid` docs.
- Plot a simple linear relationship between two variables:
- .. plot::
- :context: close-figs
- >>> import seaborn as sns; sns.set(color_codes=True)
- >>> tips = sns.load_dataset("tips")
- >>> g = sns.lmplot(x="total_bill", y="tip", data=tips)
- Condition on a third variable and plot the levels in different colors:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips)
- Use different markers as well as colors so the plot will reproduce to
- black-and-white more easily:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
- ... markers=["o", "x"])
- Use a different color palette:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
- ... palette="Set1")
- Map ``hue`` levels to colors with a dictionary:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", hue="smoker", data=tips,
- ... palette=dict(Yes="g", No="m"))
- Plot the levels of the third variable across different columns:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", col="smoker", data=tips)
- Change the height and aspect ratio of the facets:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="size", y="total_bill", hue="day", col="day",
- ... data=tips, height=6, aspect=.4, x_jitter=.1)
- Wrap the levels of the column variable into multiple rows:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", col="day", hue="day",
- ... data=tips, col_wrap=2, height=3)
- Condition on two variables to make a full grid:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", row="sex", col="time",
- ... data=tips, height=3)
- Use methods on the returned :class:`FacetGrid` instance to further tweak
- the plot:
- .. plot::
- :context: close-figs
- >>> g = sns.lmplot(x="total_bill", y="tip", row="sex", col="time",
- ... data=tips, height=3)
- >>> g = (g.set_axis_labels("Total bill (US Dollars)", "Tip")
- ... .set(xlim=(0, 60), ylim=(0, 12),
- ... xticks=[10, 30, 50], yticks=[2, 6, 10])
- ... .fig.subplots_adjust(wspace=.02))
- """).format(**_regression_docs)
def regplot(x, y, data=None, x_estimator=None, x_bins=None, x_ci="ci",
            scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
            seed=None, order=1, logistic=False, lowess=False, robust=False,
            logx=False, x_partial=None, y_partial=None,
            truncate=True, dropna=True, x_jitter=None, y_jitter=None,
            label=None, color=None, marker="o",
            scatter_kws=None, line_kws=None, ax=None):
    # NOTE: the user-facing docstring is assigned to ``regplot.__doc__``
    # right after this definition, so none is written inline here.

    # Build the plotter first so that invalid options fail before any
    # Axes is looked up or created
    plotter = _RegressionPlotter(
        x, y, data=data,
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot,
        units=units, seed=seed,
        order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx,
        x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, dropna=dropna,
        x_jitter=x_jitter, y_jitter=y_jitter,
        color=color, label=label,
    )

    if ax is None:
        ax = plt.gca()

    # Work on shallow copies so the caller's dictionaries are left untouched
    if scatter_kws is None:
        scatter_opts = {}
    else:
        scatter_opts = copy.copy(scatter_kws)
    # The ``marker`` argument always takes precedence
    scatter_opts["marker"] = marker

    if line_kws is None:
        line_opts = {}
    else:
        line_opts = copy.copy(line_kws)

    plotter.plot(ax, scatter_opts, line_opts)
    return ax
- regplot.__doc__ = dedent("""\
- Plot data and a linear regression model fit.
- {model_api}
- Parameters
- ----------
- x, y: string, series, or vector array
- Input variables. If strings, these should correspond with column names
- in ``data``. When pandas objects are used, axes will be labeled with
- the series name.
- {data}
- {x_estimator}
- {x_bins}
- {x_ci}
- {scatter}
- {fit_reg}
- {ci}
- {n_boot}
- {units}
- {seed}
- {order}
- {logistic}
- {lowess}
- {robust}
- {logx}
- {xy_partial}
- {truncate}
- {xy_jitter}
- label : string
- Label to apply to either the scatterplot or regression line (if
- ``scatter`` is ``False``) for use in a legend.
- color : matplotlib color
- Color to apply to all plot elements; will be superseded by colors
- passed in ``scatter_kws`` or ``line_kws``.
- marker : matplotlib marker code
- Marker to use for the scatterplot glyphs.
- {scatter_line_kws}
- ax : matplotlib Axes, optional
- Axes object to draw the plot onto, otherwise uses the current Axes.
- Returns
- -------
- ax : matplotlib Axes
- The Axes object containing the plot.
- See Also
- --------
- lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
- linear relationships in a dataset.
- jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
- ``kind="reg"``).
- pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
- ``kind="reg"``).
- residplot : Plot the residuals of a linear regression model.
- Notes
- -----
- {regplot_vs_lmplot}
- It's also easy to combine :func:`regplot` and :class:`JointGrid` or
- :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
- functions, although these do not directly accept all of :func:`regplot`'s
- parameters.
- Examples
- --------
- Plot the relationship between two variables in a DataFrame:
- .. plot::
- :context: close-figs
- >>> import seaborn as sns; sns.set(color_codes=True)
- >>> tips = sns.load_dataset("tips")
- >>> ax = sns.regplot(x="total_bill", y="tip", data=tips)
- Plot with two variables defined as numpy arrays; use a different color:
- .. plot::
- :context: close-figs
- >>> import numpy as np; np.random.seed(8)
- >>> mean, cov = [4, 6], [(1.5, .7), (.7, 1)]
- >>> x, y = np.random.multivariate_normal(mean, cov, 80).T
- >>> ax = sns.regplot(x=x, y=y, color="g")
- Plot with two variables defined as pandas Series; use a different marker:
- .. plot::
- :context: close-figs
- >>> import pandas as pd
- >>> x, y = pd.Series(x, name="x_var"), pd.Series(y, name="y_var")
- >>> ax = sns.regplot(x=x, y=y, marker="+")
- Use a 68% confidence interval, which corresponds with the standard error
- of the estimate, and extend the regression line to the axis limits:
- .. plot::
- :context: close-figs
- >>> ax = sns.regplot(x=x, y=y, ci=68, truncate=False)
- Plot with a discrete ``x`` variable and add some jitter:
- .. plot::
- :context: close-figs
- >>> ax = sns.regplot(x="size", y="total_bill", data=tips, x_jitter=.1)
- Plot with a discrete ``x`` variable showing means and confidence intervals
- for unique values:
- .. plot::
- :context: close-figs
- >>> ax = sns.regplot(x="size", y="total_bill", data=tips,
- ... x_estimator=np.mean)
- Plot with a continuous variable divided into discrete bins:
- .. plot::
- :context: close-figs
- >>> ax = sns.regplot(x=x, y=y, x_bins=4)
- Fit a higher-order polynomial regression:
- .. plot::
- :context: close-figs
- >>> ans = sns.load_dataset("anscombe")
- >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "II"],
- ... scatter_kws={{"s": 80}},
- ... order=2, ci=None)
- Fit a robust regression and don't plot a confidence interval:
- .. plot::
- :context: close-figs
- >>> ax = sns.regplot(x="x", y="y", data=ans.loc[ans.dataset == "III"],
- ... scatter_kws={{"s": 80}},
- ... robust=True, ci=None)
- Fit a logistic regression; jitter the y variable and use fewer bootstrap
- iterations:
- .. plot::
- :context: close-figs
- >>> tips["big_tip"] = (tips.tip / tips.total_bill) > .175
- >>> ax = sns.regplot(x="total_bill", y="big_tip", data=tips,
- ... logistic=True, n_boot=500, y_jitter=.03)
- Fit the regression model using log(x):
- .. plot::
- :context: close-figs
- >>> ax = sns.regplot(x="size", y="total_bill", data=tips,
- ... x_estimator=np.mean, logx=True)
- """).format(**_regression_docs)
def residplot(x, y, data=None, lowess=False, x_partial=None, y_partial=None,
              order=1, robust=False, dropna=True, label=None, color=None,
              scatter_kws=None, line_kws=None, ax=None):
    """Plot the residuals of a linear regression.

    This function will regress y on x (possibly as a robust or polynomial
    regression) and then draw a scatterplot of the residuals. You can
    optionally fit a lowess smoother to the residual plot, which can
    help in determining if there is structure to the residuals.

    Parameters
    ----------
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot.
    {x, y}_partial : matrix or string(s) , optional
        Matrix with same first dimension as `x`, or column name(s) in `data`.
        These variables are treated as confounding and are removed from
        the `x` or `y` variables before plotting.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter, line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for drawing
        the components of the plot.
    ax : matplotlib axis, optional
        Plot into this axis, otherwise grab the current axis or make a new
        one if not existing.

    Returns
    -------
    ax: matplotlib axes
        Axes with the regression plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
                (when used with ``kind="resid"``).

    """
    plotter = _RegressionPlotter(x, y, data, ci=None,
                                 order=order, robust=robust,
                                 x_partial=x_partial, y_partial=y_partial,
                                 dropna=dropna, color=color, label=label)

    if ax is None:
        ax = plt.gca()

    # Calculate the residual from a linear regression
    _, yhat, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - yhat

    # Set the regression option on the plotter
    if lowess:
        plotter.lowess = True
        # ``fit_regression`` tests ``order > 1`` before ``lowess``, so a
        # leftover polynomial order would draw a polynomial through the
        # residuals instead of the requested smoother
        plotter.order = 1
    else:
        plotter.fit_reg = False

    # Plot a horizontal line at 0
    ax.axhline(0, ls=":", c=".2")

    # Draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()
    plotter.plot(ax, scatter_kws, line_kws)
    return ax
|