matrix.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391
  1. """Functions to visualize matrices of data."""
  2. import itertools
  3. import warnings
  4. import matplotlib as mpl
  5. from matplotlib.collections import LineCollection
  6. import matplotlib.pyplot as plt
  7. from matplotlib import gridspec
  8. import numpy as np
  9. import pandas as pd
  10. from scipy.cluster import hierarchy
  11. from . import cm
  12. from .axisgrid import Grid
  13. from .utils import (despine, axis_ticklabels_overlap, relative_luminance,
  14. to_utf8)
  15. __all__ = ["heatmap", "clustermap"]
  16. def _index_to_label(index):
  17. """Convert a pandas index or multiindex to an axis label."""
  18. if isinstance(index, pd.MultiIndex):
  19. return "-".join(map(to_utf8, index.names))
  20. else:
  21. return index.name
  22. def _index_to_ticklabels(index):
  23. """Convert a pandas index or multiindex into ticklabels."""
  24. if isinstance(index, pd.MultiIndex):
  25. return ["-".join(map(to_utf8, i)) for i in index.values]
  26. else:
  27. return index.values
  28. def _convert_colors(colors):
  29. """Convert either a list of colors or nested lists of colors to RGB."""
  30. to_rgb = mpl.colors.colorConverter.to_rgb
  31. if isinstance(colors, pd.DataFrame):
  32. # Convert dataframe
  33. return pd.DataFrame({col: colors[col].map(to_rgb)
  34. for col in colors})
  35. elif isinstance(colors, pd.Series):
  36. return colors.map(to_rgb)
  37. else:
  38. try:
  39. to_rgb(colors[0])
  40. # If this works, there is only one level of colors
  41. return list(map(to_rgb, colors))
  42. except ValueError:
  43. # If we get here, we have nested lists
  44. return [list(map(to_rgb, l)) for l in colors]
  45. def _matrix_mask(data, mask):
  46. """Ensure that data and mask are compatabile and add missing values.
  47. Values will be plotted for cells where ``mask`` is ``False``.
  48. ``data`` is expected to be a DataFrame; ``mask`` can be an array or
  49. a DataFrame.
  50. """
  51. if mask is None:
  52. mask = np.zeros(data.shape, np.bool)
  53. if isinstance(mask, np.ndarray):
  54. # For array masks, ensure that shape matches data then convert
  55. if mask.shape != data.shape:
  56. raise ValueError("Mask must have the same shape as data.")
  57. mask = pd.DataFrame(mask,
  58. index=data.index,
  59. columns=data.columns,
  60. dtype=np.bool)
  61. elif isinstance(mask, pd.DataFrame):
  62. # For DataFrame masks, ensure that semantic labels match data
  63. if not mask.index.equals(data.index) \
  64. and mask.columns.equals(data.columns):
  65. err = "Mask must have the same index and columns as data."
  66. raise ValueError(err)
  67. # Add any cells with missing data to the mask
  68. # This works around an issue where `plt.pcolormesh` doesn't represent
  69. # missing data properly
  70. mask = mask | pd.isnull(data)
  71. return mask
  72. class _HeatMapper(object):
  73. """Draw a heatmap plot of a matrix with nice labels and colormaps."""
  74. def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
  75. annot_kws, cbar, cbar_kws,
  76. xticklabels=True, yticklabels=True, mask=None):
  77. """Initialize the plotting object."""
  78. # We always want to have a DataFrame with semantic information
  79. # and an ndarray to pass to matplotlib
  80. if isinstance(data, pd.DataFrame):
  81. plot_data = data.values
  82. else:
  83. plot_data = np.asarray(data)
  84. data = pd.DataFrame(plot_data)
  85. # Validate the mask and convet to DataFrame
  86. mask = _matrix_mask(data, mask)
  87. plot_data = np.ma.masked_where(np.asarray(mask), plot_data)
  88. # Get good names for the rows and columns
  89. xtickevery = 1
  90. if isinstance(xticklabels, int):
  91. xtickevery = xticklabels
  92. xticklabels = _index_to_ticklabels(data.columns)
  93. elif xticklabels is True:
  94. xticklabels = _index_to_ticklabels(data.columns)
  95. elif xticklabels is False:
  96. xticklabels = []
  97. ytickevery = 1
  98. if isinstance(yticklabels, int):
  99. ytickevery = yticklabels
  100. yticklabels = _index_to_ticklabels(data.index)
  101. elif yticklabels is True:
  102. yticklabels = _index_to_ticklabels(data.index)
  103. elif yticklabels is False:
  104. yticklabels = []
  105. # Get the positions and used label for the ticks
  106. nx, ny = data.T.shape
  107. if not len(xticklabels):
  108. self.xticks = []
  109. self.xticklabels = []
  110. elif isinstance(xticklabels, str) and xticklabels == "auto":
  111. self.xticks = "auto"
  112. self.xticklabels = _index_to_ticklabels(data.columns)
  113. else:
  114. self.xticks, self.xticklabels = self._skip_ticks(xticklabels,
  115. xtickevery)
  116. if not len(yticklabels):
  117. self.yticks = []
  118. self.yticklabels = []
  119. elif isinstance(yticklabels, str) and yticklabels == "auto":
  120. self.yticks = "auto"
  121. self.yticklabels = _index_to_ticklabels(data.index)
  122. else:
  123. self.yticks, self.yticklabels = self._skip_ticks(yticklabels,
  124. ytickevery)
  125. # Get good names for the axis labels
  126. xlabel = _index_to_label(data.columns)
  127. ylabel = _index_to_label(data.index)
  128. self.xlabel = xlabel if xlabel is not None else ""
  129. self.ylabel = ylabel if ylabel is not None else ""
  130. # Determine good default values for the colormapping
  131. self._determine_cmap_params(plot_data, vmin, vmax,
  132. cmap, center, robust)
  133. # Sort out the annotations
  134. if annot is None or annot is False:
  135. annot = False
  136. annot_data = None
  137. else:
  138. if isinstance(annot, bool):
  139. annot_data = plot_data
  140. else:
  141. annot_data = np.asarray(annot)
  142. if annot_data.shape != plot_data.shape:
  143. err = "`data` and `annot` must have same shape."
  144. raise ValueError(err)
  145. annot = True
  146. # Save other attributes to the object
  147. self.data = data
  148. self.plot_data = plot_data
  149. self.annot = annot
  150. self.annot_data = annot_data
  151. self.fmt = fmt
  152. self.annot_kws = {} if annot_kws is None else annot_kws.copy()
  153. self.cbar = cbar
  154. self.cbar_kws = {} if cbar_kws is None else cbar_kws.copy()
  155. def _determine_cmap_params(self, plot_data, vmin, vmax,
  156. cmap, center, robust):
  157. """Use some heuristics to set good defaults for colorbar and range."""
  158. # plot_data is a np.ma.array instance
  159. calc_data = plot_data.filled(np.nan)
  160. if vmin is None:
  161. if robust:
  162. vmin = np.nanpercentile(calc_data, 2)
  163. else:
  164. vmin = np.nanmin(calc_data)
  165. if vmax is None:
  166. if robust:
  167. vmax = np.nanpercentile(calc_data, 98)
  168. else:
  169. vmax = np.nanmax(calc_data)
  170. self.vmin, self.vmax = vmin, vmax
  171. # Choose default colormaps if not provided
  172. if cmap is None:
  173. if center is None:
  174. self.cmap = cm.rocket
  175. else:
  176. self.cmap = cm.icefire
  177. elif isinstance(cmap, str):
  178. self.cmap = mpl.cm.get_cmap(cmap)
  179. elif isinstance(cmap, list):
  180. self.cmap = mpl.colors.ListedColormap(cmap)
  181. else:
  182. self.cmap = cmap
  183. # Recenter a divergent colormap
  184. if center is not None:
  185. # Copy bad values
  186. # in mpl<3.2 only masked values are honored with "bad" color spec
  187. # (see https://github.com/matplotlib/matplotlib/pull/14257)
  188. bad = self.cmap(np.ma.masked_invalid([np.nan]))[0]
  189. # under/over values are set for sure when cmap extremes
  190. # do not map to the same color as +-inf
  191. under = self.cmap(-np.inf)
  192. over = self.cmap(np.inf)
  193. under_set = under != self.cmap(0)
  194. over_set = over != self.cmap(self.cmap.N - 1)
  195. vrange = max(vmax - center, center - vmin)
  196. normlize = mpl.colors.Normalize(center - vrange, center + vrange)
  197. cmin, cmax = normlize([vmin, vmax])
  198. cc = np.linspace(cmin, cmax, 256)
  199. self.cmap = mpl.colors.ListedColormap(self.cmap(cc))
  200. self.cmap.set_bad(bad)
  201. if under_set:
  202. self.cmap.set_under(under)
  203. if over_set:
  204. self.cmap.set_over(over)
  205. def _annotate_heatmap(self, ax, mesh):
  206. """Add textual labels with the value in each cell."""
  207. mesh.update_scalarmappable()
  208. height, width = self.annot_data.shape
  209. xpos, ypos = np.meshgrid(np.arange(width) + .5, np.arange(height) + .5)
  210. for x, y, m, color, val in zip(xpos.flat, ypos.flat,
  211. mesh.get_array(), mesh.get_facecolors(),
  212. self.annot_data.flat):
  213. if m is not np.ma.masked:
  214. lum = relative_luminance(color)
  215. text_color = ".15" if lum > .408 else "w"
  216. annotation = ("{:" + self.fmt + "}").format(val)
  217. text_kwargs = dict(color=text_color, ha="center", va="center")
  218. text_kwargs.update(self.annot_kws)
  219. ax.text(x, y, annotation, **text_kwargs)
  220. def _skip_ticks(self, labels, tickevery):
  221. """Return ticks and labels at evenly spaced intervals."""
  222. n = len(labels)
  223. if tickevery == 0:
  224. ticks, labels = [], []
  225. elif tickevery == 1:
  226. ticks, labels = np.arange(n) + .5, labels
  227. else:
  228. start, end, step = 0, n, tickevery
  229. ticks = np.arange(start, end, step) + .5
  230. labels = labels[start:end:step]
  231. return ticks, labels
  232. def _auto_ticks(self, ax, labels, axis):
  233. """Determine ticks and ticklabels that minimize overlap."""
  234. transform = ax.figure.dpi_scale_trans.inverted()
  235. bbox = ax.get_window_extent().transformed(transform)
  236. size = [bbox.width, bbox.height][axis]
  237. axis = [ax.xaxis, ax.yaxis][axis]
  238. tick, = axis.set_ticks([0])
  239. fontsize = tick.label1.get_size()
  240. max_ticks = int(size // (fontsize / 72))
  241. if max_ticks < 1:
  242. return [], []
  243. tick_every = len(labels) // max_ticks + 1
  244. tick_every = 1 if tick_every == 0 else tick_every
  245. ticks, labels = self._skip_ticks(labels, tick_every)
  246. return ticks, labels
  247. def plot(self, ax, cax, kws):
  248. """Draw the heatmap on the provided Axes."""
  249. # Remove all the Axes spines
  250. despine(ax=ax, left=True, bottom=True)
  251. # Draw the heatmap
  252. mesh = ax.pcolormesh(self.plot_data, vmin=self.vmin, vmax=self.vmax,
  253. cmap=self.cmap, **kws)
  254. # Set the axis limits
  255. ax.set(xlim=(0, self.data.shape[1]), ylim=(0, self.data.shape[0]))
  256. # Invert the y axis to show the plot in matrix form
  257. ax.invert_yaxis()
  258. # Possibly add a colorbar
  259. if self.cbar:
  260. cb = ax.figure.colorbar(mesh, cax, ax, **self.cbar_kws)
  261. cb.outline.set_linewidth(0)
  262. # If rasterized is passed to pcolormesh, also rasterize the
  263. # colorbar to avoid white lines on the PDF rendering
  264. if kws.get('rasterized', False):
  265. cb.solids.set_rasterized(True)
  266. # Add row and column labels
  267. if isinstance(self.xticks, str) and self.xticks == "auto":
  268. xticks, xticklabels = self._auto_ticks(ax, self.xticklabels, 0)
  269. else:
  270. xticks, xticklabels = self.xticks, self.xticklabels
  271. if isinstance(self.yticks, str) and self.yticks == "auto":
  272. yticks, yticklabels = self._auto_ticks(ax, self.yticklabels, 1)
  273. else:
  274. yticks, yticklabels = self.yticks, self.yticklabels
  275. ax.set(xticks=xticks, yticks=yticks)
  276. xtl = ax.set_xticklabels(xticklabels)
  277. ytl = ax.set_yticklabels(yticklabels, rotation="vertical")
  278. # Possibly rotate them if they overlap
  279. if hasattr(ax.figure.canvas, "get_renderer"):
  280. ax.figure.draw(ax.figure.canvas.get_renderer())
  281. if axis_ticklabels_overlap(xtl):
  282. plt.setp(xtl, rotation="vertical")
  283. if axis_ticklabels_overlap(ytl):
  284. plt.setp(ytl, rotation="horizontal")
  285. # Add the axis labels
  286. ax.set(xlabel=self.xlabel, ylabel=self.ylabel)
  287. # Annotate the cells with the formatted values
  288. if self.annot:
  289. self._annotate_heatmap(ax, mesh)
  290. def heatmap(data, vmin=None, vmax=None, cmap=None, center=None, robust=False,
  291. annot=None, fmt=".2g", annot_kws=None,
  292. linewidths=0, linecolor="white",
  293. cbar=True, cbar_kws=None, cbar_ax=None,
  294. square=False, xticklabels="auto", yticklabels="auto",
  295. mask=None, ax=None, **kwargs):
  296. """Plot rectangular data as a color-encoded matrix.
  297. This is an Axes-level function and will draw the heatmap into the
  298. currently-active Axes if none is provided to the ``ax`` argument. Part of
  299. this Axes space will be taken and used to plot a colormap, unless ``cbar``
  300. is False or a separate Axes is provided to ``cbar_ax``.
  301. Parameters
  302. ----------
  303. data : rectangular dataset
  304. 2D dataset that can be coerced into an ndarray. If a Pandas DataFrame
  305. is provided, the index/column information will be used to label the
  306. columns and rows.
  307. vmin, vmax : floats, optional
  308. Values to anchor the colormap, otherwise they are inferred from the
  309. data and other keyword arguments.
  310. cmap : matplotlib colormap name or object, or list of colors, optional
  311. The mapping from data values to color space. If not provided, the
  312. default will depend on whether ``center`` is set.
  313. center : float, optional
  314. The value at which to center the colormap when plotting divergant data.
  315. Using this parameter will change the default ``cmap`` if none is
  316. specified.
  317. robust : bool, optional
  318. If True and ``vmin`` or ``vmax`` are absent, the colormap range is
  319. computed with robust quantiles instead of the extreme values.
  320. annot : bool or rectangular dataset, optional
  321. If True, write the data value in each cell. If an array-like with the
  322. same shape as ``data``, then use this to annotate the heatmap instead
  323. of the data. Note that DataFrames will match on position, not index.
  324. fmt : string, optional
  325. String formatting code to use when adding annotations.
  326. annot_kws : dict of key, value mappings, optional
  327. Keyword arguments for ``ax.text`` when ``annot`` is True.
  328. linewidths : float, optional
  329. Width of the lines that will divide each cell.
  330. linecolor : color, optional
  331. Color of the lines that will divide each cell.
  332. cbar : boolean, optional
  333. Whether to draw a colorbar.
  334. cbar_kws : dict of key, value mappings, optional
  335. Keyword arguments for `fig.colorbar`.
  336. cbar_ax : matplotlib Axes, optional
  337. Axes in which to draw the colorbar, otherwise take space from the
  338. main Axes.
  339. square : boolean, optional
  340. If True, set the Axes aspect to "equal" so each cell will be
  341. square-shaped.
  342. xticklabels, yticklabels : "auto", bool, list-like, or int, optional
  343. If True, plot the column names of the dataframe. If False, don't plot
  344. the column names. If list-like, plot these alternate labels as the
  345. xticklabels. If an integer, use the column names but plot only every
  346. n label. If "auto", try to densely plot non-overlapping labels.
  347. mask : boolean array or DataFrame, optional
  348. If passed, data will not be shown in cells where ``mask`` is True.
  349. Cells with missing values are automatically masked.
  350. ax : matplotlib Axes, optional
  351. Axes in which to draw the plot, otherwise use the currently-active
  352. Axes.
  353. kwargs : other keyword arguments
  354. All other keyword arguments are passed to
  355. :func:`matplotlib.axes.Axes.pcolormesh`.
  356. Returns
  357. -------
  358. ax : matplotlib Axes
  359. Axes object with the heatmap.
  360. See also
  361. --------
  362. clustermap : Plot a matrix using hierachical clustering to arrange the
  363. rows and columns.
  364. Examples
  365. --------
  366. Plot a heatmap for a numpy array:
  367. .. plot::
  368. :context: close-figs
  369. >>> import numpy as np; np.random.seed(0)
  370. >>> import seaborn as sns; sns.set()
  371. >>> uniform_data = np.random.rand(10, 12)
  372. >>> ax = sns.heatmap(uniform_data)
  373. Change the limits of the colormap:
  374. .. plot::
  375. :context: close-figs
  376. >>> ax = sns.heatmap(uniform_data, vmin=0, vmax=1)
  377. Plot a heatmap for data centered on 0 with a diverging colormap:
  378. .. plot::
  379. :context: close-figs
  380. >>> normal_data = np.random.randn(10, 12)
  381. >>> ax = sns.heatmap(normal_data, center=0)
  382. Plot a dataframe with meaningful row and column labels:
  383. .. plot::
  384. :context: close-figs
  385. >>> flights = sns.load_dataset("flights")
  386. >>> flights = flights.pivot("month", "year", "passengers")
  387. >>> ax = sns.heatmap(flights)
  388. Annotate each cell with the numeric value using integer formatting:
  389. .. plot::
  390. :context: close-figs
  391. >>> ax = sns.heatmap(flights, annot=True, fmt="d")
  392. Add lines between each cell:
  393. .. plot::
  394. :context: close-figs
  395. >>> ax = sns.heatmap(flights, linewidths=.5)
  396. Use a different colormap:
  397. .. plot::
  398. :context: close-figs
  399. >>> ax = sns.heatmap(flights, cmap="YlGnBu")
  400. Center the colormap at a specific value:
  401. .. plot::
  402. :context: close-figs
  403. >>> ax = sns.heatmap(flights, center=flights.loc["January", 1955])
  404. Plot every other column label and don't plot row labels:
  405. .. plot::
  406. :context: close-figs
  407. >>> data = np.random.randn(50, 20)
  408. >>> ax = sns.heatmap(data, xticklabels=2, yticklabels=False)
  409. Don't draw a colorbar:
  410. .. plot::
  411. :context: close-figs
  412. >>> ax = sns.heatmap(flights, cbar=False)
  413. Use different axes for the colorbar:
  414. .. plot::
  415. :context: close-figs
  416. >>> grid_kws = {"height_ratios": (.9, .05), "hspace": .3}
  417. >>> f, (ax, cbar_ax) = plt.subplots(2, gridspec_kw=grid_kws)
  418. >>> ax = sns.heatmap(flights, ax=ax,
  419. ... cbar_ax=cbar_ax,
  420. ... cbar_kws={"orientation": "horizontal"})
  421. Use a mask to plot only part of a matrix
  422. .. plot::
  423. :context: close-figs
  424. >>> corr = np.corrcoef(np.random.randn(10, 200))
  425. >>> mask = np.zeros_like(corr)
  426. >>> mask[np.triu_indices_from(mask)] = True
  427. >>> with sns.axes_style("white"):
  428. ... f, ax = plt.subplots(figsize=(7, 5))
  429. ... ax = sns.heatmap(corr, mask=mask, vmax=.3, square=True)
  430. """
  431. # Initialize the plotter object
  432. plotter = _HeatMapper(data, vmin, vmax, cmap, center, robust, annot, fmt,
  433. annot_kws, cbar, cbar_kws, xticklabels,
  434. yticklabels, mask)
  435. # Add the pcolormesh kwargs here
  436. kwargs["linewidths"] = linewidths
  437. kwargs["edgecolor"] = linecolor
  438. # Draw the plot and return the Axes
  439. if ax is None:
  440. ax = plt.gca()
  441. if square:
  442. ax.set_aspect("equal")
  443. plotter.plot(ax, cbar_ax, kwargs)
  444. return ax
  445. class _DendrogramPlotter(object):
  446. """Object for drawing tree of similarities between data rows/columns"""
  447. def __init__(self, data, linkage, metric, method, axis, label, rotate):
  448. """Plot a dendrogram of the relationships between the columns of data
  449. Parameters
  450. ----------
  451. data : pandas.DataFrame
  452. Rectangular data
  453. """
  454. self.axis = axis
  455. if self.axis == 1:
  456. data = data.T
  457. if isinstance(data, pd.DataFrame):
  458. array = data.values
  459. else:
  460. array = np.asarray(data)
  461. data = pd.DataFrame(array)
  462. self.array = array
  463. self.data = data
  464. self.shape = self.data.shape
  465. self.metric = metric
  466. self.method = method
  467. self.axis = axis
  468. self.label = label
  469. self.rotate = rotate
  470. if linkage is None:
  471. self.linkage = self.calculated_linkage
  472. else:
  473. self.linkage = linkage
  474. self.dendrogram = self.calculate_dendrogram()
  475. # Dendrogram ends are always at multiples of 5, who knows why
  476. ticks = 10 * np.arange(self.data.shape[0]) + 5
  477. if self.label:
  478. ticklabels = _index_to_ticklabels(self.data.index)
  479. ticklabels = [ticklabels[i] for i in self.reordered_ind]
  480. if self.rotate:
  481. self.xticks = []
  482. self.yticks = ticks
  483. self.xticklabels = []
  484. self.yticklabels = ticklabels
  485. self.ylabel = _index_to_label(self.data.index)
  486. self.xlabel = ''
  487. else:
  488. self.xticks = ticks
  489. self.yticks = []
  490. self.xticklabels = ticklabels
  491. self.yticklabels = []
  492. self.ylabel = ''
  493. self.xlabel = _index_to_label(self.data.index)
  494. else:
  495. self.xticks, self.yticks = [], []
  496. self.yticklabels, self.xticklabels = [], []
  497. self.xlabel, self.ylabel = '', ''
  498. self.dependent_coord = self.dendrogram['dcoord']
  499. self.independent_coord = self.dendrogram['icoord']
  500. def _calculate_linkage_scipy(self):
  501. linkage = hierarchy.linkage(self.array, method=self.method,
  502. metric=self.metric)
  503. return linkage
  504. def _calculate_linkage_fastcluster(self):
  505. import fastcluster
  506. # Fastcluster has a memory-saving vectorized version, but only
  507. # with certain linkage methods, and mostly with euclidean metric
  508. # vector_methods = ('single', 'centroid', 'median', 'ward')
  509. euclidean_methods = ('centroid', 'median', 'ward')
  510. euclidean = self.metric == 'euclidean' and self.method in \
  511. euclidean_methods
  512. if euclidean or self.method == 'single':
  513. return fastcluster.linkage_vector(self.array,
  514. method=self.method,
  515. metric=self.metric)
  516. else:
  517. linkage = fastcluster.linkage(self.array, method=self.method,
  518. metric=self.metric)
  519. return linkage
  520. @property
  521. def calculated_linkage(self):
  522. try:
  523. return self._calculate_linkage_fastcluster()
  524. except ImportError:
  525. if np.product(self.shape) >= 10000:
  526. msg = ("Clustering large matrix with scipy. Installing "
  527. "`fastcluster` may give better performance.")
  528. warnings.warn(msg)
  529. return self._calculate_linkage_scipy()
  530. def calculate_dendrogram(self):
  531. """Calculates a dendrogram based on the linkage matrix
  532. Made a separate function, not a property because don't want to
  533. recalculate the dendrogram every time it is accessed.
  534. Returns
  535. -------
  536. dendrogram : dict
  537. Dendrogram dictionary as returned by scipy.cluster.hierarchy
  538. .dendrogram. The important key-value pairing is
  539. "reordered_ind" which indicates the re-ordering of the matrix
  540. """
  541. return hierarchy.dendrogram(self.linkage, no_plot=True,
  542. color_threshold=-np.inf)
  543. @property
  544. def reordered_ind(self):
  545. """Indices of the matrix, reordered by the dendrogram"""
  546. return self.dendrogram['leaves']
  547. def plot(self, ax, tree_kws):
  548. """Plots a dendrogram of the similarities between data on the axes
  549. Parameters
  550. ----------
  551. ax : matplotlib.axes.Axes
  552. Axes object upon which the dendrogram is plotted
  553. """
  554. tree_kws = {} if tree_kws is None else tree_kws.copy()
  555. tree_kws.setdefault("linewidths", .5)
  556. tree_kws.setdefault("colors", ".2")
  557. if self.rotate and self.axis == 0:
  558. coords = zip(self.dependent_coord, self.independent_coord)
  559. else:
  560. coords = zip(self.independent_coord, self.dependent_coord)
  561. lines = LineCollection([list(zip(x, y)) for x, y in coords],
  562. **tree_kws)
  563. ax.add_collection(lines)
  564. number_of_leaves = len(self.reordered_ind)
  565. max_dependent_coord = max(map(max, self.dependent_coord))
  566. if self.rotate:
  567. ax.yaxis.set_ticks_position('right')
  568. # Constants 10 and 1.05 come from
  569. # `scipy.cluster.hierarchy._plot_dendrogram`
  570. ax.set_ylim(0, number_of_leaves * 10)
  571. ax.set_xlim(0, max_dependent_coord * 1.05)
  572. ax.invert_xaxis()
  573. ax.invert_yaxis()
  574. else:
  575. # Constants 10 and 1.05 come from
  576. # `scipy.cluster.hierarchy._plot_dendrogram`
  577. ax.set_xlim(0, number_of_leaves * 10)
  578. ax.set_ylim(0, max_dependent_coord * 1.05)
  579. despine(ax=ax, bottom=True, left=True)
  580. ax.set(xticks=self.xticks, yticks=self.yticks,
  581. xlabel=self.xlabel, ylabel=self.ylabel)
  582. xtl = ax.set_xticklabels(self.xticklabels)
  583. ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical')
  584. # Force a draw of the plot to avoid matplotlib window error
  585. if hasattr(ax.figure.canvas, "get_renderer"):
  586. ax.figure.draw(ax.figure.canvas.get_renderer())
  587. if len(ytl) > 0 and axis_ticklabels_overlap(ytl):
  588. plt.setp(ytl, rotation="horizontal")
  589. if len(xtl) > 0 and axis_ticklabels_overlap(xtl):
  590. plt.setp(xtl, rotation="vertical")
  591. return self
  592. def dendrogram(data, linkage=None, axis=1, label=True, metric='euclidean',
  593. method='average', rotate=False, tree_kws=None, ax=None):
  594. """Draw a tree diagram of relationships within a matrix
  595. Parameters
  596. ----------
  597. data : pandas.DataFrame
  598. Rectangular data
  599. linkage : numpy.array, optional
  600. Linkage matrix
  601. axis : int, optional
  602. Which axis to use to calculate linkage. 0 is rows, 1 is columns.
  603. label : bool, optional
  604. If True, label the dendrogram at leaves with column or row names
  605. metric : str, optional
  606. Distance metric. Anything valid for scipy.spatial.distance.pdist
  607. method : str, optional
  608. Linkage method to use. Anything valid for
  609. scipy.cluster.hierarchy.linkage
  610. rotate : bool, optional
  611. When plotting the matrix, whether to rotate it 90 degrees
  612. counter-clockwise, so the leaves face right
  613. tree_kws : dict, optional
  614. Keyword arguments for the ``matplotlib.collections.LineCollection``
  615. that is used for plotting the lines of the dendrogram tree.
  616. ax : matplotlib axis, optional
  617. Axis to plot on, otherwise uses current axis
  618. Returns
  619. -------
  620. dendrogramplotter : _DendrogramPlotter
  621. A Dendrogram plotter object.
  622. Notes
  623. -----
  624. Access the reordered dendrogram indices with
  625. dendrogramplotter.reordered_ind
  626. """
  627. plotter = _DendrogramPlotter(data, linkage=linkage, axis=axis,
  628. metric=metric, method=method,
  629. label=label, rotate=rotate)
  630. if ax is None:
  631. ax = plt.gca()
  632. return plotter.plot(ax=ax, tree_kws=tree_kws)
  633. class ClusterGrid(Grid):
  634. def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None,
  635. figsize=None, row_colors=None, col_colors=None, mask=None,
  636. dendrogram_ratio=None, colors_ratio=None, cbar_pos=None):
  637. """Grid object for organizing clustered heatmap input on to axes"""
  638. if isinstance(data, pd.DataFrame):
  639. self.data = data
  640. else:
  641. self.data = pd.DataFrame(data)
  642. self.data2d = self.format_data(self.data, pivot_kws, z_score,
  643. standard_scale)
  644. self.mask = _matrix_mask(self.data2d, mask)
  645. self.fig = plt.figure(figsize=figsize)
  646. self.row_colors, self.row_color_labels = \
  647. self._preprocess_colors(data, row_colors, axis=0)
  648. self.col_colors, self.col_color_labels = \
  649. self._preprocess_colors(data, col_colors, axis=1)
  650. try:
  651. row_dendrogram_ratio, col_dendrogram_ratio = dendrogram_ratio
  652. except TypeError:
  653. row_dendrogram_ratio = col_dendrogram_ratio = dendrogram_ratio
  654. try:
  655. row_colors_ratio, col_colors_ratio = colors_ratio
  656. except TypeError:
  657. row_colors_ratio = col_colors_ratio = colors_ratio
  658. width_ratios = self.dim_ratios(self.row_colors,
  659. row_dendrogram_ratio,
  660. row_colors_ratio)
  661. height_ratios = self.dim_ratios(self.col_colors,
  662. col_dendrogram_ratio,
  663. col_colors_ratio)
  664. nrows = 2 if self.col_colors is None else 3
  665. ncols = 2 if self.row_colors is None else 3
  666. self.gs = gridspec.GridSpec(nrows, ncols,
  667. width_ratios=width_ratios,
  668. height_ratios=height_ratios)
  669. self.ax_row_dendrogram = self.fig.add_subplot(self.gs[-1, 0])
  670. self.ax_col_dendrogram = self.fig.add_subplot(self.gs[0, -1])
  671. self.ax_row_dendrogram.set_axis_off()
  672. self.ax_col_dendrogram.set_axis_off()
  673. self.ax_row_colors = None
  674. self.ax_col_colors = None
  675. if self.row_colors is not None:
  676. self.ax_row_colors = self.fig.add_subplot(
  677. self.gs[-1, 1])
  678. if self.col_colors is not None:
  679. self.ax_col_colors = self.fig.add_subplot(
  680. self.gs[1, -1])
  681. self.ax_heatmap = self.fig.add_subplot(self.gs[-1, -1])
  682. if cbar_pos is None:
  683. self.ax_cbar = self.cax = None
  684. else:
  685. # Initialize the colorbar axes in the gridspec so that tight_layout
  686. # works. We will move it where it belongs later. This is a hack.
  687. self.ax_cbar = self.fig.add_subplot(self.gs[0, 0])
  688. self.cax = self.ax_cbar # Backwards compatability
  689. self.cbar_pos = cbar_pos
  690. self.dendrogram_row = None
  691. self.dendrogram_col = None
  692. def _preprocess_colors(self, data, colors, axis):
  693. """Preprocess {row/col}_colors to extract labels and convert colors."""
  694. labels = None
  695. if colors is not None:
  696. if isinstance(colors, (pd.DataFrame, pd.Series)):
  697. # Ensure colors match data indices
  698. if axis == 0:
  699. colors = colors.reindex(data.index)
  700. else:
  701. colors = colors.reindex(data.columns)
  702. # Replace na's with background color
  703. # TODO We should set these to transparent instead
  704. colors = colors.fillna('white')
  705. # Extract color values and labels from frame/series
  706. if isinstance(colors, pd.DataFrame):
  707. labels = list(colors.columns)
  708. colors = colors.T.values
  709. else:
  710. if colors.name is None:
  711. labels = [""]
  712. else:
  713. labels = [colors.name]
  714. colors = colors.values
  715. colors = _convert_colors(colors)
  716. return colors, labels
  717. def format_data(self, data, pivot_kws, z_score=None,
  718. standard_scale=None):
  719. """Extract variables from data or use directly."""
  720. # Either the data is already in 2d matrix format, or need to do a pivot
  721. if pivot_kws is not None:
  722. data2d = data.pivot(**pivot_kws)
  723. else:
  724. data2d = data
  725. if z_score is not None and standard_scale is not None:
  726. raise ValueError(
  727. 'Cannot perform both z-scoring and standard-scaling on data')
  728. if z_score is not None:
  729. data2d = self.z_score(data2d, z_score)
  730. if standard_scale is not None:
  731. data2d = self.standard_scale(data2d, standard_scale)
  732. return data2d
  733. @staticmethod
  734. def z_score(data2d, axis=1):
  735. """Standarize the mean and variance of the data axis
  736. Parameters
  737. ----------
  738. data2d : pandas.DataFrame
  739. Data to normalize
  740. axis : int
  741. Which axis to normalize across. If 0, normalize across rows, if 1,
  742. normalize across columns.
  743. Returns
  744. -------
  745. normalized : pandas.DataFrame
  746. Noramlized data with a mean of 0 and variance of 1 across the
  747. specified axis.
  748. """
  749. if axis == 1:
  750. z_scored = data2d
  751. else:
  752. z_scored = data2d.T
  753. z_scored = (z_scored - z_scored.mean()) / z_scored.std()
  754. if axis == 1:
  755. return z_scored
  756. else:
  757. return z_scored.T
  758. @staticmethod
  759. def standard_scale(data2d, axis=1):
  760. """Divide the data by the difference between the max and min
  761. Parameters
  762. ----------
  763. data2d : pandas.DataFrame
  764. Data to normalize
  765. axis : int
  766. Which axis to normalize across. If 0, normalize across rows, if 1,
  767. normalize across columns.
  768. vmin : int
  769. If 0, then subtract the minimum of the data before dividing by
  770. the range.
  771. Returns
  772. -------
  773. standardized : pandas.DataFrame
  774. Noramlized data with a mean of 0 and variance of 1 across the
  775. specified axis.
  776. """
  777. # Normalize these values to range from 0 to 1
  778. if axis == 1:
  779. standardized = data2d
  780. else:
  781. standardized = data2d.T
  782. subtract = standardized.min()
  783. standardized = (standardized - subtract) / (
  784. standardized.max() - standardized.min())
  785. if axis == 1:
  786. return standardized
  787. else:
  788. return standardized.T
  789. def dim_ratios(self, colors, dendrogram_ratio, colors_ratio):
  790. """Get the proportions of the figure taken up by each axes."""
  791. ratios = [dendrogram_ratio]
  792. if colors is not None:
  793. # Colors are encoded as rgb, so ther is an extra dimention
  794. if np.ndim(colors) > 2:
  795. n_colors = len(colors)
  796. else:
  797. n_colors = 1
  798. ratios += [n_colors * colors_ratio]
  799. # Add the ratio for the heatmap itself
  800. ratios.append(1 - sum(ratios))
  801. return ratios
  802. @staticmethod
  803. def color_list_to_matrix_and_cmap(colors, ind, axis=0):
  804. """Turns a list of colors into a numpy matrix and matplotlib colormap
  805. These arguments can now be plotted using heatmap(matrix, cmap)
  806. and the provided colors will be plotted.
  807. Parameters
  808. ----------
  809. colors : list of matplotlib colors
  810. Colors to label the rows or columns of a dataframe.
  811. ind : list of ints
  812. Ordering of the rows or columns, to reorder the original colors
  813. by the clustered dendrogram order
  814. axis : int
  815. Which axis this is labeling
  816. Returns
  817. -------
  818. matrix : numpy.array
  819. A numpy array of integer values, where each corresponds to a color
  820. from the originally provided list of colors
  821. cmap : matplotlib.colors.ListedColormap
  822. """
  823. # check for nested lists/color palettes.
  824. # Will fail if matplotlib color is list not tuple
  825. if any(issubclass(type(x), list) for x in colors):
  826. all_colors = set(itertools.chain(*colors))
  827. n = len(colors)
  828. m = len(colors[0])
  829. else:
  830. all_colors = set(colors)
  831. n = 1
  832. m = len(colors)
  833. colors = [colors]
  834. color_to_value = dict((col, i) for i, col in enumerate(all_colors))
  835. matrix = np.array([color_to_value[c]
  836. for color in colors for c in color])
  837. shape = (n, m)
  838. matrix = matrix.reshape(shape)
  839. matrix = matrix[:, ind]
  840. if axis == 0:
  841. # row-side:
  842. matrix = matrix.T
  843. cmap = mpl.colors.ListedColormap(all_colors)
  844. return matrix, cmap
  845. def savefig(self, *args, **kwargs):
  846. if 'bbox_inches' not in kwargs:
  847. kwargs['bbox_inches'] = 'tight'
  848. self.fig.savefig(*args, **kwargs)
  849. def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
  850. row_linkage, col_linkage, tree_kws):
  851. # Plot the row dendrogram
  852. if row_cluster:
  853. self.dendrogram_row = dendrogram(
  854. self.data2d, metric=metric, method=method, label=False, axis=0,
  855. ax=self.ax_row_dendrogram, rotate=True, linkage=row_linkage,
  856. tree_kws=tree_kws
  857. )
  858. else:
  859. self.ax_row_dendrogram.set_xticks([])
  860. self.ax_row_dendrogram.set_yticks([])
  861. # PLot the column dendrogram
  862. if col_cluster:
  863. self.dendrogram_col = dendrogram(
  864. self.data2d, metric=metric, method=method, label=False,
  865. axis=1, ax=self.ax_col_dendrogram, linkage=col_linkage,
  866. tree_kws=tree_kws
  867. )
  868. else:
  869. self.ax_col_dendrogram.set_xticks([])
  870. self.ax_col_dendrogram.set_yticks([])
  871. despine(ax=self.ax_row_dendrogram, bottom=True, left=True)
  872. despine(ax=self.ax_col_dendrogram, bottom=True, left=True)
  873. def plot_colors(self, xind, yind, **kws):
  874. """Plots color labels between the dendrogram and the heatmap
  875. Parameters
  876. ----------
  877. heatmap_kws : dict
  878. Keyword arguments heatmap
  879. """
  880. # Remove any custom colormap and centering
  881. # TODO this code has consistently caused problems when we
  882. # have missed kwargs that need to be excluded that it might
  883. # be better to rewrite *in*clusively.
  884. kws = kws.copy()
  885. kws.pop('cmap', None)
  886. kws.pop('norm', None)
  887. kws.pop('center', None)
  888. kws.pop('annot', None)
  889. kws.pop('vmin', None)
  890. kws.pop('vmax', None)
  891. kws.pop('robust', None)
  892. kws.pop('xticklabels', None)
  893. kws.pop('yticklabels', None)
  894. # Plot the row colors
  895. if self.row_colors is not None:
  896. matrix, cmap = self.color_list_to_matrix_and_cmap(
  897. self.row_colors, yind, axis=0)
  898. # Get row_color labels
  899. if self.row_color_labels is not None:
  900. row_color_labels = self.row_color_labels
  901. else:
  902. row_color_labels = False
  903. heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_row_colors,
  904. xticklabels=row_color_labels, yticklabels=False, **kws)
  905. # Adjust rotation of labels
  906. if row_color_labels is not False:
  907. plt.setp(self.ax_row_colors.get_xticklabels(), rotation=90)
  908. else:
  909. despine(self.ax_row_colors, left=True, bottom=True)
  910. # Plot the column colors
  911. if self.col_colors is not None:
  912. matrix, cmap = self.color_list_to_matrix_and_cmap(
  913. self.col_colors, xind, axis=1)
  914. # Get col_color labels
  915. if self.col_color_labels is not None:
  916. col_color_labels = self.col_color_labels
  917. else:
  918. col_color_labels = False
  919. heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_col_colors,
  920. xticklabels=False, yticklabels=col_color_labels, **kws)
  921. # Adjust rotation of labels, place on right side
  922. if col_color_labels is not False:
  923. self.ax_col_colors.yaxis.tick_right()
  924. plt.setp(self.ax_col_colors.get_yticklabels(), rotation=0)
  925. else:
  926. despine(self.ax_col_colors, left=True, bottom=True)
  927. def plot_matrix(self, colorbar_kws, xind, yind, **kws):
  928. self.data2d = self.data2d.iloc[yind, xind]
  929. self.mask = self.mask.iloc[yind, xind]
  930. # Try to reorganize specified tick labels, if provided
  931. xtl = kws.pop("xticklabels", "auto")
  932. try:
  933. xtl = np.asarray(xtl)[xind]
  934. except (TypeError, IndexError):
  935. pass
  936. ytl = kws.pop("yticklabels", "auto")
  937. try:
  938. ytl = np.asarray(ytl)[yind]
  939. except (TypeError, IndexError):
  940. pass
  941. # Reorganize the annotations to match the heatmap
  942. annot = kws.pop("annot", None)
  943. if annot is None:
  944. pass
  945. else:
  946. if isinstance(annot, bool):
  947. annot_data = self.data2d
  948. else:
  949. annot_data = np.asarray(annot)
  950. if annot_data.shape != self.data2d.shape:
  951. err = "`data` and `annot` must have same shape."
  952. raise ValueError(err)
  953. annot_data = annot_data[yind][:, xind]
  954. annot = annot_data
  955. # Setting ax_cbar=None in clustermap call implies no colorbar
  956. kws.setdefault("cbar", self.ax_cbar is not None)
  957. heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.ax_cbar,
  958. cbar_kws=colorbar_kws, mask=self.mask,
  959. xticklabels=xtl, yticklabels=ytl, annot=annot, **kws)
  960. ytl = self.ax_heatmap.get_yticklabels()
  961. ytl_rot = None if not ytl else ytl[0].get_rotation()
  962. self.ax_heatmap.yaxis.set_ticks_position('right')
  963. self.ax_heatmap.yaxis.set_label_position('right')
  964. if ytl_rot is not None:
  965. ytl = self.ax_heatmap.get_yticklabels()
  966. plt.setp(ytl, rotation=ytl_rot)
  967. tight_params = dict(h_pad=.02, w_pad=.02)
  968. if self.ax_cbar is None:
  969. self.fig.tight_layout(**tight_params)
  970. else:
  971. # Turn the colorbar axes off for tight layout so that its
  972. # ticks don't interfere with the rest of the plot layout.
  973. # Then move it.
  974. self.ax_cbar.set_axis_off()
  975. self.fig.tight_layout(**tight_params)
  976. self.ax_cbar.set_axis_on()
  977. self.ax_cbar.set_position(self.cbar_pos)
  978. def plot(self, metric, method, colorbar_kws, row_cluster, col_cluster,
  979. row_linkage, col_linkage, tree_kws, **kws):
  980. # heatmap square=True sets the aspect ratio on the axes, but that is
  981. # not compatible with the multi-axes layout of clustergrid
  982. if kws.get("square", False):
  983. msg = "``square=True`` ignored in clustermap"
  984. warnings.warn(msg)
  985. kws.pop("square")
  986. colorbar_kws = {} if colorbar_kws is None else colorbar_kws
  987. self.plot_dendrograms(row_cluster, col_cluster, metric, method,
  988. row_linkage=row_linkage, col_linkage=col_linkage,
  989. tree_kws=tree_kws)
  990. try:
  991. xind = self.dendrogram_col.reordered_ind
  992. except AttributeError:
  993. xind = np.arange(self.data2d.shape[1])
  994. try:
  995. yind = self.dendrogram_row.reordered_ind
  996. except AttributeError:
  997. yind = np.arange(self.data2d.shape[0])
  998. self.plot_colors(xind, yind, **kws)
  999. self.plot_matrix(colorbar_kws, xind, yind, **kws)
  1000. return self
  1001. def clustermap(data, pivot_kws=None, method='average', metric='euclidean',
  1002. z_score=None, standard_scale=None, figsize=(10, 10),
  1003. cbar_kws=None, row_cluster=True, col_cluster=True,
  1004. row_linkage=None, col_linkage=None,
  1005. row_colors=None, col_colors=None, mask=None,
  1006. dendrogram_ratio=.2, colors_ratio=0.03,
  1007. cbar_pos=(.02, .8, .05, .18), tree_kws=None,
  1008. **kwargs):
  1009. """Plot a matrix dataset as a hierarchically-clustered heatmap.
  1010. Parameters
  1011. ----------
  1012. data: 2D array-like
  1013. Rectangular data for clustering. Cannot contain NAs.
  1014. pivot_kws : dict, optional
  1015. If `data` is a tidy dataframe, can provide keyword arguments for
  1016. pivot to create a rectangular dataframe.
  1017. method : str, optional
  1018. Linkage method to use for calculating clusters.
  1019. See scipy.cluster.hierarchy.linkage documentation for more information:
  1020. https://docs.scipy.org/doc/scipy/reference/generated/scipy.cluster.hierarchy.linkage.html
  1021. metric : str, optional
  1022. Distance metric to use for the data. See
  1023. scipy.spatial.distance.pdist documentation for more options
  1024. https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.distance.pdist.html
  1025. To use different metrics (or methods) for rows and columns, you may
  1026. construct each linkage matrix yourself and provide them as
  1027. {row,col}_linkage.
  1028. z_score : int or None, optional
  1029. Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores
  1030. for the rows or the columns. Z scores are: z = (x - mean)/std, so
  1031. values in each row (column) will get the mean of the row (column)
  1032. subtracted, then divided by the standard deviation of the row (column).
  1033. This ensures that each row (column) has mean of 0 and variance of 1.
  1034. standard_scale : int or None, optional
  1035. Either 0 (rows) or 1 (columns). Whether or not to standardize that
  1036. dimension, meaning for each row or column, subtract the minimum and
  1037. divide each by its maximum.
  1038. figsize: (width, height), optional
  1039. Overall size of the figure.
  1040. cbar_kws : dict, optional
  1041. Keyword arguments to pass to ``cbar_kws`` in ``heatmap``, e.g. to
  1042. add a label to the colorbar.
  1043. {row,col}_cluster : bool, optional
  1044. If True, cluster the {rows, columns}.
  1045. {row,col}_linkage : numpy.array, optional
  1046. Precomputed linkage matrix for the rows or columns. See
  1047. scipy.cluster.hierarchy.linkage for specific formats.
  1048. {row,col}_colors : list-like or pandas DataFrame/Series, optional
  1049. List of colors to label for either the rows or columns. Useful to
  1050. evaluate whether samples within a group are clustered together. Can
  1051. use nested lists or DataFrame for multiple color levels of labeling.
  1052. If given as a DataFrame or Series, labels for the colors are extracted
  1053. from the DataFrames column names or from the name of the Series.
  1054. DataFrame/Series colors are also matched to the data by their
  1055. index, ensuring colors are drawn in the correct order.
  1056. mask : boolean array or DataFrame, optional
  1057. If passed, data will not be shown in cells where ``mask`` is True.
  1058. Cells with missing values are automatically masked. Only used for
  1059. visualizing, not for calculating.
  1060. {dendrogram,colors}_ratio: float, or pair of floats, optional
  1061. Proportion of the figure size devoted to the two marginal elements. If
  1062. a pair is given, they correspond to (row, col) ratios.
  1063. cbar_pos : (left, bottom, width, height), optional
  1064. Position of the colorbar axes in the figure. Setting to ``None`` will
  1065. disable the colorbar.
  1066. tree_kws : dict, optional
  1067. Parameters for the :class:`matplotlib.collections.LineCollection`
  1068. that is used to plot the lines of the dendrogram tree.
  1069. kwargs : other keyword arguments
  1070. All other keyword arguments are passed to :func:`heatmap`
  1071. Returns
  1072. -------
  1073. clustergrid : ClusterGrid
  1074. A ClusterGrid instance.
  1075. Notes
  1076. -----
  1077. The returned object has a ``savefig`` method that should be used if you
  1078. want to save the figure object without clipping the dendrograms.
  1079. To access the reordered row indices, use:
  1080. ``clustergrid.dendrogram_row.reordered_ind``
  1081. Column indices, use:
  1082. ``clustergrid.dendrogram_col.reordered_ind``
  1083. Examples
  1084. --------
  1085. Plot a clustered heatmap:
  1086. .. plot::
  1087. :context: close-figs
  1088. >>> import seaborn as sns; sns.set(color_codes=True)
  1089. >>> iris = sns.load_dataset("iris")
  1090. >>> species = iris.pop("species")
  1091. >>> g = sns.clustermap(iris)
  1092. Change the size and layout of the figure:
  1093. .. plot::
  1094. :context: close-figs
  1095. >>> g = sns.clustermap(iris,
  1096. ... figsize=(7, 5),
  1097. ... row_cluster=False,
  1098. ... dendrogram_ratio=(.1, .2),
  1099. ... cbar_pos=(0, .2, .03, .4))
  1100. Add colored labels to identify observations:
  1101. .. plot::
  1102. :context: close-figs
  1103. >>> lut = dict(zip(species.unique(), "rbg"))
  1104. >>> row_colors = species.map(lut)
  1105. >>> g = sns.clustermap(iris, row_colors=row_colors)
  1106. Use a different colormap and adjust the limits of the color range:
  1107. .. plot::
  1108. :context: close-figs
  1109. >>> g = sns.clustermap(iris, cmap="mako", vmin=0, vmax=10)
  1110. Use a different similarity metric:
  1111. .. plot::
  1112. :context: close-figs
  1113. >>> g = sns.clustermap(iris, metric="correlation")
  1114. Use a different clustering method:
  1115. .. plot::
  1116. :context: close-figs
  1117. >>> g = sns.clustermap(iris, method="single")
  1118. Standardize the data within the columns:
  1119. .. plot::
  1120. :context: close-figs
  1121. >>> g = sns.clustermap(iris, standard_scale=1)
  1122. Normalize the data within the rows:
  1123. .. plot::
  1124. :context: close-figs
  1125. >>> g = sns.clustermap(iris, z_score=0, cmap="vlag")
  1126. """
  1127. plotter = ClusterGrid(data, pivot_kws=pivot_kws, figsize=figsize,
  1128. row_colors=row_colors, col_colors=col_colors,
  1129. z_score=z_score, standard_scale=standard_scale,
  1130. mask=mask, dendrogram_ratio=dendrogram_ratio,
  1131. colors_ratio=colors_ratio, cbar_pos=cbar_pos)
  1132. return plotter.plot(metric=metric, method=method,
  1133. colorbar_kws=cbar_kws,
  1134. row_cluster=row_cluster, col_cluster=col_cluster,
  1135. row_linkage=row_linkage, col_linkage=col_linkage,
  1136. tree_kws=tree_kws, **kwargs)