1
0

timeseries.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454
  1. """Timeseries plotting functions."""
  2. from __future__ import division
  3. import numpy as np
  4. import pandas as pd
  5. from scipy import stats, interpolate
  6. import matplotlib as mpl
  7. import matplotlib.pyplot as plt
  8. import warnings
  9. from .external.six import string_types
  10. from . import utils
  11. from . import algorithms as algo
  12. from .palettes import color_palette
  13. __all__ = ["tsplot"]
  14. def tsplot(data, time=None, unit=None, condition=None, value=None,
  15. err_style="ci_band", ci=68, interpolate=True, color=None,
  16. estimator=np.mean, n_boot=5000, err_palette=None, err_kws=None,
  17. legend=True, ax=None, **kwargs):
  18. """Plot one or more timeseries with flexible representation of uncertainty.
  19. This function is intended to be used with data where observations are
  20. nested within sampling units that were measured at multiple timepoints.
  21. It can take data specified either as a long-form (tidy) DataFrame or as an
  22. ndarray with dimensions (unit, time) The interpretation of some of the
  23. other parameters changes depending on the type of object passed as data.
  24. Parameters
  25. ----------
  26. data : DataFrame or ndarray
  27. Data for the plot. Should either be a "long form" dataframe or an
  28. array with dimensions (unit, time, condition). In both cases, the
  29. condition field/dimension is optional. The type of this argument
  30. determines the interpretation of the next few parameters. When
  31. using a DataFrame, the index has to be sequential.
  32. time : string or series-like
  33. Either the name of the field corresponding to time in the data
  34. DataFrame or x values for a plot when data is an array. If a Series,
  35. the name will be used to label the x axis.
  36. unit : string
  37. Field in the data DataFrame identifying the sampling unit (e.g.
  38. subject, neuron, etc.). The error representation will collapse over
  39. units at each time/condition observation. This has no role when data
  40. is an array.
  41. value : string
  42. Either the name of the field corresponding to the data values in
  43. the data DataFrame (i.e. the y coordinate) or a string that forms
  44. the y axis label when data is an array.
  45. condition : string or Series-like
  46. Either the name of the field identifying the condition an observation
  47. falls under in the data DataFrame, or a sequence of names with a length
  48. equal to the size of the third dimension of data. There will be a
  49. separate trace plotted for each condition. If condition is a Series
  50. with a name attribute, the name will form the title for the plot
  51. legend (unless legend is set to False).
  52. err_style : string or list of strings or None
  53. Names of ways to plot uncertainty across units from set of
  54. {ci_band, ci_bars, boot_traces, boot_kde, unit_traces, unit_points}.
  55. Can use one or more than one method.
  56. ci : float or list of floats in [0, 100] or "sd" or None
  57. Confidence interval size(s). If a list, it will stack the error plots
  58. for each confidence interval. If ``"sd"``, show standard deviation of
  59. the observations instead of boostrapped confidence intervals. Only
  60. relevant for error styles with "ci" in the name.
  61. interpolate : boolean
  62. Whether to do a linear interpolation between each timepoint when
  63. plotting. The value of this parameter also determines the marker
  64. used for the main plot traces, unless marker is specified as a keyword
  65. argument.
  66. color : seaborn palette or matplotlib color name or dictionary
  67. Palette or color for the main plots and error representation (unless
  68. plotting by unit, which can be separately controlled with err_palette).
  69. If a dictionary, should map condition name to color spec.
  70. estimator : callable
  71. Function to determine central tendency and to pass to bootstrap
  72. must take an ``axis`` argument.
  73. n_boot : int
  74. Number of bootstrap iterations.
  75. err_palette : seaborn palette
  76. Palette name or list of colors used when plotting data for each unit.
  77. err_kws : dict, optional
  78. Keyword argument dictionary passed through to matplotlib function
  79. generating the error plot,
  80. legend : bool, optional
  81. If ``True`` and there is a ``condition`` variable, add a legend to
  82. the plot.
  83. ax : axis object, optional
  84. Plot in given axis; if None creates a new figure
  85. kwargs :
  86. Other keyword arguments are passed to main plot() call
  87. Returns
  88. -------
  89. ax : matplotlib axis
  90. axis with plot data
  91. Examples
  92. --------
  93. Plot a trace with translucent confidence bands:
  94. .. plot::
  95. :context: close-figs
  96. >>> import numpy as np; np.random.seed(22)
  97. >>> import seaborn as sns; sns.set(color_codes=True)
  98. >>> x = np.linspace(0, 15, 31)
  99. >>> data = np.sin(x) + np.random.rand(10, 31) + np.random.randn(10, 1)
  100. >>> ax = sns.tsplot(data=data)
  101. Plot a long-form dataframe with several conditions:
  102. .. plot::
  103. :context: close-figs
  104. >>> gammas = sns.load_dataset("gammas")
  105. >>> ax = sns.tsplot(time="timepoint", value="BOLD signal",
  106. ... unit="subject", condition="ROI",
  107. ... data=gammas)
  108. Use error bars at the positions of the observations:
  109. .. plot::
  110. :context: close-figs
  111. >>> ax = sns.tsplot(data=data, err_style="ci_bars", color="g")
  112. Don't interpolate between the observations:
  113. .. plot::
  114. :context: close-figs
  115. >>> import matplotlib.pyplot as plt
  116. >>> ax = sns.tsplot(data=data, err_style="ci_bars", interpolate=False)
  117. Show multiple confidence bands:
  118. .. plot::
  119. :context: close-figs
  120. >>> ax = sns.tsplot(data=data, ci=[68, 95], color="m")
  121. Show the standard deviation of the observations:
  122. .. plot::
  123. :context: close-figs
  124. >>> ax = sns.tsplot(data=data, ci="sd")
  125. Use a different estimator:
  126. .. plot::
  127. :context: close-figs
  128. >>> ax = sns.tsplot(data=data, estimator=np.median)
  129. Show each bootstrap resample:
  130. .. plot::
  131. :context: close-figs
  132. >>> ax = sns.tsplot(data=data, err_style="boot_traces", n_boot=500)
  133. Show the trace from each sampling unit:
  134. .. plot::
  135. :context: close-figs
  136. >>> ax = sns.tsplot(data=data, err_style="unit_traces")
  137. """
  138. msg = (
  139. "The `tsplot` function is deprecated and will be removed in a future "
  140. "release. Please update your code to use the new `lineplot` function."
  141. )
  142. warnings.warn(msg, UserWarning)
  143. # Sort out default values for the parameters
  144. if ax is None:
  145. ax = plt.gca()
  146. if err_kws is None:
  147. err_kws = {}
  148. # Handle different types of input data
  149. if isinstance(data, pd.DataFrame):
  150. xlabel = time
  151. ylabel = value
  152. # Condition is optional
  153. if condition is None:
  154. condition = pd.Series(1, index=data.index)
  155. legend = False
  156. legend_name = None
  157. n_cond = 1
  158. else:
  159. legend = True and legend
  160. legend_name = condition
  161. n_cond = len(data[condition].unique())
  162. else:
  163. data = np.asarray(data)
  164. # Data can be a timecourse from a single unit or
  165. # several observations in one condition
  166. if data.ndim == 1:
  167. data = data[np.newaxis, :, np.newaxis]
  168. elif data.ndim == 2:
  169. data = data[:, :, np.newaxis]
  170. n_unit, n_time, n_cond = data.shape
  171. # Units are experimental observations. Maybe subjects, or neurons
  172. if unit is None:
  173. units = np.arange(n_unit)
  174. unit = "unit"
  175. units = np.repeat(units, n_time * n_cond)
  176. ylabel = None
  177. # Time forms the xaxis of the plot
  178. if time is None:
  179. times = np.arange(n_time)
  180. else:
  181. times = np.asarray(time)
  182. xlabel = None
  183. if hasattr(time, "name"):
  184. xlabel = time.name
  185. time = "time"
  186. times = np.tile(np.repeat(times, n_cond), n_unit)
  187. # Conditions split the timeseries plots
  188. if condition is None:
  189. conds = range(n_cond)
  190. legend = False
  191. if isinstance(color, dict):
  192. err = "Must have condition names if using color dict."
  193. raise ValueError(err)
  194. else:
  195. conds = np.asarray(condition)
  196. legend = True and legend
  197. if hasattr(condition, "name"):
  198. legend_name = condition.name
  199. else:
  200. legend_name = None
  201. condition = "cond"
  202. conds = np.tile(conds, n_unit * n_time)
  203. # Value forms the y value in the plot
  204. if value is None:
  205. ylabel = None
  206. else:
  207. ylabel = value
  208. value = "value"
  209. # Convert to long-form DataFrame
  210. data = pd.DataFrame(dict(value=data.ravel(),
  211. time=times,
  212. unit=units,
  213. cond=conds))
  214. # Set up the err_style and ci arguments for the loop below
  215. if isinstance(err_style, string_types):
  216. err_style = [err_style]
  217. elif err_style is None:
  218. err_style = []
  219. if not hasattr(ci, "__iter__"):
  220. ci = [ci]
  221. # Set up the color palette
  222. if color is None:
  223. current_palette = utils.get_color_cycle()
  224. if len(current_palette) < n_cond:
  225. colors = color_palette("husl", n_cond)
  226. else:
  227. colors = color_palette(n_colors=n_cond)
  228. elif isinstance(color, dict):
  229. colors = [color[c] for c in data[condition].unique()]
  230. else:
  231. try:
  232. colors = color_palette(color, n_cond)
  233. except ValueError:
  234. color = mpl.colors.colorConverter.to_rgb(color)
  235. colors = [color] * n_cond
  236. # Do a groupby with condition and plot each trace
  237. c = None
  238. for c, (cond, df_c) in enumerate(data.groupby(condition, sort=False)):
  239. df_c = df_c.pivot(unit, time, value)
  240. x = df_c.columns.values.astype(np.float)
  241. # Bootstrap the data for confidence intervals
  242. if "sd" in ci:
  243. est = estimator(df_c.values, axis=0)
  244. sd = np.std(df_c.values, axis=0)
  245. cis = [(est - sd, est + sd)]
  246. boot_data = df_c.values
  247. else:
  248. boot_data = algo.bootstrap(df_c.values, n_boot=n_boot,
  249. axis=0, func=estimator)
  250. cis = [utils.ci(boot_data, v, axis=0) for v in ci]
  251. central_data = estimator(df_c.values, axis=0)
  252. # Get the color for this condition
  253. color = colors[c]
  254. # Use subroutines to plot the uncertainty
  255. for style in err_style:
  256. # Allow for null style (only plot central tendency)
  257. if style is None:
  258. continue
  259. # Grab the function from the global environment
  260. try:
  261. plot_func = globals()["_plot_%s" % style]
  262. except KeyError:
  263. raise ValueError("%s is not a valid err_style" % style)
  264. # Possibly set up to plot each observation in a different color
  265. if err_palette is not None and "unit" in style:
  266. orig_color = color
  267. color = color_palette(err_palette, len(df_c.values))
  268. # Pass all parameters to the error plotter as keyword args
  269. plot_kwargs = dict(ax=ax, x=x, data=df_c.values,
  270. boot_data=boot_data,
  271. central_data=central_data,
  272. color=color, err_kws=err_kws)
  273. # Plot the error representation, possibly for multiple cis
  274. for ci_i in cis:
  275. plot_kwargs["ci"] = ci_i
  276. plot_func(**plot_kwargs)
  277. if err_palette is not None and "unit" in style:
  278. color = orig_color
  279. # Plot the central trace
  280. kwargs.setdefault("marker", "" if interpolate else "o")
  281. ls = kwargs.pop("ls", "-" if interpolate else "")
  282. kwargs.setdefault("linestyle", ls)
  283. label = cond if legend else "_nolegend_"
  284. ax.plot(x, central_data, color=color, label=label, **kwargs)
  285. if c is None:
  286. raise RuntimeError("Invalid input data for tsplot.")
  287. # Pad the sides of the plot only when not interpolating
  288. ax.set_xlim(x.min(), x.max())
  289. x_diff = x[1] - x[0]
  290. if not interpolate:
  291. ax.set_xlim(x.min() - x_diff, x.max() + x_diff)
  292. # Add the plot labels
  293. if xlabel is not None:
  294. ax.set_xlabel(xlabel)
  295. if ylabel is not None:
  296. ax.set_ylabel(ylabel)
  297. if legend:
  298. ax.legend(loc=0, title=legend_name)
  299. return ax
  300. # Subroutines for tsplot errorbar plotting
  301. # ----------------------------------------
  302. def _plot_ci_band(ax, x, ci, color, err_kws, **kwargs):
  303. """Plot translucent error bands around the central tendancy."""
  304. low, high = ci
  305. if "alpha" not in err_kws:
  306. err_kws["alpha"] = 0.2
  307. ax.fill_between(x, low, high, facecolor=color, **err_kws)
  308. def _plot_ci_bars(ax, x, central_data, ci, color, err_kws, **kwargs):
  309. """Plot error bars at each data point."""
  310. for x_i, y_i, (low, high) in zip(x, central_data, ci.T):
  311. ax.plot([x_i, x_i], [low, high], color=color,
  312. solid_capstyle="round", **err_kws)
  313. def _plot_boot_traces(ax, x, boot_data, color, err_kws, **kwargs):
  314. """Plot 250 traces from bootstrap."""
  315. err_kws.setdefault("alpha", 0.25)
  316. err_kws.setdefault("linewidth", 0.25)
  317. if "lw" in err_kws:
  318. err_kws["linewidth"] = err_kws.pop("lw")
  319. ax.plot(x, boot_data.T, color=color, label="_nolegend_", **err_kws)
  320. def _plot_unit_traces(ax, x, data, ci, color, err_kws, **kwargs):
  321. """Plot a trace for each observation in the original data."""
  322. if isinstance(color, list):
  323. if "alpha" not in err_kws:
  324. err_kws["alpha"] = .5
  325. for i, obs in enumerate(data):
  326. ax.plot(x, obs, color=color[i], label="_nolegend_", **err_kws)
  327. else:
  328. if "alpha" not in err_kws:
  329. err_kws["alpha"] = .2
  330. ax.plot(x, data.T, color=color, label="_nolegend_", **err_kws)
  331. def _plot_unit_points(ax, x, data, color, err_kws, **kwargs):
  332. """Plot each original data point discretely."""
  333. if isinstance(color, list):
  334. for i, obs in enumerate(data):
  335. ax.plot(x, obs, "o", color=color[i], alpha=0.8, markersize=4,
  336. label="_nolegend_", **err_kws)
  337. else:
  338. ax.plot(x, data.T, "o", color=color, alpha=0.5, markersize=4,
  339. label="_nolegend_", **err_kws)
  340. def _plot_boot_kde(ax, x, boot_data, color, **kwargs):
  341. """Plot the kernal density estimate of the bootstrap distribution."""
  342. kwargs.pop("data")
  343. _ts_kde(ax, x, boot_data, color, **kwargs)
  344. def _plot_unit_kde(ax, x, data, color, **kwargs):
  345. """Plot the kernal density estimate over the sample."""
  346. _ts_kde(ax, x, data, color, **kwargs)
  347. def _ts_kde(ax, x, data, color, **kwargs):
  348. """Upsample over time and plot a KDE of the bootstrap distribution."""
  349. kde_data = []
  350. y_min, y_max = data.min(), data.max()
  351. y_vals = np.linspace(y_min, y_max, 100)
  352. upsampler = interpolate.interp1d(x, data)
  353. data_upsample = upsampler(np.linspace(x.min(), x.max(), 100))
  354. for pt_data in data_upsample.T:
  355. pt_kde = stats.kde.gaussian_kde(pt_data)
  356. kde_data.append(pt_kde(y_vals))
  357. kde_data = np.transpose(kde_data)
  358. rgb = mpl.colors.ColorConverter().to_rgb(color)
  359. img = np.zeros((kde_data.shape[0], kde_data.shape[1], 4))
  360. img[:, :, :3] = rgb
  361. kde_data /= kde_data.max(axis=0)
  362. kde_data[kde_data > 1] = 1
  363. img[:, :, 3] = kde_data
  364. ax.imshow(img, interpolation="spline16", zorder=2,
  365. extent=(x.min(), x.max(), y_min, y_max),
  366. aspect="auto", origin="lower")