test_matrix.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285
  1. import itertools
  2. import tempfile
  3. import copy
  4. import numpy as np
  5. import matplotlib as mpl
  6. import matplotlib.pyplot as plt
  7. import pandas as pd
  8. from scipy.spatial import distance
  9. from scipy.cluster import hierarchy
  10. import nose.tools as nt
  11. import numpy.testing as npt
  12. try:
  13. import pandas.testing as pdt
  14. except ImportError:
  15. import pandas.util.testing as pdt
  16. import pytest
  17. from .. import matrix as mat
  18. from .. import color_palette
  19. try:
  20. import fastcluster
  21. assert fastcluster
  22. _no_fastcluster = False
  23. except ImportError:
  24. _no_fastcluster = True
  25. class TestHeatmap(object):
  26. rs = np.random.RandomState(sum(map(ord, "heatmap")))
  27. x_norm = rs.randn(4, 8)
  28. letters = pd.Series(["A", "B", "C", "D"], name="letters")
  29. df_norm = pd.DataFrame(x_norm, index=letters)
  30. x_unif = rs.rand(20, 13)
  31. df_unif = pd.DataFrame(x_unif)
  32. default_kws = dict(vmin=None, vmax=None, cmap=None, center=None,
  33. robust=False, annot=False, fmt=".2f", annot_kws=None,
  34. cbar=True, cbar_kws=None, mask=None)
  35. def test_ndarray_input(self):
  36. p = mat._HeatMapper(self.x_norm, **self.default_kws)
  37. npt.assert_array_equal(p.plot_data, self.x_norm)
  38. pdt.assert_frame_equal(p.data, pd.DataFrame(self.x_norm))
  39. npt.assert_array_equal(p.xticklabels, np.arange(8))
  40. npt.assert_array_equal(p.yticklabels, np.arange(4))
  41. nt.assert_equal(p.xlabel, "")
  42. nt.assert_equal(p.ylabel, "")
  43. def test_df_input(self):
  44. p = mat._HeatMapper(self.df_norm, **self.default_kws)
  45. npt.assert_array_equal(p.plot_data, self.x_norm)
  46. pdt.assert_frame_equal(p.data, self.df_norm)
  47. npt.assert_array_equal(p.xticklabels, np.arange(8))
  48. npt.assert_array_equal(p.yticklabels, self.letters.values)
  49. nt.assert_equal(p.xlabel, "")
  50. nt.assert_equal(p.ylabel, "letters")
  51. def test_df_multindex_input(self):
  52. df = self.df_norm.copy()
  53. index = pd.MultiIndex.from_tuples([("A", 1), ("B", 2),
  54. ("C", 3), ("D", 4)],
  55. names=["letter", "number"])
  56. index.name = "letter-number"
  57. df.index = index
  58. p = mat._HeatMapper(df, **self.default_kws)
  59. combined_tick_labels = ["A-1", "B-2", "C-3", "D-4"]
  60. npt.assert_array_equal(p.yticklabels, combined_tick_labels)
  61. nt.assert_equal(p.ylabel, "letter-number")
  62. p = mat._HeatMapper(df.T, **self.default_kws)
  63. npt.assert_array_equal(p.xticklabels, combined_tick_labels)
  64. nt.assert_equal(p.xlabel, "letter-number")
  65. def test_mask_input(self):
  66. kws = self.default_kws.copy()
  67. mask = self.x_norm > 0
  68. kws['mask'] = mask
  69. p = mat._HeatMapper(self.x_norm, **kws)
  70. plot_data = np.ma.masked_where(mask, self.x_norm)
  71. npt.assert_array_equal(p.plot_data, plot_data)
  72. def test_mask_limits(self):
  73. """Make sure masked cells are not used to calculate extremes"""
  74. kws = self.default_kws.copy()
  75. mask = self.x_norm > 0
  76. kws['mask'] = mask
  77. p = mat._HeatMapper(self.x_norm, **kws)
  78. assert p.vmax == np.ma.array(self.x_norm, mask=mask).max()
  79. assert p.vmin == np.ma.array(self.x_norm, mask=mask).min()
  80. mask = self.x_norm < 0
  81. kws['mask'] = mask
  82. p = mat._HeatMapper(self.x_norm, **kws)
  83. assert p.vmin == np.ma.array(self.x_norm, mask=mask).min()
  84. assert p.vmax == np.ma.array(self.x_norm, mask=mask).max()
  85. def test_default_vlims(self):
  86. p = mat._HeatMapper(self.df_unif, **self.default_kws)
  87. nt.assert_equal(p.vmin, self.x_unif.min())
  88. nt.assert_equal(p.vmax, self.x_unif.max())
  89. def test_robust_vlims(self):
  90. kws = self.default_kws.copy()
  91. kws["robust"] = True
  92. p = mat._HeatMapper(self.df_unif, **kws)
  93. nt.assert_equal(p.vmin, np.percentile(self.x_unif, 2))
  94. nt.assert_equal(p.vmax, np.percentile(self.x_unif, 98))
  95. def test_custom_sequential_vlims(self):
  96. kws = self.default_kws.copy()
  97. kws["vmin"] = 0
  98. kws["vmax"] = 1
  99. p = mat._HeatMapper(self.df_unif, **kws)
  100. nt.assert_equal(p.vmin, 0)
  101. nt.assert_equal(p.vmax, 1)
  102. def test_custom_diverging_vlims(self):
  103. kws = self.default_kws.copy()
  104. kws["vmin"] = -4
  105. kws["vmax"] = 5
  106. kws["center"] = 0
  107. p = mat._HeatMapper(self.df_norm, **kws)
  108. nt.assert_equal(p.vmin, -4)
  109. nt.assert_equal(p.vmax, 5)
  110. def test_array_with_nans(self):
  111. x1 = self.rs.rand(10, 10)
  112. nulls = np.zeros(10) * np.nan
  113. x2 = np.c_[x1, nulls]
  114. m1 = mat._HeatMapper(x1, **self.default_kws)
  115. m2 = mat._HeatMapper(x2, **self.default_kws)
  116. nt.assert_equal(m1.vmin, m2.vmin)
  117. nt.assert_equal(m1.vmax, m2.vmax)
  118. def test_mask(self):
  119. df = pd.DataFrame(data={'a': [1, 1, 1],
  120. 'b': [2, np.nan, 2],
  121. 'c': [3, 3, np.nan]})
  122. kws = self.default_kws.copy()
  123. kws["mask"] = np.isnan(df.values)
  124. m = mat._HeatMapper(df, **kws)
  125. npt.assert_array_equal(np.isnan(m.plot_data.data),
  126. m.plot_data.mask)
  127. def test_custom_cmap(self):
  128. kws = self.default_kws.copy()
  129. kws["cmap"] = "BuGn"
  130. p = mat._HeatMapper(self.df_unif, **kws)
  131. nt.assert_equal(p.cmap, mpl.cm.BuGn)
  132. def test_centered_vlims(self):
  133. kws = self.default_kws.copy()
  134. kws["center"] = .5
  135. p = mat._HeatMapper(self.df_unif, **kws)
  136. nt.assert_equal(p.vmin, self.df_unif.values.min())
  137. nt.assert_equal(p.vmax, self.df_unif.values.max())
  138. def test_default_colors(self):
  139. vals = np.linspace(.2, 1, 9)
  140. cmap = mpl.cm.binary
  141. ax = mat.heatmap([vals], cmap=cmap)
  142. fc = ax.collections[0].get_facecolors()
  143. cvals = np.linspace(0, 1, 9)
  144. npt.assert_array_almost_equal(fc, cmap(cvals), 2)
  145. def test_custom_vlim_colors(self):
  146. vals = np.linspace(.2, 1, 9)
  147. cmap = mpl.cm.binary
  148. ax = mat.heatmap([vals], vmin=0, cmap=cmap)
  149. fc = ax.collections[0].get_facecolors()
  150. npt.assert_array_almost_equal(fc, cmap(vals), 2)
  151. def test_custom_center_colors(self):
  152. vals = np.linspace(.2, 1, 9)
  153. cmap = mpl.cm.binary
  154. ax = mat.heatmap([vals], center=.5, cmap=cmap)
  155. fc = ax.collections[0].get_facecolors()
  156. npt.assert_array_almost_equal(fc, cmap(vals), 2)
  157. def test_cmap_with_properties(self):
  158. kws = self.default_kws.copy()
  159. cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
  160. cmap.set_bad("red")
  161. kws["cmap"] = cmap
  162. hm = mat._HeatMapper(self.df_unif, **kws)
  163. npt.assert_array_equal(
  164. cmap(np.ma.masked_invalid([np.nan])),
  165. hm.cmap(np.ma.masked_invalid([np.nan])))
  166. kws["center"] = 0.5
  167. hm = mat._HeatMapper(self.df_unif, **kws)
  168. npt.assert_array_equal(
  169. cmap(np.ma.masked_invalid([np.nan])),
  170. hm.cmap(np.ma.masked_invalid([np.nan])))
  171. kws = self.default_kws.copy()
  172. cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
  173. cmap.set_under("red")
  174. kws["cmap"] = cmap
  175. hm = mat._HeatMapper(self.df_unif, **kws)
  176. npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf))
  177. kws["center"] = .5
  178. hm = mat._HeatMapper(self.df_unif, **kws)
  179. npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf))
  180. kws = self.default_kws.copy()
  181. cmap = copy.copy(mpl.cm.get_cmap("BrBG"))
  182. cmap.set_over("red")
  183. kws["cmap"] = cmap
  184. hm = mat._HeatMapper(self.df_unif, **kws)
  185. npt.assert_array_equal(cmap(-np.inf), hm.cmap(-np.inf))
  186. kws["center"] = .5
  187. hm = mat._HeatMapper(self.df_unif, **kws)
  188. npt.assert_array_equal(cmap(np.inf), hm.cmap(np.inf))
  189. def test_tickabels_off(self):
  190. kws = self.default_kws.copy()
  191. kws['xticklabels'] = False
  192. kws['yticklabels'] = False
  193. p = mat._HeatMapper(self.df_norm, **kws)
  194. nt.assert_equal(p.xticklabels, [])
  195. nt.assert_equal(p.yticklabels, [])
  196. def test_custom_ticklabels(self):
  197. kws = self.default_kws.copy()
  198. xticklabels = list('iheartheatmaps'[:self.df_norm.shape[1]])
  199. yticklabels = list('heatmapsarecool'[:self.df_norm.shape[0]])
  200. kws['xticklabels'] = xticklabels
  201. kws['yticklabels'] = yticklabels
  202. p = mat._HeatMapper(self.df_norm, **kws)
  203. nt.assert_equal(p.xticklabels, xticklabels)
  204. nt.assert_equal(p.yticklabels, yticklabels)
  205. def test_custom_ticklabel_interval(self):
  206. kws = self.default_kws.copy()
  207. xstep, ystep = 2, 3
  208. kws['xticklabels'] = xstep
  209. kws['yticklabels'] = ystep
  210. p = mat._HeatMapper(self.df_norm, **kws)
  211. nx, ny = self.df_norm.T.shape
  212. npt.assert_array_equal(p.xticks, np.arange(0, nx, xstep) + .5)
  213. npt.assert_array_equal(p.yticks, np.arange(0, ny, ystep) + .5)
  214. npt.assert_array_equal(p.xticklabels,
  215. self.df_norm.columns[0:nx:xstep])
  216. npt.assert_array_equal(p.yticklabels,
  217. self.df_norm.index[0:ny:ystep])
  218. def test_heatmap_annotation(self):
  219. ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
  220. annot_kws={"fontsize": 14})
  221. for val, text in zip(self.x_norm.flat, ax.texts):
  222. nt.assert_equal(text.get_text(), "{:.1f}".format(val))
  223. nt.assert_equal(text.get_fontsize(), 14)
  224. def test_heatmap_annotation_overwrite_kws(self):
  225. annot_kws = dict(color="0.3", va="bottom", ha="left")
  226. ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
  227. annot_kws=annot_kws)
  228. for text in ax.texts:
  229. nt.assert_equal(text.get_color(), "0.3")
  230. nt.assert_equal(text.get_ha(), "left")
  231. nt.assert_equal(text.get_va(), "bottom")
  232. def test_heatmap_annotation_with_mask(self):
  233. df = pd.DataFrame(data={'a': [1, 1, 1],
  234. 'b': [2, np.nan, 2],
  235. 'c': [3, 3, np.nan]})
  236. mask = np.isnan(df.values)
  237. df_masked = np.ma.masked_where(mask, df)
  238. ax = mat.heatmap(df, annot=True, fmt='.1f', mask=mask)
  239. nt.assert_equal(len(df_masked.compressed()), len(ax.texts))
  240. for val, text in zip(df_masked.compressed(), ax.texts):
  241. nt.assert_equal("{:.1f}".format(val), text.get_text())
  242. def test_heatmap_annotation_mesh_colors(self):
  243. ax = mat.heatmap(self.df_norm, annot=True)
  244. mesh = ax.collections[0]
  245. nt.assert_equal(len(mesh.get_facecolors()), self.df_norm.values.size)
  246. plt.close("all")
  247. def test_heatmap_annotation_other_data(self):
  248. annot_data = self.df_norm + 10
  249. ax = mat.heatmap(self.df_norm, annot=annot_data, fmt=".1f",
  250. annot_kws={"fontsize": 14})
  251. for val, text in zip(annot_data.values.flat, ax.texts):
  252. nt.assert_equal(text.get_text(), "{:.1f}".format(val))
  253. nt.assert_equal(text.get_fontsize(), 14)
  254. def test_heatmap_annotation_with_limited_ticklabels(self):
  255. ax = mat.heatmap(self.df_norm, fmt=".2f", annot=True,
  256. xticklabels=False, yticklabels=False)
  257. for val, text in zip(self.x_norm.flat, ax.texts):
  258. nt.assert_equal(text.get_text(), "{:.2f}".format(val))
  259. def test_heatmap_cbar(self):
  260. f = plt.figure()
  261. mat.heatmap(self.df_norm)
  262. nt.assert_equal(len(f.axes), 2)
  263. plt.close(f)
  264. f = plt.figure()
  265. mat.heatmap(self.df_norm, cbar=False)
  266. nt.assert_equal(len(f.axes), 1)
  267. plt.close(f)
  268. f, (ax1, ax2) = plt.subplots(2)
  269. mat.heatmap(self.df_norm, ax=ax1, cbar_ax=ax2)
  270. nt.assert_equal(len(f.axes), 2)
  271. plt.close(f)
  272. @pytest.mark.xfail(mpl.__version__ == "3.1.1",
  273. reason="matplotlib 3.1.1 bug")
  274. def test_heatmap_axes(self):
  275. ax = mat.heatmap(self.df_norm)
  276. xtl = [int(l.get_text()) for l in ax.get_xticklabels()]
  277. nt.assert_equal(xtl, list(self.df_norm.columns))
  278. ytl = [l.get_text() for l in ax.get_yticklabels()]
  279. nt.assert_equal(ytl, list(self.df_norm.index))
  280. nt.assert_equal(ax.get_xlabel(), "")
  281. nt.assert_equal(ax.get_ylabel(), "letters")
  282. nt.assert_equal(ax.get_xlim(), (0, 8))
  283. nt.assert_equal(ax.get_ylim(), (4, 0))
  284. def test_heatmap_ticklabel_rotation(self):
  285. f, ax = plt.subplots(figsize=(2, 2))
  286. mat.heatmap(self.df_norm, xticklabels=1, yticklabels=1, ax=ax)
  287. for t in ax.get_xticklabels():
  288. nt.assert_equal(t.get_rotation(), 0)
  289. for t in ax.get_yticklabels():
  290. nt.assert_equal(t.get_rotation(), 90)
  291. plt.close(f)
  292. df = self.df_norm.copy()
  293. df.columns = [str(c) * 10 for c in df.columns]
  294. df.index = [i * 10 for i in df.index]
  295. f, ax = plt.subplots(figsize=(2, 2))
  296. mat.heatmap(df, xticklabels=1, yticklabels=1, ax=ax)
  297. for t in ax.get_xticklabels():
  298. nt.assert_equal(t.get_rotation(), 90)
  299. for t in ax.get_yticklabels():
  300. nt.assert_equal(t.get_rotation(), 0)
  301. plt.close(f)
  302. def test_heatmap_inner_lines(self):
  303. c = (0, 0, 1, 1)
  304. ax = mat.heatmap(self.df_norm, linewidths=2, linecolor=c)
  305. mesh = ax.collections[0]
  306. nt.assert_equal(mesh.get_linewidths()[0], 2)
  307. nt.assert_equal(tuple(mesh.get_edgecolor()[0]), c)
  308. def test_square_aspect(self):
  309. ax = mat.heatmap(self.df_norm, square=True)
  310. nt.assert_equal(ax.get_aspect(), "equal")
  311. def test_mask_validation(self):
  312. mask = mat._matrix_mask(self.df_norm, None)
  313. nt.assert_equal(mask.shape, self.df_norm.shape)
  314. nt.assert_equal(mask.values.sum(), 0)
  315. with nt.assert_raises(ValueError):
  316. bad_array_mask = self.rs.randn(3, 6) > 0
  317. mat._matrix_mask(self.df_norm, bad_array_mask)
  318. with nt.assert_raises(ValueError):
  319. bad_df_mask = pd.DataFrame(self.rs.randn(4, 8) > 0)
  320. mat._matrix_mask(self.df_norm, bad_df_mask)
  321. def test_missing_data_mask(self):
  322. data = pd.DataFrame(np.arange(4, dtype=np.float).reshape(2, 2))
  323. data.loc[0, 0] = np.nan
  324. mask = mat._matrix_mask(data, None)
  325. npt.assert_array_equal(mask, [[True, False], [False, False]])
  326. mask_in = np.array([[False, True], [False, False]])
  327. mask_out = mat._matrix_mask(data, mask_in)
  328. npt.assert_array_equal(mask_out, [[True, True], [False, False]])
  329. def test_cbar_ticks(self):
  330. f, (ax1, ax2) = plt.subplots(2)
  331. mat.heatmap(self.df_norm, ax=ax1, cbar_ax=ax2,
  332. cbar_kws=dict(drawedges=True))
  333. assert len(ax2.collections) == 2
  334. class TestDendrogram(object):
  335. rs = np.random.RandomState(sum(map(ord, "dendrogram")))
  336. x_norm = rs.randn(4, 8) + np.arange(8)
  337. x_norm = (x_norm.T + np.arange(4)).T
  338. letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"],
  339. name="letters")
  340. df_norm = pd.DataFrame(x_norm, columns=letters)
  341. try:
  342. import fastcluster
  343. x_norm_linkage = fastcluster.linkage_vector(x_norm.T,
  344. metric='euclidean',
  345. method='single')
  346. except ImportError:
  347. x_norm_distances = distance.pdist(x_norm.T, metric='euclidean')
  348. x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single')
  349. x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True,
  350. color_threshold=-np.inf)
  351. x_norm_leaves = x_norm_dendrogram['leaves']
  352. df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves])
  353. default_kws = dict(linkage=None, metric='euclidean', method='single',
  354. axis=1, label=True, rotate=False)
  355. def test_ndarray_input(self):
  356. p = mat._DendrogramPlotter(self.x_norm, **self.default_kws)
  357. npt.assert_array_equal(p.array.T, self.x_norm)
  358. pdt.assert_frame_equal(p.data.T, pd.DataFrame(self.x_norm))
  359. npt.assert_array_equal(p.linkage, self.x_norm_linkage)
  360. nt.assert_dict_equal(p.dendrogram, self.x_norm_dendrogram)
  361. npt.assert_array_equal(p.reordered_ind, self.x_norm_leaves)
  362. npt.assert_array_equal(p.xticklabels, self.x_norm_leaves)
  363. npt.assert_array_equal(p.yticklabels, [])
  364. nt.assert_equal(p.xlabel, None)
  365. nt.assert_equal(p.ylabel, '')
  366. def test_df_input(self):
  367. p = mat._DendrogramPlotter(self.df_norm, **self.default_kws)
  368. npt.assert_array_equal(p.array.T, np.asarray(self.df_norm))
  369. pdt.assert_frame_equal(p.data.T, self.df_norm)
  370. npt.assert_array_equal(p.linkage, self.x_norm_linkage)
  371. nt.assert_dict_equal(p.dendrogram, self.x_norm_dendrogram)
  372. npt.assert_array_equal(p.xticklabels,
  373. np.asarray(self.df_norm.columns)[
  374. self.x_norm_leaves])
  375. npt.assert_array_equal(p.yticklabels, [])
  376. nt.assert_equal(p.xlabel, 'letters')
  377. nt.assert_equal(p.ylabel, '')
  378. def test_df_multindex_input(self):
  379. df = self.df_norm.copy()
  380. index = pd.MultiIndex.from_tuples([("A", 1), ("B", 2),
  381. ("C", 3), ("D", 4)],
  382. names=["letter", "number"])
  383. index.name = "letter-number"
  384. df.index = index
  385. kws = self.default_kws.copy()
  386. kws['label'] = True
  387. p = mat._DendrogramPlotter(df.T, **kws)
  388. xticklabels = ["A-1", "B-2", "C-3", "D-4"]
  389. xticklabels = [xticklabels[i] for i in p.reordered_ind]
  390. npt.assert_array_equal(p.xticklabels, xticklabels)
  391. npt.assert_array_equal(p.yticklabels, [])
  392. nt.assert_equal(p.xlabel, "letter-number")
  393. def test_axis0_input(self):
  394. kws = self.default_kws.copy()
  395. kws['axis'] = 0
  396. p = mat._DendrogramPlotter(self.df_norm.T, **kws)
  397. npt.assert_array_equal(p.array, np.asarray(self.df_norm.T))
  398. pdt.assert_frame_equal(p.data, self.df_norm.T)
  399. npt.assert_array_equal(p.linkage, self.x_norm_linkage)
  400. nt.assert_dict_equal(p.dendrogram, self.x_norm_dendrogram)
  401. npt.assert_array_equal(p.xticklabels, self.df_norm_leaves)
  402. npt.assert_array_equal(p.yticklabels, [])
  403. nt.assert_equal(p.xlabel, 'letters')
  404. nt.assert_equal(p.ylabel, '')
  405. def test_rotate_input(self):
  406. kws = self.default_kws.copy()
  407. kws['rotate'] = True
  408. p = mat._DendrogramPlotter(self.df_norm, **kws)
  409. npt.assert_array_equal(p.array.T, np.asarray(self.df_norm))
  410. pdt.assert_frame_equal(p.data.T, self.df_norm)
  411. npt.assert_array_equal(p.xticklabels, [])
  412. npt.assert_array_equal(p.yticklabels, self.df_norm_leaves)
  413. nt.assert_equal(p.xlabel, '')
  414. nt.assert_equal(p.ylabel, 'letters')
  415. def test_rotate_axis0_input(self):
  416. kws = self.default_kws.copy()
  417. kws['rotate'] = True
  418. kws['axis'] = 0
  419. p = mat._DendrogramPlotter(self.df_norm.T, **kws)
  420. npt.assert_array_equal(p.reordered_ind, self.x_norm_leaves)
  421. def test_custom_linkage(self):
  422. kws = self.default_kws.copy()
  423. try:
  424. import fastcluster
  425. linkage = fastcluster.linkage_vector(self.x_norm, method='single',
  426. metric='euclidean')
  427. except ImportError:
  428. d = distance.pdist(self.x_norm, metric='euclidean')
  429. linkage = hierarchy.linkage(d, method='single')
  430. dendrogram = hierarchy.dendrogram(linkage, no_plot=True,
  431. color_threshold=-np.inf)
  432. kws['linkage'] = linkage
  433. p = mat._DendrogramPlotter(self.df_norm, **kws)
  434. npt.assert_array_equal(p.linkage, linkage)
  435. nt.assert_dict_equal(p.dendrogram, dendrogram)
  436. def test_label_false(self):
  437. kws = self.default_kws.copy()
  438. kws['label'] = False
  439. p = mat._DendrogramPlotter(self.df_norm, **kws)
  440. nt.assert_equal(p.xticks, [])
  441. nt.assert_equal(p.yticks, [])
  442. nt.assert_equal(p.xticklabels, [])
  443. nt.assert_equal(p.yticklabels, [])
  444. nt.assert_equal(p.xlabel, "")
  445. nt.assert_equal(p.ylabel, "")
  446. def test_linkage_scipy(self):
  447. p = mat._DendrogramPlotter(self.x_norm, **self.default_kws)
  448. scipy_linkage = p._calculate_linkage_scipy()
  449. from scipy.spatial import distance
  450. from scipy.cluster import hierarchy
  451. dists = distance.pdist(self.x_norm.T,
  452. metric=self.default_kws['metric'])
  453. linkage = hierarchy.linkage(dists, method=self.default_kws['method'])
  454. npt.assert_array_equal(scipy_linkage, linkage)
  455. @pytest.mark.skipif(_no_fastcluster, reason="fastcluster not installed")
  456. def test_fastcluster_other_method(self):
  457. import fastcluster
  458. kws = self.default_kws.copy()
  459. kws['method'] = 'average'
  460. linkage = fastcluster.linkage(self.x_norm.T, method='average',
  461. metric='euclidean')
  462. p = mat._DendrogramPlotter(self.x_norm, **kws)
  463. npt.assert_array_equal(p.linkage, linkage)
  464. @pytest.mark.skipif(_no_fastcluster, reason="fastcluster not installed")
  465. def test_fastcluster_non_euclidean(self):
  466. import fastcluster
  467. kws = self.default_kws.copy()
  468. kws['metric'] = 'cosine'
  469. kws['method'] = 'average'
  470. linkage = fastcluster.linkage(self.x_norm.T, method=kws['method'],
  471. metric=kws['metric'])
  472. p = mat._DendrogramPlotter(self.x_norm, **kws)
  473. npt.assert_array_equal(p.linkage, linkage)
  474. def test_dendrogram_plot(self):
  475. d = mat.dendrogram(self.x_norm, **self.default_kws)
  476. ax = plt.gca()
  477. xlim = ax.get_xlim()
  478. # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy
  479. xmax = len(d.reordered_ind) * 10
  480. nt.assert_equal(xlim[0], 0)
  481. nt.assert_equal(xlim[1], xmax)
  482. nt.assert_equal(len(ax.collections[0].get_paths()),
  483. len(d.dependent_coord))
  484. @pytest.mark.xfail(mpl.__version__ == "3.1.1",
  485. reason="matplotlib 3.1.1 bug")
  486. def test_dendrogram_rotate(self):
  487. kws = self.default_kws.copy()
  488. kws['rotate'] = True
  489. d = mat.dendrogram(self.x_norm, **kws)
  490. ax = plt.gca()
  491. ylim = ax.get_ylim()
  492. # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy
  493. ymax = len(d.reordered_ind) * 10
  494. # Since y axis is inverted, ylim is (80, 0)
  495. # and therefore not (0, 80) as usual:
  496. nt.assert_equal(ylim[1], 0)
  497. nt.assert_equal(ylim[0], ymax)
  498. def test_dendrogram_ticklabel_rotation(self):
  499. f, ax = plt.subplots(figsize=(2, 2))
  500. mat.dendrogram(self.df_norm, ax=ax)
  501. for t in ax.get_xticklabels():
  502. nt.assert_equal(t.get_rotation(), 0)
  503. plt.close(f)
  504. df = self.df_norm.copy()
  505. df.columns = [str(c) * 10 for c in df.columns]
  506. df.index = [i * 10 for i in df.index]
  507. f, ax = plt.subplots(figsize=(2, 2))
  508. mat.dendrogram(df, ax=ax)
  509. for t in ax.get_xticklabels():
  510. nt.assert_equal(t.get_rotation(), 90)
  511. plt.close(f)
  512. f, ax = plt.subplots(figsize=(2, 2))
  513. mat.dendrogram(df.T, axis=0, rotate=True)
  514. for t in ax.get_yticklabels():
  515. nt.assert_equal(t.get_rotation(), 0)
  516. plt.close(f)
  517. class TestClustermap(object):
  518. rs = np.random.RandomState(sum(map(ord, "clustermap")))
  519. x_norm = rs.randn(4, 8) + np.arange(8)
  520. x_norm = (x_norm.T + np.arange(4)).T
  521. letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"],
  522. name="letters")
  523. df_norm = pd.DataFrame(x_norm, columns=letters)
  524. try:
  525. import fastcluster
  526. x_norm_linkage = fastcluster.linkage_vector(x_norm.T,
  527. metric='euclidean',
  528. method='single')
  529. except ImportError:
  530. x_norm_distances = distance.pdist(x_norm.T, metric='euclidean')
  531. x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single')
  532. x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True,
  533. color_threshold=-np.inf)
  534. x_norm_leaves = x_norm_dendrogram['leaves']
  535. df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves])
  536. default_kws = dict(pivot_kws=None, z_score=None, standard_scale=None,
  537. figsize=(10, 10), row_colors=None, col_colors=None,
  538. dendrogram_ratio=.2, colors_ratio=.03,
  539. cbar_pos=(0, .8, .05, .2))
  540. default_plot_kws = dict(metric='euclidean', method='average',
  541. colorbar_kws=None,
  542. row_cluster=True, col_cluster=True,
  543. row_linkage=None, col_linkage=None,
  544. tree_kws=None)
  545. row_colors = color_palette('Set2', df_norm.shape[0])
  546. col_colors = color_palette('Dark2', df_norm.shape[1])
  547. def test_ndarray_input(self):
  548. cm = mat.ClusterGrid(self.x_norm, **self.default_kws)
  549. pdt.assert_frame_equal(cm.data, pd.DataFrame(self.x_norm))
  550. nt.assert_equal(len(cm.fig.axes), 4)
  551. nt.assert_equal(cm.ax_row_colors, None)
  552. nt.assert_equal(cm.ax_col_colors, None)
  553. def test_df_input(self):
  554. cm = mat.ClusterGrid(self.df_norm, **self.default_kws)
  555. pdt.assert_frame_equal(cm.data, self.df_norm)
  556. def test_corr_df_input(self):
  557. df = self.df_norm.corr()
  558. cg = mat.ClusterGrid(df, **self.default_kws)
  559. cg.plot(**self.default_plot_kws)
  560. diag = cg.data2d.values[np.diag_indices_from(cg.data2d)]
  561. npt.assert_array_equal(diag, np.ones(cg.data2d.shape[0]))
  562. def test_pivot_input(self):
  563. df_norm = self.df_norm.copy()
  564. df_norm.index.name = 'numbers'
  565. df_long = pd.melt(df_norm.reset_index(), var_name='letters',
  566. id_vars='numbers')
  567. kws = self.default_kws.copy()
  568. kws['pivot_kws'] = dict(index='numbers', columns='letters',
  569. values='value')
  570. cm = mat.ClusterGrid(df_long, **kws)
  571. pdt.assert_frame_equal(cm.data2d, df_norm)
  572. def test_colors_input(self):
  573. kws = self.default_kws.copy()
  574. kws['row_colors'] = self.row_colors
  575. kws['col_colors'] = self.col_colors
  576. cm = mat.ClusterGrid(self.df_norm, **kws)
  577. npt.assert_array_equal(cm.row_colors, self.row_colors)
  578. npt.assert_array_equal(cm.col_colors, self.col_colors)
  579. nt.assert_equal(len(cm.fig.axes), 6)
  580. def test_nested_colors_input(self):
  581. kws = self.default_kws.copy()
  582. row_colors = [self.row_colors, self.row_colors]
  583. col_colors = [self.col_colors, self.col_colors]
  584. kws['row_colors'] = row_colors
  585. kws['col_colors'] = col_colors
  586. cm = mat.ClusterGrid(self.df_norm, **kws)
  587. npt.assert_array_equal(cm.row_colors, row_colors)
  588. npt.assert_array_equal(cm.col_colors, col_colors)
  589. nt.assert_equal(len(cm.fig.axes), 6)
  590. def test_colors_input_custom_cmap(self):
  591. kws = self.default_kws.copy()
  592. kws['cmap'] = mpl.cm.PRGn
  593. kws['row_colors'] = self.row_colors
  594. kws['col_colors'] = self.col_colors
  595. cm = mat.clustermap(self.df_norm, **kws)
  596. npt.assert_array_equal(cm.row_colors, self.row_colors)
  597. npt.assert_array_equal(cm.col_colors, self.col_colors)
  598. nt.assert_equal(len(cm.fig.axes), 6)
  599. def test_z_score(self):
  600. df = self.df_norm.copy()
  601. df = (df - df.mean()) / df.std()
  602. kws = self.default_kws.copy()
  603. kws['z_score'] = 1
  604. cm = mat.ClusterGrid(self.df_norm, **kws)
  605. pdt.assert_frame_equal(cm.data2d, df)
  606. def test_z_score_axis0(self):
  607. df = self.df_norm.copy()
  608. df = df.T
  609. df = (df - df.mean()) / df.std()
  610. df = df.T
  611. kws = self.default_kws.copy()
  612. kws['z_score'] = 0
  613. cm = mat.ClusterGrid(self.df_norm, **kws)
  614. pdt.assert_frame_equal(cm.data2d, df)
  615. def test_standard_scale(self):
  616. df = self.df_norm.copy()
  617. df = (df - df.min()) / (df.max() - df.min())
  618. kws = self.default_kws.copy()
  619. kws['standard_scale'] = 1
  620. cm = mat.ClusterGrid(self.df_norm, **kws)
  621. pdt.assert_frame_equal(cm.data2d, df)
  622. def test_standard_scale_axis0(self):
  623. df = self.df_norm.copy()
  624. df = df.T
  625. df = (df - df.min()) / (df.max() - df.min())
  626. df = df.T
  627. kws = self.default_kws.copy()
  628. kws['standard_scale'] = 0
  629. cm = mat.ClusterGrid(self.df_norm, **kws)
  630. pdt.assert_frame_equal(cm.data2d, df)
  631. def test_z_score_standard_scale(self):
  632. kws = self.default_kws.copy()
  633. kws['z_score'] = True
  634. kws['standard_scale'] = True
  635. with nt.assert_raises(ValueError):
  636. mat.ClusterGrid(self.df_norm, **kws)
  def test_color_list_to_matrix_and_cmap(self):
      """A flat color list becomes an (n, 1) code matrix in leaf order."""
      matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
          self.col_colors, self.x_norm_leaves)

      palette = set(self.col_colors)
      code_of = {color: code for code, color in enumerate(palette)}
      codes = np.array([code_of[color] for color in self.col_colors])
      expected = codes[self.x_norm_leaves].reshape(len(self.col_colors), 1)
      expected_cmap = mpl.colors.ListedColormap(palette)

      npt.assert_array_equal(matrix, expected)
      npt.assert_array_equal(cmap.colors, expected_cmap.colors)
  def test_nested_color_list_to_matrix_and_cmap(self):
      """Nested color lists give one column per list, rows in leaf order."""
      nested = [self.col_colors, self.col_colors]
      matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
          nested, self.x_norm_leaves)

      palette = set(itertools.chain.from_iterable(nested))
      code_of = {color: code for code, color in enumerate(palette)}
      # One row of codes per input list; reorder its entries, then transpose.
      codes = np.array([[code_of[c] for c in row] for row in nested])
      expected = codes[:, self.x_norm_leaves].T
      expected_cmap = mpl.colors.ListedColormap(palette)

      npt.assert_array_equal(matrix, expected)
      npt.assert_array_equal(cmap.colors, expected_cmap.colors)
  def test_color_list_to_matrix_and_cmap_axis1(self):
      """With ``axis=1`` the code matrix has shape (1, n), in leaf order."""
      matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
          self.col_colors, self.x_norm_leaves, axis=1)

      palette = set(self.col_colors)
      code_of = {color: code for code, color in enumerate(palette)}
      codes = np.array([code_of[color] for color in self.col_colors])
      expected = codes[self.x_norm_leaves].reshape(1, len(self.col_colors))
      expected_cmap = mpl.colors.ListedColormap(palette)

      npt.assert_array_equal(matrix, expected)
      npt.assert_array_equal(cmap.colors, expected_cmap.colors)
  def test_savefig(self):
      """Smoke test: a plotted ClusterGrid can be saved as a PNG.

      Only checks that ``savefig`` runs without raising; the written image
      is not inspected.
      """
      cm = mat.ClusterGrid(self.df_norm, **self.default_kws)
      cm.plot(**self.default_plot_kws)
      # Use the temporary file as a context manager so it is closed (and
      # deleted) deterministically instead of leaking an open handle.
      with tempfile.NamedTemporaryFile() as tmp:
          cm.savefig(tmp, format='png')
  def test_plot_dendrograms(self):
      """Both dendrograms are drawn and data2d follows their leaf order."""
      cm = mat.clustermap(self.df_norm, **self.default_kws)

      pairs = [(cm.ax_row_dendrogram, cm.dendrogram_row),
               (cm.ax_col_dendrogram, cm.dendrogram_col)]
      for ax, tree in pairs:
          drawn_paths = ax.collections[0].get_paths()
          assert len(drawn_paths) == len(tree.independent_coord)

      expected = self.df_norm.iloc[cm.dendrogram_row.reordered_ind,
                                   cm.dendrogram_col.reordered_ind]
      pdt.assert_frame_equal(cm.data2d, expected)
  def test_cluster_false(self):
      """With clustering off, dendrogram axes are empty and data unreordered."""
      kws = dict(self.default_kws, row_cluster=False, col_cluster=False)
      cm = mat.clustermap(self.df_norm, **kws)

      dendrogram_axes = (cm.ax_row_dendrogram, cm.ax_col_dendrogram)
      for ax in dendrogram_axes:
          assert len(ax.lines) == 0
      for ax in dendrogram_axes:
          assert len(ax.get_xticks()) == 0
          assert len(ax.get_yticks()) == 0
      pdt.assert_frame_equal(cm.data2d, self.df_norm)
  def test_row_col_colors(self):
      """Each color strip axis receives exactly one collection."""
      kws = dict(self.default_kws,
                 row_colors=self.row_colors,
                 col_colors=self.col_colors)
      cm = mat.clustermap(self.df_norm, **kws)

      for strip_ax in (cm.ax_row_colors, cm.ax_col_colors):
          assert len(strip_ax.collections) == 1
  def test_cluster_false_row_col_colors(self):
      """Color strips are still drawn when both clusterings are disabled."""
      kws = dict(self.default_kws,
                 row_cluster=False, col_cluster=False,
                 row_colors=self.row_colors, col_colors=self.col_colors)
      cm = mat.clustermap(self.df_norm, **kws)

      dendrogram_axes = (cm.ax_row_dendrogram, cm.ax_col_dendrogram)
      for ax in dendrogram_axes:
          assert len(ax.lines) == 0
      for ax in dendrogram_axes:
          assert len(ax.get_xticks()) == 0
          assert len(ax.get_yticks()) == 0
      for strip_ax in (cm.ax_row_colors, cm.ax_col_colors):
          assert len(strip_ax.collections) == 1
      # Without clustering the data keep their input order.
      pdt.assert_frame_equal(cm.data2d, self.df_norm)
  def test_row_col_colors_df(self):
      """DataFrame colors use their column names as color-strip tick labels."""
      row_names = ['row_1', 'row_2']
      col_names = ['col_1', 'col_2']
      row_frame = pd.DataFrame({name: list(self.row_colors) for name in row_names},
                               index=self.df_norm.index,
                               columns=row_names)
      col_frame = pd.DataFrame({name: list(self.col_colors) for name in col_names},
                               index=self.df_norm.columns,
                               columns=col_names)
      kws = dict(self.default_kws, row_colors=row_frame, col_colors=col_frame)
      cm = mat.clustermap(self.df_norm, **kws)

      # Row strips are labelled along x, column strips along y.
      shown_rows = [t.get_text() for t in cm.ax_row_colors.get_xticklabels()]
      assert cm.row_color_labels == ['row_1', 'row_2']
      assert shown_rows == cm.row_color_labels

      shown_cols = [t.get_text() for t in cm.ax_col_colors.get_yticklabels()]
      assert cm.col_color_labels == ['col_1', 'col_2']
      assert shown_cols == cm.col_color_labels
  def test_row_col_colors_df_shuffled(self):
      """DataFrame colors are matched by label even when given out of order."""
      def evens_then_odds(k):
          # Positions 0, 2, 4, ... followed by 1, 3, 5, ...
          return list(range(0, k, 2)) + list(range(1, k, 2))

      n_rows, n_cols = self.df_norm.shape
      shuffled_inds = [self.df_norm.index[i] for i in evens_then_odds(n_rows)]
      shuffled_cols = [self.df_norm.columns[i] for i in evens_then_odds(n_cols)]

      row_frame = pd.DataFrame({'row_annot': list(self.row_colors)},
                               index=self.df_norm.index)
      col_frame = pd.DataFrame({'col_annot': list(self.col_colors)},
                               index=self.df_norm.columns)
      kws = dict(self.default_kws,
                 row_colors=row_frame.loc[shuffled_inds],
                 col_colors=col_frame.loc[shuffled_cols])
      cm = mat.clustermap(self.df_norm, **kws)

      assert list(cm.col_colors)[0] == list(self.col_colors)
      assert list(cm.row_colors)[0] == list(self.row_colors)
  def test_row_col_colors_df_missing(self):
      """Labels missing from a color DataFrame are filled with white."""
      row_frame = pd.DataFrame({'row_annot': list(self.row_colors)},
                               index=self.df_norm.index)
      col_frame = pd.DataFrame({'col_annot': list(self.col_colors)},
                               index=self.df_norm.columns)
      # Drop the first label on each axis so it has no color entry.
      kws = dict(self.default_kws,
                 row_colors=row_frame.drop(self.df_norm.index[0]),
                 col_colors=col_frame.drop(self.df_norm.columns[0]))
      cm = mat.clustermap(self.df_norm, **kws)

      white = (1.0, 1.0, 1.0)
      assert list(cm.col_colors)[0] == [white] + list(self.col_colors[1:])
      assert list(cm.row_colors)[0] == [white] + list(self.row_colors[1:])
  def test_row_col_colors_df_one_axis(self):
      """DataFrame colors are labelled correctly when only one axis has them."""
      # Case 1: only row annotations.
      row_frame = pd.DataFrame({'row_1': list(self.row_colors),
                                'row_2': list(self.row_colors)},
                               index=self.df_norm.index,
                               columns=['row_1', 'row_2'])
      rows_only = mat.clustermap(self.df_norm,
                                 **dict(self.default_kws, row_colors=row_frame))
      shown_rows = [t.get_text()
                    for t in rows_only.ax_row_colors.get_xticklabels()]
      assert rows_only.row_color_labels == ['row_1', 'row_2']
      assert shown_rows == rows_only.row_color_labels

      # Case 2: only column annotations.
      col_frame = pd.DataFrame({'col_1': list(self.col_colors),
                                'col_2': list(self.col_colors)},
                               index=self.df_norm.columns,
                               columns=['col_1', 'col_2'])
      cols_only = mat.clustermap(self.df_norm,
                                 **dict(self.default_kws, col_colors=col_frame))
      shown_cols = [t.get_text()
                    for t in cols_only.ax_col_colors.get_yticklabels()]
      assert cols_only.col_color_labels == ['col_1', 'col_2']
      assert shown_cols == cols_only.col_color_labels
  def test_row_col_colors_series(self):
      """Series colors use the Series name as the color-strip tick label."""
      row_series = pd.Series(list(self.row_colors), name='row_annot',
                             index=self.df_norm.index)
      col_series = pd.Series(list(self.col_colors), name='col_annot',
                             index=self.df_norm.columns)
      kws = dict(self.default_kws, row_colors=row_series, col_colors=col_series)
      cm = mat.clustermap(self.df_norm, **kws)

      shown_rows = [t.get_text() for t in cm.ax_row_colors.get_xticklabels()]
      assert cm.row_color_labels == ['row_annot']
      assert shown_rows == cm.row_color_labels

      shown_cols = [t.get_text() for t in cm.ax_col_colors.get_yticklabels()]
      assert cm.col_color_labels == ['col_annot']
      assert shown_cols == cm.col_color_labels
  def test_row_col_colors_series_shuffled(self):
      """Series colors are matched by label even when given out of order."""
      def evens_then_odds(k):
          # Positions 0, 2, 4, ... followed by 1, 3, 5, ...
          return list(range(0, k, 2)) + list(range(1, k, 2))

      n_rows, n_cols = self.df_norm.shape
      shuffled_inds = [self.df_norm.index[i] for i in evens_then_odds(n_rows)]
      shuffled_cols = [self.df_norm.columns[i] for i in evens_then_odds(n_cols)]

      row_series = pd.Series(list(self.row_colors), name='row_annot',
                             index=self.df_norm.index)
      col_series = pd.Series(list(self.col_colors), name='col_annot',
                             index=self.df_norm.columns)
      kws = dict(self.default_kws,
                 row_colors=row_series.loc[shuffled_inds],
                 col_colors=col_series.loc[shuffled_cols])
      cm = mat.clustermap(self.df_norm, **kws)

      assert list(cm.col_colors) == list(self.col_colors)
      assert list(cm.row_colors) == list(self.row_colors)
  def test_row_col_colors_series_missing(self):
      """Labels missing from a color Series are filled with white."""
      row_series = pd.Series(list(self.row_colors), name='row_annot',
                             index=self.df_norm.index)
      col_series = pd.Series(list(self.col_colors), name='col_annot',
                             index=self.df_norm.columns)
      # Drop the first label on each axis so it has no color entry.
      kws = dict(self.default_kws,
                 row_colors=row_series.drop(self.df_norm.index[0]),
                 col_colors=col_series.drop(self.df_norm.columns[0]))
      cm = mat.clustermap(self.df_norm, **kws)

      white = (1.0, 1.0, 1.0)
      assert list(cm.col_colors) == [white] + list(self.col_colors[1:])
      assert list(cm.row_colors) == [white] + list(self.row_colors[1:])
  def test_row_col_colors_ignore_heatmap_kwargs(self):
      """Heatmap color kwargs (cmap/norm/vmax) do not recolor the color strips."""
      data = self.rs.uniform(0, 200, self.df_norm.shape)
      g = mat.clustermap(data,
                         row_colors=self.row_colors,
                         col_colors=self.col_colors,
                         cmap="Spectral",
                         norm=mpl.colors.LogNorm(),
                         vmax=100)

      strips = [(self.row_colors, g.dendrogram_row, g.ax_row_colors),
                (self.col_colors, g.dendrogram_col, g.ax_col_colors)]
      for colors, tree, strip_ax in strips:
          expected = np.array(colors)[tree.reordered_ind]
          # Compare only the first three (RGB) facecolor channels.
          drawn = strip_ax.collections[0].get_facecolors()[:, :3]
          assert np.array_equal(expected, drawn)
  def test_mask_reorganization(self):
      """The mask is reordered together with the clustered data."""
      kws = dict(self.default_kws, mask=self.df_norm > 0)
      g = mat.clustermap(self.df_norm, **kws)

      npt.assert_array_equal(g.data2d.index, g.mask.index)
      npt.assert_array_equal(g.data2d.columns, g.mask.columns)

      row_order = self.df_norm.index[g.dendrogram_row.reordered_ind]
      col_order = self.df_norm.columns[g.dendrogram_col.reordered_ind]
      npt.assert_array_equal(g.mask.index, row_order)
      npt.assert_array_equal(g.mask.columns, col_order)
  def test_ticklabel_reorganization(self):
      """Custom tick labels are permuted together with the clustered data."""
      n_rows, n_cols = self.df_norm.shape
      kws = self.default_kws.copy()

      xtl = np.arange(n_cols)
      kws["xticklabels"] = list(xtl)
      # ``.loc`` slices include the stop label. If ``self.letters`` has a
      # default 0-based integer index (TODO confirm against the fixture),
      # ``.loc[:n_rows]`` would return n_rows + 1 labels. ``.iloc`` takes
      # exactly n_rows labels by position.
      ytl = self.letters.iloc[:n_rows]
      kws["yticklabels"] = ytl

      g = mat.clustermap(self.df_norm, **kws)

      xtl_actual = [t.get_text() for t in g.ax_heatmap.get_xticklabels()]
      ytl_actual = [t.get_text() for t in g.ax_heatmap.get_yticklabels()]

      # Index positionally and convert with ``astype(str)``: a fixed
      # "<U1" dtype would truncate multi-character labels such as "10".
      xtl_want = xtl[g.dendrogram_col.reordered_ind].astype(str)
      ytl_want = np.asarray(ytl)[g.dendrogram_row.reordered_ind].astype(str)

      npt.assert_array_equal(xtl_actual, xtl_want)
      npt.assert_array_equal(ytl_actual, ytl_want)
  def test_noticklabels(self):
      """``xticklabels=False`` / ``yticklabels=False`` leave no tick text."""
      kws = dict(self.default_kws, xticklabels=False, yticklabels=False)
      g = mat.clustermap(self.df_norm, **kws)

      heatmap = g.ax_heatmap
      for labels in (heatmap.get_xticklabels(), heatmap.get_yticklabels()):
          assert [t.get_text() for t in labels] == []
  def test_size_ratios(self):
      """Larger dendrogram/colors ratios give those axes more room.

      GridSpec wspace/hspace make the mapping from an input ratio to the
      final axes size non-trivial, so only relative (larger/smaller)
      relationships are asserted.
      """
      def pos(grid, name):
          return getattr(grid, name).get_position()

      # Raising both ratios grows the annotation axes and shrinks the heatmap.
      small = dict(self.default_kws, dendrogram_ratio=.2, colors_ratio=.03,
                   col_colors=self.col_colors, row_colors=self.row_colors)
      large = dict(small, dendrogram_ratio=.3, colors_ratio=.05)
      g1 = mat.clustermap(self.df_norm, **small)
      g2 = mat.clustermap(self.df_norm, **large)
      for name in ("ax_col_dendrogram", "ax_col_colors"):
          assert pos(g2, name).height > pos(g1, name).height
      assert pos(g2, "ax_heatmap").height < pos(g1, "ax_heatmap").height
      for name in ("ax_row_dendrogram", "ax_row_colors"):
          assert pos(g2, name).width > pos(g1, name).width
      assert pos(g2, "ax_heatmap").width < pos(g1, "ax_heatmap").width

      # Two stacked column-color lists need a taller colors axis than one.
      one_list = dict(self.default_kws, col_colors=self.col_colors)
      two_lists = dict(one_list, col_colors=[self.col_colors, self.col_colors])
      g1 = mat.clustermap(self.df_norm, **one_list)
      g2 = mat.clustermap(self.df_norm, **two_lists)
      assert pos(g2, "ax_col_colors").height > pos(g1, "ax_col_colors").height

      # Changing only the second element of a tuple ratio resizes only the
      # column dendrogram.
      even = dict(self.default_kws, dendrogram_ratio=(.2, .2))
      taller = dict(even, dendrogram_ratio=(.2, .3))
      g1 = mat.clustermap(self.df_norm, **even)
      g2 = mat.clustermap(self.df_norm, **taller)
      assert (pos(g2, "ax_row_dendrogram").width
              == pos(g1, "ax_row_dendrogram").width)
      assert (pos(g2, "ax_col_dendrogram").height
              > pos(g1, "ax_col_dendrogram").height)
  def test_cbar_pos(self):
      """``cbar_pos`` places the colorbar axes; ``None`` omits it."""
      cbar_pos = (.2, .1, .4, .3)
      kws = dict(self.default_kws, cbar_pos=cbar_pos)
      left, bottom, width, height = cbar_pos

      pos = mat.clustermap(self.df_norm, **kws).ax_cbar.get_position()
      assert pytest.approx(tuple(pos.p0)) == (left, bottom)
      assert pytest.approx(pos.width) == width
      assert pytest.approx(pos.height) == height

      g = mat.clustermap(self.df_norm, **dict(kws, cbar_pos=None))
      assert g.ax_cbar is None
  def test_square_warning(self):
      """``square=True`` warns and leaves the heatmap geometry unchanged."""
      kws = self.default_kws.copy()
      g1 = mat.clustermap(self.df_norm, **kws)

      square_kws = dict(kws, square=True)
      with pytest.warns(UserWarning):
          g2 = mat.clustermap(self.df_norm, **square_kws)

      before = g1.ax_heatmap.get_position().get_points()
      after = g2.ax_heatmap.get_position().get_points()
      assert np.array_equal(before, after)
  def test_clustermap_annotation(self):
      """Cell annotation text matches the reordered data values.

      Covers both ``annot=True`` and ``annot=<DataFrame>``. In each case the
      drawn texts must equal ``data2d`` formatted with ``fmt``.
      """
      for annot in (True, self.df_norm):
          g = mat.clustermap(self.df_norm, annot=annot, fmt=".1f")
          texts = g.ax_heatmap.texts
          # zip() stops at the shorter input; without this check the loop
          # below would pass vacuously if no annotations were drawn.
          assert len(texts) == g.data2d.size
          for val, text in zip(np.asarray(g.data2d).flat, texts):
              assert text.get_text() == "{:.1f}".format(val)
  def test_tree_kws(self):
      """``tree_kws`` colors are applied to both dendrogram collections."""
      rgb = (1, .5, .2)
      g = mat.clustermap(self.df_norm, tree_kws={"color": rgb})
      for dendro_ax in (g.ax_col_dendrogram, g.ax_row_dendrogram):
          (tree,) = dendro_ax.collections
          drawn_rgb = tuple(tree.get_color().squeeze())[:3]
          assert drawn_rgb == rgb