123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454 |
- """Timeseries plotting functions."""
- from __future__ import division
- import numpy as np
- import pandas as pd
- from scipy import stats, interpolate
- import matplotlib as mpl
- import matplotlib.pyplot as plt
- import warnings
- from .external.six import string_types
- from . import utils
- from . import algorithms as algo
- from .palettes import color_palette
- __all__ = ["tsplot"]
def tsplot(data, time=None, unit=None, condition=None, value=None,
           err_style="ci_band", ci=68, interpolate=True, color=None,
           estimator=np.mean, n_boot=5000, err_palette=None, err_kws=None,
           legend=True, ax=None, **kwargs):
    """Plot one or more timeseries with flexible representation of uncertainty.

    This function is intended to be used with data where observations are
    nested within sampling units that were measured at multiple timepoints.

    It can take data specified either as a long-form (tidy) DataFrame or as an
    ndarray with dimensions (unit, time). The interpretation of some of the
    other parameters changes depending on the type of object passed as data.

    Parameters
    ----------
    data : DataFrame or ndarray
        Data for the plot. Should either be a "long form" dataframe or an
        array with dimensions (unit, time, condition). In both cases, the
        condition field/dimension is optional. The type of this argument
        determines the interpretation of the next few parameters. When
        using a DataFrame, the index has to be sequential.
    time : string or series-like
        Either the name of the field corresponding to time in the data
        DataFrame or x values for a plot when data is an array. If a Series,
        the name will be used to label the x axis.
    unit : string
        Field in the data DataFrame identifying the sampling unit (e.g.
        subject, neuron, etc.). The error representation will collapse over
        units at each time/condition observation. This has no role when data
        is an array.
    value : string
        Either the name of the field corresponding to the data values in
        the data DataFrame (i.e. the y coordinate) or a string that forms
        the y axis label when data is an array.
    condition : string or Series-like
        Either the name of the field identifying the condition an observation
        falls under in the data DataFrame, or a sequence of names with a length
        equal to the size of the third dimension of data. There will be a
        separate trace plotted for each condition. If condition is a Series
        with a name attribute, the name will form the title for the plot
        legend (unless legend is set to False).
    err_style : string or list of strings or None
        Names of ways to plot uncertainty across units from set of
        {ci_band, ci_bars, boot_traces, boot_kde, unit_traces, unit_points}.
        Can use one or more than one method.
    ci : float or list of floats in [0, 100] or "sd" or None
        Confidence interval size(s). If a list, it will stack the error plots
        for each confidence interval. If ``"sd"``, show standard deviation of
        the observations instead of bootstrapped confidence intervals. Only
        relevant for error styles with "ci" in the name.
    interpolate : boolean
        Whether to do a linear interpolation between each timepoint when
        plotting. The value of this parameter also determines the marker
        used for the main plot traces, unless marker is specified as a keyword
        argument.
    color : seaborn palette or matplotlib color name or dictionary
        Palette or color for the main plots and error representation (unless
        plotting by unit, which can be separately controlled with err_palette).
        If a dictionary, should map condition name to color spec.
    estimator : callable
        Function to determine central tendency and to pass to bootstrap
        must take an ``axis`` argument.
    n_boot : int
        Number of bootstrap iterations.
    err_palette : seaborn palette
        Palette name or list of colors used when plotting data for each unit.
    err_kws : dict, optional
        Keyword argument dictionary passed through to matplotlib function
        generating the error plot. The dictionary itself is not modified.
    legend : bool, optional
        If ``True`` and there is a ``condition`` variable, add a legend to
        the plot.
    ax : axis object, optional
        Plot in given axis; if None creates a new figure
    kwargs :
        Other keyword arguments are passed to main plot() call

    Returns
    -------
    ax : matplotlib axis
        axis with plot data

    Raises
    ------
    ValueError
        If ``color`` is a dict but array data is given without condition
        names, or if an entry of ``err_style`` is not a known style.
    RuntimeError
        If grouping the data by condition yields no groups to plot.

    Examples
    --------

    Plot a trace with translucent confidence bands:

    .. plot::
        :context: close-figs

        >>> import numpy as np; np.random.seed(22)
        >>> import seaborn as sns; sns.set(color_codes=True)
        >>> x = np.linspace(0, 15, 31)
        >>> data = np.sin(x) + np.random.rand(10, 31) + np.random.randn(10, 1)
        >>> ax = sns.tsplot(data=data)

    Plot a long-form dataframe with several conditions:

    .. plot::
        :context: close-figs

        >>> gammas = sns.load_dataset("gammas")
        >>> ax = sns.tsplot(time="timepoint", value="BOLD signal",
        ...                 unit="subject", condition="ROI",
        ...                 data=gammas)

    Use error bars at the positions of the observations:

    .. plot::
        :context: close-figs

        >>> ax = sns.tsplot(data=data, err_style="ci_bars", color="g")

    Don't interpolate between the observations:

    .. plot::
        :context: close-figs

        >>> import matplotlib.pyplot as plt
        >>> ax = sns.tsplot(data=data, err_style="ci_bars", interpolate=False)

    Show multiple confidence bands:

    .. plot::
        :context: close-figs

        >>> ax = sns.tsplot(data=data, ci=[68, 95], color="m")

    Show the standard deviation of the observations:

    .. plot::
        :context: close-figs

        >>> ax = sns.tsplot(data=data, ci="sd")

    Use a different estimator:

    .. plot::
        :context: close-figs

        >>> ax = sns.tsplot(data=data, estimator=np.median)

    Show each bootstrap resample:

    .. plot::
        :context: close-figs

        >>> ax = sns.tsplot(data=data, err_style="boot_traces", n_boot=500)

    Show the trace from each sampling unit:

    .. plot::
        :context: close-figs

        >>> ax = sns.tsplot(data=data, err_style="unit_traces")

    """
    msg = (
        "The `tsplot` function is deprecated and will be removed in a future "
        "release. Please update your code to use the new `lineplot` function."
    )
    warnings.warn(msg, UserWarning)

    # Sort out default values for the parameters
    if ax is None:
        ax = plt.gca()

    if err_kws is None:
        err_kws = {}

    # Handle different types of input data
    if isinstance(data, pd.DataFrame):

        xlabel = time
        ylabel = value

        # Condition is optional
        if condition is None:
            # Group everything under one constant pseudo-condition
            condition = pd.Series(1, index=data.index)
            legend = False
            legend_name = None
            n_cond = 1
        else:
            legend_name = condition
            n_cond = len(data[condition].unique())

    else:
        data = np.asarray(data)

        # Data can be a timecourse from a single unit or
        # several observations in one condition
        if data.ndim == 1:
            data = data[np.newaxis, :, np.newaxis]
        elif data.ndim == 2:
            data = data[:, :, np.newaxis]
        n_unit, n_time, n_cond = data.shape

        # Units are experimental observations. Maybe subjects, or neurons.
        # The ``unit`` argument has no role for array input, so the unit ids
        # are always the positions along the first axis (previously a
        # non-None ``unit`` left ``units`` undefined and raised NameError).
        units = np.repeat(np.arange(n_unit), n_time * n_cond)
        unit = "unit"

        # Time forms the xaxis of the plot
        if time is None:
            times = np.arange(n_time)
        else:
            times = np.asarray(time)
        xlabel = None
        if hasattr(time, "name"):
            xlabel = time.name
        time = "time"
        times = np.tile(np.repeat(times, n_cond), n_unit)

        # Conditions split the timeseries plots
        legend_name = None
        if condition is None:
            conds = range(n_cond)
            legend = False
            if isinstance(color, dict):
                err = "Must have condition names if using color dict."
                raise ValueError(err)
        else:
            conds = np.asarray(condition)
            if hasattr(condition, "name"):
                legend_name = condition.name
        condition = "cond"
        conds = np.tile(conds, n_unit * n_time)

        # Value forms the y value in the plot
        ylabel = value
        value = "value"

        # Convert to long-form DataFrame; ravel() order matches the
        # repeat/tile layout of units, times and conds built above
        data = pd.DataFrame(dict(value=data.ravel(),
                                 time=times,
                                 unit=units,
                                 cond=conds))

    # Set up the err_style and ci arguments for the loop below
    if isinstance(err_style, string_types):
        err_style = [err_style]
    elif err_style is None:
        err_style = []
    if not hasattr(ci, "__iter__"):
        ci = [ci]

    # Set up the color palette
    if color is None:
        current_palette = utils.get_color_cycle()
        if len(current_palette) < n_cond:
            colors = color_palette("husl", n_cond)
        else:
            colors = color_palette(n_colors=n_cond)
    elif isinstance(color, dict):
        colors = [color[c] for c in data[condition].unique()]
    else:
        try:
            colors = color_palette(color, n_cond)
        except ValueError:
            # Not a palette name: treat it as a single matplotlib color
            color = mpl.colors.colorConverter.to_rgb(color)
            colors = [color] * n_cond

    # Do a groupby with condition and plot each trace
    c = None
    for c, (cond, df_c) in enumerate(data.groupby(condition, sort=False)):

        # Keyword arguments: positional pivot() arguments are rejected by
        # pandas >= 2.0
        df_c = df_c.pivot(index=unit, columns=time, values=value)
        # Builtin float: the ``np.float`` alias was removed in NumPy 1.24
        x = df_c.columns.values.astype(float)

        # Bootstrap the data for confidence intervals
        if "sd" in ci:
            est = estimator(df_c.values, axis=0)
            sd = np.std(df_c.values, axis=0)
            cis = [(est - sd, est + sd)]
            boot_data = df_c.values
        else:
            boot_data = algo.bootstrap(df_c.values, n_boot=n_boot,
                                       axis=0, func=estimator)
            cis = [utils.ci(boot_data, v, axis=0) for v in ci]
        central_data = estimator(df_c.values, axis=0)

        # Get the color for this condition
        color = colors[c]

        # Use subroutines to plot the uncertainty
        for style in err_style:

            # Allow for null style (only plot central tendency)
            if style is None:
                continue

            # Grab the function from the global environment
            try:
                plot_func = globals()["_plot_%s" % style]
            except KeyError:
                raise ValueError("%s is not a valid err_style" % style)

            # Possibly set up to plot each observation in a different color
            if err_palette is not None and "unit" in style:
                orig_color = color
                color = color_palette(err_palette, len(df_c.values))

            # Pass all parameters to the error plotter as keyword args.
            # Each style gets its own copy of err_kws so that defaults one
            # plotter fills in cannot leak into another style or into the
            # caller's dictionary.
            plot_kwargs = dict(ax=ax, x=x, data=df_c.values,
                               boot_data=boot_data,
                               central_data=central_data,
                               color=color, err_kws=dict(err_kws))

            # Plot the error representation, possibly for multiple cis
            for ci_i in cis:
                plot_kwargs["ci"] = ci_i
                plot_func(**plot_kwargs)

            if err_palette is not None and "unit" in style:
                color = orig_color

        # Plot the central trace
        kwargs.setdefault("marker", "" if interpolate else "o")
        ls = kwargs.pop("ls", "-" if interpolate else "")
        kwargs.setdefault("linestyle", ls)
        label = cond if legend else "_nolegend_"
        ax.plot(x, central_data, color=color, label=label, **kwargs)

    if c is None:
        raise RuntimeError("Invalid input data for tsplot.")

    # Pad the sides of the plot only when not interpolating
    ax.set_xlim(x.min(), x.max())
    if not interpolate:
        # A single timepoint has no spacing to measure; pad by one unit
        x_diff = x[1] - x[0] if len(x) > 1 else 1.0
        ax.set_xlim(x.min() - x_diff, x.max() + x_diff)

    # Add the plot labels
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if legend:
        ax.legend(loc=0, title=legend_name)

    return ax
- # Subroutines for tsplot errorbar plotting
- # ----------------------------------------
def _plot_ci_band(ax, x, ci, color, err_kws, **kwargs):
    """Plot translucent error bands around the central tendency.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose ``fill_between`` method draws the band.
    x : sequence
        Positions along the x axis.
    ci : pair of sequences
        ``(low, high)`` bounds of the band at each x position.
    color : matplotlib color
        Face color of the band.
    err_kws : dict
        Extra keyword arguments for ``fill_between``; ``alpha`` defaults
        to 0.2 when absent. The dictionary is not modified.
    kwargs :
        Ignored; accepted so every error plotter can be called with the
        same keyword set.

    """
    low, high = ci
    # Work on a copy so the default alpha does not leak into the caller's
    # dictionary (and from there into other error styles).
    band_kws = dict(err_kws)
    band_kws.setdefault("alpha", 0.2)
    ax.fill_between(x, low, high, facecolor=color, **band_kws)
def _plot_ci_bars(ax, x, central_data, ci, color, err_kws, **kwargs):
    """Plot a vertical error bar at each data point.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose ``plot`` method draws each bar.
    x : sequence
        Positions along the x axis.
    central_data : sequence
        Central estimate at each x position; one bar is drawn per entry.
    ci : array-like with shape (2, len(x))
        Lower and upper bounds. May be an array or a ``(low, high)`` pair
        of sequences, which is what is supplied when plotting the
        standard deviation.
    color : matplotlib color
        Color of the bars.
    err_kws : dict
        Extra keyword arguments passed to ``ax.plot``.
    kwargs :
        Ignored; accepted so every error plotter can be called with the
        same keyword set.

    """
    # Coerce to an array first: a plain (low, high) tuple has no ``.T``
    bounds = np.asarray(ci).T
    for x_i, _, (low, high) in zip(x, central_data, bounds):
        ax.plot([x_i, x_i], [low, high], color=color,
                solid_capstyle="round", **err_kws)
def _plot_boot_traces(ax, x, boot_data, color, err_kws, **kwargs):
    """Plot one thin trace for every row of the bootstrap distribution.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose ``plot`` method draws the traces.
    x : sequence
        Positions along the x axis.
    boot_data : array with shape (n_traces, len(x))
        Bootstrap estimates; each row becomes one line.
    color : matplotlib color
        Color shared by all traces.
    err_kws : dict
        Extra keyword arguments for ``ax.plot``. ``alpha`` and
        ``linewidth`` default to 0.25, and ``lw`` is accepted as an alias
        for ``linewidth`` (taking precedence when both are given). The
        dictionary is not modified.
    kwargs :
        Ignored; accepted so every error plotter can be called with the
        same keyword set.

    """
    # Copy so neither the defaults nor the alias handling touch the
    # caller's dictionary.
    trace_kws = dict(err_kws)
    if "lw" in trace_kws:
        trace_kws["linewidth"] = trace_kws.pop("lw")
    trace_kws.setdefault("alpha", 0.25)
    trace_kws.setdefault("linewidth", 0.25)
    ax.plot(x, boot_data.T, color=color, label="_nolegend_", **trace_kws)
def _plot_unit_traces(ax, x, data, ci, color, err_kws, **kwargs):
    """Plot a trace for each observation in the original data.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose ``plot`` method draws the traces.
    x : sequence
        Positions along the x axis.
    data : array with shape (n_units, len(x))
        Observations; each row becomes one line.
    ci : object
        Unused; part of the common error plotter signature.
    color : matplotlib color or list of colors
        A single color for all traces, or a list with one color per row
        of ``data``.
    err_kws : dict
        Extra keyword arguments for ``ax.plot``. ``alpha`` defaults to 0.5
        with per-unit colors and 0.2 otherwise. The dictionary is not
        modified.
    kwargs :
        Ignored; accepted so every error plotter can be called with the
        same keyword set.

    """
    # Copy so the default alpha chosen here does not stick to the caller's
    # dictionary and override the other branch's default on a later call.
    trace_kws = dict(err_kws)
    if isinstance(color, list):
        trace_kws.setdefault("alpha", .5)
        for i, obs in enumerate(data):
            ax.plot(x, obs, color=color[i], label="_nolegend_", **trace_kws)
    else:
        trace_kws.setdefault("alpha", .2)
        ax.plot(x, data.T, color=color, label="_nolegend_", **trace_kws)
def _plot_unit_points(ax, x, data, color, err_kws, **kwargs):
    """Plot each original data point discretely.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose ``plot`` method draws the points.
    x : sequence
        Positions along the x axis.
    data : array with shape (n_units, len(x))
        Observations; each row is one sampling unit.
    color : matplotlib color or list of colors
        A single color for all points, or a list with one color per row
        of ``data``.
    err_kws : dict
        Extra keyword arguments for ``ax.plot``. ``alpha`` (0.8 with
        per-unit colors, 0.5 otherwise), ``markersize`` (4) and ``label``
        are defaults that entries here override; ``ms`` is accepted as an
        alias for ``markersize``. The dictionary is not modified.
    kwargs :
        Ignored; accepted so every error plotter can be called with the
        same keyword set.

    """
    per_unit = isinstance(color, list)

    # Treat the point styling as overridable defaults. Passing them as
    # fixed keywords next to **err_kws raised TypeError ("multiple values
    # for keyword argument") whenever err_kws carried the same key.
    point_kws = dict(err_kws)
    if "ms" in point_kws:
        point_kws["markersize"] = point_kws.pop("ms")
    point_kws.setdefault("alpha", 0.8 if per_unit else 0.5)
    point_kws.setdefault("markersize", 4)
    point_kws.setdefault("label", "_nolegend_")

    if per_unit:
        for i, obs in enumerate(data):
            ax.plot(x, obs, "o", color=color[i], **point_kws)
    else:
        ax.plot(x, data.T, "o", color=color, **point_kws)
def _plot_boot_kde(ax, x, boot_data, color, **kwargs):
    """Draw a kernel density image of the bootstrap distribution.

    The density is estimated from ``boot_data`` rather than from the raw
    observations; all remaining keyword arguments are forwarded to
    ``_ts_kde``.
    """
    # ``data`` is filled positionally with boot_data below, so the raw
    # observations must not be forwarded under the same name. A missing
    # ``data`` entry raises KeyError.
    del kwargs["data"]
    _ts_kde(ax, x, boot_data, color, **kwargs)
def _plot_unit_kde(ax, x, data, color, **kwargs):
    """Draw a kernel density image estimated over the sampled units.

    Thin wrapper that hands the raw observations in ``data`` to
    ``_ts_kde``; any extra keyword arguments are forwarded unchanged.
    """
    _ts_kde(ax=ax, x=x, data=data, color=color, **kwargs)
def _ts_kde(ax, x, data, color, **kwargs):
    """Upsample over time and plot a KDE of the distribution as an image.

    Parameters
    ----------
    ax : matplotlib axis
        Axis whose ``imshow`` method draws the density image.
    x : array
        Positions along the x axis; one per column of ``data``.
    data : array with shape (n_samples, len(x))
        Samples whose distribution is estimated at each point in time.
    color : matplotlib color
        Color of the image; the density only controls its opacity.
    kwargs :
        Ignored; accepted so the error plotters can forward their full
        keyword set.

    """
    y_min, y_max = data.min(), data.max()
    y_vals = np.linspace(y_min, y_max, 100)

    # Linearly interpolate every sample onto a 100-point time grid
    upsampler = interpolate.interp1d(x, data)
    data_upsample = upsampler(np.linspace(x.min(), x.max(), 100))

    # One density curve per upsampled timepoint, evaluated on the shared
    # y grid. ``stats.gaussian_kde`` is the public entry point; the
    # ``scipy.stats.kde`` namespace used previously is deprecated.
    kde_data = np.transpose([stats.gaussian_kde(pt_data)(y_vals)
                             for pt_data in data_upsample.T])

    # Build an RGBA image: constant color, density in the alpha channel
    rgb = mpl.colors.ColorConverter().to_rgb(color)
    img = np.zeros((kde_data.shape[0], kde_data.shape[1], 4))
    img[:, :, :3] = rgb
    # Scale each time column so its peak is fully opaque
    kde_data /= kde_data.max(axis=0)
    kde_data[kde_data > 1] = 1
    img[:, :, 3] = kde_data
    ax.imshow(img, interpolation="spline16", zorder=2,
              extent=(x.min(), x.max(), y_min, y_max),
              aspect="auto", origin="lower")
|