hist.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. import numpy as np
  2. from pandas.core.dtypes.common import is_integer, is_list_like
  3. from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass
  4. from pandas.core.dtypes.missing import isna, remove_na_arraylike
  5. import pandas.core.common as com
  6. from pandas.io.formats.printing import pprint_thing
  7. from pandas.plotting._matplotlib.core import LinePlot, MPLPlot
  8. from pandas.plotting._matplotlib.tools import _flatten, _set_ticks_props, _subplots
  9. class HistPlot(LinePlot):
  10. _kind = "hist"
  11. def __init__(self, data, bins=10, bottom=0, **kwargs):
  12. self.bins = bins # use mpl default
  13. self.bottom = bottom
  14. # Do not call LinePlot.__init__ which may fill nan
  15. MPLPlot.__init__(self, data, **kwargs)
  16. def _args_adjust(self):
  17. if is_integer(self.bins):
  18. # create common bin edge
  19. values = self.data._convert(datetime=True)._get_numeric_data()
  20. values = np.ravel(values)
  21. values = values[~isna(values)]
  22. _, self.bins = np.histogram(
  23. values,
  24. bins=self.bins,
  25. range=self.kwds.get("range", None),
  26. weights=self.kwds.get("weights", None),
  27. )
  28. if is_list_like(self.bottom):
  29. self.bottom = np.array(self.bottom)
  30. @classmethod
  31. def _plot(
  32. cls,
  33. ax,
  34. y,
  35. style=None,
  36. bins=None,
  37. bottom=0,
  38. column_num=0,
  39. stacking_id=None,
  40. **kwds,
  41. ):
  42. if column_num == 0:
  43. cls._initialize_stacker(ax, stacking_id, len(bins) - 1)
  44. y = y[~isna(y)]
  45. base = np.zeros(len(bins) - 1)
  46. bottom = bottom + cls._get_stacked_values(ax, stacking_id, base, kwds["label"])
  47. # ignore style
  48. n, bins, patches = ax.hist(y, bins=bins, bottom=bottom, **kwds)
  49. cls._update_stacker(ax, stacking_id, n)
  50. return patches
  51. def _make_plot(self):
  52. colors = self._get_colors()
  53. stacking_id = self._get_stacking_id()
  54. for i, (label, y) in enumerate(self._iter_data()):
  55. ax = self._get_ax(i)
  56. kwds = self.kwds.copy()
  57. label = pprint_thing(label)
  58. kwds["label"] = label
  59. style, kwds = self._apply_style_colors(colors, kwds, i, label)
  60. if style is not None:
  61. kwds["style"] = style
  62. kwds = self._make_plot_keywords(kwds, y)
  63. artists = self._plot(ax, y, column_num=i, stacking_id=stacking_id, **kwds)
  64. self._add_legend_handle(artists[0], label, index=i)
  65. def _make_plot_keywords(self, kwds, y):
  66. """merge BoxPlot/KdePlot properties to passed kwds"""
  67. # y is required for KdePlot
  68. kwds["bottom"] = self.bottom
  69. kwds["bins"] = self.bins
  70. return kwds
  71. def _post_plot_logic(self, ax, data):
  72. if self.orientation == "horizontal":
  73. ax.set_xlabel("Frequency")
  74. else:
  75. ax.set_ylabel("Frequency")
  76. @property
  77. def orientation(self):
  78. if self.kwds.get("orientation", None) == "horizontal":
  79. return "horizontal"
  80. else:
  81. return "vertical"
  82. class KdePlot(HistPlot):
  83. _kind = "kde"
  84. orientation = "vertical"
  85. def __init__(self, data, bw_method=None, ind=None, **kwargs):
  86. MPLPlot.__init__(self, data, **kwargs)
  87. self.bw_method = bw_method
  88. self.ind = ind
  89. def _args_adjust(self):
  90. pass
  91. def _get_ind(self, y):
  92. if self.ind is None:
  93. # np.nanmax() and np.nanmin() ignores the missing values
  94. sample_range = np.nanmax(y) - np.nanmin(y)
  95. ind = np.linspace(
  96. np.nanmin(y) - 0.5 * sample_range,
  97. np.nanmax(y) + 0.5 * sample_range,
  98. 1000,
  99. )
  100. elif is_integer(self.ind):
  101. sample_range = np.nanmax(y) - np.nanmin(y)
  102. ind = np.linspace(
  103. np.nanmin(y) - 0.5 * sample_range,
  104. np.nanmax(y) + 0.5 * sample_range,
  105. self.ind,
  106. )
  107. else:
  108. ind = self.ind
  109. return ind
  110. @classmethod
  111. def _plot(
  112. cls,
  113. ax,
  114. y,
  115. style=None,
  116. bw_method=None,
  117. ind=None,
  118. column_num=None,
  119. stacking_id=None,
  120. **kwds,
  121. ):
  122. from scipy.stats import gaussian_kde
  123. y = remove_na_arraylike(y)
  124. gkde = gaussian_kde(y, bw_method=bw_method)
  125. y = gkde.evaluate(ind)
  126. lines = MPLPlot._plot(ax, ind, y, style=style, **kwds)
  127. return lines
  128. def _make_plot_keywords(self, kwds, y):
  129. kwds["bw_method"] = self.bw_method
  130. kwds["ind"] = self._get_ind(y)
  131. return kwds
  132. def _post_plot_logic(self, ax, data):
  133. ax.set_ylabel("Density")
  134. def _grouped_plot(
  135. plotf,
  136. data,
  137. column=None,
  138. by=None,
  139. numeric_only=True,
  140. figsize=None,
  141. sharex=True,
  142. sharey=True,
  143. layout=None,
  144. rot=0,
  145. ax=None,
  146. **kwargs,
  147. ):
  148. if figsize == "default":
  149. # allowed to specify mpl default with 'default'
  150. raise ValueError(
  151. "figsize='default' is no longer supported. "
  152. "Specify figure size by tuple instead"
  153. )
  154. grouped = data.groupby(by)
  155. if column is not None:
  156. grouped = grouped[column]
  157. naxes = len(grouped)
  158. fig, axes = _subplots(
  159. naxes=naxes, figsize=figsize, sharex=sharex, sharey=sharey, ax=ax, layout=layout
  160. )
  161. _axes = _flatten(axes)
  162. for i, (key, group) in enumerate(grouped):
  163. ax = _axes[i]
  164. if numeric_only and isinstance(group, ABCDataFrame):
  165. group = group._get_numeric_data()
  166. plotf(group, ax, **kwargs)
  167. ax.set_title(pprint_thing(key))
  168. return fig, axes
  169. def _grouped_hist(
  170. data,
  171. column=None,
  172. by=None,
  173. ax=None,
  174. bins=50,
  175. figsize=None,
  176. layout=None,
  177. sharex=False,
  178. sharey=False,
  179. rot=90,
  180. grid=True,
  181. xlabelsize=None,
  182. xrot=None,
  183. ylabelsize=None,
  184. yrot=None,
  185. **kwargs,
  186. ):
  187. """
  188. Grouped histogram
  189. Parameters
  190. ----------
  191. data : Series/DataFrame
  192. column : object, optional
  193. by : object, optional
  194. ax : axes, optional
  195. bins : int, default 50
  196. figsize : tuple, optional
  197. layout : optional
  198. sharex : bool, default False
  199. sharey : bool, default False
  200. rot : int, default 90
  201. grid : bool, default True
  202. kwargs : dict, keyword arguments passed to matplotlib.Axes.hist
  203. Returns
  204. -------
  205. collection of Matplotlib Axes
  206. """
  207. def plot_group(group, ax):
  208. ax.hist(group.dropna().values, bins=bins, **kwargs)
  209. if xrot is None:
  210. xrot = rot
  211. fig, axes = _grouped_plot(
  212. plot_group,
  213. data,
  214. column=column,
  215. by=by,
  216. sharex=sharex,
  217. sharey=sharey,
  218. ax=ax,
  219. figsize=figsize,
  220. layout=layout,
  221. rot=rot,
  222. )
  223. _set_ticks_props(
  224. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  225. )
  226. fig.subplots_adjust(
  227. bottom=0.15, top=0.9, left=0.1, right=0.9, hspace=0.5, wspace=0.3
  228. )
  229. return axes
  230. def hist_series(
  231. self,
  232. by=None,
  233. ax=None,
  234. grid=True,
  235. xlabelsize=None,
  236. xrot=None,
  237. ylabelsize=None,
  238. yrot=None,
  239. figsize=None,
  240. bins=10,
  241. **kwds,
  242. ):
  243. import matplotlib.pyplot as plt
  244. if by is None:
  245. if kwds.get("layout", None) is not None:
  246. raise ValueError("The 'layout' keyword is not supported when 'by' is None")
  247. # hack until the plotting interface is a bit more unified
  248. fig = kwds.pop(
  249. "figure", plt.gcf() if plt.get_fignums() else plt.figure(figsize=figsize)
  250. )
  251. if figsize is not None and tuple(figsize) != tuple(fig.get_size_inches()):
  252. fig.set_size_inches(*figsize, forward=True)
  253. if ax is None:
  254. ax = fig.gca()
  255. elif ax.get_figure() != fig:
  256. raise AssertionError("passed axis not bound to passed figure")
  257. values = self.dropna().values
  258. ax.hist(values, bins=bins, **kwds)
  259. ax.grid(grid)
  260. axes = np.array([ax])
  261. _set_ticks_props(
  262. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  263. )
  264. else:
  265. if "figure" in kwds:
  266. raise ValueError(
  267. "Cannot pass 'figure' when using the "
  268. "'by' argument, since a new 'Figure' instance "
  269. "will be created"
  270. )
  271. axes = _grouped_hist(
  272. self,
  273. by=by,
  274. ax=ax,
  275. grid=grid,
  276. figsize=figsize,
  277. bins=bins,
  278. xlabelsize=xlabelsize,
  279. xrot=xrot,
  280. ylabelsize=ylabelsize,
  281. yrot=yrot,
  282. **kwds,
  283. )
  284. if hasattr(axes, "ndim"):
  285. if axes.ndim == 1 and len(axes) == 1:
  286. return axes[0]
  287. return axes
  288. def hist_frame(
  289. data,
  290. column=None,
  291. by=None,
  292. grid=True,
  293. xlabelsize=None,
  294. xrot=None,
  295. ylabelsize=None,
  296. yrot=None,
  297. ax=None,
  298. sharex=False,
  299. sharey=False,
  300. figsize=None,
  301. layout=None,
  302. bins=10,
  303. **kwds,
  304. ):
  305. if by is not None:
  306. axes = _grouped_hist(
  307. data,
  308. column=column,
  309. by=by,
  310. ax=ax,
  311. grid=grid,
  312. figsize=figsize,
  313. sharex=sharex,
  314. sharey=sharey,
  315. layout=layout,
  316. bins=bins,
  317. xlabelsize=xlabelsize,
  318. xrot=xrot,
  319. ylabelsize=ylabelsize,
  320. yrot=yrot,
  321. **kwds,
  322. )
  323. return axes
  324. if column is not None:
  325. if not isinstance(column, (list, np.ndarray, ABCIndexClass)):
  326. column = [column]
  327. data = data[column]
  328. data = data._get_numeric_data()
  329. naxes = len(data.columns)
  330. if naxes == 0:
  331. raise ValueError("hist method requires numerical columns, nothing to plot.")
  332. fig, axes = _subplots(
  333. naxes=naxes,
  334. ax=ax,
  335. squeeze=False,
  336. sharex=sharex,
  337. sharey=sharey,
  338. figsize=figsize,
  339. layout=layout,
  340. )
  341. _axes = _flatten(axes)
  342. for i, col in enumerate(com.try_sort(data.columns)):
  343. ax = _axes[i]
  344. ax.hist(data[col].dropna().values, bins=bins, **kwds)
  345. ax.set_title(col)
  346. ax.grid(grid)
  347. _set_ticks_props(
  348. axes, xlabelsize=xlabelsize, xrot=xrot, ylabelsize=ylabelsize, yrot=yrot
  349. )
  350. fig.subplots_adjust(wspace=0.3, hspace=0.3)
  351. return axes