1
0

relational.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845
  1. from itertools import product
  2. from textwrap import dedent
  3. import warnings
  4. import numpy as np
  5. import pandas as pd
  6. import matplotlib as mpl
  7. import matplotlib.pyplot as plt
  8. from . import utils
  9. from .utils import (categorical_order, get_color_cycle, ci_to_errsize,
  10. remove_na, locator_to_legend_entries)
  11. from .algorithms import bootstrap
  12. from .palettes import (color_palette, cubehelix_palette,
  13. _parse_cubehelix_args, QUAL_PALETTES)
  14. from .axisgrid import FacetGrid, _facet_docs
  15. __all__ = ["relplot", "scatterplot", "lineplot"]
class _RelationalPlotter(object):
    """Shared input parsing and hue/size/style mapping for relational plots."""

    # Marker codes paired, in order, with the style levels when the
    # ``markers`` argument is ``True`` (see ``style_to_attributes``).
    default_markers = ["o", "X", "s", "P", "D", "^", "v", "p"]

    # Dash specs paired, in order, with the style levels when the ``dashes``
    # argument is ``True``. The leading "" presumably means a solid line --
    # confirm against the matplotlib ``dashes`` handling in the callers.
    default_dashes = ["", (4, 1.5), (1, 1),
                      (3, 1, 1.5, 1), (5, 1, 1, 1),
                      (5, 1, 2, 1, 2, 1)]
  21. def establish_variables(self, x=None, y=None,
  22. hue=None, size=None, style=None,
  23. units=None, data=None):
  24. """Parse the inputs to define data for plotting."""
  25. # Initialize label variables
  26. x_label = y_label = hue_label = size_label = style_label = None
  27. # Option 1:
  28. # We have a wide-form datast
  29. # --------------------------
  30. if x is None and y is None:
  31. self.input_format = "wide"
  32. # Option 1a:
  33. # The input data is a Pandas DataFrame
  34. # ------------------------------------
  35. # We will assign the index to x, the values to y,
  36. # and the columns names to both hue and style
  37. # TODO accept a dict and try to coerce to a dataframe?
  38. if isinstance(data, pd.DataFrame):
  39. # Enforce numeric values
  40. try:
  41. data.astype(np.float)
  42. except ValueError:
  43. err = "A wide-form input must have only numeric values."
  44. raise ValueError(err)
  45. plot_data = data.copy()
  46. plot_data.loc[:, "x"] = data.index
  47. plot_data = pd.melt(plot_data, "x",
  48. var_name="hue", value_name="y")
  49. plot_data["style"] = plot_data["hue"]
  50. x_label = getattr(data.index, "name", None)
  51. hue_label = style_label = getattr(plot_data.columns,
  52. "name", None)
  53. # Option 1b:
  54. # The input data is an array or list
  55. # ----------------------------------
  56. else:
  57. if not len(data):
  58. plot_data = pd.DataFrame(columns=["x", "y"])
  59. elif np.isscalar(np.asarray(data)[0]):
  60. # The input data is a flat list(like):
  61. # We assign a numeric index for x and use the values for y
  62. x = getattr(data, "index", np.arange(len(data)))
  63. plot_data = pd.DataFrame(dict(x=x, y=data))
  64. elif hasattr(data, "shape"):
  65. # The input data is an array(like):
  66. # We either use the index or assign a numeric index to x,
  67. # the values to y, and id keys to both hue and style
  68. plot_data = pd.DataFrame(data)
  69. plot_data.loc[:, "x"] = plot_data.index
  70. plot_data = pd.melt(plot_data, "x",
  71. var_name="hue",
  72. value_name="y")
  73. plot_data["style"] = plot_data["hue"]
  74. else:
  75. # The input data is a nested list: We will either use the
  76. # index or assign a numeric index for x, use the values
  77. # for y, and use numeric hue/style identifiers.
  78. plot_data = []
  79. for i, data_i in enumerate(data):
  80. x = getattr(data_i, "index", np.arange(len(data_i)))
  81. n = getattr(data_i, "name", i)
  82. data_i = dict(x=x, y=data_i, hue=n, style=n, size=None)
  83. plot_data.append(pd.DataFrame(data_i))
  84. plot_data = pd.concat(plot_data)
  85. # Option 2:
  86. # We have long-form data
  87. # ----------------------
  88. elif x is not None and y is not None:
  89. self.input_format = "long"
  90. # Use variables as from the dataframe if specified
  91. if data is not None:
  92. x = data.get(x, x)
  93. y = data.get(y, y)
  94. hue = data.get(hue, hue)
  95. size = data.get(size, size)
  96. style = data.get(style, style)
  97. units = data.get(units, units)
  98. # Validate the inputs
  99. for var in [x, y, hue, size, style, units]:
  100. if isinstance(var, str):
  101. err = "Could not interpret input '{}'".format(var)
  102. raise ValueError(err)
  103. # Extract variable names
  104. x_label = getattr(x, "name", None)
  105. y_label = getattr(y, "name", None)
  106. hue_label = getattr(hue, "name", None)
  107. size_label = getattr(size, "name", None)
  108. style_label = getattr(style, "name", None)
  109. # Reassemble into a DataFrame
  110. plot_data = dict(
  111. x=x, y=y,
  112. hue=hue, style=style, size=size,
  113. units=units
  114. )
  115. plot_data = pd.DataFrame(plot_data)
  116. # Option 3:
  117. # Only one variable argument
  118. # --------------------------
  119. else:
  120. err = ("Either both or neither of `x` and `y` must be specified "
  121. "(but try passing to `data`, which is more flexible).")
  122. raise ValueError(err)
  123. # ---- Post-processing
  124. # Assign default values for missing attribute variables
  125. for attr in ["hue", "style", "size", "units"]:
  126. if attr not in plot_data:
  127. plot_data[attr] = None
  128. # Determine which semantics have (some) data
  129. plot_valid = plot_data.notnull().any()
  130. semantics = ["x", "y"] + [
  131. name for name in ["hue", "size", "style"]
  132. if plot_valid[name]
  133. ]
  134. self.x_label = x_label
  135. self.y_label = y_label
  136. self.hue_label = hue_label
  137. self.size_label = size_label
  138. self.style_label = style_label
  139. self.plot_data = plot_data
  140. self.semantics = semantics
  141. return plot_data
  142. def categorical_to_palette(self, data, order, palette):
  143. """Determine colors when the hue variable is qualitative."""
  144. # -- Identify the order and name of the levels
  145. if order is None:
  146. levels = categorical_order(data)
  147. else:
  148. levels = order
  149. n_colors = len(levels)
  150. # -- Identify the set of colors to use
  151. if isinstance(palette, dict):
  152. missing = set(levels) - set(palette)
  153. if any(missing):
  154. err = "The palette dictionary is missing keys: {}"
  155. raise ValueError(err.format(missing))
  156. else:
  157. if palette is None:
  158. if n_colors <= len(get_color_cycle()):
  159. colors = color_palette(None, n_colors)
  160. else:
  161. colors = color_palette("husl", n_colors)
  162. elif isinstance(palette, list):
  163. if len(palette) != n_colors:
  164. err = "The palette list has the wrong number of colors."
  165. raise ValueError(err)
  166. colors = palette
  167. else:
  168. colors = color_palette(palette, n_colors)
  169. palette = dict(zip(levels, colors))
  170. return levels, palette
  171. def numeric_to_palette(self, data, order, palette, norm):
  172. """Determine colors when the hue variable is quantitative."""
  173. levels = list(np.sort(remove_na(data.unique())))
  174. # TODO do we want to do something complicated to ensure contrast
  175. # at the extremes of the colormap against the background?
  176. # Identify the colormap to use
  177. palette = "ch:" if palette is None else palette
  178. if isinstance(palette, mpl.colors.Colormap):
  179. cmap = palette
  180. elif str(palette).startswith("ch:"):
  181. args, kwargs = _parse_cubehelix_args(palette)
  182. cmap = cubehelix_palette(0, *args, as_cmap=True, **kwargs)
  183. elif isinstance(palette, dict):
  184. colors = [palette[k] for k in sorted(palette)]
  185. cmap = mpl.colors.ListedColormap(colors)
  186. else:
  187. try:
  188. cmap = mpl.cm.get_cmap(palette)
  189. except (ValueError, TypeError):
  190. err = "Palette {} not understood"
  191. raise ValueError(err)
  192. if norm is None:
  193. norm = mpl.colors.Normalize()
  194. elif isinstance(norm, tuple):
  195. norm = mpl.colors.Normalize(*norm)
  196. elif not isinstance(norm, mpl.colors.Normalize):
  197. err = "``hue_norm`` must be None, tuple, or Normalize object."
  198. raise ValueError(err)
  199. if not norm.scaled():
  200. norm(np.asarray(data.dropna()))
  201. # TODO this should also use color_lookup, but that needs the
  202. # class attributes that get set after using this function...
  203. if not isinstance(palette, dict):
  204. palette = dict(zip(levels, cmap(norm(levels))))
  205. # palette = {l: cmap(norm([l, 1]))[0] for l in levels}
  206. return levels, palette, cmap, norm
  207. def color_lookup(self, key):
  208. """Return the color corresponding to the hue level."""
  209. if self.hue_type == "numeric":
  210. normed = self.hue_norm(key)
  211. if np.ma.is_masked(normed):
  212. normed = np.nan
  213. return self.cmap(normed)
  214. elif self.hue_type == "categorical":
  215. return self.palette[key]
  216. def size_lookup(self, key):
  217. """Return the size corresponding to the size level."""
  218. if self.size_type == "numeric":
  219. min_size, max_size = self.size_range
  220. val = self.size_norm(key)
  221. if np.ma.is_masked(val):
  222. return 0
  223. return min_size + val * (max_size - min_size)
  224. elif self.size_type == "categorical":
  225. return self.sizes[key]
  226. def style_to_attributes(self, levels, style, defaults, name):
  227. """Convert a style argument to a dict of matplotlib attributes."""
  228. if style is True:
  229. attrdict = dict(zip(levels, defaults))
  230. elif style and isinstance(style, dict):
  231. attrdict = style
  232. elif style:
  233. attrdict = dict(zip(levels, style))
  234. else:
  235. attrdict = {}
  236. if attrdict:
  237. missing_levels = set(levels) - set(attrdict)
  238. if any(missing_levels):
  239. err = "These `style` levels are missing {}: {}"
  240. raise ValueError(err.format(name, missing_levels))
  241. return attrdict
  242. def subset_data(self):
  243. """Return (x, y) data for each subset defined by semantics."""
  244. data = self.plot_data
  245. all_true = pd.Series(True, data.index)
  246. iter_levels = product(self.hue_levels,
  247. self.size_levels,
  248. self.style_levels)
  249. for hue, size, style in iter_levels:
  250. hue_rows = all_true if hue is None else data["hue"] == hue
  251. size_rows = all_true if size is None else data["size"] == size
  252. style_rows = all_true if style is None else data["style"] == style
  253. rows = hue_rows & size_rows & style_rows
  254. data["units"] = data.units.fillna("")
  255. subset_data = data.loc[rows, ["units", "x", "y"]].dropna()
  256. if not len(subset_data):
  257. continue
  258. if self.sort:
  259. subset_data = subset_data.sort_values(["units", "x", "y"])
  260. if self.units is None:
  261. subset_data = subset_data.drop("units", axis=1)
  262. yield (hue, size, style), subset_data
  263. def parse_hue(self, data, palette, order, norm):
  264. """Determine what colors to use given data characteristics."""
  265. if self._empty_data(data):
  266. # Set default values when not using a hue mapping
  267. levels = [None]
  268. limits = None
  269. norm = None
  270. palette = {}
  271. var_type = None
  272. cmap = None
  273. else:
  274. # Determine what kind of hue mapping we want
  275. var_type = self._semantic_type(data)
  276. # Override depending on the type of the palette argument
  277. if palette in QUAL_PALETTES:
  278. var_type = "categorical"
  279. elif norm is not None:
  280. var_type = "numeric"
  281. elif isinstance(palette, (dict, list)):
  282. var_type = "categorical"
  283. # -- Option 1: categorical color palette
  284. if var_type == "categorical":
  285. cmap = None
  286. limits = None
  287. levels, palette = self.categorical_to_palette(
  288. # List comprehension here is required to
  289. # overcome differences in the way pandas
  290. # externalizes numpy datetime64
  291. list(data), order, palette
  292. )
  293. # -- Option 2: sequential color palette
  294. elif var_type == "numeric":
  295. data = pd.to_numeric(data)
  296. levels, palette, cmap, norm = self.numeric_to_palette(
  297. data, order, palette, norm
  298. )
  299. limits = norm.vmin, norm.vmax
  300. self.hue_levels = levels
  301. self.hue_norm = norm
  302. self.hue_limits = limits
  303. self.hue_type = var_type
  304. self.palette = palette
  305. self.cmap = cmap
  306. # Update data as it may have changed dtype
  307. self.plot_data["hue"] = data
  308. def parse_size(self, data, sizes, order, norm):
  309. """Determine the linewidths given data characteristics."""
  310. # TODO could break out two options like parse_hue does for clarity
  311. if self._empty_data(data):
  312. levels = [None]
  313. limits = None
  314. norm = None
  315. sizes = {}
  316. var_type = None
  317. width_range = None
  318. else:
  319. var_type = self._semantic_type(data)
  320. # Override depending on the type of the sizes argument
  321. if norm is not None:
  322. var_type = "numeric"
  323. elif isinstance(sizes, (dict, list)):
  324. var_type = "categorical"
  325. if var_type == "categorical":
  326. levels = categorical_order(data, order)
  327. numbers = np.arange(1, 1 + len(levels))[::-1]
  328. elif var_type == "numeric":
  329. data = pd.to_numeric(data)
  330. levels = numbers = np.sort(remove_na(data.unique()))
  331. if isinstance(sizes, (dict, list)):
  332. # Use literal size values
  333. if isinstance(sizes, list):
  334. if len(sizes) != len(levels):
  335. err = "The `sizes` list has wrong number of levels"
  336. raise ValueError(err)
  337. sizes = dict(zip(levels, sizes))
  338. missing = set(levels) - set(sizes)
  339. if any(missing):
  340. err = "Missing sizes for the following levels: {}"
  341. raise ValueError(err.format(missing))
  342. width_range = min(sizes.values()), max(sizes.values())
  343. try:
  344. limits = min(sizes.keys()), max(sizes.keys())
  345. except TypeError:
  346. limits = None
  347. else:
  348. # Infer the range of sizes to use
  349. if sizes is None:
  350. min_width, max_width = self._default_size_range
  351. else:
  352. try:
  353. min_width, max_width = sizes
  354. except (TypeError, ValueError):
  355. err = "sizes argument {} not understood".format(sizes)
  356. raise ValueError(err)
  357. width_range = min_width, max_width
  358. if norm is None:
  359. norm = mpl.colors.Normalize()
  360. elif isinstance(norm, tuple):
  361. norm = mpl.colors.Normalize(*norm)
  362. elif not isinstance(norm, mpl.colors.Normalize):
  363. err = ("``size_norm`` must be None, tuple, "
  364. "or Normalize object.")
  365. raise ValueError(err)
  366. norm.clip = True
  367. if not norm.scaled():
  368. norm(np.asarray(numbers))
  369. limits = norm.vmin, norm.vmax
  370. scl = norm(numbers)
  371. widths = np.asarray(min_width + scl * (max_width - min_width))
  372. if scl.mask.any():
  373. widths[scl.mask] = 0
  374. sizes = dict(zip(levels, widths))
  375. # sizes = {l: min_width + norm(n) * (max_width - min_width)
  376. # for l, n in zip(levels, numbers)}
  377. if var_type == "categorical":
  378. # Don't keep a reference to the norm, which will avoid
  379. # downstream code from switching to numerical interpretation
  380. norm = None
  381. self.sizes = sizes
  382. self.size_type = var_type
  383. self.size_levels = levels
  384. self.size_norm = norm
  385. self.size_limits = limits
  386. self.size_range = width_range
  387. # Update data as it may have changed dtype
  388. self.plot_data["size"] = data
  389. def parse_style(self, data, markers, dashes, order):
  390. """Determine the markers and line dashes."""
  391. if self._empty_data(data):
  392. levels = [None]
  393. dashes = {}
  394. markers = {}
  395. else:
  396. if order is None:
  397. # List comprehension here is required to
  398. # overcome differences in the way pandas
  399. # coerces numpy datatypes
  400. levels = categorical_order(list(data))
  401. else:
  402. levels = order
  403. markers = self.style_to_attributes(
  404. levels, markers, self.default_markers, "markers"
  405. )
  406. dashes = self.style_to_attributes(
  407. levels, dashes, self.default_dashes, "dashes"
  408. )
  409. paths = {}
  410. filled_markers = []
  411. for k, m in markers.items():
  412. if not isinstance(m, mpl.markers.MarkerStyle):
  413. m = mpl.markers.MarkerStyle(m)
  414. paths[k] = m.get_path().transformed(m.get_transform())
  415. filled_markers.append(m.is_filled())
  416. # Mixture of filled and unfilled markers will show line art markers
  417. # in the edge color, which defaults to white. This can be handled,
  418. # but there would be additional complexity with specifying the
  419. # weight of the line art markers without overwhelming the filled
  420. # ones with the edges. So for now, we will disallow mixtures.
  421. if any(filled_markers) and not all(filled_markers):
  422. err = "Filled and line art markers cannot be mixed"
  423. raise ValueError(err)
  424. self.style_levels = levels
  425. self.dashes = dashes
  426. self.markers = markers
  427. self.paths = paths
  428. def _empty_data(self, data):
  429. """Test if a series is completely missing."""
  430. return data.isnull().all()
  431. def _semantic_type(self, data):
  432. """Determine if data should considered numeric or categorical."""
  433. if self.input_format == "wide":
  434. return "categorical"
  435. elif isinstance(data, pd.Series) and data.dtype.name == "category":
  436. return "categorical"
  437. else:
  438. try:
  439. float_data = data.astype(np.float)
  440. values = np.unique(float_data.dropna())
  441. # TODO replace with isin when pinned np version >= 1.13
  442. if np.all(np.in1d(values, np.array([0., 1.]))):
  443. return "categorical"
  444. return "numeric"
  445. except (ValueError, TypeError):
  446. return "categorical"
  447. def label_axes(self, ax):
  448. """Set x and y labels with visibility that matches the ticklabels."""
  449. if self.x_label is not None:
  450. x_visible = any(t.get_visible() for t in ax.get_xticklabels())
  451. ax.set_xlabel(self.x_label, visible=x_visible)
  452. if self.y_label is not None:
  453. y_visible = any(t.get_visible() for t in ax.get_yticklabels())
  454. ax.set_ylabel(self.y_label, visible=y_visible)
    def add_legend_data(self, ax):
        """Add labeled artists to represent the different plot semantics.

        Empty proxy artists are created through the axes method named by
        ``self._legend_func`` -- one per legend entry -- so that a later
        legend call on ``ax`` can pick them up by label. The artists are
        also stored in ``self.legend_data`` (keyed by
        ``(variable, value)``) with their creation order in
        ``self.legend_order``.

        Raises
        ------
        ValueError
            If ``self.legend`` is neither "brief" nor "full".

        """
        verbosity = self.legend
        if verbosity not in ["brief", "full"]:
            err = "`legend` must be 'brief', 'full', or False"
            raise ValueError(err)

        # Keyword arguments for each entry, plus the insertion order of keys
        legend_kwargs = {}
        keys = []

        # Attributes intended to make a subtitle entry render without a
        # visible handle (zero size/width, no marker, no dashes)
        title_kws = dict(color="w", s=0, linewidth=0, marker="", dashes="")

        def update(var_name, val_name, **kws):
            # Merge ``kws`` into the entry for this (variable, value) pair,
            # so a level shared by several semantics becomes a single entry
            # combining their attributes.
            key = var_name, val_name
            if key in legend_kwargs:
                legend_kwargs[key].update(**kws)
            else:
                keys.append(key)
                legend_kwargs[key] = dict(**kws)

        # -- Add a legend for hue semantics

        # A "brief" legend for a numeric hue shows a few tick-like values
        # chosen by a locator instead of every level.
        if verbosity == "brief" and self.hue_type == "numeric":
            if isinstance(self.hue_norm, mpl.colors.LogNorm):
                locator = mpl.ticker.LogLocator(numticks=3)
            else:
                locator = mpl.ticker.MaxNLocator(nbins=3)
            hue_levels, hue_formatted_levels = locator_to_legend_entries(
                locator, self.hue_limits, self.plot_data["hue"].dtype
            )
        else:
            hue_levels = hue_formatted_levels = self.hue_levels

        # Add the hue semantic subtitle
        if self.hue_label is not None:
            update((self.hue_label, "title"), self.hue_label, **title_kws)

        # Add the hue semantic labels
        for level, formatted_level in zip(hue_levels, hue_formatted_levels):
            if level is not None:
                color = self.color_lookup(level)
                update(self.hue_label, formatted_level, color=color)

        # -- Add a legend for size semantics

        # Same "brief" treatment as for hue
        if verbosity == "brief" and self.size_type == "numeric":
            if isinstance(self.size_norm, mpl.colors.LogNorm):
                locator = mpl.ticker.LogLocator(numticks=3)
            else:
                locator = mpl.ticker.MaxNLocator(nbins=3)
            size_levels, size_formatted_levels = locator_to_legend_entries(
                locator, self.size_limits, self.plot_data["size"].dtype)
        else:
            size_levels = size_formatted_levels = self.size_levels

        # Add the size semantic subtitle
        if self.size_label is not None:
            update((self.size_label, "title"), self.size_label, **title_kws)

        # Add the size semantic labels
        for level, formatted_level in zip(size_levels, size_formatted_levels):
            if level is not None:
                size = self.size_lookup(level)
                # Both spellings are stored; the ``_legend_attributes``
                # filter below keeps whichever the subclass declares
                update(
                    self.size_label, formatted_level, linewidth=size, s=size)

        # -- Add a legend for style semantics

        # Add the style semantic title
        if self.style_label is not None:
            update((self.style_label, "title"), self.style_label, **title_kws)

        # Add the style semantic labels
        for level in self.style_levels:
            if level is not None:
                update(self.style_label, level,
                       marker=self.markers.get(level, ""),
                       dashes=self.dashes.get(level, ""))

        # Axes method ("plot" or "scatter" in the visible subclasses) used
        # to build the proxy artists
        func = getattr(ax, self._legend_func)

        legend_data = {}
        legend_order = []

        for key in keys:

            _, label = key
            kws = legend_kwargs[key]
            kws.setdefault("color", ".2")
            # Pass on only the attributes this plot type declares
            use_kws = {}
            for attr in self._legend_attributes + ["visible"]:
                if attr in kws:
                    use_kws[attr] = kws[attr]
            artist = func([], [], label=label, **use_kws)
            # The result is indexed for "plot"; its first item is kept
            if self._legend_func == "plot":
                artist = artist[0]

            legend_data[key] = artist
            legend_order.append(key)

        self.legend_data = legend_data
        self.legend_order = legend_order
class _LinePlotter(_RelationalPlotter):
    """Relational plotter that draws lines, with optional aggregation."""

    # Keyword names that ``add_legend_data`` forwards to the proxy artists
    _legend_attributes = ["color", "linewidth", "marker", "dashes"]
    # Name of the axes method ``add_legend_data`` calls to create them
    _legend_func = "plot"
  540. def __init__(self,
  541. x=None, y=None, hue=None, size=None, style=None, data=None,
  542. palette=None, hue_order=None, hue_norm=None,
  543. sizes=None, size_order=None, size_norm=None,
  544. dashes=None, markers=None, style_order=None,
  545. units=None, estimator=None, ci=None, n_boot=None, seed=None,
  546. sort=True, err_style=None, err_kws=None, legend=None):
  547. plot_data = self.establish_variables(
  548. x, y, hue, size, style, units, data
  549. )
  550. self._default_size_range = (
  551. np.r_[.5, 2] * mpl.rcParams["lines.linewidth"]
  552. )
  553. self.parse_hue(plot_data["hue"], palette, hue_order, hue_norm)
  554. self.parse_size(plot_data["size"], sizes, size_order, size_norm)
  555. self.parse_style(plot_data["style"], markers, dashes, style_order)
  556. self.units = units
  557. self.estimator = estimator
  558. self.ci = ci
  559. self.n_boot = n_boot
  560. self.seed = seed
  561. self.sort = sort
  562. self.err_style = err_style
  563. self.err_kws = {} if err_kws is None else err_kws
  564. self.legend = legend
  565. def aggregate(self, vals, grouper, units=None):
  566. """Compute an estimate and confidence interval using grouper."""
  567. func = self.estimator
  568. ci = self.ci
  569. n_boot = self.n_boot
  570. seed = self.seed
  571. # Define a "null" CI for when we only have one value
  572. null_ci = pd.Series(index=["low", "high"], dtype=np.float)
  573. # Function to bootstrap in the context of a pandas group by
  574. def bootstrapped_cis(vals):
  575. if len(vals) <= 1:
  576. return null_ci
  577. boots = bootstrap(vals, func=func, n_boot=n_boot, seed=seed)
  578. cis = utils.ci(boots, ci)
  579. return pd.Series(cis, ["low", "high"])
  580. # Group and get the aggregation estimate
  581. grouped = vals.groupby(grouper, sort=self.sort)
  582. est = grouped.agg(func)
  583. # Exit early if we don't want a confidence interval
  584. if ci is None:
  585. return est.index, est, None
  586. # Compute the error bar extents
  587. if ci == "sd":
  588. sd = grouped.std()
  589. cis = pd.DataFrame(np.c_[est - sd, est + sd],
  590. index=est.index,
  591. columns=["low", "high"]).stack()
  592. else:
  593. cis = grouped.apply(bootstrapped_cis)
  594. # Unpack the CIs into "wide" format for plotting
  595. if cis.notnull().any():
  596. cis = cis.unstack().reindex(est.index)
  597. else:
  598. cis = None
  599. return est.index, est, cis
    def plot(self, ax, kws):
        """Draw the plot onto an axes, passing matplotlib kwargs.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes to draw onto.
        kws : dict
            Keyword arguments for ``ax.plot``. This dict is modified in
            place: aliases are popped and per-line values are written.

        Raises
        ------
        ValueError
            If ``self.err_style`` is not "band", "bars" or ``None``, or if
            both an estimator and units are specified.

        """
        # Draw a test plot, using the passed in kwargs. The goal here is to
        # honor both (a) the current state of the plot cycler and (b) the
        # specified kwargs on all the lines we will draw, overriding when
        # relevant with the data semantics. Note that we won't cycle
        # internally; in other words, if ``hue`` is not used, all elements will
        # have the same color, but they will have the color that you would have
        # gotten from the corresponding matplotlib function, and calling the
        # function will advance the axes property cycle.

        scout, = ax.plot([], [], **kws)

        # Fallback values for subsets without a corresponding semantic
        orig_color = kws.pop("color", scout.get_color())
        orig_marker = kws.pop("marker", scout.get_marker())
        orig_linewidth = kws.pop("linewidth",
                                 kws.pop("lw", scout.get_linewidth()))

        orig_dashes = kws.pop("dashes", "")

        # Fold the short aliases into the long names, with defaults
        kws.setdefault("markeredgewidth", kws.pop("mew", .75))
        kws.setdefault("markeredgecolor", kws.pop("mec", "w"))

        scout.remove()

        # Set default error kwargs
        err_kws = self.err_kws.copy()
        if self.err_style == "band":
            err_kws.setdefault("alpha", .2)
        elif self.err_style == "bars":
            pass
        elif self.err_style is not None:
            err = "`err_style` must be 'band' or 'bars', not {}"
            raise ValueError(err.format(self.err_style))

        # Loop over the semantic subsets and draw a line for each

        for semantics, data in self.subset_data():

            hue, size, style = semantics
            x, y, units = data["x"], data["y"], data.get("units", None)

            if self.estimator is not None:
                if self.units is not None:
                    err = "estimator must be None when specifying units"
                    raise ValueError(err)
                x, y, y_ci = self.aggregate(y, x, units)
            else:
                y_ci = None

            # Semantic mappings win over the values passed in ``kws``
            kws["color"] = self.palette.get(hue, orig_color)
            kws["dashes"] = self.dashes.get(style, orig_dashes)
            kws["marker"] = self.markers.get(style, orig_marker)
            kws["linewidth"] = self.sizes.get(size, orig_linewidth)

            # Throwaway empty line, used only to read back the resolved
            # color, alpha and cap style for the error artists below
            line, = ax.plot([], [], **kws)
            line_color = line.get_color()
            line_alpha = line.get_alpha()
            line_capstyle = line.get_solid_capstyle()
            line.remove()

            # --- Draw the main line

            x, y = np.asarray(x), np.asarray(y)

            if self.units is None:
                line, = ax.plot(x, y, **kws)

            else:
                # One separate line per sampling unit
                for u in units.unique():
                    rows = np.asarray(units == u)
                    ax.plot(x[rows], y[rows], **kws)

            # --- Draw the confidence intervals

            if y_ci is not None:

                low, high = np.asarray(y_ci["low"]), np.asarray(y_ci["high"])

                if self.err_style == "band":

                    ax.fill_between(x, low, high, color=line_color, **err_kws)

                elif self.err_style == "bars":

                    y_err = ci_to_errsize((low, high), y)
                    ebars = ax.errorbar(x, y, y_err, linestyle="",
                                        color=line_color, alpha=line_alpha,
                                        **err_kws)

                    # Set the capstyle properly on the error bars
                    for obj in ebars.get_children():
                        try:
                            obj.set_capstyle(line_capstyle)
                        except AttributeError:
                            # Does not exist on mpl < 2.2
                            pass

        # Finalize the axes details
        self.label_axes(ax)
        if self.legend:
            self.add_legend_data(ax)
            handles, _ = ax.get_legend_handles_labels()
            # Only draw a legend when there is something to show
            if handles:
                ax.legend()
  680. class _ScatterPlotter(_RelationalPlotter):
  681. _legend_attributes = ["color", "s", "marker"]
  682. _legend_func = "scatter"
  683. def __init__(self,
  684. x=None, y=None, hue=None, size=None, style=None, data=None,
  685. palette=None, hue_order=None, hue_norm=None,
  686. sizes=None, size_order=None, size_norm=None,
  687. dashes=None, markers=None, style_order=None,
  688. x_bins=None, y_bins=None,
  689. units=None, estimator=None, ci=None, n_boot=None,
  690. alpha=None, x_jitter=None, y_jitter=None,
  691. legend=None):
  692. plot_data = self.establish_variables(
  693. x, y, hue, size, style, units, data
  694. )
  695. self._default_size_range = (
  696. np.r_[.5, 2] * np.square(mpl.rcParams["lines.markersize"])
  697. )
  698. self.parse_hue(plot_data["hue"], palette, hue_order, hue_norm)
  699. self.parse_size(plot_data["size"], sizes, size_order, size_norm)
  700. self.parse_style(plot_data["style"], markers, None, style_order)
  701. self.units = units
  702. self.alpha = alpha
  703. self.legend = legend
  704. def plot(self, ax, kws):
  705. # Draw a test plot, using the passed in kwargs. The goal here is to
  706. # honor both (a) the current state of the plot cycler and (b) the
  707. # specified kwargs on all the lines we will draw, overriding when
  708. # relevant with the data semantics. Note that we won't cycle
  709. # internally; in other words, if ``hue`` is not used, all elements will
  710. # have the same color, but they will have the color that you would have
  711. # gotten from the corresponding matplotlib function, and calling the
  712. # function will advance the axes property cycle.
  713. scout = ax.scatter([], [], **kws)
  714. s = kws.pop("s", scout.get_sizes())
  715. c = kws.pop("c", scout.get_facecolors())
  716. scout.remove()
  717. kws.pop("color", None) # TODO is this optimal?
  718. kws.setdefault("linewidth", .75) # TODO scale with marker size?
  719. kws.setdefault("edgecolor", "w")
  720. if self.markers:
  721. # Use a representative marker so scatter sets the edgecolor
  722. # properly for line art markers. We currently enforce either
  723. # all or none line art so this works.
  724. example_marker = list(self.markers.values())[0]
  725. kws.setdefault("marker", example_marker)
  726. # TODO this makes it impossible to vary alpha with hue which might
  727. # otherwise be useful? Should we just pass None?
  728. kws["alpha"] = 1 if self.alpha == "auto" else self.alpha
  729. # Assign arguments for plt.scatter and draw the plot
  730. data = self.plot_data[self.semantics].dropna()
  731. if not data.size:
  732. return
  733. x = data["x"]
  734. y = data["y"]
  735. if self.palette:
  736. c = [self.palette.get(val) for val in data["hue"]]
  737. if self.sizes:
  738. s = [self.sizes.get(val) for val in data["size"]]
  739. args = np.asarray(x), np.asarray(y), np.asarray(s), np.asarray(c)
  740. points = ax.scatter(*args, **kws)
  741. # Update the paths to get different marker shapes. This has to be
  742. # done here because plt.scatter allows varying sizes and colors
  743. # but only a single marker shape per call.
  744. if self.paths:
  745. p = [self.paths.get(val) for val in data["style"]]
  746. points.set_paths(p)
  747. # Finalize the axes details
  748. self.label_axes(ax)
  749. if self.legend:
  750. self.add_legend_data(ax)
  751. handles, _ = ax.get_legend_handles_labels()
  752. if handles:
  753. ax.legend()
  754. _relational_docs = dict(
  755. # --- Introductory prose
  756. main_api_narrative=dedent("""\
  757. The relationship between ``x`` and ``y`` can be shown for different subsets
  758. of the data using the ``hue``, ``size``, and ``style`` parameters. These
  759. parameters control what visual semantics are used to identify the different
  760. subsets. It is possible to show up to three dimensions independently by
  761. using all three semantic types, but this style of plot can be hard to
  762. interpret and is often ineffective. Using redundant semantics (i.e. both
  763. ``hue`` and ``style`` for the same variable) can be helpful for making
  764. graphics more accessible.
  765. See the :ref:`tutorial <relational_tutorial>` for more information.\
  766. """),
  767. relational_semantic_narrative=dedent("""\
  768. The default treatment of the ``hue`` (and to a lesser extent, ``size``)
  769. semantic, if present, depends on whether the variable is inferred to
  770. represent "numeric" or "categorical" data. In particular, numeric variables
  771. are represented with a sequential colormap by default, and the legend
  772. entries show regular "ticks" with values that may or may not exist in the
  773. data. This behavior can be controlled through various parameters, as
  774. described and illustrated below.\
  775. """),
  776. # --- Shared function parameters
  777. data_vars=dedent("""\
  778. x, y : names of variables in ``data`` or vector data, optional
  779. Input data variables; must be numeric. Can pass data directly or
  780. reference columns in ``data``.\
  781. """),
  782. data=dedent("""\
  783. data : DataFrame, array, or list of arrays, optional
  784. Input data structure. If ``x`` and ``y`` are specified as names, this
  785. should be a "long-form" DataFrame containing those columns. Otherwise
  786. it is treated as "wide-form" data and grouping variables are ignored.
  787. See the examples for the various ways this parameter can be specified
  788. and the different effects of each.\
  789. """),
  790. palette=dedent("""\
  791. palette : string, list, dict, or matplotlib colormap
  792. An object that determines how colors are chosen when ``hue`` is used.
  793. It can be the name of a seaborn palette or matplotlib colormap, a list
  794. of colors (anything matplotlib understands), a dict mapping levels
  795. of the ``hue`` variable to colors, or a matplotlib colormap object.\
  796. """),
  797. hue_order=dedent("""\
  798. hue_order : list, optional
  799. Specified order for the appearance of the ``hue`` variable levels,
  800. otherwise they are determined from the data. Not relevant when the
  801. ``hue`` variable is numeric.\
  802. """),
  803. hue_norm=dedent("""\
  804. hue_norm : tuple or Normalize object, optional
  805. Normalization in data units for colormap applied to the ``hue``
  806. variable when it is numeric. Not relevant if it is categorical.\
  807. """),
  808. sizes=dedent("""\
  809. sizes : list, dict, or tuple, optional
  810. An object that determines how sizes are chosen when ``size`` is used.
  811. It can always be a list of size values or a dict mapping levels of the
  812. ``size`` variable to sizes. When ``size`` is numeric, it can also be
  813. a tuple specifying the minimum and maximum size to use such that other
  814. values are normalized within this range.\
  815. """),
  816. size_order=dedent("""\
  817. size_order : list, optional
  818. Specified order for appearance of the ``size`` variable levels,
  819. otherwise they are determined from the data. Not relevant when the
  820. ``size`` variable is numeric.\
  821. """),
  822. size_norm=dedent("""\
  823. size_norm : tuple or Normalize object, optional
  824. Normalization in data units for scaling plot objects when the
  825. ``size`` variable is numeric.\
  826. """),
  827. markers=dedent("""\
  828. markers : boolean, list, or dictionary, optional
  829. Object determining how to draw the markers for different levels of the
  830. ``style`` variable. Setting to ``True`` will use default markers, or
  831. you can pass a list of markers or a dictionary mapping levels of the
  832. ``style`` variable to markers. Setting to ``False`` will draw
  833. marker-less lines. Markers are specified as in matplotlib.\
  834. """),
  835. style_order=dedent("""\
  836. style_order : list, optional
  837. Specified order for appearance of the ``style`` variable levels,
  838. otherwise they are determined from the data. Not relevant when the
  839. ``style`` variable is numeric.\
  840. """),
  841. units=dedent("""\
  842. units : {long_form_var}
  843. Grouping variable identifying sampling units. When used, a separate
  844. line will be drawn for each unit with appropriate semantics, but no
  845. legend entry will be added. Useful for showing distribution of
  846. experimental replicates when exact identities are not needed.
  847. """),
  848. estimator=dedent("""\
  849. estimator : name of pandas method or callable or None, optional
  850. Method for aggregating across multiple observations of the ``y``
  851. variable at the same ``x`` level. If ``None``, all observations will
  852. be drawn.\
  853. """),
  854. ci=dedent("""\
  855. ci : int or "sd" or None, optional
  856. Size of the confidence interval to draw when aggregating with an
  857. estimator. "sd" means to draw the standard deviation of the data.
  858. Setting to ``None`` will skip bootstrapping.\
  859. """),
  860. n_boot=dedent("""\
  861. n_boot : int, optional
  862. Number of bootstraps to use for computing the confidence interval.\
  863. """),
  864. seed=dedent("""\
  865. seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
  866. Seed or random number generator for reproducible bootstrapping.\
  867. """),
  868. legend=dedent("""\
  869. legend : "brief", "full", or False, optional
  870. How to draw the legend. If "brief", numeric ``hue`` and ``size``
  871. variables will be represented with a sample of evenly spaced values.
  872. If "full", every group will get an entry in the legend. If ``False``,
  873. no legend data is added and no legend is drawn.\
  874. """),
  875. ax_in=dedent("""\
  876. ax : matplotlib Axes, optional
  877. Axes object to draw the plot onto, otherwise uses the current Axes.\
  878. """),
  879. ax_out=dedent("""\
  880. ax : matplotlib Axes
  881. Returns the Axes object with the plot drawn onto it.\
  882. """),
  883. # --- Repeated phrases
  884. long_form_var="name of variable in ``data`` or vector data, optional",
  885. )
  886. _relational_docs.update(_facet_docs)
  887. def lineplot(x=None, y=None, hue=None, size=None, style=None, data=None,
  888. palette=None, hue_order=None, hue_norm=None,
  889. sizes=None, size_order=None, size_norm=None,
  890. dashes=True, markers=None, style_order=None,
  891. units=None, estimator="mean", ci=95, n_boot=1000, seed=None,
  892. sort=True, err_style="band", err_kws=None,
  893. legend="brief", ax=None, **kwargs):
  894. p = _LinePlotter(
  895. x=x, y=y, hue=hue, size=size, style=style, data=data,
  896. palette=palette, hue_order=hue_order, hue_norm=hue_norm,
  897. sizes=sizes, size_order=size_order, size_norm=size_norm,
  898. dashes=dashes, markers=markers, style_order=style_order,
  899. units=units, estimator=estimator, ci=ci, n_boot=n_boot, seed=seed,
  900. sort=sort, err_style=err_style, err_kws=err_kws, legend=legend,
  901. )
  902. if ax is None:
  903. ax = plt.gca()
  904. p.plot(ax, kwargs)
  905. return ax
  906. lineplot.__doc__ = dedent("""\
  907. Draw a line plot with possibility of several semantic groupings.
  908. {main_api_narrative}
  909. {relational_semantic_narrative}
  910. By default, the plot aggregates over multiple ``y`` values at each value of
  911. ``x`` and shows an estimate of the central tendency and a confidence
  912. interval for that estimate.
  913. Parameters
  914. ----------
  915. {data_vars}
  916. hue : {long_form_var}
  917. Grouping variable that will produce lines with different colors.
  918. Can be either categorical or numeric, although color mapping will
  919. behave differently in latter case.
  920. size : {long_form_var}
  921. Grouping variable that will produce lines with different widths.
  922. Can be either categorical or numeric, although size mapping will
  923. behave differently in latter case.
  924. style : {long_form_var}
  925. Grouping variable that will produce lines with different dashes
  926. and/or markers. Can have a numeric dtype but will always be treated
  927. as categorical.
  928. {data}
  929. {palette}
  930. {hue_order}
  931. {hue_norm}
  932. {sizes}
  933. {size_order}
  934. {size_norm}
  935. dashes : boolean, list, or dictionary, optional
  936. Object determining how to draw the lines for different levels of the
  937. ``style`` variable. Setting to ``True`` will use default dash codes, or
  938. you can pass a list of dash codes or a dictionary mapping levels of the
  939. ``style`` variable to dash codes. Setting to ``False`` will use solid
  940. lines for all subsets. Dashes are specified as in matplotlib: a tuple
  941. of ``(segment, gap)`` lengths, or an empty string to draw a solid line.
  942. {markers}
  943. {style_order}
  944. {units}
  945. {estimator}
  946. {ci}
  947. {n_boot}
  948. {seed}
  949. sort : boolean, optional
  950. If True, the data will be sorted by the x and y variables, otherwise
  951. lines will connect points in the order they appear in the dataset.
  952. err_style : "band" or "bars", optional
  953. Whether to draw the confidence intervals with translucent error bands
  954. or discrete error bars.
  955. err_kws : dict of keyword arguments
  956. Additional parameters to control the aesthetics of the error bars. The
  957. kwargs are passed either to :meth:`matplotlib.axes.Axes.fill_between`
  958. or :meth:`matplotlib.axes.Axes.errorbar`, depending on ``err_style``.
  959. {legend}
  960. {ax_in}
  961. kwargs : key, value mappings
  962. Other keyword arguments are passed down to
  963. :meth:`matplotlib.axes.Axes.plot`.
  964. Returns
  965. -------
  966. {ax_out}
  967. See Also
  968. --------
  969. scatterplot : Show the relationship between two variables without
  970. emphasizing continuity of the ``x`` variable.
  971. pointplot : Show the relationship between two variables when one is
  972. categorical.
  973. Examples
  974. --------
  975. Draw a single line plot with error bands showing a confidence interval:
  976. .. plot::
  977. :context: close-figs
  978. >>> import seaborn as sns; sns.set()
  979. >>> import matplotlib.pyplot as plt
  980. >>> fmri = sns.load_dataset("fmri")
  981. >>> ax = sns.lineplot(x="timepoint", y="signal", data=fmri)
  982. Group by another variable and show the groups with different colors:
  983. .. plot::
  984. :context: close-figs
  985. >>> ax = sns.lineplot(x="timepoint", y="signal", hue="event",
  986. ... data=fmri)
  987. Show the grouping variable with both color and line dashing:
  988. .. plot::
  989. :context: close-figs
  990. >>> ax = sns.lineplot(x="timepoint", y="signal",
  991. ... hue="event", style="event", data=fmri)
  992. Use color and line dashing to represent two different grouping variables:
  993. .. plot::
  994. :context: close-figs
  995. >>> ax = sns.lineplot(x="timepoint", y="signal",
  996. ... hue="region", style="event", data=fmri)
  997. Use markers instead of the dashes to identify groups:
  998. .. plot::
  999. :context: close-figs
  1000. >>> ax = sns.lineplot(x="timepoint", y="signal",
  1001. ... hue="event", style="event",
  1002. ... markers=True, dashes=False, data=fmri)
  1003. Show error bars instead of error bands and plot the standard error:
  1004. .. plot::
  1005. :context: close-figs
  1006. >>> ax = sns.lineplot(x="timepoint", y="signal", hue="event",
  1007. ... err_style="bars", ci=68, data=fmri)
  1008. Show experimental replicates instead of aggregating:
  1009. .. plot::
  1010. :context: close-figs
  1011. >>> ax = sns.lineplot(x="timepoint", y="signal", hue="event",
  1012. ... units="subject", estimator=None, lw=1,
  1013. ... data=fmri.query("region == 'frontal'"))
  1014. Use a quantitative color mapping:
  1015. .. plot::
  1016. :context: close-figs
  1017. >>> dots = sns.load_dataset("dots").query("align == 'dots'")
  1018. >>> ax = sns.lineplot(x="time", y="firing_rate",
  1019. ... hue="coherence", style="choice",
  1020. ... data=dots)
  1021. Use a different normalization for the colormap:
  1022. .. plot::
  1023. :context: close-figs
  1024. >>> from matplotlib.colors import LogNorm
  1025. >>> ax = sns.lineplot(x="time", y="firing_rate",
  1026. ... hue="coherence", style="choice",
  1027. ... hue_norm=LogNorm(), data=dots)
  1028. Use a different color palette:
  1029. .. plot::
  1030. :context: close-figs
  1031. >>> ax = sns.lineplot(x="time", y="firing_rate",
  1032. ... hue="coherence", style="choice",
  1033. ... palette="ch:2.5,.25", data=dots)
  1034. Use specific color values, treating the hue variable as categorical:
  1035. .. plot::
  1036. :context: close-figs
  1037. >>> palette = sns.color_palette("mako_r", 6)
  1038. >>> ax = sns.lineplot(x="time", y="firing_rate",
  1039. ... hue="coherence", style="choice",
  1040. ... palette=palette, data=dots)
  1041. Change the width of the lines with a quantitative variable:
  1042. .. plot::
  1043. :context: close-figs
  1044. >>> ax = sns.lineplot(x="time", y="firing_rate",
  1045. ... size="coherence", hue="choice",
  1046. ... legend="full", data=dots)
  1047. Change the range of line widths used to normalize the size variable:
  1048. .. plot::
  1049. :context: close-figs
  1050. >>> ax = sns.lineplot(x="time", y="firing_rate",
  1051. ... size="coherence", hue="choice",
  1052. ... sizes=(.25, 2.5), data=dots)
  1053. Plot from a wide-form DataFrame:
  1054. .. plot::
  1055. :context: close-figs
  1056. >>> import numpy as np, pandas as pd; plt.close("all")
  1057. >>> index = pd.date_range("1 1 2000", periods=100,
  1058. ... freq="m", name="date")
  1059. >>> data = np.random.randn(100, 4).cumsum(axis=0)
  1060. >>> wide_df = pd.DataFrame(data, index, ["a", "b", "c", "d"])
  1061. >>> ax = sns.lineplot(data=wide_df)
  1062. Plot from a list of Series:
  1063. .. plot::
  1064. :context: close-figs
  1065. >>> list_data = [wide_df.loc[:"2005", "a"], wide_df.loc["2003":, "b"]]
  1066. >>> ax = sns.lineplot(data=list_data)
  1067. Plot a single Series, pass kwargs to :meth:`matplotlib.axes.Axes.plot`:
  1068. .. plot::
  1069. :context: close-figs
  1070. >>> ax = sns.lineplot(data=wide_df["a"], color="coral", label="line")
  1071. Draw lines at points as they appear in the dataset:
  1072. .. plot::
  1073. :context: close-figs
  1074. >>> x, y = np.random.randn(2, 5000).cumsum(axis=1)
  1075. >>> ax = sns.lineplot(x=x, y=y, sort=False, lw=1)
  1076. Use :func:`relplot` to combine :func:`lineplot` and :class:`FacetGrid`:
  1077. This allows grouping within additional categorical variables. Using
  1078. :func:`relplot` is safer than using :class:`FacetGrid` directly, as it
  1079. ensures synchronization of the semantic mappings across facets.
  1080. .. plot::
  1081. :context: close-figs
  1082. >>> g = sns.relplot(x="timepoint", y="signal",
  1083. ... col="region", hue="event", style="event",
  1084. ... kind="line", data=fmri)
  1085. """).format(**_relational_docs)
  1086. def scatterplot(x=None, y=None, hue=None, style=None, size=None, data=None,
  1087. palette=None, hue_order=None, hue_norm=None,
  1088. sizes=None, size_order=None, size_norm=None,
  1089. markers=True, style_order=None,
  1090. x_bins=None, y_bins=None,
  1091. units=None, estimator=None, ci=95, n_boot=1000,
  1092. alpha="auto", x_jitter=None, y_jitter=None,
  1093. legend="brief", ax=None, **kwargs):
  1094. p = _ScatterPlotter(
  1095. x=x, y=y, hue=hue, style=style, size=size, data=data,
  1096. palette=palette, hue_order=hue_order, hue_norm=hue_norm,
  1097. sizes=sizes, size_order=size_order, size_norm=size_norm,
  1098. markers=markers, style_order=style_order,
  1099. x_bins=x_bins, y_bins=y_bins,
  1100. estimator=estimator, ci=ci, n_boot=n_boot,
  1101. alpha=alpha, x_jitter=x_jitter, y_jitter=y_jitter, legend=legend,
  1102. )
  1103. if ax is None:
  1104. ax = plt.gca()
  1105. p.plot(ax, kwargs)
  1106. return ax
  1107. scatterplot.__doc__ = dedent("""\
  1108. Draw a scatter plot with possibility of several semantic groupings.
  1109. {main_api_narrative}
  1110. {relational_semantic_narrative}
  1111. Parameters
  1112. ----------
  1113. {data_vars}
  1114. hue : {long_form_var}
  1115. Grouping variable that will produce points with different colors.
  1116. Can be either categorical or numeric, although color mapping will
  1117. behave differently in latter case.
  1118. size : {long_form_var}
  1119. Grouping variable that will produce points with different sizes.
  1120. Can be either categorical or numeric, although size mapping will
  1121. behave differently in latter case.
  1122. style : {long_form_var}
  1123. Grouping variable that will produce points with different markers.
  1124. Can have a numeric dtype but will always be treated as categorical.
  1125. {data}
  1126. {palette}
  1127. {hue_order}
  1128. {hue_norm}
  1129. {sizes}
  1130. {size_order}
  1131. {size_norm}
  1132. {markers}
  1133. {style_order}
  1134. {{x,y}}_bins : lists or arrays or functions
  1135. *Currently non-functional.*
  1136. {units}
  1137. *Currently non-functional.*
  1138. {estimator}
  1139. *Currently non-functional.*
  1140. {ci}
  1141. *Currently non-functional.*
  1142. {n_boot}
  1143. *Currently non-functional.*
  1144. alpha : float
  1145. Proportional opacity of the points.
  1146. {{x,y}}_jitter : booleans or floats
  1147. *Currently non-functional.*
  1148. {legend}
  1149. {ax_in}
  1150. kwargs : key, value mappings
  1151. Other keyword arguments are passed down to
  1152. :meth:`matplotlib.axes.Axes.scatter`.
  1153. Returns
  1154. -------
  1155. {ax_out}
  1156. See Also
  1157. --------
  1158. lineplot : Show the relationship between two variables connected with
  1159. lines to emphasize continuity.
  1160. swarmplot : Draw a scatter plot with one categorical variable, arranging
  1161. the points to show the distribution of values.
  1162. Examples
  1163. --------
  1164. Draw a simple scatter plot between two variables:
  1165. .. plot::
  1166. :context: close-figs
  1167. >>> import seaborn as sns; sns.set()
  1168. >>> import matplotlib.pyplot as plt
  1169. >>> tips = sns.load_dataset("tips")
  1170. >>> ax = sns.scatterplot(x="total_bill", y="tip", data=tips)
  1171. Group by another variable and show the groups with different colors:
  1172. .. plot::
  1173. :context: close-figs
  1174. >>> ax = sns.scatterplot(x="total_bill", y="tip", hue="time",
  1175. ... data=tips)
  1176. Show the grouping variable by varying both color and marker:
  1177. .. plot::
  1178. :context: close-figs
  1179. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1180. ... hue="time", style="time", data=tips)
  1181. Vary colors and markers to show two different grouping variables:
  1182. .. plot::
  1183. :context: close-figs
  1184. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1185. ... hue="day", style="time", data=tips)
  1186. Show a quantitative variable by varying the size of the points:
  1187. .. plot::
  1188. :context: close-figs
  1189. >>> ax = sns.scatterplot(x="total_bill", y="tip", size="size",
  1190. ... data=tips)
  1191. Also show the quantitative variable by using continuous colors:
  1192. .. plot::
  1193. :context: close-figs
  1194. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1195. ... hue="size", size="size",
  1196. ... data=tips)
  1197. Use a different continuous color map:
  1198. .. plot::
  1199. :context: close-figs
  1200. >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True)
  1201. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1202. ... hue="size", size="size",
  1203. ... palette=cmap,
  1204. ... data=tips)
  1205. Change the minimum and maximum point size and show all sizes in legend:
  1206. .. plot::
  1207. :context: close-figs
  1208. >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True)
  1209. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1210. ... hue="size", size="size",
  1211. ... sizes=(20, 200), palette=cmap,
  1212. ... legend="full", data=tips)
  1213. Use a narrower range of color map intensities:
  1214. .. plot::
  1215. :context: close-figs
  1216. >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True)
  1217. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1218. ... hue="size", size="size",
  1219. ... sizes=(20, 200), hue_norm=(0, 7),
  1220. ... legend="full", data=tips)
  1221. Vary the size with a categorical variable, and use a different palette:
  1222. .. plot::
  1223. :context: close-figs
  1224. >>> cmap = sns.cubehelix_palette(dark=.3, light=.8, as_cmap=True)
  1225. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1226. ... hue="day", size="smoker",
  1227. ... palette="Set2",
  1228. ... data=tips)
  1229. Use a specific set of markers:
  1230. .. plot::
  1231. :context: close-figs
  1232. >>> markers = {{"Lunch": "s", "Dinner": "X"}}
  1233. >>> ax = sns.scatterplot(x="total_bill", y="tip", style="time",
  1234. ... markers=markers,
  1235. ... data=tips)
  1236. Control plot attributes using matplotlib parameters:
  1237. .. plot::
  1238. :context: close-figs
  1239. >>> ax = sns.scatterplot(x="total_bill", y="tip",
  1240. ... s=100, color=".2", marker="+",
  1241. ... data=tips)
  1242. Pass data vectors instead of names in a data frame:
  1243. .. plot::
  1244. :context: close-figs
  1245. >>> iris = sns.load_dataset("iris")
  1246. >>> ax = sns.scatterplot(x=iris.sepal_length, y=iris.sepal_width,
  1247. ... hue=iris.species, style=iris.species)
  1248. Pass a wide-form dataset and plot against its index:
  1249. .. plot::
  1250. :context: close-figs
  1251. >>> import numpy as np, pandas as pd; plt.close("all")
  1252. >>> index = pd.date_range("1 1 2000", periods=100,
  1253. ... freq="m", name="date")
  1254. >>> data = np.random.randn(100, 4).cumsum(axis=0)
  1255. >>> wide_df = pd.DataFrame(data, index, ["a", "b", "c", "d"])
  1256. >>> ax = sns.scatterplot(data=wide_df)
  1257. Use :func:`relplot` to combine :func:`scatterplot` and :class:`FacetGrid`:
  1258. This allows grouping within additional categorical variables. Using
  1259. :func:`relplot` is safer than using :class:`FacetGrid` directly, as it
  1260. ensures synchronization of the semantic mappings across facets.
  1261. .. plot::
  1262. :context: close-figs
  1263. >>> g = sns.relplot(x="total_bill", y="tip",
  1264. ... col="time", hue="day", style="day",
  1265. ... kind="scatter", data=tips)
  1266. """).format(**_relational_docs)
  1267. def relplot(x=None, y=None, hue=None, size=None, style=None, data=None,
  1268. row=None, col=None, col_wrap=None, row_order=None, col_order=None,
  1269. palette=None, hue_order=None, hue_norm=None,
  1270. sizes=None, size_order=None, size_norm=None,
  1271. markers=None, dashes=None, style_order=None,
  1272. legend="brief", kind="scatter",
  1273. height=5, aspect=1, facet_kws=None, **kwargs):
  1274. if kind == "scatter":
  1275. plotter = _ScatterPlotter
  1276. func = scatterplot
  1277. markers = True if markers is None else markers
  1278. elif kind == "line":
  1279. plotter = _LinePlotter
  1280. func = lineplot
  1281. dashes = True if dashes is None else dashes
  1282. else:
  1283. err = "Plot kind {} not recognized".format(kind)
  1284. raise ValueError(err)
  1285. # Check for attempt to plot onto specific axes and warn
  1286. if "ax" in kwargs:
  1287. msg = ("relplot is a figure-level function and does not accept "
  1288. "target axes. You may wish to try {}".format(kind + "plot"))
  1289. warnings.warn(msg, UserWarning)
  1290. kwargs.pop("ax")
  1291. # Use the full dataset to establish how to draw the semantics
  1292. p = plotter(
  1293. x=x, y=y, hue=hue, size=size, style=style, data=data,
  1294. palette=palette, hue_order=hue_order, hue_norm=hue_norm,
  1295. sizes=sizes, size_order=size_order, size_norm=size_norm,
  1296. markers=markers, dashes=dashes, style_order=style_order,
  1297. legend=legend,
  1298. )
  1299. palette = p.palette if p.palette else None
  1300. hue_order = p.hue_levels if any(p.hue_levels) else None
  1301. hue_norm = p.hue_norm if p.hue_norm is not None else None
  1302. sizes = p.sizes if p.sizes else None
  1303. size_order = p.size_levels if any(p.size_levels) else None
  1304. size_norm = p.size_norm if p.size_norm is not None else None
  1305. markers = p.markers if p.markers else None
  1306. dashes = p.dashes if p.dashes else None
  1307. style_order = p.style_levels if any(p.style_levels) else None
  1308. plot_kws = dict(
  1309. palette=palette, hue_order=hue_order, hue_norm=p.hue_norm,
  1310. sizes=sizes, size_order=size_order, size_norm=p.size_norm,
  1311. markers=markers, dashes=dashes, style_order=style_order,
  1312. legend=False,
  1313. )
  1314. plot_kws.update(kwargs)
  1315. if kind == "scatter":
  1316. plot_kws.pop("dashes")
  1317. # Set up the FacetGrid object
  1318. facet_kws = {} if facet_kws is None else facet_kws
  1319. g = FacetGrid(
  1320. data=data, row=row, col=col, col_wrap=col_wrap,
  1321. row_order=row_order, col_order=col_order,
  1322. height=height, aspect=aspect, dropna=False,
  1323. **facet_kws
  1324. )
  1325. # Draw the plot
  1326. g.map_dataframe(func, x, y,
  1327. hue=hue, size=size, style=style,
  1328. **plot_kws)
  1329. # Show the legend
  1330. if legend:
  1331. p.add_legend_data(g.axes.flat[0])
  1332. if p.legend_data:
  1333. g.add_legend(legend_data=p.legend_data,
  1334. label_order=p.legend_order)
  1335. return g
  1336. relplot.__doc__ = dedent("""\
  1337. Figure-level interface for drawing relational plots onto a FacetGrid.
  1338. This function provides access to several different axes-level functions
  1339. that show the relationship between two variables with semantic mappings
  1340. of subsets. The ``kind`` parameter selects the underlying axes-level
  1341. function to use:
  1342. - :func:`scatterplot` (with ``kind="scatter"``; the default)
  1343. - :func:`lineplot` (with ``kind="line"``)
  1344. Extra keyword arguments are passed to the underlying function, so you
  1345. should refer to the documentation for each to see kind-specific options.
  1346. {main_api_narrative}
  1347. {relational_semantic_narrative}
  1348. After plotting, the :class:`FacetGrid` with the plot is returned and can
  1349. be used directly to tweak supporting plot details or add other layers.
  1350. Note that, unlike when using the underlying plotting functions directly,
  1351. data must be passed in a long-form DataFrame with variables specified by
  1352. passing strings to ``x``, ``y``, and other parameters.
  1353. Parameters
  1354. ----------
  1355. x, y : names of variables in ``data``
  1356. Input data variables; must be numeric.
  1357. hue : name in ``data``, optional
  1358. Grouping variable that will produce elements with different colors.
  1359. Can be either categorical or numeric, although color mapping will
  1360. behave differently in latter case.
  1361. size : name in ``data``, optional
  1362. Grouping variable that will produce elements with different sizes.
  1363. Can be either categorical or numeric, although size mapping will
  1364. behave differently in latter case.
  1365. style : name in ``data``, optional
  1366. Grouping variable that will produce elements with different styles.
  1367. Can have a numeric dtype but will always be treated as categorical.
  1368. {data}
  1369. row, col : names of variables in ``data``, optional
  1370. Categorical variables that will determine the faceting of the grid.
  1371. {col_wrap}
  1372. row_order, col_order : lists of strings, optional
  1373. Order to organize the rows and/or columns of the grid in, otherwise the
  1374. orders are inferred from the data objects.
  1375. {palette}
  1376. {hue_order}
  1377. {hue_norm}
  1378. {sizes}
  1379. {size_order}
  1380. {size_norm}
  1381. {legend}
  1382. kind : string, optional
  1383. Kind of plot to draw, corresponding to a seaborn relational plot.
  1384. Options are {{``scatter`` and ``line``}}.
  1385. {height}
  1386. {aspect}
  1387. facet_kws : dict, optional
  1388. Dictionary of other keyword arguments to pass to :class:`FacetGrid`.
  1389. kwargs : key, value pairings
  1390. Other keyword arguments are passed through to the underlying plotting
  1391. function.
  1392. Returns
  1393. -------
  1394. g : :class:`FacetGrid`
  1395. Returns the :class:`FacetGrid` object with the plot on it for further
  1396. tweaking.
  1397. Examples
  1398. --------
  1399. Draw a single facet to use the :class:`FacetGrid` legend placement:
  1400. .. plot::
  1401. :context: close-figs
  1402. >>> import seaborn as sns
  1403. >>> sns.set(style="ticks")
  1404. >>> tips = sns.load_dataset("tips")
  1405. >>> g = sns.relplot(x="total_bill", y="tip", hue="day", data=tips)
  1406. Facet on the columns with another variable:
  1407. .. plot::
  1408. :context: close-figs
  1409. >>> g = sns.relplot(x="total_bill", y="tip",
  1410. ... hue="day", col="time", data=tips)
  1411. Facet on the columns and rows:
  1412. .. plot::
  1413. :context: close-figs
  1414. >>> g = sns.relplot(x="total_bill", y="tip", hue="day",
  1415. ... col="time", row="sex", data=tips)
  1416. "Wrap" many column facets into multiple rows:
  1417. .. plot::
  1418. :context: close-figs
  1419. >>> g = sns.relplot(x="total_bill", y="tip", hue="time",
  1420. ... col="day", col_wrap=2, data=tips)
  1421. Use multiple semantic variables on each facet with specified attributes:
  1422. .. plot::
  1423. :context: close-figs
  1424. >>> g = sns.relplot(x="total_bill", y="tip", hue="time", size="size",
  1425. ... palette=["b", "r"], sizes=(10, 100),
  1426. ... col="time", data=tips)
  1427. Use a different kind of plot:
  1428. .. plot::
  1429. :context: close-figs
  1430. >>> fmri = sns.load_dataset("fmri")
  1431. >>> g = sns.relplot(x="timepoint", y="signal",
  1432. ... hue="event", style="event", col="region",
  1433. ... kind="line", data=fmri)
  1434. Change the size of each facet:
  1435. .. plot::
  1436. :context: close-figs
  1437. >>> g = sns.relplot(x="timepoint", y="signal",
  1438. ... hue="event", style="event", col="region",
  1439. ... height=5, aspect=.7, kind="line", data=fmri)
  1440. """).format(**_relational_docs)