1
0

utils.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674
  1. """Small plotting-related utility functions."""
  2. import colorsys
  3. import os
  4. import numpy as np
  5. from scipy import stats
  6. import pandas as pd
  7. import matplotlib as mpl
  8. import matplotlib.colors as mplcol
  9. import matplotlib.pyplot as plt
  10. import warnings
  11. from urllib.request import urlopen, urlretrieve
  12. from http.client import HTTPException
  13. __all__ = ["desaturate", "saturate", "set_hls_values",
  14. "despine", "get_dataset_names", "get_data_home", "load_dataset"]
  15. def remove_na(arr):
  16. """Helper method for removing NA values from array-like.
  17. Parameters
  18. ----------
  19. arr : array-like
  20. The array-like from which to remove NA values.
  21. Returns
  22. -------
  23. clean_arr : array-like
  24. The original array with NA values removed.
  25. """
  26. return arr[pd.notnull(arr)]
  27. def sort_df(df, *args, **kwargs):
  28. """Wrapper to handle different pandas sorting API pre/post 0.17."""
  29. msg = "This function is deprecated and will be removed in a future version"
  30. warnings.warn(msg)
  31. try:
  32. return df.sort_values(*args, **kwargs)
  33. except AttributeError:
  34. return df.sort(*args, **kwargs)
  35. def ci_to_errsize(cis, heights):
  36. """Convert intervals to error arguments relative to plot heights.
  37. Parameters
  38. ----------
  39. cis: 2 x n sequence
  40. sequence of confidence interval limits
  41. heights : n sequence
  42. sequence of plot heights
  43. Returns
  44. -------
  45. errsize : 2 x n array
  46. sequence of error size relative to height values in correct
  47. format as argument for plt.bar
  48. """
  49. cis = np.atleast_2d(cis).reshape(2, -1)
  50. heights = np.atleast_1d(heights)
  51. errsize = []
  52. for i, (low, high) in enumerate(np.transpose(cis)):
  53. h = heights[i]
  54. elow = h - low
  55. ehigh = high - h
  56. errsize.append([elow, ehigh])
  57. errsize = np.asarray(errsize).T
  58. return errsize
  59. def pmf_hist(a, bins=10):
  60. """Return arguments to plt.bar for pmf-like histogram of an array.
  61. DEPRECATED: will be removed in a future version.
  62. Parameters
  63. ----------
  64. a: array-like
  65. array to make histogram of
  66. bins: int
  67. number of bins
  68. Returns
  69. -------
  70. x: array
  71. left x position of bars
  72. h: array
  73. height of bars
  74. w: float
  75. width of bars
  76. """
  77. msg = "This function is deprecated and will be removed in a future version"
  78. warnings.warn(msg)
  79. n, x = np.histogram(a, bins)
  80. h = n / n.sum()
  81. w = x[1] - x[0]
  82. return x[:-1], h, w
  83. def desaturate(color, prop):
  84. """Decrease the saturation channel of a color by some percent.
  85. Parameters
  86. ----------
  87. color : matplotlib color
  88. hex, rgb-tuple, or html color name
  89. prop : float
  90. saturation channel of color will be multiplied by this value
  91. Returns
  92. -------
  93. new_color : rgb tuple
  94. desaturated color code in RGB tuple representation
  95. """
  96. # Check inputs
  97. if not 0 <= prop <= 1:
  98. raise ValueError("prop must be between 0 and 1")
  99. # Get rgb tuple rep
  100. rgb = mplcol.colorConverter.to_rgb(color)
  101. # Convert to hls
  102. h, l, s = colorsys.rgb_to_hls(*rgb)
  103. # Desaturate the saturation channel
  104. s *= prop
  105. # Convert back to rgb
  106. new_color = colorsys.hls_to_rgb(h, l, s)
  107. return new_color
  108. def saturate(color):
  109. """Return a fully saturated color with the same hue.
  110. Parameters
  111. ----------
  112. color : matplotlib color
  113. hex, rgb-tuple, or html color name
  114. Returns
  115. -------
  116. new_color : rgb tuple
  117. saturated color code in RGB tuple representation
  118. """
  119. return set_hls_values(color, s=1)
  120. def set_hls_values(color, h=None, l=None, s=None): # noqa
  121. """Independently manipulate the h, l, or s channels of a color.
  122. Parameters
  123. ----------
  124. color : matplotlib color
  125. hex, rgb-tuple, or html color name
  126. h, l, s : floats between 0 and 1, or None
  127. new values for each channel in hls space
  128. Returns
  129. -------
  130. new_color : rgb tuple
  131. new color code in RGB tuple representation
  132. """
  133. # Get an RGB tuple representation
  134. rgb = mplcol.colorConverter.to_rgb(color)
  135. vals = list(colorsys.rgb_to_hls(*rgb))
  136. for i, val in enumerate([h, l, s]):
  137. if val is not None:
  138. vals[i] = val
  139. rgb = colorsys.hls_to_rgb(*vals)
  140. return rgb
  141. def axlabel(xlabel, ylabel, **kwargs):
  142. """Grab current axis and label it."""
  143. ax = plt.gca()
  144. ax.set_xlabel(xlabel, **kwargs)
  145. ax.set_ylabel(ylabel, **kwargs)
  146. def despine(fig=None, ax=None, top=True, right=True, left=False,
  147. bottom=False, offset=None, trim=False):
  148. """Remove the top and right spines from plot(s).
  149. fig : matplotlib figure, optional
  150. Figure to despine all axes of, default uses current figure.
  151. ax : matplotlib axes, optional
  152. Specific axes object to despine.
  153. top, right, left, bottom : boolean, optional
  154. If True, remove that spine.
  155. offset : int or dict, optional
  156. Absolute distance, in points, spines should be moved away
  157. from the axes (negative values move spines inward). A single value
  158. applies to all spines; a dict can be used to set offset values per
  159. side.
  160. trim : bool, optional
  161. If True, limit spines to the smallest and largest major tick
  162. on each non-despined axis.
  163. Returns
  164. -------
  165. None
  166. """
  167. # Get references to the axes we want
  168. if fig is None and ax is None:
  169. axes = plt.gcf().axes
  170. elif fig is not None:
  171. axes = fig.axes
  172. elif ax is not None:
  173. axes = [ax]
  174. for ax_i in axes:
  175. for side in ["top", "right", "left", "bottom"]:
  176. # Toggle the spine objects
  177. is_visible = not locals()[side]
  178. ax_i.spines[side].set_visible(is_visible)
  179. if offset is not None and is_visible:
  180. try:
  181. val = offset.get(side, 0)
  182. except AttributeError:
  183. val = offset
  184. ax_i.spines[side].set_position(('outward', val))
  185. # Potentially move the ticks
  186. if left and not right:
  187. maj_on = any(
  188. t.tick1line.get_visible()
  189. for t in ax_i.yaxis.majorTicks
  190. )
  191. min_on = any(
  192. t.tick1line.get_visible()
  193. for t in ax_i.yaxis.minorTicks
  194. )
  195. ax_i.yaxis.set_ticks_position("right")
  196. for t in ax_i.yaxis.majorTicks:
  197. t.tick2line.set_visible(maj_on)
  198. for t in ax_i.yaxis.minorTicks:
  199. t.tick2line.set_visible(min_on)
  200. if bottom and not top:
  201. maj_on = any(
  202. t.tick1line.get_visible()
  203. for t in ax_i.xaxis.majorTicks
  204. )
  205. min_on = any(
  206. t.tick1line.get_visible()
  207. for t in ax_i.xaxis.minorTicks
  208. )
  209. ax_i.xaxis.set_ticks_position("top")
  210. for t in ax_i.xaxis.majorTicks:
  211. t.tick2line.set_visible(maj_on)
  212. for t in ax_i.xaxis.minorTicks:
  213. t.tick2line.set_visible(min_on)
  214. if trim:
  215. # clip off the parts of the spines that extend past major ticks
  216. xticks = np.asarray(ax_i.get_xticks())
  217. if xticks.size:
  218. firsttick = np.compress(xticks >= min(ax_i.get_xlim()),
  219. xticks)[0]
  220. lasttick = np.compress(xticks <= max(ax_i.get_xlim()),
  221. xticks)[-1]
  222. ax_i.spines['bottom'].set_bounds(firsttick, lasttick)
  223. ax_i.spines['top'].set_bounds(firsttick, lasttick)
  224. newticks = xticks.compress(xticks <= lasttick)
  225. newticks = newticks.compress(newticks >= firsttick)
  226. ax_i.set_xticks(newticks)
  227. yticks = np.asarray(ax_i.get_yticks())
  228. if yticks.size:
  229. firsttick = np.compress(yticks >= min(ax_i.get_ylim()),
  230. yticks)[0]
  231. lasttick = np.compress(yticks <= max(ax_i.get_ylim()),
  232. yticks)[-1]
  233. ax_i.spines['left'].set_bounds(firsttick, lasttick)
  234. ax_i.spines['right'].set_bounds(firsttick, lasttick)
  235. newticks = yticks.compress(yticks <= lasttick)
  236. newticks = newticks.compress(newticks >= firsttick)
  237. ax_i.set_yticks(newticks)
  238. def _kde_support(data, bw, gridsize, cut, clip):
  239. """Establish support for a kernel density estimate."""
  240. support_min = max(data.min() - bw * cut, clip[0])
  241. support_max = min(data.max() + bw * cut, clip[1])
  242. return np.linspace(support_min, support_max, gridsize)
  243. def percentiles(a, pcts, axis=None):
  244. """Like scoreatpercentile but can take and return array of percentiles.
  245. DEPRECATED: will be removed in a future version.
  246. Parameters
  247. ----------
  248. a : array
  249. data
  250. pcts : sequence of percentile values
  251. percentile or percentiles to find score at
  252. axis : int or None
  253. if not None, computes scores over this axis
  254. Returns
  255. -------
  256. scores: array
  257. array of scores at requested percentiles
  258. first dimension is length of object passed to ``pcts``
  259. """
  260. msg = "This function is deprecated and will be removed in a future version"
  261. warnings.warn(msg)
  262. scores = []
  263. try:
  264. n = len(pcts)
  265. except TypeError:
  266. pcts = [pcts]
  267. n = 0
  268. for i, p in enumerate(pcts):
  269. if axis is None:
  270. score = stats.scoreatpercentile(a.ravel(), p)
  271. else:
  272. score = np.apply_along_axis(stats.scoreatpercentile, axis, a, p)
  273. scores.append(score)
  274. scores = np.asarray(scores)
  275. if not n:
  276. scores = scores.squeeze()
  277. return scores
  278. def ci(a, which=95, axis=None):
  279. """Return a percentile range from an array of values."""
  280. p = 50 - which / 2, 50 + which / 2
  281. return np.percentile(a, p, axis)
  282. def sig_stars(p):
  283. """Return a R-style significance string corresponding to p values.
  284. DEPRECATED: will be removed in a future version.
  285. """
  286. msg = "This function is deprecated and will be removed in a future version"
  287. warnings.warn(msg)
  288. if p < 0.001:
  289. return "***"
  290. elif p < 0.01:
  291. return "**"
  292. elif p < 0.05:
  293. return "*"
  294. elif p < 0.1:
  295. return "."
  296. return ""
  297. def iqr(a):
  298. """Calculate the IQR for an array of numbers."""
  299. a = np.asarray(a)
  300. q1 = stats.scoreatpercentile(a, 25)
  301. q3 = stats.scoreatpercentile(a, 75)
  302. return q3 - q1
  303. def get_dataset_names():
  304. """Report available example datasets, useful for reporting issues."""
  305. # delayed import to not demand bs4 unless this function is actually used
  306. from bs4 import BeautifulSoup
  307. http = urlopen('https://github.com/mwaskom/seaborn-data/')
  308. gh_list = BeautifulSoup(http)
  309. return [l.text.replace('.csv', '')
  310. for l in gh_list.find_all("a", {"class": "js-navigation-open"})
  311. if l.text.endswith('.csv')]
  312. def get_data_home(data_home=None):
  313. """Return a path to the cache directory for example datasets.
  314. This directory is then used by :func:`load_dataset`.
  315. If the ``data_home`` argument is not specified, it tries to read from the
  316. ``SEABORN_DATA`` environment variable and defaults to ``~/seaborn-data``.
  317. """
  318. if data_home is None:
  319. data_home = os.environ.get('SEABORN_DATA',
  320. os.path.join('~', 'seaborn-data'))
  321. data_home = os.path.expanduser(data_home)
  322. if not os.path.exists(data_home):
  323. os.makedirs(data_home)
  324. return data_home
  325. def load_dataset(name, cache=True, data_home=None, **kws):
  326. """Load an example dataset from the online repository (requires internet).
  327. This function provides quick access to a small number of example datasets
  328. that are useful for documenting seaborn or generating reproducible examples
  329. for bug reports. It is not necessary for normal usage.
  330. Note that some of the datasets have a small amount of preprocessing applied
  331. to define a proper ordering for categorical variables.
  332. Use :func:`get_dataset_names` to see a list of available datasets.
  333. Parameters
  334. ----------
  335. name : str
  336. Name of the dataset (``{name}.csv`` on
  337. https://github.com/mwaskom/seaborn-data).
  338. cache : boolean, optional
  339. If True, try to load from the local cache first, and save to the cache
  340. if a download is required.
  341. data_home : string, optional
  342. The directory in which to cache data; see :func:`get_data_home`.
  343. kws : keys and values, optional
  344. Additional keyword arguments are passed to passed through to
  345. :func:`pandas.read_csv`.
  346. Returns
  347. -------
  348. df : :class:`pandas.DataFrame`
  349. Tabular data, possibly with some preprocessing applied.
  350. """
  351. path = ("https://raw.githubusercontent.com/"
  352. "mwaskom/seaborn-data/master/{}.csv")
  353. full_path = path.format(name)
  354. if cache:
  355. cache_path = os.path.join(get_data_home(data_home),
  356. os.path.basename(full_path))
  357. if not os.path.exists(cache_path):
  358. urlretrieve(full_path, cache_path)
  359. full_path = cache_path
  360. df = pd.read_csv(full_path, **kws)
  361. if df.iloc[-1].isnull().all():
  362. df = df.iloc[:-1]
  363. # Set some columns as a categorical type with ordered levels
  364. if name == "tips":
  365. df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
  366. df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
  367. df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
  368. df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])
  369. if name == "flights":
  370. df["month"] = pd.Categorical(df["month"], df.month.unique())
  371. if name == "exercise":
  372. df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
  373. df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
  374. df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])
  375. if name == "titanic":
  376. df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
  377. df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))
  378. return df
  379. def axis_ticklabels_overlap(labels):
  380. """Return a boolean for whether the list of ticklabels have overlaps.
  381. Parameters
  382. ----------
  383. labels : list of matplotlib ticklabels
  384. Returns
  385. -------
  386. overlap : boolean
  387. True if any of the labels overlap.
  388. """
  389. if not labels:
  390. return False
  391. try:
  392. bboxes = [l.get_window_extent() for l in labels]
  393. overlaps = [b.count_overlaps(bboxes) for b in bboxes]
  394. return max(overlaps) > 1
  395. except RuntimeError:
  396. # Issue on macos backend raises an error in the above code
  397. return False
  398. def axes_ticklabels_overlap(ax):
  399. """Return booleans for whether the x and y ticklabels on an Axes overlap.
  400. Parameters
  401. ----------
  402. ax : matplotlib Axes
  403. Returns
  404. -------
  405. x_overlap, y_overlap : booleans
  406. True when the labels on that axis overlap.
  407. """
  408. return (axis_ticklabels_overlap(ax.get_xticklabels()),
  409. axis_ticklabels_overlap(ax.get_yticklabels()))
  410. def categorical_order(values, order=None):
  411. """Return a list of unique data values.
  412. Determine an ordered list of levels in ``values``.
  413. Parameters
  414. ----------
  415. values : list, array, Categorical, or Series
  416. Vector of "categorical" values
  417. order : list-like, optional
  418. Desired order of category levels to override the order determined
  419. from the ``values`` object.
  420. Returns
  421. -------
  422. order : list
  423. Ordered list of category levels not including null values.
  424. """
  425. if order is None:
  426. if hasattr(values, "categories"):
  427. order = values.categories
  428. else:
  429. try:
  430. order = values.cat.categories
  431. except (TypeError, AttributeError):
  432. try:
  433. order = values.unique()
  434. except AttributeError:
  435. order = pd.unique(values)
  436. try:
  437. np.asarray(values).astype(np.float)
  438. order = np.sort(order)
  439. except (ValueError, TypeError):
  440. order = order
  441. order = filter(pd.notnull, order)
  442. return list(order)
  443. def locator_to_legend_entries(locator, limits, dtype):
  444. """Return levels and formatted levels for brief numeric legends."""
  445. raw_levels = locator.tick_values(*limits).astype(dtype)
  446. class dummy_axis:
  447. def get_view_interval(self):
  448. return limits
  449. if isinstance(locator, mpl.ticker.LogLocator):
  450. formatter = mpl.ticker.LogFormatter()
  451. else:
  452. formatter = mpl.ticker.ScalarFormatter()
  453. formatter.axis = dummy_axis()
  454. # TODO: The following two lines should be replaced
  455. # once pinned matplotlib>=3.1.0 with:
  456. # formatted_levels = formatter.format_ticks(raw_levels)
  457. formatter.set_locs(raw_levels)
  458. formatted_levels = [formatter(x) for x in raw_levels]
  459. return raw_levels, formatted_levels
  460. def get_color_cycle():
  461. """Return the list of colors in the current matplotlib color cycle
  462. Parameters
  463. ----------
  464. None
  465. Returns
  466. -------
  467. colors : list
  468. List of matplotlib colors in the current cycle, or dark gray if
  469. the current color cycle is empty.
  470. """
  471. cycler = mpl.rcParams['axes.prop_cycle']
  472. return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"]
  473. def relative_luminance(color):
  474. """Calculate the relative luminance of a color according to W3C standards
  475. Parameters
  476. ----------
  477. color : matplotlib color or sequence of matplotlib colors
  478. Hex code, rgb-tuple, or html color name.
  479. Returns
  480. -------
  481. luminance : float(s) between 0 and 1
  482. """
  483. rgb = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
  484. rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
  485. lum = rgb.dot([.2126, .7152, .0722])
  486. try:
  487. return lum.item()
  488. except ValueError:
  489. return lum
  490. def to_utf8(obj):
  491. """Return a string representing a Python object.
  492. Strings (i.e. type ``str``) are returned unchanged.
  493. Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.
  494. For other objects, the method ``__str__()`` is called, and the result is
  495. returned as a string.
  496. Parameters
  497. ----------
  498. obj : object
  499. Any Python object
  500. Returns
  501. -------
  502. s : str
  503. UTF-8-decoded string representation of ``obj``
  504. """
  505. if isinstance(obj, str):
  506. return obj
  507. try:
  508. return obj.decode(encoding="utf-8")
  509. except AttributeError: # obj is not bytes-like
  510. return str(obj)
  511. def _network(t=None, url='https://google.com'):
  512. """
  513. Decorator that will skip a test if `url` is unreachable.
  514. Parameters
  515. ----------
  516. t : function, optional
  517. url : str, optional
  518. """
  519. import nose
  520. if t is None:
  521. return lambda x: _network(x, url=url)
  522. def wrapper(*args, **kwargs):
  523. # attempt to connect
  524. try:
  525. f = urlopen(url)
  526. except (IOError, HTTPException):
  527. raise nose.SkipTest()
  528. else:
  529. f.close()
  530. return t(*args, **kwargs)
  531. return wrapper