test_misc.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408
  1. # coding: utf-8
  2. """ Test cases for misc plot functions """
  3. import numpy as np
  4. from numpy import random
  5. from numpy.random import randn
  6. import pytest
  7. import pandas.util._test_decorators as td
  8. from pandas import DataFrame, Series
  9. import pandas._testing as tm
  10. from pandas.tests.plotting.common import TestPlotBase, _check_plot_works
  11. import pandas.plotting as plotting
  12. @td.skip_if_mpl
  13. def test_import_error_message():
  14. # GH-19810
  15. df = DataFrame({"A": [1, 2]})
  16. with pytest.raises(ImportError, match="matplotlib is required for plotting"):
  17. df.plot()
  18. def test_get_accessor_args():
  19. func = plotting._core.PlotAccessor._get_call_args
  20. msg = "Called plot accessor for type list, expected Series or DataFrame"
  21. with pytest.raises(TypeError, match=msg):
  22. func(backend_name="", data=[], args=[], kwargs={})
  23. msg = "should not be called with positional arguments"
  24. with pytest.raises(TypeError, match=msg):
  25. func(backend_name="", data=Series(dtype=object), args=["line", None], kwargs={})
  26. x, y, kind, kwargs = func(
  27. backend_name="",
  28. data=DataFrame(),
  29. args=["x"],
  30. kwargs={"y": "y", "kind": "bar", "grid": False},
  31. )
  32. assert x == "x"
  33. assert y == "y"
  34. assert kind == "bar"
  35. assert kwargs == {"grid": False}
  36. x, y, kind, kwargs = func(
  37. backend_name="pandas.plotting._matplotlib",
  38. data=Series(dtype=object),
  39. args=[],
  40. kwargs={},
  41. )
  42. assert x is None
  43. assert y is None
  44. assert kind == "line"
  45. assert len(kwargs) == 22
  46. @td.skip_if_no_mpl
  47. class TestSeriesPlots(TestPlotBase):
  48. def setup_method(self, method):
  49. TestPlotBase.setup_method(self, method)
  50. import matplotlib as mpl
  51. mpl.rcdefaults()
  52. self.ts = tm.makeTimeSeries()
  53. self.ts.name = "ts"
  54. @pytest.mark.slow
  55. def test_autocorrelation_plot(self):
  56. from pandas.plotting import autocorrelation_plot
  57. _check_plot_works(autocorrelation_plot, series=self.ts)
  58. _check_plot_works(autocorrelation_plot, series=self.ts.values)
  59. ax = autocorrelation_plot(self.ts, label="Test")
  60. self._check_legend_labels(ax, labels=["Test"])
  61. @pytest.mark.slow
  62. def test_lag_plot(self):
  63. from pandas.plotting import lag_plot
  64. _check_plot_works(lag_plot, series=self.ts)
  65. _check_plot_works(lag_plot, series=self.ts, lag=5)
  66. @pytest.mark.slow
  67. def test_bootstrap_plot(self):
  68. from pandas.plotting import bootstrap_plot
  69. _check_plot_works(bootstrap_plot, series=self.ts, size=10)
  70. @td.skip_if_no_mpl
  71. class TestDataFramePlots(TestPlotBase):
  72. @td.skip_if_no_scipy
  73. def test_scatter_matrix_axis(self):
  74. scatter_matrix = plotting.scatter_matrix
  75. with tm.RNGContext(42):
  76. df = DataFrame(randn(100, 3))
  77. # we are plotting multiples on a sub-plot
  78. with tm.assert_produces_warning(UserWarning):
  79. axes = _check_plot_works(
  80. scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
  81. )
  82. axes0_labels = axes[0][0].yaxis.get_majorticklabels()
  83. # GH 5662
  84. expected = ["-2", "0", "2"]
  85. self._check_text_labels(axes0_labels, expected)
  86. self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  87. df[0] = (df[0] - 2) / 3
  88. # we are plotting multiples on a sub-plot
  89. with tm.assert_produces_warning(UserWarning):
  90. axes = _check_plot_works(
  91. scatter_matrix, filterwarnings="always", frame=df, range_padding=0.1
  92. )
  93. axes0_labels = axes[0][0].yaxis.get_majorticklabels()
  94. expected = ["-1.0", "-0.5", "0.0"]
  95. self._check_text_labels(axes0_labels, expected)
  96. self._check_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  97. @pytest.mark.slow
  98. def test_andrews_curves(self, iris):
  99. from pandas.plotting import andrews_curves
  100. from matplotlib import cm
  101. df = iris
  102. _check_plot_works(andrews_curves, frame=df, class_column="Name")
  103. rgba = ("#556270", "#4ECDC4", "#C7F464")
  104. ax = _check_plot_works(
  105. andrews_curves, frame=df, class_column="Name", color=rgba
  106. )
  107. self._check_colors(
  108. ax.get_lines()[:10], linecolors=rgba, mapping=df["Name"][:10]
  109. )
  110. cnames = ["dodgerblue", "aquamarine", "seagreen"]
  111. ax = _check_plot_works(
  112. andrews_curves, frame=df, class_column="Name", color=cnames
  113. )
  114. self._check_colors(
  115. ax.get_lines()[:10], linecolors=cnames, mapping=df["Name"][:10]
  116. )
  117. ax = _check_plot_works(
  118. andrews_curves, frame=df, class_column="Name", colormap=cm.jet
  119. )
  120. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  121. self._check_colors(
  122. ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
  123. )
  124. length = 10
  125. df = DataFrame(
  126. {
  127. "A": random.rand(length),
  128. "B": random.rand(length),
  129. "C": random.rand(length),
  130. "Name": ["A"] * length,
  131. }
  132. )
  133. _check_plot_works(andrews_curves, frame=df, class_column="Name")
  134. rgba = ("#556270", "#4ECDC4", "#C7F464")
  135. ax = _check_plot_works(
  136. andrews_curves, frame=df, class_column="Name", color=rgba
  137. )
  138. self._check_colors(
  139. ax.get_lines()[:10], linecolors=rgba, mapping=df["Name"][:10]
  140. )
  141. cnames = ["dodgerblue", "aquamarine", "seagreen"]
  142. ax = _check_plot_works(
  143. andrews_curves, frame=df, class_column="Name", color=cnames
  144. )
  145. self._check_colors(
  146. ax.get_lines()[:10], linecolors=cnames, mapping=df["Name"][:10]
  147. )
  148. ax = _check_plot_works(
  149. andrews_curves, frame=df, class_column="Name", colormap=cm.jet
  150. )
  151. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  152. self._check_colors(
  153. ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
  154. )
  155. colors = ["b", "g", "r"]
  156. df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
  157. ax = andrews_curves(df, "Name", color=colors)
  158. handles, labels = ax.get_legend_handles_labels()
  159. self._check_colors(handles, linecolors=colors)
  160. @pytest.mark.slow
  161. def test_parallel_coordinates(self, iris):
  162. from pandas.plotting import parallel_coordinates
  163. from matplotlib import cm
  164. df = iris
  165. ax = _check_plot_works(parallel_coordinates, frame=df, class_column="Name")
  166. nlines = len(ax.get_lines())
  167. nxticks = len(ax.xaxis.get_ticklabels())
  168. rgba = ("#556270", "#4ECDC4", "#C7F464")
  169. ax = _check_plot_works(
  170. parallel_coordinates, frame=df, class_column="Name", color=rgba
  171. )
  172. self._check_colors(
  173. ax.get_lines()[:10], linecolors=rgba, mapping=df["Name"][:10]
  174. )
  175. cnames = ["dodgerblue", "aquamarine", "seagreen"]
  176. ax = _check_plot_works(
  177. parallel_coordinates, frame=df, class_column="Name", color=cnames
  178. )
  179. self._check_colors(
  180. ax.get_lines()[:10], linecolors=cnames, mapping=df["Name"][:10]
  181. )
  182. ax = _check_plot_works(
  183. parallel_coordinates, frame=df, class_column="Name", colormap=cm.jet
  184. )
  185. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  186. self._check_colors(
  187. ax.get_lines()[:10], linecolors=cmaps, mapping=df["Name"][:10]
  188. )
  189. ax = _check_plot_works(
  190. parallel_coordinates, frame=df, class_column="Name", axvlines=False
  191. )
  192. assert len(ax.get_lines()) == (nlines - nxticks)
  193. colors = ["b", "g", "r"]
  194. df = DataFrame({"A": [1, 2, 3], "B": [1, 2, 3], "C": [1, 2, 3], "Name": colors})
  195. ax = parallel_coordinates(df, "Name", color=colors)
  196. handles, labels = ax.get_legend_handles_labels()
  197. self._check_colors(handles, linecolors=colors)
  198. # not sure if this is indicative of a problem
  199. @pytest.mark.filterwarnings("ignore:Attempting to set:UserWarning")
  200. def test_parallel_coordinates_with_sorted_labels(self):
  201. """ For #15908 """
  202. from pandas.plotting import parallel_coordinates
  203. df = DataFrame(
  204. {
  205. "feat": list(range(30)),
  206. "class": [2 for _ in range(10)]
  207. + [3 for _ in range(10)]
  208. + [1 for _ in range(10)],
  209. }
  210. )
  211. ax = parallel_coordinates(df, "class", sort_labels=True)
  212. polylines, labels = ax.get_legend_handles_labels()
  213. color_label_tuples = zip(
  214. [polyline.get_color() for polyline in polylines], labels
  215. )
  216. ordered_color_label_tuples = sorted(color_label_tuples, key=lambda x: x[1])
  217. prev_next_tupels = zip(
  218. list(ordered_color_label_tuples[0:-1]), list(ordered_color_label_tuples[1:])
  219. )
  220. for prev, nxt in prev_next_tupels:
  221. # labels and colors are ordered strictly increasing
  222. assert prev[1] < nxt[1] and prev[0] < nxt[0]
  223. @pytest.mark.slow
  224. def test_radviz(self, iris):
  225. from pandas.plotting import radviz
  226. from matplotlib import cm
  227. df = iris
  228. _check_plot_works(radviz, frame=df, class_column="Name")
  229. rgba = ("#556270", "#4ECDC4", "#C7F464")
  230. ax = _check_plot_works(radviz, frame=df, class_column="Name", color=rgba)
  231. # skip Circle drawn as ticks
  232. patches = [p for p in ax.patches[:20] if p.get_label() != ""]
  233. self._check_colors(patches[:10], facecolors=rgba, mapping=df["Name"][:10])
  234. cnames = ["dodgerblue", "aquamarine", "seagreen"]
  235. _check_plot_works(radviz, frame=df, class_column="Name", color=cnames)
  236. patches = [p for p in ax.patches[:20] if p.get_label() != ""]
  237. self._check_colors(patches, facecolors=cnames, mapping=df["Name"][:10])
  238. _check_plot_works(radviz, frame=df, class_column="Name", colormap=cm.jet)
  239. cmaps = [cm.jet(n) for n in np.linspace(0, 1, df["Name"].nunique())]
  240. patches = [p for p in ax.patches[:20] if p.get_label() != ""]
  241. self._check_colors(patches, facecolors=cmaps, mapping=df["Name"][:10])
  242. colors = [[0.0, 0.0, 1.0, 1.0], [0.0, 0.5, 1.0, 1.0], [1.0, 0.0, 0.0, 1.0]]
  243. df = DataFrame(
  244. {"A": [1, 2, 3], "B": [2, 1, 3], "C": [3, 2, 1], "Name": ["b", "g", "r"]}
  245. )
  246. ax = radviz(df, "Name", color=colors)
  247. handles, labels = ax.get_legend_handles_labels()
  248. self._check_colors(handles, facecolors=colors)
  249. @pytest.mark.slow
  250. def test_subplot_titles(self, iris):
  251. df = iris.drop("Name", axis=1).head()
  252. # Use the column names as the subplot titles
  253. title = list(df.columns)
  254. # Case len(title) == len(df)
  255. plot = df.plot(subplots=True, title=title)
  256. assert [p.get_title() for p in plot] == title
  257. # Case len(title) > len(df)
  258. msg = (
  259. "The length of `title` must equal the number of columns if"
  260. " using `title` of type `list` and `subplots=True`"
  261. )
  262. with pytest.raises(ValueError, match=msg):
  263. df.plot(subplots=True, title=title + ["kittens > puppies"])
  264. # Case len(title) < len(df)
  265. with pytest.raises(ValueError, match=msg):
  266. df.plot(subplots=True, title=title[:2])
  267. # Case subplots=False and title is of type list
  268. msg = (
  269. "Using `title` of type `list` is not supported unless"
  270. " `subplots=True` is passed"
  271. )
  272. with pytest.raises(ValueError, match=msg):
  273. df.plot(subplots=False, title=title)
  274. # Case df with 3 numeric columns but layout of (2,2)
  275. plot = df.drop("SepalWidth", axis=1).plot(
  276. subplots=True, layout=(2, 2), title=title[:-1]
  277. )
  278. title_list = [ax.get_title() for sublist in plot for ax in sublist]
  279. assert title_list == title[:3] + [""]
  280. def test_get_standard_colors_random_seed(self):
  281. # GH17525
  282. df = DataFrame(np.zeros((10, 10)))
  283. # Make sure that the random seed isn't reset by _get_standard_colors
  284. plotting.parallel_coordinates(df, 0)
  285. rand1 = random.random()
  286. plotting.parallel_coordinates(df, 0)
  287. rand2 = random.random()
  288. assert rand1 != rand2
  289. # Make sure it produces the same colors every time it's called
  290. from pandas.plotting._matplotlib.style import _get_standard_colors
  291. color1 = _get_standard_colors(1, color_type="random")
  292. color2 = _get_standard_colors(1, color_type="random")
  293. assert color1 == color2
  294. def test_get_standard_colors_default_num_colors(self):
  295. from pandas.plotting._matplotlib.style import _get_standard_colors
  296. # Make sure the default color_types returns the specified amount
  297. color1 = _get_standard_colors(1, color_type="default")
  298. color2 = _get_standard_colors(9, color_type="default")
  299. color3 = _get_standard_colors(20, color_type="default")
  300. assert len(color1) == 1
  301. assert len(color2) == 9
  302. assert len(color3) == 20
  303. def test_plot_single_color(self):
  304. # Example from #20585. All 3 bars should have the same color
  305. df = DataFrame(
  306. {
  307. "account-start": ["2017-02-03", "2017-03-03", "2017-01-01"],
  308. "client": ["Alice Anders", "Bob Baker", "Charlie Chaplin"],
  309. "balance": [-1432.32, 10.43, 30000.00],
  310. "db-id": [1234, 2424, 251],
  311. "proxy-id": [525, 1525, 2542],
  312. "rank": [52, 525, 32],
  313. }
  314. )
  315. ax = df.client.value_counts().plot.bar()
  316. colors = [rect.get_facecolor() for rect in ax.get_children()[0:3]]
  317. assert all(color == colors[0] for color in colors)
  318. def test_get_standard_colors_no_appending(self):
  319. # GH20726
  320. # Make sure not to add more colors so that matplotlib can cycle
  321. # correctly.
  322. from matplotlib import cm
  323. from pandas.plotting._matplotlib.style import _get_standard_colors
  324. color_before = cm.gnuplot(range(5))
  325. color_after = _get_standard_colors(1, color=color_before)
  326. assert len(color_after) == len(color_before)
  327. df = DataFrame(np.random.randn(48, 4), columns=list("ABCD"))
  328. color_list = cm.gnuplot(np.linspace(0, 1, 16))
  329. p = df.A.plot.bar(figsize=(16, 7), color=color_list)
  330. assert p.patches[1].get_facecolor() == p.patches[17].get_facecolor()