common.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563
  1. #!/usr/bin/env python
  2. # coding: utf-8
  3. import os
  4. import warnings
  5. import numpy as np
  6. from numpy import random
  7. from pandas.util._decorators import cache_readonly
  8. import pandas.util._test_decorators as td
  9. from pandas.core.dtypes.api import is_list_like
  10. import pandas as pd
  11. from pandas import DataFrame, Series
  12. import pandas._testing as tm
  13. """
  14. This is a common base class used for various plotting tests
  15. """
  16. @td.skip_if_no_mpl
  17. class TestPlotBase:
  18. def setup_method(self, method):
  19. import matplotlib as mpl
  20. from pandas.plotting._matplotlib import compat
  21. mpl.rcdefaults()
  22. self.mpl_ge_2_2_3 = compat._mpl_ge_2_2_3()
  23. self.mpl_ge_3_0_0 = compat._mpl_ge_3_0_0()
  24. self.mpl_ge_3_1_0 = compat._mpl_ge_3_1_0()
  25. self.mpl_ge_3_2_0 = compat._mpl_ge_3_2_0()
  26. self.bp_n_objects = 7
  27. self.polycollection_factor = 2
  28. self.default_figsize = (6.4, 4.8)
  29. self.default_tick_position = "left"
  30. n = 100
  31. with tm.RNGContext(42):
  32. gender = np.random.choice(["Male", "Female"], size=n)
  33. classroom = np.random.choice(["A", "B", "C"], size=n)
  34. self.hist_df = DataFrame(
  35. {
  36. "gender": gender,
  37. "classroom": classroom,
  38. "height": random.normal(66, 4, size=n),
  39. "weight": random.normal(161, 32, size=n),
  40. "category": random.randint(4, size=n),
  41. }
  42. )
  43. self.tdf = tm.makeTimeDataFrame()
  44. self.hexbin_df = DataFrame(
  45. {
  46. "A": np.random.uniform(size=20),
  47. "B": np.random.uniform(size=20),
  48. "C": np.arange(20) + np.random.uniform(size=20),
  49. }
  50. )
  51. def teardown_method(self, method):
  52. tm.close()
  53. @cache_readonly
  54. def plt(self):
  55. import matplotlib.pyplot as plt
  56. return plt
  57. @cache_readonly
  58. def colorconverter(self):
  59. import matplotlib.colors as colors
  60. return colors.colorConverter
  61. def _check_legend_labels(self, axes, labels=None, visible=True):
  62. """
  63. Check each axes has expected legend labels
  64. Parameters
  65. ----------
  66. axes : matplotlib Axes object, or its list-like
  67. labels : list-like
  68. expected legend labels
  69. visible : bool
  70. expected legend visibility. labels are checked only when visible is
  71. True
  72. """
  73. if visible and (labels is None):
  74. raise ValueError("labels must be specified when visible is True")
  75. axes = self._flatten_visible(axes)
  76. for ax in axes:
  77. if visible:
  78. assert ax.get_legend() is not None
  79. self._check_text_labels(ax.get_legend().get_texts(), labels)
  80. else:
  81. assert ax.get_legend() is None
  82. def _check_legend_marker(self, ax, expected_markers=None, visible=True):
  83. """
  84. Check ax has expected legend markers
  85. Parameters
  86. ----------
  87. ax : matplotlib Axes object
  88. expected_markers : list-like
  89. expected legend markers
  90. visible : bool
  91. expected legend visibility. labels are checked only when visible is
  92. True
  93. """
  94. if visible and (expected_markers is None):
  95. raise ValueError("Markers must be specified when visible is True")
  96. if visible:
  97. handles, _ = ax.get_legend_handles_labels()
  98. markers = [handle.get_marker() for handle in handles]
  99. assert markers == expected_markers
  100. else:
  101. assert ax.get_legend() is None
  102. def _check_data(self, xp, rs):
  103. """
  104. Check each axes has identical lines
  105. Parameters
  106. ----------
  107. xp : matplotlib Axes object
  108. rs : matplotlib Axes object
  109. """
  110. xp_lines = xp.get_lines()
  111. rs_lines = rs.get_lines()
  112. def check_line(xpl, rsl):
  113. xpdata = xpl.get_xydata()
  114. rsdata = rsl.get_xydata()
  115. tm.assert_almost_equal(xpdata, rsdata)
  116. assert len(xp_lines) == len(rs_lines)
  117. [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
  118. tm.close()
  119. def _check_visible(self, collections, visible=True):
  120. """
  121. Check each artist is visible or not
  122. Parameters
  123. ----------
  124. collections : matplotlib Artist or its list-like
  125. target Artist or its list or collection
  126. visible : bool
  127. expected visibility
  128. """
  129. from matplotlib.collections import Collection
  130. if not isinstance(collections, Collection) and not is_list_like(collections):
  131. collections = [collections]
  132. for patch in collections:
  133. assert patch.get_visible() == visible
  134. def _get_colors_mapped(self, series, colors):
  135. unique = series.unique()
  136. # unique and colors length can be differed
  137. # depending on slice value
  138. mapped = dict(zip(unique, colors))
  139. return [mapped[v] for v in series.values]
  140. def _check_colors(
  141. self, collections, linecolors=None, facecolors=None, mapping=None
  142. ):
  143. """
  144. Check each artist has expected line colors and face colors
  145. Parameters
  146. ----------
  147. collections : list-like
  148. list or collection of target artist
  149. linecolors : list-like which has the same length as collections
  150. list of expected line colors
  151. facecolors : list-like which has the same length as collections
  152. list of expected face colors
  153. mapping : Series
  154. Series used for color grouping key
  155. used for andrew_curves, parallel_coordinates, radviz test
  156. """
  157. from matplotlib.lines import Line2D
  158. from matplotlib.collections import Collection, PolyCollection, LineCollection
  159. conv = self.colorconverter
  160. if linecolors is not None:
  161. if mapping is not None:
  162. linecolors = self._get_colors_mapped(mapping, linecolors)
  163. linecolors = linecolors[: len(collections)]
  164. assert len(collections) == len(linecolors)
  165. for patch, color in zip(collections, linecolors):
  166. if isinstance(patch, Line2D):
  167. result = patch.get_color()
  168. # Line2D may contains string color expression
  169. result = conv.to_rgba(result)
  170. elif isinstance(patch, (PolyCollection, LineCollection)):
  171. result = tuple(patch.get_edgecolor()[0])
  172. else:
  173. result = patch.get_edgecolor()
  174. expected = conv.to_rgba(color)
  175. assert result == expected
  176. if facecolors is not None:
  177. if mapping is not None:
  178. facecolors = self._get_colors_mapped(mapping, facecolors)
  179. facecolors = facecolors[: len(collections)]
  180. assert len(collections) == len(facecolors)
  181. for patch, color in zip(collections, facecolors):
  182. if isinstance(patch, Collection):
  183. # returned as list of np.array
  184. result = patch.get_facecolor()[0]
  185. else:
  186. result = patch.get_facecolor()
  187. if isinstance(result, np.ndarray):
  188. result = tuple(result)
  189. expected = conv.to_rgba(color)
  190. assert result == expected
  191. def _check_text_labels(self, texts, expected):
  192. """
  193. Check each text has expected labels
  194. Parameters
  195. ----------
  196. texts : matplotlib Text object, or its list-like
  197. target text, or its list
  198. expected : str or list-like which has the same length as texts
  199. expected text label, or its list
  200. """
  201. if not is_list_like(texts):
  202. assert texts.get_text() == expected
  203. else:
  204. labels = [t.get_text() for t in texts]
  205. assert len(labels) == len(expected)
  206. for label, e in zip(labels, expected):
  207. assert label == e
  208. def _check_ticks_props(
  209. self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
  210. ):
  211. """
  212. Check each axes has expected tick properties
  213. Parameters
  214. ----------
  215. axes : matplotlib Axes object, or its list-like
  216. xlabelsize : number
  217. expected xticks font size
  218. xrot : number
  219. expected xticks rotation
  220. ylabelsize : number
  221. expected yticks font size
  222. yrot : number
  223. expected yticks rotation
  224. """
  225. from matplotlib.ticker import NullFormatter
  226. axes = self._flatten_visible(axes)
  227. for ax in axes:
  228. if xlabelsize or xrot:
  229. if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
  230. # If minor ticks has NullFormatter, rot / fontsize are not
  231. # retained
  232. labels = ax.get_xticklabels()
  233. else:
  234. labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)
  235. for label in labels:
  236. if xlabelsize is not None:
  237. tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
  238. if xrot is not None:
  239. tm.assert_almost_equal(label.get_rotation(), xrot)
  240. if ylabelsize or yrot:
  241. if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
  242. labels = ax.get_yticklabels()
  243. else:
  244. labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)
  245. for label in labels:
  246. if ylabelsize is not None:
  247. tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
  248. if yrot is not None:
  249. tm.assert_almost_equal(label.get_rotation(), yrot)
  250. def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
  251. """
  252. Check each axes has expected scales
  253. Parameters
  254. ----------
  255. axes : matplotlib Axes object, or its list-like
  256. xaxis : {'linear', 'log'}
  257. expected xaxis scale
  258. yaxis : {'linear', 'log'}
  259. expected yaxis scale
  260. """
  261. axes = self._flatten_visible(axes)
  262. for ax in axes:
  263. assert ax.xaxis.get_scale() == xaxis
  264. assert ax.yaxis.get_scale() == yaxis
  265. def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
  266. """
  267. Check expected number of axes is drawn in expected layout
  268. Parameters
  269. ----------
  270. axes : matplotlib Axes object, or its list-like
  271. axes_num : number
  272. expected number of axes. Unnecessary axes should be set to
  273. invisible.
  274. layout : tuple
  275. expected layout, (expected number of rows , columns)
  276. figsize : tuple
  277. expected figsize. default is matplotlib default
  278. """
  279. from pandas.plotting._matplotlib.tools import _flatten
  280. if figsize is None:
  281. figsize = self.default_figsize
  282. visible_axes = self._flatten_visible(axes)
  283. if axes_num is not None:
  284. assert len(visible_axes) == axes_num
  285. for ax in visible_axes:
  286. # check something drawn on visible axes
  287. assert len(ax.get_children()) > 0
  288. if layout is not None:
  289. result = self._get_axes_layout(_flatten(axes))
  290. assert result == layout
  291. tm.assert_numpy_array_equal(
  292. visible_axes[0].figure.get_size_inches(),
  293. np.array(figsize, dtype=np.float64),
  294. )
  295. def _get_axes_layout(self, axes):
  296. x_set = set()
  297. y_set = set()
  298. for ax in axes:
  299. # check axes coordinates to estimate layout
  300. points = ax.get_position().get_points()
  301. x_set.add(points[0][0])
  302. y_set.add(points[0][1])
  303. return (len(y_set), len(x_set))
  304. def _flatten_visible(self, axes):
  305. """
  306. Flatten axes, and filter only visible
  307. Parameters
  308. ----------
  309. axes : matplotlib Axes object, or its list-like
  310. """
  311. from pandas.plotting._matplotlib.tools import _flatten
  312. axes = _flatten(axes)
  313. axes = [ax for ax in axes if ax.get_visible()]
  314. return axes
  315. def _check_has_errorbars(self, axes, xerr=0, yerr=0):
  316. """
  317. Check axes has expected number of errorbars
  318. Parameters
  319. ----------
  320. axes : matplotlib Axes object, or its list-like
  321. xerr : number
  322. expected number of x errorbar
  323. yerr : number
  324. expected number of y errorbar
  325. """
  326. axes = self._flatten_visible(axes)
  327. for ax in axes:
  328. containers = ax.containers
  329. xerr_count = 0
  330. yerr_count = 0
  331. for c in containers:
  332. has_xerr = getattr(c, "has_xerr", False)
  333. has_yerr = getattr(c, "has_yerr", False)
  334. if has_xerr:
  335. xerr_count += 1
  336. if has_yerr:
  337. yerr_count += 1
  338. assert xerr == xerr_count
  339. assert yerr == yerr_count
  340. def _check_box_return_type(
  341. self, returned, return_type, expected_keys=None, check_ax_title=True
  342. ):
  343. """
  344. Check box returned type is correct
  345. Parameters
  346. ----------
  347. returned : object to be tested, returned from boxplot
  348. return_type : str
  349. return_type passed to boxplot
  350. expected_keys : list-like, optional
  351. group labels in subplot case. If not passed,
  352. the function checks assuming boxplot uses single ax
  353. check_ax_title : bool
  354. Whether to check the ax.title is the same as expected_key
  355. Intended to be checked by calling from ``boxplot``.
  356. Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
  357. """
  358. from matplotlib.axes import Axes
  359. types = {"dict": dict, "axes": Axes, "both": tuple}
  360. if expected_keys is None:
  361. # should be fixed when the returning default is changed
  362. if return_type is None:
  363. return_type = "dict"
  364. assert isinstance(returned, types[return_type])
  365. if return_type == "both":
  366. assert isinstance(returned.ax, Axes)
  367. assert isinstance(returned.lines, dict)
  368. else:
  369. # should be fixed when the returning default is changed
  370. if return_type is None:
  371. for r in self._flatten_visible(returned):
  372. assert isinstance(r, Axes)
  373. return
  374. assert isinstance(returned, Series)
  375. assert sorted(returned.keys()) == sorted(expected_keys)
  376. for key, value in returned.items():
  377. assert isinstance(value, types[return_type])
  378. # check returned dict has correct mapping
  379. if return_type == "axes":
  380. if check_ax_title:
  381. assert value.get_title() == key
  382. elif return_type == "both":
  383. if check_ax_title:
  384. assert value.ax.get_title() == key
  385. assert isinstance(value.ax, Axes)
  386. assert isinstance(value.lines, dict)
  387. elif return_type == "dict":
  388. line = value["medians"][0]
  389. axes = line.axes
  390. if check_ax_title:
  391. assert axes.get_title() == key
  392. else:
  393. raise AssertionError
  394. def _check_grid_settings(self, obj, kinds, kws={}):
  395. # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
  396. import matplotlib as mpl
  397. def is_grid_on():
  398. xticks = self.plt.gca().xaxis.get_major_ticks()
  399. yticks = self.plt.gca().yaxis.get_major_ticks()
  400. # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
  401. # for new MPL, they are the same.
  402. if self.mpl_ge_3_1_0:
  403. xoff = all(not g.gridline.get_visible() for g in xticks)
  404. yoff = all(not g.gridline.get_visible() for g in yticks)
  405. else:
  406. xoff = all(not g.gridOn for g in xticks)
  407. yoff = all(not g.gridOn for g in yticks)
  408. return not (xoff and yoff)
  409. spndx = 1
  410. for kind in kinds:
  411. self.plt.subplot(1, 4 * len(kinds), spndx)
  412. spndx += 1
  413. mpl.rc("axes", grid=False)
  414. obj.plot(kind=kind, **kws)
  415. assert not is_grid_on()
  416. self.plt.subplot(1, 4 * len(kinds), spndx)
  417. spndx += 1
  418. mpl.rc("axes", grid=True)
  419. obj.plot(kind=kind, grid=False, **kws)
  420. assert not is_grid_on()
  421. if kind != "pie":
  422. self.plt.subplot(1, 4 * len(kinds), spndx)
  423. spndx += 1
  424. mpl.rc("axes", grid=True)
  425. obj.plot(kind=kind, **kws)
  426. assert is_grid_on()
  427. self.plt.subplot(1, 4 * len(kinds), spndx)
  428. spndx += 1
  429. mpl.rc("axes", grid=False)
  430. obj.plot(kind=kind, grid=True, **kws)
  431. assert is_grid_on()
  432. def _unpack_cycler(self, rcParams, field="color"):
  433. """
  434. Auxiliary function for correctly unpacking cycler after MPL >= 1.5
  435. """
  436. return [v[field] for v in rcParams["axes.prop_cycle"]]
  437. def _check_plot_works(f, filterwarnings="always", **kwargs):
  438. import matplotlib.pyplot as plt
  439. ret = None
  440. with warnings.catch_warnings():
  441. warnings.simplefilter(filterwarnings)
  442. try:
  443. try:
  444. fig = kwargs["figure"]
  445. except KeyError:
  446. fig = plt.gcf()
  447. plt.clf()
  448. kwargs.get("ax", fig.add_subplot(211))
  449. ret = f(**kwargs)
  450. tm.assert_is_valid_plot_return_object(ret)
  451. if f is pd.plotting.bootstrap_plot:
  452. assert "ax" not in kwargs
  453. else:
  454. kwargs["ax"] = fig.add_subplot(212)
  455. ret = f(**kwargs)
  456. tm.assert_is_valid_plot_return_object(ret)
  457. with tm.ensure_clean(return_filelike=True) as path:
  458. plt.savefig(path)
  459. finally:
  460. tm.close(fig)
  461. return ret
  462. def curpath():
  463. pth, _ = os.path.split(os.path.abspath(__file__))
  464. return pth