__init__.py 69 KB


  1. """
  2. A collection of utility functions and classes. Originally, many
  3. (but not all) were from the Python Cookbook -- hence the name cbook.
  4. This module is safe to import from anywhere within matplotlib;
  5. it imports matplotlib only at runtime.
  6. """
  7. import collections
  8. import collections.abc
  9. import contextlib
  10. import functools
  11. import glob
  12. import gzip
  13. import itertools
  14. import locale
  15. import numbers
  16. import operator
  17. import os
  18. from pathlib import Path
  19. import re
  20. import shlex
  21. import subprocess
  22. import sys
  23. import time
  24. import traceback
  25. import types
  26. import warnings
  27. import weakref
  28. from weakref import WeakMethod
  29. import numpy as np
  30. import matplotlib
  31. from .deprecation import (
  32. deprecated, warn_deprecated,
  33. _rename_parameter, _delete_parameter, _make_keyword_only,
  34. _suppress_matplotlib_deprecation_warning,
  35. MatplotlibDeprecationWarning, mplDeprecation)
  36. def _exception_printer(exc):
  37. traceback.print_exc()
  38. class _StrongRef:
  39. """
  40. Wrapper similar to a weakref, but keeping a strong reference to the object.
  41. """
  42. def __init__(self, obj):
  43. self._obj = obj
  44. def __call__(self):
  45. return self._obj
  46. def __eq__(self, other):
  47. return isinstance(other, _StrongRef) and self._obj == other._obj
  48. def __hash__(self):
  49. return hash(self._obj)
  50. class CallbackRegistry:
  51. """Handle registering and disconnecting for a set of signals and callbacks:
  52. >>> def oneat(x):
  53. ... print('eat', x)
  54. >>> def ondrink(x):
  55. ... print('drink', x)
  56. >>> from matplotlib.cbook import CallbackRegistry
  57. >>> callbacks = CallbackRegistry()
  58. >>> id_eat = callbacks.connect('eat', oneat)
  59. >>> id_drink = callbacks.connect('drink', ondrink)
  60. >>> callbacks.process('drink', 123)
  61. drink 123
  62. >>> callbacks.process('eat', 456)
  63. eat 456
  64. >>> callbacks.process('be merry', 456) # nothing will be called
  65. >>> callbacks.disconnect(id_eat)
  66. >>> callbacks.process('eat', 456) # nothing will be called
  67. In practice, one should always disconnect all callbacks when they are
  68. no longer needed to avoid dangling references (and thus memory leaks).
  69. However, real code in Matplotlib rarely does so, and due to its design,
  70. it is rather difficult to place this kind of code. To get around this,
  71. and prevent this class of memory leaks, we instead store weak references
  72. to bound methods only, so when the destination object needs to die, the
  73. CallbackRegistry won't keep it alive.
  74. Parameters
  75. ----------
  76. exception_handler : callable, optional
  77. If provided must have signature ::
  78. def handler(exc: Exception) -> None:
  79. If not None this function will be called with any `Exception`
  80. subclass raised by the callbacks in `CallbackRegistry.process`.
  81. The handler may either consume the exception or re-raise.
  82. The callable must be pickle-able.
  83. The default handler is ::
  84. def h(exc):
  85. traceback.print_exc()
  86. """
  87. # We maintain two mappings:
  88. # callbacks: signal -> {cid -> callback}
  89. # _func_cid_map: signal -> {callback -> cid}
  90. # (actually, callbacks are weakrefs to the actual callbacks).
  91. def __init__(self, exception_handler=_exception_printer):
  92. self.exception_handler = exception_handler
  93. self.callbacks = {}
  94. self._cid_gen = itertools.count()
  95. self._func_cid_map = {}
  96. # In general, callbacks may not be pickled; thus, we simply recreate an
  97. # empty dictionary at unpickling. In order to ensure that `__setstate__`
  98. # (which just defers to `__init__`) is called, `__getstate__` must
  99. # return a truthy value (for pickle protocol>=3, i.e. Py3, the
  100. # *actual* behavior is that `__setstate__` will be called as long as
  101. # `__getstate__` does not return `None`, but this is undocumented -- see
  102. # http://bugs.python.org/issue12290).
  103. def __getstate__(self):
  104. return {'exception_handler': self.exception_handler}
  105. def __setstate__(self, state):
  106. self.__init__(**state)
  107. def connect(self, s, func):
  108. """Register *func* to be called when signal *s* is generated.
  109. """
  110. self._func_cid_map.setdefault(s, {})
  111. try:
  112. proxy = WeakMethod(func, self._remove_proxy)
  113. except TypeError:
  114. proxy = _StrongRef(func)
  115. if proxy in self._func_cid_map[s]:
  116. return self._func_cid_map[s][proxy]
  117. cid = next(self._cid_gen)
  118. self._func_cid_map[s][proxy] = cid
  119. self.callbacks.setdefault(s, {})
  120. self.callbacks[s][cid] = proxy
  121. return cid
  122. # Keep a reference to sys.is_finalizing, as sys may have been cleared out
  123. # at that point.
  124. def _remove_proxy(self, proxy, *, _is_finalizing=sys.is_finalizing):
  125. if _is_finalizing():
  126. # Weakrefs can't be properly torn down at that point anymore.
  127. return
  128. for signal, proxies in list(self._func_cid_map.items()):
  129. try:
  130. del self.callbacks[signal][proxies[proxy]]
  131. except KeyError:
  132. pass
  133. if len(self.callbacks[signal]) == 0:
  134. del self.callbacks[signal]
  135. del self._func_cid_map[signal]
  136. def disconnect(self, cid):
  137. """Disconnect the callback registered with callback id *cid*.
  138. """
  139. for eventname, callbackd in list(self.callbacks.items()):
  140. try:
  141. del callbackd[cid]
  142. except KeyError:
  143. continue
  144. else:
  145. for signal, functions in list(self._func_cid_map.items()):
  146. for function, value in list(functions.items()):
  147. if value == cid:
  148. del functions[function]
  149. return
  150. def process(self, s, *args, **kwargs):
  151. """
  152. Process signal *s*.
  153. All of the functions registered to receive callbacks on *s* will be
  154. called with ``*args`` and ``**kwargs``.
  155. """
  156. for cid, ref in list(self.callbacks.get(s, {}).items()):
  157. func = ref()
  158. if func is not None:
  159. try:
  160. func(*args, **kwargs)
  161. # this does not capture KeyboardInterrupt, SystemExit,
  162. # and GeneratorExit
  163. except Exception as exc:
  164. if self.exception_handler is not None:
  165. self.exception_handler(exc)
  166. else:
  167. raise
  168. class silent_list(list):
  169. """
  170. A list with a short ``repr()``.
  171. This is meant to be used for a homogeneous list of artists, so that they
  172. don't cause long, meaningless output.
  173. Instead of ::
  174. [<matplotlib.lines.Line2D object at 0x7f5749fed3c8>,
  175. <matplotlib.lines.Line2D object at 0x7f5749fed4e0>,
  176. <matplotlib.lines.Line2D object at 0x7f5758016550>]
  177. one will get ::
  178. <a list of 3 Line2D objects>
  179. """
  180. def __init__(self, type, seq=None):
  181. self.type = type
  182. if seq is not None:
  183. self.extend(seq)
  184. def __repr__(self):
  185. return '<a list of %d %s objects>' % (len(self), self.type)
  186. __str__ = __repr__
  187. def __getstate__(self):
  188. # store a dictionary of this SilentList's state
  189. return {'type': self.type, 'seq': self[:]}
  190. def __setstate__(self, state):
  191. self.type = state['type']
  192. self.extend(state['seq'])
  193. class IgnoredKeywordWarning(UserWarning):
  194. """
  195. A class for issuing warnings about keyword arguments that will be ignored
  196. by Matplotlib.
  197. """
  198. pass
  199. def local_over_kwdict(local_var, kwargs, *keys):
  200. """
  201. Enforces the priority of a local variable over potentially conflicting
  202. argument(s) from a kwargs dict. The following possible output values are
  203. considered in order of priority::
  204. local_var > kwargs[keys[0]] > ... > kwargs[keys[-1]]
  205. The first of these whose value is not None will be returned. If all are
  206. None then None will be returned. Each key in keys will be removed from the
  207. kwargs dict in place.
  208. Parameters
  209. ----------
  210. local_var : any object
  211. The local variable (highest priority).
  212. kwargs : dict
  213. Dictionary of keyword arguments; modified in place.
  214. keys : str(s)
  215. Name(s) of keyword arguments to process, in descending order of
  216. priority.
  217. Returns
  218. -------
  219. out : any object
  220. Either local_var or one of kwargs[key] for key in keys.
  221. Raises
  222. ------
  223. IgnoredKeywordWarning
  224. For each key in keys that is removed from kwargs but not used as
  225. the output value.
  226. """
  227. out = local_var
  228. for key in keys:
  229. kwarg_val = kwargs.pop(key, None)
  230. if kwarg_val is not None:
  231. if out is None:
  232. out = kwarg_val
  233. else:
  234. _warn_external('"%s" keyword argument will be ignored' % key,
  235. IgnoredKeywordWarning)
  236. return out
  237. def strip_math(s):
  238. """
  239. Remove latex formatting from mathtext.
  240. Only handles fully math and fully non-math strings.
  241. """
  242. if len(s) >= 2 and s[0] == s[-1] == "$":
  243. s = s[1:-1]
  244. for tex, plain in [
  245. (r"\times", "x"), # Specifically for Formatter support.
  246. (r"\mathdefault", ""),
  247. (r"\rm", ""),
  248. (r"\cal", ""),
  249. (r"\tt", ""),
  250. (r"\it", ""),
  251. ("\\", ""),
  252. ("{", ""),
  253. ("}", ""),
  254. ]:
  255. s = s.replace(tex, plain)
  256. return s
  257. @deprecated('3.1', alternative='np.iterable')
  258. def iterable(obj):
  259. """return true if *obj* is iterable"""
  260. try:
  261. iter(obj)
  262. except TypeError:
  263. return False
  264. return True
  265. @deprecated("3.1", alternative="isinstance(..., collections.abc.Hashable)")
  266. def is_hashable(obj):
  267. """Returns true if *obj* can be hashed"""
  268. try:
  269. hash(obj)
  270. except TypeError:
  271. return False
  272. return True
  273. def is_writable_file_like(obj):
  274. """Return whether *obj* looks like a file object with a *write* method."""
  275. return callable(getattr(obj, 'write', None))
  276. def file_requires_unicode(x):
  277. """
  278. Return whether the given writable file-like object requires Unicode to be
  279. written to it.
  280. """
  281. try:
  282. x.write(b'')
  283. except TypeError:
  284. return True
  285. else:
  286. return False
  287. def to_filehandle(fname, flag='r', return_opened=False, encoding=None):
  288. """
  289. Convert a path to an open file handle or pass-through a file-like object.
  290. Consider using `open_file_cm` instead, as it allows one to properly close
  291. newly created file objects more easily.
  292. Parameters
  293. ----------
  294. fname : str or path-like or file-like object
  295. If `str` or `os.PathLike`, the file is opened using the flags specified
  296. by *flag* and *encoding*. If a file-like object, it is passed through.
  297. flag : str, default 'r'
  298. Passed as the *mode* argument to `open` when *fname* is `str` or
  299. `os.PathLike`; ignored if *fname* is file-like.
  300. return_opened : bool, default False
  301. If True, return both the file object and a boolean indicating whether
  302. this was a new file (that the caller needs to close). If False, return
  303. only the new file.
  304. encoding : str or None, default None
  305. Passed as the *mode* argument to `open` when *fname* is `str` or
  306. `os.PathLike`; ignored if *fname* is file-like.
  307. Returns
  308. -------
  309. fh : file-like
  310. opened : bool
  311. *opened* is only returned if *return_opened* is True.
  312. """
  313. if isinstance(fname, os.PathLike):
  314. fname = os.fspath(fname)
  315. if isinstance(fname, str):
  316. if fname.endswith('.gz'):
  317. # get rid of 'U' in flag for gzipped files.
  318. flag = flag.replace('U', '')
  319. fh = gzip.open(fname, flag)
  320. elif fname.endswith('.bz2'):
  321. # python may not be complied with bz2 support,
  322. # bury import until we need it
  323. import bz2
  324. # get rid of 'U' in flag for bz2 files
  325. flag = flag.replace('U', '')
  326. fh = bz2.BZ2File(fname, flag)
  327. else:
  328. fh = open(fname, flag, encoding=encoding)
  329. opened = True
  330. elif hasattr(fname, 'seek'):
  331. fh = fname
  332. opened = False
  333. else:
  334. raise ValueError('fname must be a PathLike or file handle')
  335. if return_opened:
  336. return fh, opened
  337. return fh
  338. @contextlib.contextmanager
  339. def open_file_cm(path_or_file, mode="r", encoding=None):
  340. r"""Pass through file objects and context-manage `.PathLike`\s."""
  341. fh, opened = to_filehandle(path_or_file, mode, True, encoding)
  342. if opened:
  343. with fh:
  344. yield fh
  345. else:
  346. yield fh
  347. def is_scalar_or_string(val):
  348. """Return whether the given object is a scalar or string like."""
  349. return isinstance(val, str) or not np.iterable(val)
  350. def get_sample_data(fname, asfileobj=True):
  351. """
  352. Return a sample data file. *fname* is a path relative to the
  353. `mpl-data/sample_data` directory. If *asfileobj* is `True`
  354. return a file object, otherwise just a file path.
  355. Sample data files are stored in the 'mpl-data/sample_data' directory within
  356. the Matplotlib package.
  357. If the filename ends in .gz, the file is implicitly ungzipped.
  358. """
  359. path = Path(matplotlib.get_data_path(), 'sample_data', fname)
  360. if asfileobj:
  361. suffix = path.suffix.lower()
  362. if suffix == '.gz':
  363. return gzip.open(path)
  364. elif suffix in ['.csv', '.xrc', '.txt']:
  365. return path.open('r')
  366. else:
  367. return path.open('rb')
  368. else:
  369. return str(path)
  370. def _get_data_path(*args):
  371. """
  372. Return the `Path` to a resource file provided by Matplotlib.
  373. ``*args`` specify a path relative to the base data path.
  374. """
  375. return Path(matplotlib.get_data_path(), *args)
  376. def flatten(seq, scalarp=is_scalar_or_string):
  377. """
  378. Return a generator of flattened nested containers.
  379. For example:
  380. >>> from matplotlib.cbook import flatten
  381. >>> l = (('John', ['Hunter']), (1, 23), [[([42, (5, 23)], )]])
  382. >>> print(list(flatten(l)))
  383. ['John', 'Hunter', 1, 23, 42, 5, 23]
  384. By: Composite of Holger Krekel and Luther Blissett
  385. From: https://code.activestate.com/recipes/121294/
  386. and Recipe 1.12 in cookbook
  387. """
  388. for item in seq:
  389. if scalarp(item) or item is None:
  390. yield item
  391. else:
  392. yield from flatten(item, scalarp)
  393. @functools.lru_cache()
  394. def get_realpath_and_stat(path):
  395. realpath = os.path.realpath(path)
  396. stat = os.stat(realpath)
  397. stat_key = (stat.st_ino, stat.st_dev)
  398. return realpath, stat_key
  399. # A regular expression used to determine the amount of space to
  400. # remove. It looks for the first sequence of spaces immediately
  401. # following the first newline, or at the beginning of the string.
  402. _find_dedent_regex = re.compile(r"(?:(?:\n\r?)|^)( *)\S")
  403. # A cache to hold the regexs that actually remove the indent.
  404. _dedent_regex = {}
  405. @deprecated("3.1", alternative="inspect.cleandoc")
  406. def dedent(s):
  407. """
  408. Remove excess indentation from docstring *s*.
  409. Discards any leading blank lines, then removes up to n whitespace
  410. characters from each line, where n is the number of leading
  411. whitespace characters in the first line. It differs from
  412. textwrap.dedent in its deletion of leading blank lines and its use
  413. of the first non-blank line to determine the indentation.
  414. It is also faster in most cases.
  415. """
  416. # This implementation has a somewhat obtuse use of regular
  417. # expressions. However, this function accounted for almost 30% of
  418. # matplotlib startup time, so it is worthy of optimization at all
  419. # costs.
  420. if not s: # includes case of s is None
  421. return ''
  422. match = _find_dedent_regex.match(s)
  423. if match is None:
  424. return s
  425. # This is the number of spaces to remove from the left-hand side.
  426. nshift = match.end(1) - match.start(1)
  427. if nshift == 0:
  428. return s
  429. # Get a regex that will remove *up to* nshift spaces from the
  430. # beginning of each line. If it isn't in the cache, generate it.
  431. unindent = _dedent_regex.get(nshift, None)
  432. if unindent is None:
  433. unindent = re.compile("\n\r? {0,%d}" % nshift)
  434. _dedent_regex[nshift] = unindent
  435. result = unindent.sub("\n", s).strip()
  436. return result
  437. class maxdict(dict):
  438. """
  439. A dictionary with a maximum size.
  440. Notes
  441. -----
  442. This doesn't override all the relevant methods to constrain the size,
  443. just ``__setitem__``, so use with caution.
  444. """
  445. def __init__(self, maxsize):
  446. dict.__init__(self)
  447. self.maxsize = maxsize
  448. self._killkeys = []
  449. def __setitem__(self, k, v):
  450. if k not in self:
  451. if len(self) >= self.maxsize:
  452. del self[self._killkeys[0]]
  453. del self._killkeys[0]
  454. self._killkeys.append(k)
  455. dict.__setitem__(self, k, v)
  456. class Stack:
  457. """
  458. Stack of elements with a movable cursor.
  459. Mimics home/back/forward in a web browser.
  460. """
  461. def __init__(self, default=None):
  462. self.clear()
  463. self._default = default
  464. def __call__(self):
  465. """Return the current element, or None."""
  466. if not len(self._elements):
  467. return self._default
  468. else:
  469. return self._elements[self._pos]
  470. def __len__(self):
  471. return len(self._elements)
  472. def __getitem__(self, ind):
  473. return self._elements[ind]
  474. def forward(self):
  475. """Move the position forward and return the current element."""
  476. self._pos = min(self._pos + 1, len(self._elements) - 1)
  477. return self()
  478. def back(self):
  479. """Move the position back and return the current element."""
  480. if self._pos > 0:
  481. self._pos -= 1
  482. return self()
  483. def push(self, o):
  484. """
  485. Push *o* to the stack at current position. Discard all later elements.
  486. *o* is returned.
  487. """
  488. self._elements = self._elements[:self._pos + 1] + [o]
  489. self._pos = len(self._elements) - 1
  490. return self()
  491. def home(self):
  492. """
  493. Push the first element onto the top of the stack.
  494. The first element is returned.
  495. """
  496. if not len(self._elements):
  497. return
  498. self.push(self._elements[0])
  499. return self()
  500. def empty(self):
  501. """Return whether the stack is empty."""
  502. return len(self._elements) == 0
  503. def clear(self):
  504. """Empty the stack."""
  505. self._pos = -1
  506. self._elements = []
  507. def bubble(self, o):
  508. """
  509. Raise *o* to the top of the stack. *o* must be present in the stack.
  510. *o* is returned.
  511. """
  512. if o not in self._elements:
  513. raise ValueError('Unknown element o')
  514. old = self._elements[:]
  515. self.clear()
  516. bubbles = []
  517. for thiso in old:
  518. if thiso == o:
  519. bubbles.append(thiso)
  520. else:
  521. self.push(thiso)
  522. for _ in bubbles:
  523. self.push(o)
  524. return o
  525. def remove(self, o):
  526. """Remove *o* from the stack."""
  527. if o not in self._elements:
  528. raise ValueError('Unknown element o')
  529. old = self._elements[:]
  530. self.clear()
  531. for thiso in old:
  532. if thiso != o:
  533. self.push(thiso)
  534. def report_memory(i=0): # argument may go away
  535. """Return the memory consumed by the process."""
  536. def call(command, os_name):
  537. try:
  538. return subprocess.check_output(command)
  539. except subprocess.CalledProcessError:
  540. raise NotImplementedError(
  541. "report_memory works on %s only if "
  542. "the '%s' program is found" % (os_name, command[0])
  543. )
  544. pid = os.getpid()
  545. if sys.platform == 'sunos5':
  546. lines = call(['ps', '-p', '%d' % pid, '-o', 'osz'], 'Sun OS')
  547. mem = int(lines[-1].strip())
  548. elif sys.platform == 'linux':
  549. lines = call(['ps', '-p', '%d' % pid, '-o', 'rss,sz'], 'Linux')
  550. mem = int(lines[1].split()[1])
  551. elif sys.platform == 'darwin':
  552. lines = call(['ps', '-p', '%d' % pid, '-o', 'rss,vsz'], 'Mac OS')
  553. mem = int(lines[1].split()[0])
  554. elif sys.platform == 'win32':
  555. lines = call(["tasklist", "/nh", "/fi", "pid eq %d" % pid], 'Windows')
  556. mem = int(lines.strip().split()[-2].replace(',', ''))
  557. else:
  558. raise NotImplementedError(
  559. "We don't have a memory monitor for %s" % sys.platform)
  560. return mem
  561. _safezip_msg = 'In safezip, len(args[0])=%d but len(args[%d])=%d'
  562. @deprecated("3.1")
  563. def safezip(*args):
  564. """make sure *args* are equal len before zipping"""
  565. Nx = len(args[0])
  566. for i, arg in enumerate(args[1:]):
  567. if len(arg) != Nx:
  568. raise ValueError(_safezip_msg % (Nx, i + 1, len(arg)))
  569. return list(zip(*args))
  570. def safe_masked_invalid(x, copy=False):
  571. x = np.array(x, subok=True, copy=copy)
  572. if not x.dtype.isnative:
  573. # Note that the argument to `byteswap` is 'inplace',
  574. # thus if we have already made a copy, do the byteswap in
  575. # place, else make a copy with the byte order swapped.
  576. # Be explicit that we are swapping the byte order of the dtype
  577. x = x.byteswap(copy).newbyteorder('S')
  578. try:
  579. xm = np.ma.masked_invalid(x, copy=False)
  580. xm.shrink_mask()
  581. except TypeError:
  582. return x
  583. return xm
  584. def print_cycles(objects, outstream=sys.stdout, show_progress=False):
  585. """
  586. Print loops of cyclic references in the given *objects*.
  587. It is often useful to pass in ``gc.garbage`` to find the cycles that are
  588. preventing some objects from being garbage collected.
  589. Parameters
  590. ----------
  591. objects
  592. A list of objects to find cycles in.
  593. outstream
  594. The stream for output.
  595. show_progress : bool
  596. If True, print the number of objects reached as they are found.
  597. """
  598. import gc
  599. def print_path(path):
  600. for i, step in enumerate(path):
  601. # next "wraps around"
  602. next = path[(i + 1) % len(path)]
  603. outstream.write(" %s -- " % type(step))
  604. if isinstance(step, dict):
  605. for key, val in step.items():
  606. if val is next:
  607. outstream.write("[{!r}]".format(key))
  608. break
  609. if key is next:
  610. outstream.write("[key] = {!r}".format(val))
  611. break
  612. elif isinstance(step, list):
  613. outstream.write("[%d]" % step.index(next))
  614. elif isinstance(step, tuple):
  615. outstream.write("( tuple )")
  616. else:
  617. outstream.write(repr(step))
  618. outstream.write(" ->\n")
  619. outstream.write("\n")
  620. def recurse(obj, start, all, current_path):
  621. if show_progress:
  622. outstream.write("%d\r" % len(all))
  623. all[id(obj)] = None
  624. referents = gc.get_referents(obj)
  625. for referent in referents:
  626. # If we've found our way back to the start, this is
  627. # a cycle, so print it out
  628. if referent is start:
  629. print_path(current_path)
  630. # Don't go back through the original list of objects, or
  631. # through temporary references to the object, since those
  632. # are just an artifact of the cycle detector itself.
  633. elif referent is objects or isinstance(referent, types.FrameType):
  634. continue
  635. # We haven't seen this object before, so recurse
  636. elif id(referent) not in all:
  637. recurse(referent, start, all, current_path + [obj])
  638. for obj in objects:
  639. outstream.write(f"Examining: {obj!r}\n")
  640. recurse(obj, obj, {}, [])
  641. class Grouper:
  642. """
  643. This class provides a lightweight way to group arbitrary objects
  644. together into disjoint sets when a full-blown graph data structure
  645. would be overkill.
  646. Objects can be joined using :meth:`join`, tested for connectedness
  647. using :meth:`joined`, and all disjoint sets can be retrieved by
  648. using the object as an iterator.
  649. The objects being joined must be hashable and weak-referenceable.
  650. For example:
  651. >>> from matplotlib.cbook import Grouper
  652. >>> class Foo:
  653. ... def __init__(self, s):
  654. ... self.s = s
  655. ... def __repr__(self):
  656. ... return self.s
  657. ...
  658. >>> a, b, c, d, e, f = [Foo(x) for x in 'abcdef']
  659. >>> grp = Grouper()
  660. >>> grp.join(a, b)
  661. >>> grp.join(b, c)
  662. >>> grp.join(d, e)
  663. >>> sorted(map(tuple, grp))
  664. [(a, b, c), (d, e)]
  665. >>> grp.joined(a, b)
  666. True
  667. >>> grp.joined(a, c)
  668. True
  669. >>> grp.joined(a, d)
  670. False
  671. """
  672. def __init__(self, init=()):
  673. self._mapping = {weakref.ref(x): [weakref.ref(x)] for x in init}
  674. def __contains__(self, item):
  675. return weakref.ref(item) in self._mapping
  676. def clean(self):
  677. """Clean dead weak references from the dictionary."""
  678. mapping = self._mapping
  679. to_drop = [key for key in mapping if key() is None]
  680. for key in to_drop:
  681. val = mapping.pop(key)
  682. val.remove(key)
  683. def join(self, a, *args):
  684. """
  685. Join given arguments into the same set. Accepts one or more arguments.
  686. """
  687. mapping = self._mapping
  688. set_a = mapping.setdefault(weakref.ref(a), [weakref.ref(a)])
  689. for arg in args:
  690. set_b = mapping.get(weakref.ref(arg), [weakref.ref(arg)])
  691. if set_b is not set_a:
  692. if len(set_b) > len(set_a):
  693. set_a, set_b = set_b, set_a
  694. set_a.extend(set_b)
  695. for elem in set_b:
  696. mapping[elem] = set_a
  697. self.clean()
  698. def joined(self, a, b):
  699. """Return whether *a* and *b* are members of the same set."""
  700. self.clean()
  701. return (self._mapping.get(weakref.ref(a), object())
  702. is self._mapping.get(weakref.ref(b)))
  703. def remove(self, a):
  704. self.clean()
  705. set_a = self._mapping.pop(weakref.ref(a), None)
  706. if set_a:
  707. set_a.remove(weakref.ref(a))
  708. def __iter__(self):
  709. """
  710. Iterate over each of the disjoint sets as a list.
  711. The iterator is invalid if interleaved with calls to join().
  712. """
  713. self.clean()
  714. unique_groups = {id(group): group for group in self._mapping.values()}
  715. for group in unique_groups.values():
  716. yield [x() for x in group]
  717. def get_siblings(self, a):
  718. """Return all of the items joined with *a*, including itself."""
  719. self.clean()
  720. siblings = self._mapping.get(weakref.ref(a), [weakref.ref(a)])
  721. return [x() for x in siblings]
  722. def simple_linear_interpolation(a, steps):
  723. """
  724. Resample an array with ``steps - 1`` points between original point pairs.
  725. Along each column of *a*, ``(steps - 1)`` points are introduced between
  726. each original values; the values are linearly interpolated.
  727. Parameters
  728. ----------
  729. a : array, shape (n, ...)
  730. steps : int
  731. Returns
  732. -------
  733. array
  734. shape ``((n - 1) * steps + 1, ...)``
  735. """
  736. fps = a.reshape((len(a), -1))
  737. xp = np.arange(len(a)) * steps
  738. x = np.arange((len(a) - 1) * steps + 1)
  739. return (np.column_stack([np.interp(x, xp, fp) for fp in fps.T])
  740. .reshape((len(x),) + a.shape[1:]))
  741. def delete_masked_points(*args):
  742. """
  743. Find all masked and/or non-finite points in a set of arguments,
  744. and return the arguments with only the unmasked points remaining.
  745. Arguments can be in any of 5 categories:
  746. 1) 1-D masked arrays
  747. 2) 1-D ndarrays
  748. 3) ndarrays with more than one dimension
  749. 4) other non-string iterables
  750. 5) anything else
  751. The first argument must be in one of the first four categories;
  752. any argument with a length differing from that of the first
  753. argument (and hence anything in category 5) then will be
  754. passed through unchanged.
  755. Masks are obtained from all arguments of the correct length
  756. in categories 1, 2, and 4; a point is bad if masked in a masked
  757. array or if it is a nan or inf. No attempt is made to
  758. extract a mask from categories 2, 3, and 4 if :meth:`np.isfinite`
  759. does not yield a Boolean array.
  760. All input arguments that are not passed unchanged are returned
  761. as ndarrays after removing the points or rows corresponding to
  762. masks in any of the arguments.
  763. A vastly simpler version of this function was originally
  764. written as a helper for Axes.scatter().
  765. """
  766. if not len(args):
  767. return ()
  768. if is_scalar_or_string(args[0]):
  769. raise ValueError("First argument must be a sequence")
  770. nrecs = len(args[0])
  771. margs = []
  772. seqlist = [False] * len(args)
  773. for i, x in enumerate(args):
  774. if not isinstance(x, str) and np.iterable(x) and len(x) == nrecs:
  775. seqlist[i] = True
  776. if isinstance(x, np.ma.MaskedArray):
  777. if x.ndim > 1:
  778. raise ValueError("Masked arrays must be 1-D")
  779. else:
  780. x = np.asarray(x)
  781. margs.append(x)
  782. masks = [] # list of masks that are True where good
  783. for i, x in enumerate(margs):
  784. if seqlist[i]:
  785. if x.ndim > 1:
  786. continue # Don't try to get nan locations unless 1-D.
  787. if isinstance(x, np.ma.MaskedArray):
  788. masks.append(~np.ma.getmaskarray(x)) # invert the mask
  789. xd = x.data
  790. else:
  791. xd = x
  792. try:
  793. mask = np.isfinite(xd)
  794. if isinstance(mask, np.ndarray):
  795. masks.append(mask)
  796. except Exception: # Fixme: put in tuple of possible exceptions?
  797. pass
  798. if len(masks):
  799. mask = np.logical_and.reduce(masks)
  800. igood = mask.nonzero()[0]
  801. if len(igood) < nrecs:
  802. for i, x in enumerate(margs):
  803. if seqlist[i]:
  804. margs[i] = x[igood]
  805. for i, x in enumerate(margs):
  806. if seqlist[i] and isinstance(x, np.ma.MaskedArray):
  807. margs[i] = x.filled()
  808. return margs
  809. def _combine_masks(*args):
  810. """
  811. Find all masked and/or non-finite points in a set of arguments,
  812. and return the arguments as masked arrays with a common mask.
  813. Arguments can be in any of 5 categories:
  814. 1) 1-D masked arrays
  815. 2) 1-D ndarrays
  816. 3) ndarrays with more than one dimension
  817. 4) other non-string iterables
  818. 5) anything else
  819. The first argument must be in one of the first four categories;
  820. any argument with a length differing from that of the first
  821. argument (and hence anything in category 5) then will be
  822. passed through unchanged.
  823. Masks are obtained from all arguments of the correct length
  824. in categories 1, 2, and 4; a point is bad if masked in a masked
  825. array or if it is a nan or inf. No attempt is made to
  826. extract a mask from categories 2 and 4 if :meth:`np.isfinite`
  827. does not yield a Boolean array. Category 3 is included to
  828. support RGB or RGBA ndarrays, which are assumed to have only
  829. valid values and which are passed through unchanged.
  830. All input arguments that are not passed unchanged are returned
  831. as masked arrays if any masked points are found, otherwise as
  832. ndarrays.
  833. """
  834. if not len(args):
  835. return ()
  836. if is_scalar_or_string(args[0]):
  837. raise ValueError("First argument must be a sequence")
  838. nrecs = len(args[0])
  839. margs = [] # Output args; some may be modified.
  840. seqlist = [False] * len(args) # Flags: True if output will be masked.
  841. masks = [] # List of masks.
  842. for i, x in enumerate(args):
  843. if is_scalar_or_string(x) or len(x) != nrecs:
  844. margs.append(x) # Leave it unmodified.
  845. else:
  846. if isinstance(x, np.ma.MaskedArray) and x.ndim > 1:
  847. raise ValueError("Masked arrays must be 1-D")
  848. x = np.asanyarray(x)
  849. if x.ndim == 1:
  850. x = safe_masked_invalid(x)
  851. seqlist[i] = True
  852. if np.ma.is_masked(x):
  853. masks.append(np.ma.getmaskarray(x))
  854. margs.append(x) # Possibly modified.
  855. if len(masks):
  856. mask = np.logical_or.reduce(masks)
  857. for i, x in enumerate(margs):
  858. if seqlist[i]:
  859. margs[i] = np.ma.array(x, mask=mask)
  860. return margs
  861. def boxplot_stats(X, whis=1.5, bootstrap=None, labels=None,
  862. autorange=False):
  863. r"""
  864. Returns list of dictionaries of statistics used to draw a series
  865. of box and whisker plots. The `Returns` section enumerates the
  866. required keys of the dictionary. Users can skip this function and
  867. pass a user-defined set of dictionaries to the new `axes.bxp` method
  868. instead of relying on Matplotlib to do the calculations.
  869. Parameters
  870. ----------
  871. X : array-like
  872. Data that will be represented in the boxplots. Should have 2 or
  873. fewer dimensions.
  874. whis : float or (float, float) (default = 1.5)
  875. The position of the whiskers.
  876. If a float, the lower whisker is at the lowest datum above
  877. ``Q1 - whis*(Q3-Q1)``, and the upper whisker at the highest datum below
  878. ``Q3 + whis*(Q3-Q1)``, where Q1 and Q3 are the first and third
  879. quartiles. The default value of ``whis = 1.5`` corresponds to Tukey's
  880. original definition of boxplots.
  881. If a pair of floats, they indicate the percentiles at which to draw the
  882. whiskers (e.g., (5, 95)). In particular, setting this to (0, 100)
  883. results in whiskers covering the whole range of the data. "range" is
  884. a deprecated synonym for (0, 100).
  885. In the edge case where ``Q1 == Q3``, *whis* is automatically set to
  886. (0, 100) (cover the whole range of the data) if *autorange* is True.
  887. Beyond the whiskers, data are considered outliers and are plotted as
  888. individual points.
  889. bootstrap : int, optional
  890. Number of times the confidence intervals around the median
  891. should be bootstrapped (percentile method).
  892. labels : array-like, optional
  893. Labels for each dataset. Length must be compatible with
  894. dimensions of *X*.
  895. autorange : bool, optional (False)
  896. When `True` and the data are distributed such that the 25th and 75th
  897. percentiles are equal, ``whis`` is set to (0, 100) such that the
  898. whisker ends are at the minimum and maximum of the data.
  899. Returns
  900. -------
  901. bxpstats : list of dict
  902. A list of dictionaries containing the results for each column
  903. of data. Keys of each dictionary are the following:
  904. ======== ===================================
  905. Key Value Description
  906. ======== ===================================
  907. label tick label for the boxplot
  908. mean arithmetic mean value
  909. med 50th percentile
  910. q1 first quartile (25th percentile)
  911. q3 third quartile (75th percentile)
  912. cilo lower notch around the median
  913. cihi upper notch around the median
  914. whislo end of the lower whisker
  915. whishi end of the upper whisker
  916. fliers outliers
  917. ======== ===================================
  918. Notes
  919. -----
  920. Non-bootstrapping approach to confidence interval uses Gaussian-
  921. based asymptotic approximation:
  922. .. math::
  923. \mathrm{med} \pm 1.57 \times \frac{\mathrm{iqr}}{\sqrt{N}}
  924. General approach from:
  925. McGill, R., Tukey, J.W., and Larsen, W.A. (1978) "Variations of
  926. Boxplots", The American Statistician, 32:12-16.
  927. """
  928. def _bootstrap_median(data, N=5000):
  929. # determine 95% confidence intervals of the median
  930. M = len(data)
  931. percentiles = [2.5, 97.5]
  932. bs_index = np.random.randint(M, size=(N, M))
  933. bsData = data[bs_index]
  934. estimate = np.median(bsData, axis=1, overwrite_input=True)
  935. CI = np.percentile(estimate, percentiles)
  936. return CI
  937. def _compute_conf_interval(data, med, iqr, bootstrap):
  938. if bootstrap is not None:
  939. # Do a bootstrap estimate of notch locations.
  940. # get conf. intervals around median
  941. CI = _bootstrap_median(data, N=bootstrap)
  942. notch_min = CI[0]
  943. notch_max = CI[1]
  944. else:
  945. N = len(data)
  946. notch_min = med - 1.57 * iqr / np.sqrt(N)
  947. notch_max = med + 1.57 * iqr / np.sqrt(N)
  948. return notch_min, notch_max
  949. # output is a list of dicts
  950. bxpstats = []
  951. # convert X to a list of lists
  952. X = _reshape_2D(X, "X")
  953. ncols = len(X)
  954. if labels is None:
  955. labels = itertools.repeat(None)
  956. elif len(labels) != ncols:
  957. raise ValueError("Dimensions of labels and X must be compatible")
  958. input_whis = whis
  959. for ii, (x, label) in enumerate(zip(X, labels)):
  960. # empty dict
  961. stats = {}
  962. if label is not None:
  963. stats['label'] = label
  964. # restore whis to the input values in case it got changed in the loop
  965. whis = input_whis
  966. # note tricksiness, append up here and then mutate below
  967. bxpstats.append(stats)
  968. # if empty, bail
  969. if len(x) == 0:
  970. stats['fliers'] = np.array([])
  971. stats['mean'] = np.nan
  972. stats['med'] = np.nan
  973. stats['q1'] = np.nan
  974. stats['q3'] = np.nan
  975. stats['cilo'] = np.nan
  976. stats['cihi'] = np.nan
  977. stats['whislo'] = np.nan
  978. stats['whishi'] = np.nan
  979. stats['med'] = np.nan
  980. continue
  981. # up-convert to an array, just to be safe
  982. x = np.asarray(x)
  983. # arithmetic mean
  984. stats['mean'] = np.mean(x)
  985. # medians and quartiles
  986. q1, med, q3 = np.percentile(x, [25, 50, 75])
  987. # interquartile range
  988. stats['iqr'] = q3 - q1
  989. if stats['iqr'] == 0 and autorange:
  990. whis = (0, 100)
  991. # conf. interval around median
  992. stats['cilo'], stats['cihi'] = _compute_conf_interval(
  993. x, med, stats['iqr'], bootstrap
  994. )
  995. # lowest/highest non-outliers
  996. if np.isscalar(whis):
  997. if np.isreal(whis):
  998. loval = q1 - whis * stats['iqr']
  999. hival = q3 + whis * stats['iqr']
  1000. elif whis in ['range', 'limit', 'limits', 'min/max']:
  1001. warn_deprecated(
  1002. "3.2", message=f"Setting whis to {whis!r} is deprecated "
  1003. "since %(since)s and support for it will be removed "
  1004. "%(removal)s; set it to [0, 100] to achieve the same "
  1005. "effect.")
  1006. loval = np.min(x)
  1007. hival = np.max(x)
  1008. else:
  1009. raise ValueError('whis must be a float or list of percentiles')
  1010. else:
  1011. loval, hival = np.percentile(x, whis)
  1012. # get high extreme
  1013. wiskhi = x[x <= hival]
  1014. if len(wiskhi) == 0 or np.max(wiskhi) < q3:
  1015. stats['whishi'] = q3
  1016. else:
  1017. stats['whishi'] = np.max(wiskhi)
  1018. # get low extreme
  1019. wisklo = x[x >= loval]
  1020. if len(wisklo) == 0 or np.min(wisklo) > q1:
  1021. stats['whislo'] = q1
  1022. else:
  1023. stats['whislo'] = np.min(wisklo)
  1024. # compute a single array of outliers
  1025. stats['fliers'] = np.hstack([
  1026. x[x < stats['whislo']],
  1027. x[x > stats['whishi']],
  1028. ])
  1029. # add in the remaining stats
  1030. stats['q1'], stats['med'], stats['q3'] = q1, med, q3
  1031. return bxpstats
  1032. # The ls_mapper maps short codes for line style to their full name used by
  1033. # backends; the reverse mapper is for mapping full names to short ones.
  1034. ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'}
  1035. ls_mapper_r = {v: k for k, v in ls_mapper.items()}
  1036. def contiguous_regions(mask):
  1037. """
  1038. Return a list of (ind0, ind1) such that ``mask[ind0:ind1].all()`` is
  1039. True and we cover all such regions.
  1040. """
  1041. mask = np.asarray(mask, dtype=bool)
  1042. if not mask.size:
  1043. return []
  1044. # Find the indices of region changes, and correct offset
  1045. idx, = np.nonzero(mask[:-1] != mask[1:])
  1046. idx += 1
  1047. # List operations are faster for moderately sized arrays
  1048. idx = idx.tolist()
  1049. # Add first and/or last index if needed
  1050. if mask[0]:
  1051. idx = [0] + idx
  1052. if mask[-1]:
  1053. idx.append(len(mask))
  1054. return list(zip(idx[::2], idx[1::2]))
  1055. def is_math_text(s):
  1056. """
  1057. Returns whether the string *s* contains math expressions.
  1058. This is done by checking whether *s* contains an even number of
  1059. non-escaped dollar signs.
  1060. """
  1061. s = str(s)
  1062. dollar_count = s.count(r'$') - s.count(r'\$')
  1063. even_dollars = (dollar_count > 0 and dollar_count % 2 == 0)
  1064. return even_dollars
  1065. def _to_unmasked_float_array(x):
  1066. """
  1067. Convert a sequence to a float array; if input was a masked array, masked
  1068. values are converted to nans.
  1069. """
  1070. if hasattr(x, 'mask'):
  1071. return np.ma.asarray(x, float).filled(np.nan)
  1072. else:
  1073. return np.asarray(x, float)
  1074. def _check_1d(x):
  1075. '''
  1076. Converts a sequence of less than 1 dimension, to an array of 1
  1077. dimension; leaves everything else untouched.
  1078. '''
  1079. if not hasattr(x, 'shape') or len(x.shape) < 1:
  1080. return np.atleast_1d(x)
  1081. else:
  1082. try:
  1083. # work around
  1084. # https://github.com/pandas-dev/pandas/issues/27775 which
  1085. # means the shape of multi-dimensional slicing is not as
  1086. # expected. That this ever worked was an unintentional
  1087. # quirk of pandas and will raise an exception in the
  1088. # future. This slicing warns in pandas >= 1.0rc0 via
  1089. # https://github.com/pandas-dev/pandas/pull/30588
  1090. #
  1091. # < 1.0rc0 : x[:, None].ndim == 1, no warning, custom type
  1092. # >= 1.0rc1 : x[:, None].ndim == 2, warns, numpy array
  1093. # future : x[:, None] -> raises
  1094. #
  1095. # This code should correctly identify and coerce to a
  1096. # numpy array all pandas versions.
  1097. with warnings.catch_warnings(record=True) as w:
  1098. warnings.filterwarnings(
  1099. "always",
  1100. category=DeprecationWarning,
  1101. message='Support for multi-dimensional indexing')
  1102. ndim = x[:, None].ndim
  1103. # we have definitely hit a pandas index or series object
  1104. # cast to a numpy array.
  1105. if len(w) > 0:
  1106. return np.asanyarray(x)
  1107. # We have likely hit a pandas object, or at least
  1108. # something where 2D slicing does not result in a 2D
  1109. # object.
  1110. if ndim < 2:
  1111. return np.atleast_1d(x)
  1112. return x
  1113. except (IndexError, TypeError):
  1114. return np.atleast_1d(x)
  1115. def _reshape_2D(X, name):
  1116. """
  1117. Use Fortran ordering to convert ndarrays and lists of iterables to lists of
  1118. 1D arrays.
  1119. Lists of iterables are converted by applying `np.asarray` to each of their
  1120. elements. 1D ndarrays are returned in a singleton list containing them.
  1121. 2D ndarrays are converted to the list of their *columns*.
  1122. *name* is used to generate the error message for invalid inputs.
  1123. """
  1124. # Iterate over columns for ndarrays, over rows otherwise.
  1125. X = np.atleast_1d(X.T if isinstance(X, np.ndarray) else np.asarray(X))
  1126. if len(X) == 0:
  1127. return [[]]
  1128. elif X.ndim == 1 and np.ndim(X[0]) == 0:
  1129. # 1D array of scalars: directly return it.
  1130. return [X]
  1131. elif X.ndim in [1, 2]:
  1132. # 2D array, or 1D array of iterables: flatten them first.
  1133. return [np.reshape(x, -1) for x in X]
  1134. else:
  1135. raise ValueError("{} must have 2 or fewer dimensions".format(name))
  1136. def violin_stats(X, method, points=100, quantiles=None):
  1137. """
  1138. Returns a list of dictionaries of data which can be used to draw a series
  1139. of violin plots.
  1140. See the Returns section below to view the required keys of the dictionary.
  1141. Users can skip this function and pass a user-defined set of dictionaries
  1142. with the same keys to `~.axes.Axes.violinplot` instead of using Matplotlib
  1143. to do the calculations. See the *Returns* section below for the keys
  1144. that must be present in the dictionaries.
  1145. Parameters
  1146. ----------
  1147. X : array-like
  1148. Sample data that will be used to produce the gaussian kernel density
  1149. estimates. Must have 2 or fewer dimensions.
  1150. method : callable
  1151. The method used to calculate the kernel density estimate for each
  1152. column of data. When called via `method(v, coords)`, it should
  1153. return a vector of the values of the KDE evaluated at the values
  1154. specified in coords.
  1155. points : int, default = 100
  1156. Defines the number of points to evaluate each of the gaussian kernel
  1157. density estimates at.
  1158. quantiles : array-like, default = None
  1159. Defines (if not None) a list of floats in interval [0, 1] for each
  1160. column of data, which represents the quantiles that will be rendered
  1161. for that column of data. Must have 2 or fewer dimensions. 1D array will
  1162. be treated as a singleton list containing them.
  1163. Returns
  1164. -------
  1165. vpstats : list of dict
  1166. A list of dictionaries containing the results for each column of data.
  1167. The dictionaries contain at least the following:
  1168. - coords: A list of scalars containing the coordinates this particular
  1169. kernel density estimate was evaluated at.
  1170. - vals: A list of scalars containing the values of the kernel density
  1171. estimate at each of the coordinates given in `coords`.
  1172. - mean: The mean value for this column of data.
  1173. - median: The median value for this column of data.
  1174. - min: The minimum value for this column of data.
  1175. - max: The maximum value for this column of data.
  1176. - quantiles: The quantile values for this column of data.
  1177. """
  1178. # List of dictionaries describing each of the violins.
  1179. vpstats = []
  1180. # Want X to be a list of data sequences
  1181. X = _reshape_2D(X, "X")
  1182. # Want quantiles to be as the same shape as data sequences
  1183. if quantiles is not None and len(quantiles) != 0:
  1184. quantiles = _reshape_2D(quantiles, "quantiles")
  1185. # Else, mock quantiles if is none or empty
  1186. else:
  1187. quantiles = [[]] * np.shape(X)[0]
  1188. # quantiles should has the same size as dataset
  1189. if np.shape(X)[:1] != np.shape(quantiles)[:1]:
  1190. raise ValueError("List of violinplot statistics and quantiles values"
  1191. " must have the same length")
  1192. # Zip x and quantiles
  1193. for (x, q) in zip(X, quantiles):
  1194. # Dictionary of results for this distribution
  1195. stats = {}
  1196. # Calculate basic stats for the distribution
  1197. min_val = np.min(x)
  1198. max_val = np.max(x)
  1199. quantile_val = np.percentile(x, 100 * q)
  1200. # Evaluate the kernel density estimate
  1201. coords = np.linspace(min_val, max_val, points)
  1202. stats['vals'] = method(x, coords)
  1203. stats['coords'] = coords
  1204. # Store additional statistics for this distribution
  1205. stats['mean'] = np.mean(x)
  1206. stats['median'] = np.median(x)
  1207. stats['min'] = min_val
  1208. stats['max'] = max_val
  1209. stats['quantiles'] = np.atleast_1d(quantile_val)
  1210. # Append to output
  1211. vpstats.append(stats)
  1212. return vpstats
  1213. def pts_to_prestep(x, *args):
  1214. """
  1215. Convert continuous line to pre-steps.
  1216. Given a set of ``N`` points, convert to ``2N - 1`` points, which when
  1217. connected linearly give a step function which changes values at the
  1218. beginning of the intervals.
  1219. Parameters
  1220. ----------
  1221. x : array
  1222. The x location of the steps. May be empty.
  1223. y1, ..., yp : array
  1224. y arrays to be turned into steps; all must be the same length as ``x``.
  1225. Returns
  1226. -------
  1227. out : array
  1228. The x and y values converted to steps in the same order as the input;
  1229. can be unpacked as ``x_out, y1_out, ..., yp_out``. If the input is
  1230. length ``N``, each of these arrays will be length ``2N + 1``. For
  1231. ``N=0``, the length will be 0.
  1232. Examples
  1233. --------
  1234. >>> x_s, y1_s, y2_s = pts_to_prestep(x, y1, y2)
  1235. """
  1236. steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0)))
  1237. # In all `pts_to_*step` functions, only assign once using *x* and *args*,
  1238. # as converting to an array may be expensive.
  1239. steps[0, 0::2] = x
  1240. steps[0, 1::2] = steps[0, 0:-2:2]
  1241. steps[1:, 0::2] = args
  1242. steps[1:, 1::2] = steps[1:, 2::2]
  1243. return steps
  1244. def pts_to_poststep(x, *args):
  1245. """
  1246. Convert continuous line to post-steps.
  1247. Given a set of ``N`` points convert to ``2N + 1`` points, which when
  1248. connected linearly give a step function which changes values at the end of
  1249. the intervals.
  1250. Parameters
  1251. ----------
  1252. x : array
  1253. The x location of the steps. May be empty.
  1254. y1, ..., yp : array
  1255. y arrays to be turned into steps; all must be the same length as ``x``.
  1256. Returns
  1257. -------
  1258. out : array
  1259. The x and y values converted to steps in the same order as the input;
  1260. can be unpacked as ``x_out, y1_out, ..., yp_out``. If the input is
  1261. length ``N``, each of these arrays will be length ``2N + 1``. For
  1262. ``N=0``, the length will be 0.
  1263. Examples
  1264. --------
  1265. >>> x_s, y1_s, y2_s = pts_to_poststep(x, y1, y2)
  1266. """
  1267. steps = np.zeros((1 + len(args), max(2 * len(x) - 1, 0)))
  1268. steps[0, 0::2] = x
  1269. steps[0, 1::2] = steps[0, 2::2]
  1270. steps[1:, 0::2] = args
  1271. steps[1:, 1::2] = steps[1:, 0:-2:2]
  1272. return steps
  1273. def pts_to_midstep(x, *args):
  1274. """
  1275. Convert continuous line to mid-steps.
  1276. Given a set of ``N`` points convert to ``2N`` points which when connected
  1277. linearly give a step function which changes values at the middle of the
  1278. intervals.
  1279. Parameters
  1280. ----------
  1281. x : array
  1282. The x location of the steps. May be empty.
  1283. y1, ..., yp : array
  1284. y arrays to be turned into steps; all must be the same length as
  1285. ``x``.
  1286. Returns
  1287. -------
  1288. out : array
  1289. The x and y values converted to steps in the same order as the input;
  1290. can be unpacked as ``x_out, y1_out, ..., yp_out``. If the input is
  1291. length ``N``, each of these arrays will be length ``2N``.
  1292. Examples
  1293. --------
  1294. >>> x_s, y1_s, y2_s = pts_to_midstep(x, y1, y2)
  1295. """
  1296. steps = np.zeros((1 + len(args), 2 * len(x)))
  1297. x = np.asanyarray(x)
  1298. steps[0, 1:-1:2] = steps[0, 2::2] = (x[:-1] + x[1:]) / 2
  1299. steps[0, :1] = x[:1] # Also works for zero-sized input.
  1300. steps[0, -1:] = x[-1:]
  1301. steps[1:, 0::2] = args
  1302. steps[1:, 1::2] = steps[1:, 0::2]
  1303. return steps
  1304. STEP_LOOKUP_MAP = {'default': lambda x, y: (x, y),
  1305. 'steps': pts_to_prestep,
  1306. 'steps-pre': pts_to_prestep,
  1307. 'steps-post': pts_to_poststep,
  1308. 'steps-mid': pts_to_midstep}
  1309. def index_of(y):
  1310. """
  1311. A helper function to create reasonable x values for the given *y*.
  1312. This is used for plotting (x, y) if x values are not explicitly given.
  1313. First try ``y.index`` (assuming *y* is a `pandas.Series`), if that
  1314. fails, use ``range(len(y))``.
  1315. This will be extended in the future to deal with more types of
  1316. labeled data.
  1317. Parameters
  1318. ----------
  1319. y : scalar or array-like
  1320. Returns
  1321. -------
  1322. x, y : ndarray
  1323. The x and y values to plot.
  1324. """
  1325. try:
  1326. return y.index.values, y.values
  1327. except AttributeError:
  1328. y = _check_1d(y)
  1329. return np.arange(y.shape[0], dtype=float), y
  1330. def safe_first_element(obj):
  1331. """
  1332. Return the first element in *obj*.
  1333. This is an type-independent way of obtaining the first element, supporting
  1334. both index access and the iterator protocol.
  1335. """
  1336. if isinstance(obj, collections.abc.Iterator):
  1337. # needed to accept `array.flat` as input.
  1338. # np.flatiter reports as an instance of collections.Iterator
  1339. # but can still be indexed via [].
  1340. # This has the side effect of re-setting the iterator, but
  1341. # that is acceptable.
  1342. try:
  1343. return obj[0]
  1344. except TypeError:
  1345. pass
  1346. raise RuntimeError("matplotlib does not support generators "
  1347. "as input")
  1348. return next(iter(obj))
  1349. def sanitize_sequence(data):
  1350. """
  1351. Convert dictview objects to list. Other inputs are returned unchanged.
  1352. """
  1353. return (list(data) if isinstance(data, collections.abc.MappingView)
  1354. else data)
  1355. def normalize_kwargs(kw, alias_mapping=None, required=(), forbidden=(),
  1356. allowed=None):
  1357. """
  1358. Helper function to normalize kwarg inputs.
  1359. The order they are resolved are:
  1360. 1. aliasing
  1361. 2. required
  1362. 3. forbidden
  1363. 4. allowed
  1364. This order means that only the canonical names need appear in
  1365. *allowed*, *forbidden*, *required*.
  1366. Parameters
  1367. ----------
  1368. kw : dict
  1369. A dict of keyword arguments.
  1370. alias_mapping : dict or Artist subclass or Artist instance, optional
  1371. A mapping between a canonical name to a list of
  1372. aliases, in order of precedence from lowest to highest.
  1373. If the canonical value is not in the list it is assumed to have
  1374. the highest priority.
  1375. If an Artist subclass or instance is passed, use its properties alias
  1376. mapping.
  1377. required : list of str, optional
  1378. A list of keys that must be in *kws*.
  1379. forbidden : list of str, optional
  1380. A list of keys which may not be in *kw*.
  1381. allowed : list of str, optional
  1382. A list of allowed fields. If this not None, then raise if
  1383. *kw* contains any keys not in the union of *required*
  1384. and *allowed*. To allow only the required fields pass in
  1385. an empty tuple ``allowed=()``.
  1386. Raises
  1387. ------
  1388. TypeError
  1389. To match what python raises if invalid args/kwargs are passed to
  1390. a callable.
  1391. """
  1392. from matplotlib.artist import Artist
  1393. # deal with default value of alias_mapping
  1394. if alias_mapping is None:
  1395. alias_mapping = dict()
  1396. elif (isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist)
  1397. or isinstance(alias_mapping, Artist)):
  1398. alias_mapping = getattr(alias_mapping, "_alias_map", {})
  1399. # make a local so we can pop
  1400. kw = dict(kw)
  1401. # output dictionary
  1402. ret = dict()
  1403. # hit all alias mappings
  1404. for canonical, alias_list in alias_mapping.items():
  1405. # the alias lists are ordered from lowest to highest priority
  1406. # so we know to use the last value in this list
  1407. tmp = []
  1408. seen = []
  1409. for a in alias_list:
  1410. try:
  1411. tmp.append(kw.pop(a))
  1412. seen.append(a)
  1413. except KeyError:
  1414. pass
  1415. # if canonical is not in the alias_list assume highest priority
  1416. if canonical not in alias_list:
  1417. try:
  1418. tmp.append(kw.pop(canonical))
  1419. seen.append(canonical)
  1420. except KeyError:
  1421. pass
  1422. # if we found anything in this set of aliases put it in the return
  1423. # dict
  1424. if tmp:
  1425. ret[canonical] = tmp[-1]
  1426. if len(tmp) > 1:
  1427. warn_deprecated(
  1428. "3.1", message=f"Saw kwargs {seen!r} which are all "
  1429. f"aliases for {canonical!r}. Kept value from "
  1430. f"{seen[-1]!r}. Passing multiple aliases for the same "
  1431. f"property will raise a TypeError %(removal)s.")
  1432. # at this point we know that all keys which are aliased are removed, update
  1433. # the return dictionary from the cleaned local copy of the input
  1434. ret.update(kw)
  1435. fail_keys = [k for k in required if k not in ret]
  1436. if fail_keys:
  1437. raise TypeError("The required keys {keys!r} "
  1438. "are not in kwargs".format(keys=fail_keys))
  1439. fail_keys = [k for k in forbidden if k in ret]
  1440. if fail_keys:
  1441. raise TypeError("The forbidden keys {keys!r} "
  1442. "are in kwargs".format(keys=fail_keys))
  1443. if allowed is not None:
  1444. allowed_set = {*required, *allowed}
  1445. fail_keys = [k for k in ret if k not in allowed_set]
  1446. if fail_keys:
  1447. raise TypeError(
  1448. "kwargs contains {keys!r} which are not in the required "
  1449. "{req!r} or allowed {allow!r} keys".format(
  1450. keys=fail_keys, req=required, allow=allowed))
  1451. return ret
  1452. @deprecated("3.1")
  1453. def get_label(y, default_name):
  1454. try:
  1455. return y.name
  1456. except AttributeError:
  1457. return default_name
  1458. _lockstr = """\
  1459. LOCKERROR: matplotlib is trying to acquire the lock
  1460. {!r}
  1461. and has failed. This maybe due to any other process holding this
  1462. lock. If you are sure no other matplotlib process is running try
  1463. removing these folders and trying again.
  1464. """
  1465. @contextlib.contextmanager
  1466. def _lock_path(path):
  1467. """
  1468. Context manager for locking a path.
  1469. Usage::
  1470. with _lock_path(path):
  1471. ...
  1472. Another thread or process that attempts to lock the same path will wait
  1473. until this context manager is exited.
  1474. The lock is implemented by creating a temporary file in the parent
  1475. directory, so that directory must exist and be writable.
  1476. """
  1477. path = Path(path)
  1478. lock_path = path.with_name(path.name + ".matplotlib-lock")
  1479. retries = 50
  1480. sleeptime = 0.1
  1481. for _ in range(retries):
  1482. try:
  1483. with lock_path.open("xb"):
  1484. break
  1485. except FileExistsError:
  1486. time.sleep(sleeptime)
  1487. else:
  1488. raise TimeoutError("""\
  1489. Lock error: Matplotlib failed to acquire the following lock file:
  1490. {}
  1491. This maybe due to another process holding this lock file. If you are sure no
  1492. other Matplotlib process is running, remove this file and try again.""".format(
  1493. lock_path))
  1494. try:
  1495. yield
  1496. finally:
  1497. lock_path.unlink()
  1498. def _topmost_artist(
  1499. artists,
  1500. _cached_max=functools.partial(max, key=operator.attrgetter("zorder"))):
  1501. """Get the topmost artist of a list.
  1502. In case of a tie, return the *last* of the tied artists, as it will be
  1503. drawn on top of the others. `max` returns the first maximum in case of
  1504. ties, so we need to iterate over the list in reverse order.
  1505. """
  1506. return _cached_max(reversed(artists))
  1507. def _str_equal(obj, s):
  1508. """Return whether *obj* is a string equal to string *s*.
  1509. This helper solely exists to handle the case where *obj* is a numpy array,
  1510. because in such cases, a naive ``obj == s`` would yield an array, which
  1511. cannot be used in a boolean context.
  1512. """
  1513. return isinstance(obj, str) and obj == s
  1514. def _str_lower_equal(obj, s):
  1515. """Return whether *obj* is a string equal, when lowercased, to string *s*.
  1516. This helper solely exists to handle the case where *obj* is a numpy array,
  1517. because in such cases, a naive ``obj == s`` would yield an array, which
  1518. cannot be used in a boolean context.
  1519. """
  1520. return isinstance(obj, str) and obj.lower() == s
  1521. def _define_aliases(alias_d, cls=None):
  1522. """Class decorator for defining property aliases.
  1523. Use as ::
  1524. @cbook._define_aliases({"property": ["alias", ...], ...})
  1525. class C: ...
  1526. For each property, if the corresponding ``get_property`` is defined in the
  1527. class so far, an alias named ``get_alias`` will be defined; the same will
  1528. be done for setters. If neither the getter nor the setter exists, an
  1529. exception will be raised.
  1530. The alias map is stored as the ``_alias_map`` attribute on the class and
  1531. can be used by `~.normalize_kwargs` (which assumes that higher priority
  1532. aliases come last).
  1533. """
  1534. if cls is None: # Return the actual class decorator.
  1535. return functools.partial(_define_aliases, alias_d)
  1536. def make_alias(name): # Enforce a closure over *name*.
  1537. @functools.wraps(getattr(cls, name))
  1538. def method(self, *args, **kwargs):
  1539. return getattr(self, name)(*args, **kwargs)
  1540. return method
  1541. for prop, aliases in alias_d.items():
  1542. exists = False
  1543. for prefix in ["get_", "set_"]:
  1544. if prefix + prop in vars(cls):
  1545. exists = True
  1546. for alias in aliases:
  1547. method = make_alias(prefix + prop)
  1548. method.__name__ = prefix + alias
  1549. method.__doc__ = "Alias for `{}`.".format(prefix + prop)
  1550. setattr(cls, prefix + alias, method)
  1551. if not exists:
  1552. raise ValueError(
  1553. "Neither getter nor setter exists for {!r}".format(prop))
  1554. if hasattr(cls, "_alias_map"):
  1555. # Need to decide on conflict resolution policy.
  1556. raise NotImplementedError("Parent class already defines aliases")
  1557. cls._alias_map = alias_d
  1558. return cls
  1559. def _array_perimeter(arr):
  1560. """
  1561. Get the elements on the perimeter of ``arr``,
  1562. Parameters
  1563. ----------
  1564. arr : ndarray, shape (M, N)
  1565. The input array
  1566. Returns
  1567. -------
  1568. perimeter : ndarray, shape (2*(M - 1) + 2*(N - 1),)
  1569. The elements on the perimeter of the array::
  1570. [arr[0, 0], ..., arr[0, -1], ..., arr[-1, -1], ..., arr[-1, 0], ...]
  1571. Examples
  1572. --------
  1573. >>> i, j = np.ogrid[:3,:4]
  1574. >>> a = i*10 + j
  1575. >>> a
  1576. array([[ 0, 1, 2, 3],
  1577. [10, 11, 12, 13],
  1578. [20, 21, 22, 23]])
  1579. >>> _array_perimeter(a)
  1580. array([ 0, 1, 2, 3, 13, 23, 22, 21, 20, 10])
  1581. """
  1582. # note we use Python's half-open ranges to avoid repeating
  1583. # the corners
  1584. forward = np.s_[0:-1] # [0 ... -1)
  1585. backward = np.s_[-1:0:-1] # [-1 ... 0)
  1586. return np.concatenate((
  1587. arr[0, forward],
  1588. arr[forward, -1],
  1589. arr[-1, backward],
  1590. arr[backward, 0],
  1591. ))
  1592. @contextlib.contextmanager
  1593. def _setattr_cm(obj, **kwargs):
  1594. """Temporarily set some attributes; restore original state at context exit.
  1595. """
  1596. sentinel = object()
  1597. origs = [(attr, getattr(obj, attr, sentinel)) for attr in kwargs]
  1598. try:
  1599. for attr, val in kwargs.items():
  1600. setattr(obj, attr, val)
  1601. yield
  1602. finally:
  1603. for attr, orig in origs:
  1604. if orig is sentinel:
  1605. delattr(obj, attr)
  1606. else:
  1607. setattr(obj, attr, orig)
  1608. def _warn_external(message, category=None):
  1609. """
  1610. `warnings.warn` wrapper that sets *stacklevel* to "outside Matplotlib".
  1611. The original emitter of the warning can be obtained by patching this
  1612. function back to `warnings.warn`, i.e. ``cbook._warn_external =
  1613. warnings.warn`` (or ``functools.partial(warnings.warn, stacklevel=2)``,
  1614. etc.).
  1615. """
  1616. frame = sys._getframe()
  1617. for stacklevel in itertools.count(1): # lgtm[py/unused-loop-variable]
  1618. if frame is None:
  1619. # when called in embedded context may hit frame is None
  1620. break
  1621. if not re.match(r"\A(matplotlib|mpl_toolkits)(\Z|\.(?!tests\.))",
  1622. # Work around sphinx-gallery not setting __name__.
  1623. frame.f_globals.get("__name__", "")):
  1624. break
  1625. frame = frame.f_back
  1626. warnings.warn(message, category, stacklevel)
  1627. class _OrderedSet(collections.abc.MutableSet):
  1628. def __init__(self):
  1629. self._od = collections.OrderedDict()
  1630. def __contains__(self, key):
  1631. return key in self._od
  1632. def __iter__(self):
  1633. return iter(self._od)
  1634. def __len__(self):
  1635. return len(self._od)
  1636. def add(self, key):
  1637. self._od.pop(key, None)
  1638. self._od[key] = None
  1639. def discard(self, key):
  1640. self._od.pop(key, None)
  1641. # Agg's buffers are unmultiplied RGBA8888, which neither PyQt4 nor cairo
  1642. # support; however, both do support premultiplied ARGB32.
  1643. def _premultiplied_argb32_to_unmultiplied_rgba8888(buf):
  1644. """
  1645. Convert a premultiplied ARGB32 buffer to an unmultiplied RGBA8888 buffer.
  1646. """
  1647. rgba = np.take( # .take() ensures C-contiguity of the result.
  1648. buf,
  1649. [2, 1, 0, 3] if sys.byteorder == "little" else [1, 2, 3, 0], axis=2)
  1650. rgb = rgba[..., :-1]
  1651. alpha = rgba[..., -1]
  1652. # Un-premultiply alpha. The formula is the same as in cairo-png.c.
  1653. mask = alpha != 0
  1654. for channel in np.rollaxis(rgb, -1):
  1655. channel[mask] = (
  1656. (channel[mask].astype(int) * 255 + alpha[mask] // 2)
  1657. // alpha[mask])
  1658. return rgba
  1659. def _unmultiplied_rgba8888_to_premultiplied_argb32(rgba8888):
  1660. """
  1661. Convert an unmultiplied RGBA8888 buffer to a premultiplied ARGB32 buffer.
  1662. """
  1663. if sys.byteorder == "little":
  1664. argb32 = np.take(rgba8888, [2, 1, 0, 3], axis=2)
  1665. rgb24 = argb32[..., :-1]
  1666. alpha8 = argb32[..., -1:]
  1667. else:
  1668. argb32 = np.take(rgba8888, [3, 0, 1, 2], axis=2)
  1669. alpha8 = argb32[..., :1]
  1670. rgb24 = argb32[..., 1:]
  1671. # Only bother premultiplying when the alpha channel is not fully opaque,
  1672. # as the cost is not negligible. The unsafe cast is needed to do the
  1673. # multiplication in-place in an integer buffer.
  1674. if alpha8.min() != 0xff:
  1675. np.multiply(rgb24, alpha8 / 0xff, out=rgb24, casting="unsafe")
  1676. return argb32
  1677. def _pformat_subprocess(command):
  1678. """Pretty-format a subprocess command for printing/logging purposes."""
  1679. return (command if isinstance(command, str)
  1680. else " ".join(shlex.quote(os.fspath(arg)) for arg in command))
  1681. def _check_and_log_subprocess(command, logger, **kwargs):
  1682. """
  1683. Run *command*, returning its stdout output if it succeeds.
  1684. If it fails (exits with nonzero return code), raise an exception whose text
  1685. includes the failed command and captured stdout and stderr output.
  1686. Regardless of the return code, the command is logged at DEBUG level on
  1687. *logger*. In case of success, the output is likewise logged.
  1688. """
  1689. logger.debug('%s', _pformat_subprocess(command))
  1690. proc = subprocess.run(
  1691. command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwargs)
  1692. if proc.returncode:
  1693. raise RuntimeError(
  1694. f"The command\n"
  1695. f" {_pformat_subprocess(command)}\n"
  1696. f"failed and generated the following output:\n"
  1697. f"{proc.stdout.decode('utf-8')}\n"
  1698. f"and the following error:\n"
  1699. f"{proc.stderr.decode('utf-8')}")
  1700. logger.debug("stdout:\n%s", proc.stdout)
  1701. logger.debug("stderr:\n%s", proc.stderr)
  1702. return proc.stdout
# In the following _check_foo functions, the first parameter starts with an
# underscore because it is intended to be positional-only (e.g., so that
# `_check_isinstance([...], types=foo)` doesn't fail).
  1706. def _check_isinstance(_types, **kwargs):
  1707. """
  1708. For each *key, value* pair in *kwargs*, check that *value* is an instance
  1709. of one of *_types*; if not, raise an appropriate TypeError.
  1710. As a special case, a ``None`` entry in *_types* is treated as NoneType.
  1711. Examples
  1712. --------
  1713. >>> cbook._check_isinstance((SomeClass, None), arg=arg)
  1714. """
  1715. types = _types
  1716. if isinstance(types, type) or types is None:
  1717. types = (types,)
  1718. none_allowed = None in types
  1719. types = tuple(tp for tp in types if tp is not None)
  1720. def type_name(tp):
  1721. return (tp.__qualname__ if tp.__module__ == "builtins"
  1722. else f"{tp.__module__}.{tp.__qualname__}")
  1723. names = [*map(type_name, types)]
  1724. if none_allowed:
  1725. types = (*types, type(None))
  1726. names.append("None")
  1727. for k, v in kwargs.items():
  1728. if not isinstance(v, types):
  1729. raise TypeError(
  1730. "{!r} must be an instance of {}, not a {}".format(
  1731. k,
  1732. ", ".join(names[:-1]) + " or " + names[-1]
  1733. if len(names) > 1 else names[0],
  1734. type_name(type(v))))
  1735. def _check_in_list(_values, **kwargs):
  1736. """
  1737. For each *key, value* pair in *kwargs*, check that *value* is in *_values*;
  1738. if not, raise an appropriate ValueError.
  1739. Examples
  1740. --------
  1741. >>> cbook._check_in_list(["foo", "bar"], arg=arg, other_arg=other_arg)
  1742. """
  1743. values = _values
  1744. for k, v in kwargs.items():
  1745. if v not in values:
  1746. raise ValueError(
  1747. "{!r} is not a valid value for {}; supported values are {}"
  1748. .format(v, k, ', '.join(map(repr, values))))
  1749. def _check_getitem(_mapping, **kwargs):
  1750. """
  1751. *kwargs* must consist of a single *key, value* pair. If *key* is in
  1752. *_mapping*, return ``_mapping[value]``; else, raise an appropriate
  1753. ValueError.
  1754. Examples
  1755. --------
  1756. >>> cbook._check_getitem({"foo": "bar"}, arg=arg)
  1757. """
  1758. mapping = _mapping
  1759. if len(kwargs) != 1:
  1760. raise ValueError("_check_getitem takes a single keyword argument")
  1761. (k, v), = kwargs.items()
  1762. try:
  1763. return mapping[v]
  1764. except KeyError:
  1765. raise ValueError(
  1766. "{!r} is not a valid value for {}; supported values are {}"
  1767. .format(v, k, ', '.join(map(repr, mapping)))) from None
  1768. class _classproperty:
  1769. """
  1770. Like `property`, but also triggers on access via the class, and it is the
  1771. *class* that's passed as argument.
  1772. Examples
  1773. --------
  1774. ::
  1775. class C:
  1776. @classproperty
  1777. def foo(cls):
  1778. return cls.__name__
  1779. assert C.foo == "C"
  1780. """
  1781. def __init__(self, fget):
  1782. self._fget = fget
  1783. def __get__(self, instance, owner):
  1784. return self._fget(owner)