boxplot.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. from collections import namedtuple
  2. import warnings
  3. from matplotlib.artist import setp
  4. import numpy as np
  5. from pandas.core.dtypes.common import is_dict_like
  6. from pandas.core.dtypes.generic import ABCSeries
  7. from pandas.core.dtypes.missing import remove_na_arraylike
  8. import pandas as pd
  9. from pandas.io.formats.printing import pprint_thing
  10. from pandas.plotting._matplotlib.core import LinePlot, MPLPlot
  11. from pandas.plotting._matplotlib.style import _get_standard_colors
  12. from pandas.plotting._matplotlib.tools import _flatten, _subplots
  13. class BoxPlot(LinePlot):
  14. _kind = "box"
  15. _layout_type = "horizontal"
  16. _valid_return_types = (None, "axes", "dict", "both")
  17. # namedtuple to hold results
  18. BP = namedtuple("Boxplot", ["ax", "lines"])
  19. def __init__(self, data, return_type="axes", **kwargs):
  20. # Do not call LinePlot.__init__ which may fill nan
  21. if return_type not in self._valid_return_types:
  22. raise ValueError("return_type must be {None, 'axes', 'dict', 'both'}")
  23. self.return_type = return_type
  24. MPLPlot.__init__(self, data, **kwargs)
  25. def _args_adjust(self):
  26. if self.subplots:
  27. # Disable label ax sharing. Otherwise, all subplots shows last
  28. # column label
  29. if self.orientation == "vertical":
  30. self.sharex = False
  31. else:
  32. self.sharey = False
  33. @classmethod
  34. def _plot(cls, ax, y, column_num=None, return_type="axes", **kwds):
  35. if y.ndim == 2:
  36. y = [remove_na_arraylike(v) for v in y]
  37. # Boxplot fails with empty arrays, so need to add a NaN
  38. # if any cols are empty
  39. # GH 8181
  40. y = [v if v.size > 0 else np.array([np.nan]) for v in y]
  41. else:
  42. y = remove_na_arraylike(y)
  43. bp = ax.boxplot(y, **kwds)
  44. if return_type == "dict":
  45. return bp, bp
  46. elif return_type == "both":
  47. return cls.BP(ax=ax, lines=bp), bp
  48. else:
  49. return ax, bp
  50. def _validate_color_args(self):
  51. if "color" in self.kwds:
  52. if self.colormap is not None:
  53. warnings.warn(
  54. "'color' and 'colormap' cannot be used "
  55. "simultaneously. Using 'color'"
  56. )
  57. self.color = self.kwds.pop("color")
  58. if isinstance(self.color, dict):
  59. valid_keys = ["boxes", "whiskers", "medians", "caps"]
  60. for key, values in self.color.items():
  61. if key not in valid_keys:
  62. raise ValueError(
  63. f"color dict contains invalid key '{key}'. "
  64. f"The key must be either {valid_keys}"
  65. )
  66. else:
  67. self.color = None
  68. # get standard colors for default
  69. colors = _get_standard_colors(num_colors=3, colormap=self.colormap, color=None)
  70. # use 2 colors by default, for box/whisker and median
  71. # flier colors isn't needed here
  72. # because it can be specified by ``sym`` kw
  73. self._boxes_c = colors[0]
  74. self._whiskers_c = colors[0]
  75. self._medians_c = colors[2]
  76. self._caps_c = "k" # mpl default
  77. def _get_colors(self, num_colors=None, color_kwds="color"):
  78. pass
  79. def maybe_color_bp(self, bp):
  80. if isinstance(self.color, dict):
  81. boxes = self.color.get("boxes", self._boxes_c)
  82. whiskers = self.color.get("whiskers", self._whiskers_c)
  83. medians = self.color.get("medians", self._medians_c)
  84. caps = self.color.get("caps", self._caps_c)
  85. else:
  86. # Other types are forwarded to matplotlib
  87. # If None, use default colors
  88. boxes = self.color or self._boxes_c
  89. whiskers = self.color or self._whiskers_c
  90. medians = self.color or self._medians_c
  91. caps = self.color or self._caps_c
  92. setp(bp["boxes"], color=boxes, alpha=1)
  93. setp(bp["whiskers"], color=whiskers, alpha=1)
  94. setp(bp["medians"], color=medians, alpha=1)
  95. setp(bp["caps"], color=caps, alpha=1)
  96. def _make_plot(self):
  97. if self.subplots:
  98. self._return_obj = pd.Series(dtype=object)
  99. for i, (label, y) in enumerate(self._iter_data()):
  100. ax = self._get_ax(i)
  101. kwds = self.kwds.copy()
  102. ret, bp = self._plot(
  103. ax, y, column_num=i, return_type=self.return_type, **kwds
  104. )
  105. self.maybe_color_bp(bp)
  106. self._return_obj[label] = ret
  107. label = [pprint_thing(label)]
  108. self._set_ticklabels(ax, label)
  109. else:
  110. y = self.data.values.T
  111. ax = self._get_ax(0)
  112. kwds = self.kwds.copy()
  113. ret, bp = self._plot(
  114. ax, y, column_num=0, return_type=self.return_type, **kwds
  115. )
  116. self.maybe_color_bp(bp)
  117. self._return_obj = ret
  118. labels = [l for l, _ in self._iter_data()]
  119. labels = [pprint_thing(l) for l in labels]
  120. if not self.use_index:
  121. labels = [pprint_thing(key) for key in range(len(labels))]
  122. self._set_ticklabels(ax, labels)
  123. def _set_ticklabels(self, ax, labels):
  124. if self.orientation == "vertical":
  125. ax.set_xticklabels(labels)
  126. else:
  127. ax.set_yticklabels(labels)
  128. def _make_legend(self):
  129. pass
  130. def _post_plot_logic(self, ax, data):
  131. pass
  132. @property
  133. def orientation(self):
  134. if self.kwds.get("vert", True):
  135. return "vertical"
  136. else:
  137. return "horizontal"
  138. @property
  139. def result(self):
  140. if self.return_type is None:
  141. return super().result
  142. else:
  143. return self._return_obj
  144. def _grouped_plot_by_column(
  145. plotf,
  146. data,
  147. columns=None,
  148. by=None,
  149. numeric_only=True,
  150. grid=False,
  151. figsize=None,
  152. ax=None,
  153. layout=None,
  154. return_type=None,
  155. **kwargs,
  156. ):
  157. grouped = data.groupby(by)
  158. if columns is None:
  159. if not isinstance(by, (list, tuple)):
  160. by = [by]
  161. columns = data._get_numeric_data().columns.difference(by)
  162. naxes = len(columns)
  163. fig, axes = _subplots(
  164. naxes=naxes, sharex=True, sharey=True, figsize=figsize, ax=ax, layout=layout
  165. )
  166. _axes = _flatten(axes)
  167. ax_values = []
  168. for i, col in enumerate(columns):
  169. ax = _axes[i]
  170. gp_col = grouped[col]
  171. keys, values = zip(*gp_col)
  172. re_plotf = plotf(keys, values, ax, **kwargs)
  173. ax.set_title(col)
  174. ax.set_xlabel(pprint_thing(by))
  175. ax_values.append(re_plotf)
  176. ax.grid(grid)
  177. result = pd.Series(ax_values, index=columns)
  178. # Return axes in multiplot case, maybe revisit later # 985
  179. if return_type is None:
  180. result = axes
  181. byline = by[0] if len(by) == 1 else by
  182. fig.suptitle(f"Boxplot grouped by {byline}")
  183. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  184. return result
  185. def boxplot(
  186. data,
  187. column=None,
  188. by=None,
  189. ax=None,
  190. fontsize=None,
  191. rot=0,
  192. grid=True,
  193. figsize=None,
  194. layout=None,
  195. return_type=None,
  196. **kwds,
  197. ):
  198. import matplotlib.pyplot as plt
  199. # validate return_type:
  200. if return_type not in BoxPlot._valid_return_types:
  201. raise ValueError("return_type must be {'axes', 'dict', 'both'}")
  202. if isinstance(data, ABCSeries):
  203. data = data.to_frame("x")
  204. column = "x"
  205. def _get_colors():
  206. # num_colors=3 is required as method maybe_color_bp takes the colors
  207. # in positions 0 and 2.
  208. # if colors not provided, use same defaults as DataFrame.plot.box
  209. result = _get_standard_colors(num_colors=3)
  210. result = np.take(result, [0, 0, 2])
  211. result = np.append(result, "k")
  212. colors = kwds.pop("color", None)
  213. if colors:
  214. if is_dict_like(colors):
  215. # replace colors in result array with user-specified colors
  216. # taken from the colors dict parameter
  217. # "boxes" value placed in position 0, "whiskers" in 1, etc.
  218. valid_keys = ["boxes", "whiskers", "medians", "caps"]
  219. key_to_index = dict(zip(valid_keys, range(4)))
  220. for key, value in colors.items():
  221. if key in valid_keys:
  222. result[key_to_index[key]] = value
  223. else:
  224. raise ValueError(
  225. f"color dict contains invalid key '{key}'. "
  226. f"The key must be either {valid_keys}"
  227. )
  228. else:
  229. result.fill(colors)
  230. return result
  231. def maybe_color_bp(bp):
  232. setp(bp["boxes"], color=colors[0], alpha=1)
  233. setp(bp["whiskers"], color=colors[1], alpha=1)
  234. setp(bp["medians"], color=colors[2], alpha=1)
  235. setp(bp["caps"], color=colors[3], alpha=1)
  236. def plot_group(keys, values, ax):
  237. keys = [pprint_thing(x) for x in keys]
  238. values = [np.asarray(remove_na_arraylike(v)) for v in values]
  239. bp = ax.boxplot(values, **kwds)
  240. if fontsize is not None:
  241. ax.tick_params(axis="both", labelsize=fontsize)
  242. if kwds.get("vert", 1):
  243. ax.set_xticklabels(keys, rotation=rot)
  244. else:
  245. ax.set_yticklabels(keys, rotation=rot)
  246. maybe_color_bp(bp)
  247. # Return axes in multiplot case, maybe revisit later # 985
  248. if return_type == "dict":
  249. return bp
  250. elif return_type == "both":
  251. return BoxPlot.BP(ax=ax, lines=bp)
  252. else:
  253. return ax
  254. colors = _get_colors()
  255. if column is None:
  256. columns = None
  257. else:
  258. if isinstance(column, (list, tuple)):
  259. columns = column
  260. else:
  261. columns = [column]
  262. if by is not None:
  263. # Prefer array return type for 2-D plots to match the subplot layout
  264. # https://github.com/pandas-dev/pandas/pull/12216#issuecomment-241175580
  265. result = _grouped_plot_by_column(
  266. plot_group,
  267. data,
  268. columns=columns,
  269. by=by,
  270. grid=grid,
  271. figsize=figsize,
  272. ax=ax,
  273. layout=layout,
  274. return_type=return_type,
  275. )
  276. else:
  277. if return_type is None:
  278. return_type = "axes"
  279. if layout is not None:
  280. raise ValueError("The 'layout' keyword is not supported when 'by' is None")
  281. if ax is None:
  282. rc = {"figure.figsize": figsize} if figsize is not None else {}
  283. with plt.rc_context(rc):
  284. ax = plt.gca()
  285. data = data._get_numeric_data()
  286. if columns is None:
  287. columns = data.columns
  288. else:
  289. data = data[columns]
  290. result = plot_group(columns, data.values.T, ax)
  291. ax.grid(grid)
  292. return result
  293. def boxplot_frame(
  294. self,
  295. column=None,
  296. by=None,
  297. ax=None,
  298. fontsize=None,
  299. rot=0,
  300. grid=True,
  301. figsize=None,
  302. layout=None,
  303. return_type=None,
  304. **kwds,
  305. ):
  306. import matplotlib.pyplot as plt
  307. ax = boxplot(
  308. self,
  309. column=column,
  310. by=by,
  311. ax=ax,
  312. fontsize=fontsize,
  313. grid=grid,
  314. rot=rot,
  315. figsize=figsize,
  316. layout=layout,
  317. return_type=return_type,
  318. **kwds,
  319. )
  320. plt.draw_if_interactive()
  321. return ax
  322. def boxplot_frame_groupby(
  323. grouped,
  324. subplots=True,
  325. column=None,
  326. fontsize=None,
  327. rot=0,
  328. grid=True,
  329. ax=None,
  330. figsize=None,
  331. layout=None,
  332. sharex=False,
  333. sharey=True,
  334. **kwds,
  335. ):
  336. if subplots is True:
  337. naxes = len(grouped)
  338. fig, axes = _subplots(
  339. naxes=naxes,
  340. squeeze=False,
  341. ax=ax,
  342. sharex=sharex,
  343. sharey=sharey,
  344. figsize=figsize,
  345. layout=layout,
  346. )
  347. axes = _flatten(axes)
  348. ret = pd.Series(dtype=object)
  349. for (key, group), ax in zip(grouped, axes):
  350. d = group.boxplot(
  351. ax=ax, column=column, fontsize=fontsize, rot=rot, grid=grid, **kwds
  352. )
  353. ax.set_title(pprint_thing(key))
  354. ret.loc[key] = d
  355. fig.subplots_adjust(bottom=0.15, top=0.9, left=0.1, right=0.9, wspace=0.2)
  356. else:
  357. keys, frames = zip(*grouped)
  358. if grouped.axis == 0:
  359. df = pd.concat(frames, keys=keys, axis=1)
  360. else:
  361. if len(frames) > 1:
  362. df = frames[0].join(frames[1::])
  363. else:
  364. df = frames[0]
  365. ret = df.boxplot(
  366. column=column,
  367. fontsize=fontsize,
  368. rot=rot,
  369. grid=grid,
  370. ax=ax,
  371. figsize=figsize,
  372. layout=layout,
  373. **kwds,
  374. )
  375. return ret