misc.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. import random
  2. import matplotlib.lines as mlines
  3. import matplotlib.patches as patches
  4. import numpy as np
  5. from pandas.core.dtypes.missing import notna
  6. from pandas.io.formats.printing import pprint_thing
  7. from pandas.plotting._matplotlib.style import _get_standard_colors
  8. from pandas.plotting._matplotlib.tools import _set_ticks_props, _subplots
  9. def scatter_matrix(
  10. frame,
  11. alpha=0.5,
  12. figsize=None,
  13. ax=None,
  14. grid=False,
  15. diagonal="hist",
  16. marker=".",
  17. density_kwds=None,
  18. hist_kwds=None,
  19. range_padding=0.05,
  20. **kwds,
  21. ):
  22. df = frame._get_numeric_data()
  23. n = df.columns.size
  24. naxes = n * n
  25. fig, axes = _subplots(naxes=naxes, figsize=figsize, ax=ax, squeeze=False)
  26. # no gaps between subplots
  27. fig.subplots_adjust(wspace=0, hspace=0)
  28. mask = notna(df)
  29. marker = _get_marker_compat(marker)
  30. hist_kwds = hist_kwds or {}
  31. density_kwds = density_kwds or {}
  32. # GH 14855
  33. kwds.setdefault("edgecolors", "none")
  34. boundaries_list = []
  35. for a in df.columns:
  36. values = df[a].values[mask[a].values]
  37. rmin_, rmax_ = np.min(values), np.max(values)
  38. rdelta_ext = (rmax_ - rmin_) * range_padding / 2.0
  39. boundaries_list.append((rmin_ - rdelta_ext, rmax_ + rdelta_ext))
  40. for i, a in enumerate(df.columns):
  41. for j, b in enumerate(df.columns):
  42. ax = axes[i, j]
  43. if i == j:
  44. values = df[a].values[mask[a].values]
  45. # Deal with the diagonal by drawing a histogram there.
  46. if diagonal == "hist":
  47. ax.hist(values, **hist_kwds)
  48. elif diagonal in ("kde", "density"):
  49. from scipy.stats import gaussian_kde
  50. y = values
  51. gkde = gaussian_kde(y)
  52. ind = np.linspace(y.min(), y.max(), 1000)
  53. ax.plot(ind, gkde.evaluate(ind), **density_kwds)
  54. ax.set_xlim(boundaries_list[i])
  55. else:
  56. common = (mask[a] & mask[b]).values
  57. ax.scatter(
  58. df[b][common], df[a][common], marker=marker, alpha=alpha, **kwds
  59. )
  60. ax.set_xlim(boundaries_list[j])
  61. ax.set_ylim(boundaries_list[i])
  62. ax.set_xlabel(b)
  63. ax.set_ylabel(a)
  64. if j != 0:
  65. ax.yaxis.set_visible(False)
  66. if i != n - 1:
  67. ax.xaxis.set_visible(False)
  68. if len(df.columns) > 1:
  69. lim1 = boundaries_list[0]
  70. locs = axes[0][1].yaxis.get_majorticklocs()
  71. locs = locs[(lim1[0] <= locs) & (locs <= lim1[1])]
  72. adj = (locs - lim1[0]) / (lim1[1] - lim1[0])
  73. lim0 = axes[0][0].get_ylim()
  74. adj = adj * (lim0[1] - lim0[0]) + lim0[0]
  75. axes[0][0].yaxis.set_ticks(adj)
  76. if np.all(locs == locs.astype(int)):
  77. # if all ticks are int
  78. locs = locs.astype(int)
  79. axes[0][0].yaxis.set_ticklabels(locs)
  80. _set_ticks_props(axes, xlabelsize=8, xrot=90, ylabelsize=8, yrot=0)
  81. return axes
  82. def _get_marker_compat(marker):
  83. if marker not in mlines.lineMarkers:
  84. return "o"
  85. return marker
  86. def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
  87. import matplotlib.pyplot as plt
  88. def normalize(series):
  89. a = min(series)
  90. b = max(series)
  91. return (series - a) / (b - a)
  92. n = len(frame)
  93. classes = frame[class_column].drop_duplicates()
  94. class_col = frame[class_column]
  95. df = frame.drop(class_column, axis=1).apply(normalize)
  96. if ax is None:
  97. ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
  98. to_plot = {}
  99. colors = _get_standard_colors(
  100. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  101. )
  102. for kls in classes:
  103. to_plot[kls] = [[], []]
  104. m = len(frame.columns) - 1
  105. s = np.array(
  106. [
  107. (np.cos(t), np.sin(t))
  108. for t in [2.0 * np.pi * (i / float(m)) for i in range(m)]
  109. ]
  110. )
  111. for i in range(n):
  112. row = df.iloc[i].values
  113. row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
  114. y = (s * row_).sum(axis=0) / row.sum()
  115. kls = class_col.iat[i]
  116. to_plot[kls][0].append(y[0])
  117. to_plot[kls][1].append(y[1])
  118. for i, kls in enumerate(classes):
  119. ax.scatter(
  120. to_plot[kls][0],
  121. to_plot[kls][1],
  122. color=colors[i],
  123. label=pprint_thing(kls),
  124. **kwds,
  125. )
  126. ax.legend()
  127. ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor="none"))
  128. for xy, name in zip(s, df.columns):
  129. ax.add_patch(patches.Circle(xy, radius=0.025, facecolor="gray"))
  130. if xy[0] < 0.0 and xy[1] < 0.0:
  131. ax.text(
  132. xy[0] - 0.025, xy[1] - 0.025, name, ha="right", va="top", size="small"
  133. )
  134. elif xy[0] < 0.0 and xy[1] >= 0.0:
  135. ax.text(
  136. xy[0] - 0.025,
  137. xy[1] + 0.025,
  138. name,
  139. ha="right",
  140. va="bottom",
  141. size="small",
  142. )
  143. elif xy[0] >= 0.0 and xy[1] < 0.0:
  144. ax.text(
  145. xy[0] + 0.025, xy[1] - 0.025, name, ha="left", va="top", size="small"
  146. )
  147. elif xy[0] >= 0.0 and xy[1] >= 0.0:
  148. ax.text(
  149. xy[0] + 0.025, xy[1] + 0.025, name, ha="left", va="bottom", size="small"
  150. )
  151. ax.axis("equal")
  152. return ax
  153. def andrews_curves(
  154. frame, class_column, ax=None, samples=200, color=None, colormap=None, **kwds
  155. ):
  156. import matplotlib.pyplot as plt
  157. def function(amplitudes):
  158. def f(t):
  159. x1 = amplitudes[0]
  160. result = x1 / np.sqrt(2.0)
  161. # Take the rest of the coefficients and resize them
  162. # appropriately. Take a copy of amplitudes as otherwise numpy
  163. # deletes the element from amplitudes itself.
  164. coeffs = np.delete(np.copy(amplitudes), 0)
  165. coeffs.resize(int((coeffs.size + 1) / 2), 2)
  166. # Generate the harmonics and arguments for the sin and cos
  167. # functions.
  168. harmonics = np.arange(0, coeffs.shape[0]) + 1
  169. trig_args = np.outer(harmonics, t)
  170. result += np.sum(
  171. coeffs[:, 0, np.newaxis] * np.sin(trig_args)
  172. + coeffs[:, 1, np.newaxis] * np.cos(trig_args),
  173. axis=0,
  174. )
  175. return result
  176. return f
  177. n = len(frame)
  178. class_col = frame[class_column]
  179. classes = frame[class_column].drop_duplicates()
  180. df = frame.drop(class_column, axis=1)
  181. t = np.linspace(-np.pi, np.pi, samples)
  182. used_legends = set()
  183. color_values = _get_standard_colors(
  184. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  185. )
  186. colors = dict(zip(classes, color_values))
  187. if ax is None:
  188. ax = plt.gca(xlim=(-np.pi, np.pi))
  189. for i in range(n):
  190. row = df.iloc[i].values
  191. f = function(row)
  192. y = f(t)
  193. kls = class_col.iat[i]
  194. label = pprint_thing(kls)
  195. if label not in used_legends:
  196. used_legends.add(label)
  197. ax.plot(t, y, color=colors[kls], label=label, **kwds)
  198. else:
  199. ax.plot(t, y, color=colors[kls], **kwds)
  200. ax.legend(loc="upper right")
  201. ax.grid()
  202. return ax
  203. def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
  204. import matplotlib.pyplot as plt
  205. # random.sample(ndarray, int) fails on python 3.3, sigh
  206. data = list(series.values)
  207. samplings = [random.sample(data, size) for _ in range(samples)]
  208. means = np.array([np.mean(sampling) for sampling in samplings])
  209. medians = np.array([np.median(sampling) for sampling in samplings])
  210. midranges = np.array(
  211. [(min(sampling) + max(sampling)) * 0.5 for sampling in samplings]
  212. )
  213. if fig is None:
  214. fig = plt.figure()
  215. x = list(range(samples))
  216. axes = []
  217. ax1 = fig.add_subplot(2, 3, 1)
  218. ax1.set_xlabel("Sample")
  219. axes.append(ax1)
  220. ax1.plot(x, means, **kwds)
  221. ax2 = fig.add_subplot(2, 3, 2)
  222. ax2.set_xlabel("Sample")
  223. axes.append(ax2)
  224. ax2.plot(x, medians, **kwds)
  225. ax3 = fig.add_subplot(2, 3, 3)
  226. ax3.set_xlabel("Sample")
  227. axes.append(ax3)
  228. ax3.plot(x, midranges, **kwds)
  229. ax4 = fig.add_subplot(2, 3, 4)
  230. ax4.set_xlabel("Mean")
  231. axes.append(ax4)
  232. ax4.hist(means, **kwds)
  233. ax5 = fig.add_subplot(2, 3, 5)
  234. ax5.set_xlabel("Median")
  235. axes.append(ax5)
  236. ax5.hist(medians, **kwds)
  237. ax6 = fig.add_subplot(2, 3, 6)
  238. ax6.set_xlabel("Midrange")
  239. axes.append(ax6)
  240. ax6.hist(midranges, **kwds)
  241. for axis in axes:
  242. plt.setp(axis.get_xticklabels(), fontsize=8)
  243. plt.setp(axis.get_yticklabels(), fontsize=8)
  244. return fig
  245. def parallel_coordinates(
  246. frame,
  247. class_column,
  248. cols=None,
  249. ax=None,
  250. color=None,
  251. use_columns=False,
  252. xticks=None,
  253. colormap=None,
  254. axvlines=True,
  255. axvlines_kwds=None,
  256. sort_labels=False,
  257. **kwds,
  258. ):
  259. import matplotlib.pyplot as plt
  260. if axvlines_kwds is None:
  261. axvlines_kwds = {"linewidth": 1, "color": "black"}
  262. n = len(frame)
  263. classes = frame[class_column].drop_duplicates()
  264. class_col = frame[class_column]
  265. if cols is None:
  266. df = frame.drop(class_column, axis=1)
  267. else:
  268. df = frame[cols]
  269. used_legends = set()
  270. ncols = len(df.columns)
  271. # determine values to use for xticks
  272. if use_columns is True:
  273. if not np.all(np.isreal(list(df.columns))):
  274. raise ValueError("Columns must be numeric to be used as xticks")
  275. x = df.columns
  276. elif xticks is not None:
  277. if not np.all(np.isreal(xticks)):
  278. raise ValueError("xticks specified must be numeric")
  279. elif len(xticks) != ncols:
  280. raise ValueError("Length of xticks must match number of columns")
  281. x = xticks
  282. else:
  283. x = list(range(ncols))
  284. if ax is None:
  285. ax = plt.gca()
  286. color_values = _get_standard_colors(
  287. num_colors=len(classes), colormap=colormap, color_type="random", color=color
  288. )
  289. if sort_labels:
  290. classes = sorted(classes)
  291. color_values = sorted(color_values)
  292. colors = dict(zip(classes, color_values))
  293. for i in range(n):
  294. y = df.iloc[i].values
  295. kls = class_col.iat[i]
  296. label = pprint_thing(kls)
  297. if label not in used_legends:
  298. used_legends.add(label)
  299. ax.plot(x, y, color=colors[kls], label=label, **kwds)
  300. else:
  301. ax.plot(x, y, color=colors[kls], **kwds)
  302. if axvlines:
  303. for i in x:
  304. ax.axvline(i, **axvlines_kwds)
  305. ax.set_xticks(x)
  306. ax.set_xticklabels(df.columns)
  307. ax.set_xlim(x[0], x[-1])
  308. ax.legend(loc="upper right")
  309. ax.grid()
  310. return ax
  311. def lag_plot(series, lag=1, ax=None, **kwds):
  312. # workaround because `c='b'` is hardcoded in matplotlibs scatter method
  313. import matplotlib.pyplot as plt
  314. kwds.setdefault("c", plt.rcParams["patch.facecolor"])
  315. data = series.values
  316. y1 = data[:-lag]
  317. y2 = data[lag:]
  318. if ax is None:
  319. ax = plt.gca()
  320. ax.set_xlabel("y(t)")
  321. ax.set_ylabel(f"y(t + {lag})")
  322. ax.scatter(y1, y2, **kwds)
  323. return ax
  324. def autocorrelation_plot(series, ax=None, **kwds):
  325. import matplotlib.pyplot as plt
  326. n = len(series)
  327. data = np.asarray(series)
  328. if ax is None:
  329. ax = plt.gca(xlim=(1, n), ylim=(-1.0, 1.0))
  330. mean = np.mean(data)
  331. c0 = np.sum((data - mean) ** 2) / float(n)
  332. def r(h):
  333. return ((data[: n - h] - mean) * (data[h:] - mean)).sum() / float(n) / c0
  334. x = np.arange(n) + 1
  335. y = [r(loc) for loc in x]
  336. z95 = 1.959963984540054
  337. z99 = 2.5758293035489004
  338. ax.axhline(y=z99 / np.sqrt(n), linestyle="--", color="grey")
  339. ax.axhline(y=z95 / np.sqrt(n), color="grey")
  340. ax.axhline(y=0.0, color="black")
  341. ax.axhline(y=-z95 / np.sqrt(n), color="grey")
  342. ax.axhline(y=-z99 / np.sqrt(n), linestyle="--", color="grey")
  343. ax.set_xlabel("Lag")
  344. ax.set_ylabel("Autocorrelation")
  345. ax.plot(x, y, **kwds)
  346. if "label" in kwds:
  347. ax.legend()
  348. ax.grid()
  349. return ax