test_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500
  1. """Tests for plotting utilities."""
  2. import tempfile
  3. import numpy as np
  4. import pandas as pd
  5. import matplotlib as mpl
  6. import matplotlib.pyplot as plt
  7. from cycler import cycler
  8. import pytest
  9. import nose
  10. import nose.tools as nt
  11. from nose.tools import assert_equal, raises
  12. import numpy.testing as npt
  13. try:
  14. import pandas.testing as pdt
  15. except ImportError:
  16. import pandas.util.testing as pdt
  17. from distutils.version import LooseVersion
  18. try:
  19. from bs4 import BeautifulSoup
  20. except ImportError:
  21. BeautifulSoup = None
  22. from .. import utils, rcmod
  23. from ..utils import get_dataset_names, load_dataset, _network
  24. a_norm = np.random.randn(100)
  25. def test_pmf_hist_basics():
  26. """Test the function to return barplot args for pmf hist."""
  27. with pytest.warns(UserWarning):
  28. out = utils.pmf_hist(a_norm)
  29. assert_equal(len(out), 3)
  30. x, h, w = out
  31. assert_equal(len(x), len(h))
  32. # Test simple case
  33. a = np.arange(10)
  34. with pytest.warns(UserWarning):
  35. x, h, w = utils.pmf_hist(a, 10)
  36. nose.tools.assert_true(np.all(h == h[0]))
  37. # Test width
  38. with pytest.warns(UserWarning):
  39. x, h, w = utils.pmf_hist(a_norm)
  40. assert_equal(x[1] - x[0], w)
  41. # Test normalization
  42. with pytest.warns(UserWarning):
  43. x, h, w = utils.pmf_hist(a_norm)
  44. nose.tools.assert_almost_equal(sum(h), 1)
  45. nose.tools.assert_less_equal(h.max(), 1)
  46. # Test bins
  47. with pytest.warns(UserWarning):
  48. x, h, w = utils.pmf_hist(a_norm, 20)
  49. assert_equal(len(x), 20)
  50. def test_ci_to_errsize():
  51. """Test behavior of ci_to_errsize."""
  52. cis = [[.5, .5],
  53. [1.25, 1.5]]
  54. heights = [1, 1.5]
  55. actual_errsize = np.array([[.5, 1],
  56. [.25, 0]])
  57. test_errsize = utils.ci_to_errsize(cis, heights)
  58. npt.assert_array_equal(actual_errsize, test_errsize)
  59. def test_desaturate():
  60. """Test color desaturation."""
  61. out1 = utils.desaturate("red", .5)
  62. assert_equal(out1, (.75, .25, .25))
  63. out2 = utils.desaturate("#00FF00", .5)
  64. assert_equal(out2, (.25, .75, .25))
  65. out3 = utils.desaturate((0, 0, 1), .5)
  66. assert_equal(out3, (.25, .25, .75))
  67. out4 = utils.desaturate("red", .5)
  68. assert_equal(out4, (.75, .25, .25))
  69. @raises(ValueError)
  70. def test_desaturation_prop():
  71. """Test that pct outside of [0, 1] raises exception."""
  72. utils.desaturate("blue", 50)
  73. def test_saturate():
  74. """Test performance of saturation function."""
  75. out = utils.saturate((.75, .25, .25))
  76. assert_equal(out, (1, 0, 0))
  77. @pytest.mark.parametrize(
  78. "p,annot", [(.0001, "***"), (.001, "**"), (.01, "*"), (.09, "."), (1, "")]
  79. )
  80. def test_sig_stars(p, annot):
  81. """Test the sig stars function."""
  82. with pytest.warns(UserWarning):
  83. stars = utils.sig_stars(p)
  84. assert_equal(stars, annot)
  85. def test_iqr():
  86. """Test the IQR function."""
  87. a = np.arange(5)
  88. iqr = utils.iqr(a)
  89. assert_equal(iqr, 2)
  90. @pytest.mark.parametrize(
  91. "s,exp",
  92. [
  93. ("a", "a"),
  94. ("abc", "abc"),
  95. (b"a", "a"),
  96. (b"abc", "abc"),
  97. (bytearray("abc", "utf-8"), "abc"),
  98. (bytearray(), ""),
  99. (1, "1"),
  100. (0, "0"),
  101. ([], str([])),
  102. ],
  103. )
  104. def test_to_utf8(s, exp):
  105. """Test the to_utf8 function: object to string"""
  106. u = utils.to_utf8(s)
  107. assert_equal(type(u), str)
  108. assert_equal(u, exp)
  109. class TestSpineUtils(object):
  110. sides = ["left", "right", "bottom", "top"]
  111. outer_sides = ["top", "right"]
  112. inner_sides = ["left", "bottom"]
  113. offset = 10
  114. original_position = ("outward", 0)
  115. offset_position = ("outward", offset)
  116. def test_despine(self):
  117. f, ax = plt.subplots()
  118. for side in self.sides:
  119. nt.assert_true(ax.spines[side].get_visible())
  120. utils.despine()
  121. for side in self.outer_sides:
  122. nt.assert_true(~ax.spines[side].get_visible())
  123. for side in self.inner_sides:
  124. nt.assert_true(ax.spines[side].get_visible())
  125. utils.despine(**dict(zip(self.sides, [True] * 4)))
  126. for side in self.sides:
  127. nt.assert_true(~ax.spines[side].get_visible())
  128. def test_despine_specific_axes(self):
  129. f, (ax1, ax2) = plt.subplots(2, 1)
  130. utils.despine(ax=ax2)
  131. for side in self.sides:
  132. nt.assert_true(ax1.spines[side].get_visible())
  133. for side in self.outer_sides:
  134. nt.assert_true(~ax2.spines[side].get_visible())
  135. for side in self.inner_sides:
  136. nt.assert_true(ax2.spines[side].get_visible())
  137. def test_despine_with_offset(self):
  138. f, ax = plt.subplots()
  139. for side in self.sides:
  140. nt.assert_equal(ax.spines[side].get_position(),
  141. self.original_position)
  142. utils.despine(ax=ax, offset=self.offset)
  143. for side in self.sides:
  144. is_visible = ax.spines[side].get_visible()
  145. new_position = ax.spines[side].get_position()
  146. if is_visible:
  147. nt.assert_equal(new_position, self.offset_position)
  148. else:
  149. nt.assert_equal(new_position, self.original_position)
  150. def test_despine_side_specific_offset(self):
  151. f, ax = plt.subplots()
  152. utils.despine(ax=ax, offset=dict(left=self.offset))
  153. for side in self.sides:
  154. is_visible = ax.spines[side].get_visible()
  155. new_position = ax.spines[side].get_position()
  156. if is_visible and side == "left":
  157. nt.assert_equal(new_position, self.offset_position)
  158. else:
  159. nt.assert_equal(new_position, self.original_position)
  160. def test_despine_with_offset_specific_axes(self):
  161. f, (ax1, ax2) = plt.subplots(2, 1)
  162. utils.despine(offset=self.offset, ax=ax2)
  163. for side in self.sides:
  164. nt.assert_equal(ax1.spines[side].get_position(),
  165. self.original_position)
  166. if ax2.spines[side].get_visible():
  167. nt.assert_equal(ax2.spines[side].get_position(),
  168. self.offset_position)
  169. else:
  170. nt.assert_equal(ax2.spines[side].get_position(),
  171. self.original_position)
  172. def test_despine_trim_spines(self):
  173. f, ax = plt.subplots()
  174. ax.plot([1, 2, 3], [1, 2, 3])
  175. ax.set_xlim(.75, 3.25)
  176. utils.despine(trim=True)
  177. for side in self.inner_sides:
  178. bounds = ax.spines[side].get_bounds()
  179. nt.assert_equal(bounds, (1, 3))
  180. def test_despine_trim_inverted(self):
  181. f, ax = plt.subplots()
  182. ax.plot([1, 2, 3], [1, 2, 3])
  183. ax.set_ylim(.85, 3.15)
  184. ax.invert_yaxis()
  185. utils.despine(trim=True)
  186. for side in self.inner_sides:
  187. bounds = ax.spines[side].get_bounds()
  188. nt.assert_equal(bounds, (1, 3))
  189. def test_despine_trim_noticks(self):
  190. f, ax = plt.subplots()
  191. ax.plot([1, 2, 3], [1, 2, 3])
  192. ax.set_yticks([])
  193. utils.despine(trim=True)
  194. nt.assert_equal(ax.get_yticks().size, 0)
  195. def test_despine_trim_categorical(self):
  196. f, ax = plt.subplots()
  197. ax.plot(["a", "b", "c"], [1, 2, 3])
  198. utils.despine(trim=True)
  199. bounds = ax.spines["left"].get_bounds()
  200. nt.assert_equal(bounds, (1, 3))
  201. bounds = ax.spines["bottom"].get_bounds()
  202. nt.assert_equal(bounds, (0, 2))
  203. def test_despine_moved_ticks(self):
  204. f, ax = plt.subplots()
  205. for t in ax.yaxis.majorTicks:
  206. t.tick1line.set_visible(True)
  207. utils.despine(ax=ax, left=True, right=False)
  208. for y in ax.yaxis.majorTicks:
  209. assert t.tick2line.get_visible()
  210. plt.close(f)
  211. f, ax = plt.subplots()
  212. for t in ax.yaxis.majorTicks:
  213. t.tick1line.set_visible(False)
  214. utils.despine(ax=ax, left=True, right=False)
  215. for y in ax.yaxis.majorTicks:
  216. assert not t.tick2line.get_visible()
  217. plt.close(f)
  218. f, ax = plt.subplots()
  219. for t in ax.xaxis.majorTicks:
  220. t.tick1line.set_visible(True)
  221. utils.despine(ax=ax, bottom=True, top=False)
  222. for y in ax.xaxis.majorTicks:
  223. assert t.tick2line.get_visible()
  224. plt.close(f)
  225. f, ax = plt.subplots()
  226. for t in ax.xaxis.majorTicks:
  227. t.tick1line.set_visible(False)
  228. utils.despine(ax=ax, bottom=True, top=False)
  229. for y in ax.xaxis.majorTicks:
  230. assert not t.tick2line.get_visible()
  231. plt.close(f)
  232. def test_ticklabels_overlap():
  233. rcmod.set()
  234. f, ax = plt.subplots(figsize=(2, 2))
  235. f.tight_layout() # This gets the Agg renderer working
  236. assert not utils.axis_ticklabels_overlap(ax.get_xticklabels())
  237. big_strings = "abcdefgh", "ijklmnop"
  238. ax.set_xlim(-.5, 1.5)
  239. ax.set_xticks([0, 1])
  240. ax.set_xticklabels(big_strings)
  241. assert utils.axis_ticklabels_overlap(ax.get_xticklabels())
  242. x, y = utils.axes_ticklabels_overlap(ax)
  243. assert x
  244. assert not y
  245. def test_categorical_order():
  246. x = ["a", "c", "c", "b", "a", "d"]
  247. y = [3, 2, 5, 1, 4]
  248. order = ["a", "b", "c", "d"]
  249. out = utils.categorical_order(x)
  250. nt.assert_equal(out, ["a", "c", "b", "d"])
  251. out = utils.categorical_order(x, order)
  252. nt.assert_equal(out, order)
  253. out = utils.categorical_order(x, ["b", "a"])
  254. nt.assert_equal(out, ["b", "a"])
  255. out = utils.categorical_order(np.array(x))
  256. nt.assert_equal(out, ["a", "c", "b", "d"])
  257. out = utils.categorical_order(pd.Series(x))
  258. nt.assert_equal(out, ["a", "c", "b", "d"])
  259. out = utils.categorical_order(y)
  260. nt.assert_equal(out, [1, 2, 3, 4, 5])
  261. out = utils.categorical_order(np.array(y))
  262. nt.assert_equal(out, [1, 2, 3, 4, 5])
  263. out = utils.categorical_order(pd.Series(y))
  264. nt.assert_equal(out, [1, 2, 3, 4, 5])
  265. x = pd.Categorical(x, order)
  266. out = utils.categorical_order(x)
  267. nt.assert_equal(out, list(x.categories))
  268. x = pd.Series(x)
  269. out = utils.categorical_order(x)
  270. nt.assert_equal(out, list(x.cat.categories))
  271. out = utils.categorical_order(x, ["b", "a"])
  272. nt.assert_equal(out, ["b", "a"])
  273. x = ["a", np.nan, "c", "c", "b", "a", "d"]
  274. out = utils.categorical_order(x)
  275. nt.assert_equal(out, ["a", "c", "b", "d"])
  276. def test_locator_to_legend_entries():
  277. locator = mpl.ticker.MaxNLocator(nbins=3)
  278. limits = (0.09, 0.4)
  279. levels, str_levels = utils.locator_to_legend_entries(
  280. locator, limits, float
  281. )
  282. assert str_levels == ["0.00", "0.15", "0.30", "0.45"]
  283. limits = (0.8, 0.9)
  284. levels, str_levels = utils.locator_to_legend_entries(
  285. locator, limits, float
  286. )
  287. assert str_levels == ["0.80", "0.84", "0.88", "0.92"]
  288. limits = (1, 6)
  289. levels, str_levels = utils.locator_to_legend_entries(locator, limits, int)
  290. assert str_levels == ["0", "2", "4", "6"]
  291. locator = mpl.ticker.LogLocator(numticks=3)
  292. limits = (5, 1425)
  293. levels, str_levels = utils.locator_to_legend_entries(locator, limits, int)
  294. if LooseVersion(mpl.__version__) >= "3.1":
  295. assert str_levels == ['0', '1', '100', '10000', '1e+06']
  296. limits = (0.00003, 0.02)
  297. levels, str_levels = utils.locator_to_legend_entries(
  298. locator, limits, float
  299. )
  300. if LooseVersion(mpl.__version__) >= "3.1":
  301. assert str_levels == ['1e-07', '1e-05', '1e-03', '1e-01', '10']
  302. @pytest.mark.parametrize(
  303. "cycler,result",
  304. [
  305. (cycler(color=["y"]), ["y"]),
  306. (cycler(color=["k"]), ["k"]),
  307. (cycler(color=["k", "y"]), ["k", "y"]),
  308. (cycler(color=["y", "k"]), ["y", "k"]),
  309. (cycler(color=["b", "r"]), ["b", "r"]),
  310. (cycler(color=["r", "b"]), ["r", "b"]),
  311. (cycler(lw=[1, 2]), [".15"]), # no color in cycle
  312. ],
  313. )
  314. def test_get_color_cycle(cycler, result):
  315. with mpl.rc_context(rc={"axes.prop_cycle": cycler}):
  316. assert utils.get_color_cycle() == result
  317. def check_load_dataset(name):
  318. ds = load_dataset(name, cache=False)
  319. assert(isinstance(ds, pd.DataFrame))
  320. def check_load_cached_dataset(name):
  321. # Test the cacheing using a temporary file.
  322. with tempfile.TemporaryDirectory() as tmpdir:
  323. # download and cache
  324. ds = load_dataset(name, cache=True, data_home=tmpdir)
  325. # use cached version
  326. ds2 = load_dataset(name, cache=True, data_home=tmpdir)
  327. pdt.assert_frame_equal(ds, ds2)
  328. @_network(url="https://github.com/mwaskom/seaborn-data")
  329. def test_get_dataset_names():
  330. if not BeautifulSoup:
  331. raise nose.SkipTest("No BeautifulSoup available for parsing html")
  332. names = get_dataset_names()
  333. assert(len(names) > 0)
  334. assert("titanic" in names)
  335. @_network(url="https://github.com/mwaskom/seaborn-data")
  336. def test_load_datasets():
  337. if not BeautifulSoup:
  338. raise nose.SkipTest("No BeautifulSoup available for parsing html")
  339. # Heavy test to verify that we can load all available datasets
  340. for name in get_dataset_names():
  341. # unfortunately @network somehow obscures this generator so it
  342. # does not get in effect, so we need to call explicitly
  343. # yield check_load_dataset, name
  344. check_load_dataset(name)
  345. @_network(url="https://github.com/mwaskom/seaborn-data")
  346. def test_load_cached_datasets():
  347. if not BeautifulSoup:
  348. raise nose.SkipTest("No BeautifulSoup available for parsing html")
  349. # Heavy test to verify that we can load all available datasets
  350. for name in get_dataset_names():
  351. # unfortunately @network somehow obscures this generator so it
  352. # does not get in effect, so we need to call explicitly
  353. # yield check_load_dataset, name
  354. check_load_cached_dataset(name)
  355. def test_relative_luminance():
  356. """Test relative luminance."""
  357. out1 = utils.relative_luminance("white")
  358. assert_equal(out1, 1)
  359. out2 = utils.relative_luminance("#000000")
  360. assert_equal(out2, 0)
  361. out3 = utils.relative_luminance((.25, .5, .75))
  362. nose.tools.assert_almost_equal(out3, 0.201624536)
  363. rgbs = mpl.cm.RdBu(np.linspace(0, 1, 10))
  364. lums1 = [utils.relative_luminance(rgb) for rgb in rgbs]
  365. lums2 = utils.relative_luminance(rgbs)
  366. for lum1, lum2 in zip(lums1, lums2):
  367. nose.tools.assert_almost_equal(lum1, lum2)
  368. def test_remove_na():
  369. a_array = np.array([1, 2, np.nan, 3])
  370. a_array_rm = utils.remove_na(a_array)
  371. npt.assert_array_equal(a_array_rm, np.array([1, 2, 3]))
  372. a_series = pd.Series([1, 2, np.nan, 3])
  373. a_series_rm = utils.remove_na(a_series)
  374. pdt.assert_series_equal(a_series_rm, pd.Series([1., 2, 3], [0, 1, 3]))