test_axisgrid.py 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474
  1. import warnings
  2. import numpy as np
  3. import pandas as pd
  4. from scipy import stats
  5. import matplotlib as mpl
  6. import matplotlib.pyplot as plt
  7. import pytest
  8. import nose.tools as nt
  9. import numpy.testing as npt
  10. try:
  11. import pandas.testing as tm
  12. except ImportError:
  13. import pandas.util.testing as tm
  14. from .. import axisgrid as ag
  15. from .. import rcmod
  16. from ..palettes import color_palette
  17. from ..distributions import kdeplot, _freedman_diaconis_bins
  18. from ..categorical import pointplot
  19. from ..utils import categorical_order
  20. rs = np.random.RandomState(0)
  21. class TestFacetGrid(object):
  22. df = pd.DataFrame(dict(x=rs.normal(size=60),
  23. y=rs.gamma(4, size=60),
  24. a=np.repeat(list("abc"), 20),
  25. b=np.tile(list("mn"), 30),
  26. c=np.tile(list("tuv"), 20),
  27. d=np.tile(list("abcdefghijkl"), 5)))
  28. def test_self_data(self):
  29. g = ag.FacetGrid(self.df)
  30. nt.assert_is(g.data, self.df)
  31. def test_self_fig(self):
  32. g = ag.FacetGrid(self.df)
  33. nt.assert_is_instance(g.fig, plt.Figure)
  34. def test_self_axes(self):
  35. g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
  36. for ax in g.axes.flat:
  37. nt.assert_is_instance(ax, plt.Axes)
  38. def test_axes_array_size(self):
  39. g1 = ag.FacetGrid(self.df)
  40. nt.assert_equal(g1.axes.shape, (1, 1))
  41. g2 = ag.FacetGrid(self.df, row="a")
  42. nt.assert_equal(g2.axes.shape, (3, 1))
  43. g3 = ag.FacetGrid(self.df, col="b")
  44. nt.assert_equal(g3.axes.shape, (1, 2))
  45. g4 = ag.FacetGrid(self.df, hue="c")
  46. nt.assert_equal(g4.axes.shape, (1, 1))
  47. g5 = ag.FacetGrid(self.df, row="a", col="b", hue="c")
  48. nt.assert_equal(g5.axes.shape, (3, 2))
  49. for ax in g5.axes.flat:
  50. nt.assert_is_instance(ax, plt.Axes)
  51. def test_single_axes(self):
  52. g1 = ag.FacetGrid(self.df)
  53. nt.assert_is_instance(g1.ax, plt.Axes)
  54. g2 = ag.FacetGrid(self.df, row="a")
  55. with nt.assert_raises(AttributeError):
  56. g2.ax
  57. g3 = ag.FacetGrid(self.df, col="a")
  58. with nt.assert_raises(AttributeError):
  59. g3.ax
  60. g4 = ag.FacetGrid(self.df, col="a", row="b")
  61. with nt.assert_raises(AttributeError):
  62. g4.ax
  63. def test_col_wrap(self):
  64. n = len(self.df.d.unique())
  65. g = ag.FacetGrid(self.df, col="d")
  66. assert g.axes.shape == (1, n)
  67. assert g.facet_axis(0, 8) is g.axes[0, 8]
  68. g_wrap = ag.FacetGrid(self.df, col="d", col_wrap=4)
  69. assert g_wrap.axes.shape == (n,)
  70. assert g_wrap.facet_axis(0, 8) is g_wrap.axes[8]
  71. assert g_wrap._ncol == 4
  72. assert g_wrap._nrow == (n / 4)
  73. with pytest.raises(ValueError):
  74. g = ag.FacetGrid(self.df, row="b", col="d", col_wrap=4)
  75. df = self.df.copy()
  76. df.loc[df.d == "j"] = np.nan
  77. g_missing = ag.FacetGrid(df, col="d")
  78. assert g_missing.axes.shape == (1, n - 1)
  79. g_missing_wrap = ag.FacetGrid(df, col="d", col_wrap=4)
  80. assert g_missing_wrap.axes.shape == (n - 1,)
  81. g = ag.FacetGrid(self.df, col="d", col_wrap=1)
  82. assert len(list(g.facet_data())) == n
  83. def test_normal_axes(self):
  84. null = np.empty(0, object).flat
  85. g = ag.FacetGrid(self.df)
  86. npt.assert_array_equal(g._bottom_axes, g.axes.flat)
  87. npt.assert_array_equal(g._not_bottom_axes, null)
  88. npt.assert_array_equal(g._left_axes, g.axes.flat)
  89. npt.assert_array_equal(g._not_left_axes, null)
  90. npt.assert_array_equal(g._inner_axes, null)
  91. g = ag.FacetGrid(self.df, col="c")
  92. npt.assert_array_equal(g._bottom_axes, g.axes.flat)
  93. npt.assert_array_equal(g._not_bottom_axes, null)
  94. npt.assert_array_equal(g._left_axes, g.axes[:, 0].flat)
  95. npt.assert_array_equal(g._not_left_axes, g.axes[:, 1:].flat)
  96. npt.assert_array_equal(g._inner_axes, null)
  97. g = ag.FacetGrid(self.df, row="c")
  98. npt.assert_array_equal(g._bottom_axes, g.axes[-1, :].flat)
  99. npt.assert_array_equal(g._not_bottom_axes, g.axes[:-1, :].flat)
  100. npt.assert_array_equal(g._left_axes, g.axes.flat)
  101. npt.assert_array_equal(g._not_left_axes, null)
  102. npt.assert_array_equal(g._inner_axes, null)
  103. g = ag.FacetGrid(self.df, col="a", row="c")
  104. npt.assert_array_equal(g._bottom_axes, g.axes[-1, :].flat)
  105. npt.assert_array_equal(g._not_bottom_axes, g.axes[:-1, :].flat)
  106. npt.assert_array_equal(g._left_axes, g.axes[:, 0].flat)
  107. npt.assert_array_equal(g._not_left_axes, g.axes[:, 1:].flat)
  108. npt.assert_array_equal(g._inner_axes, g.axes[:-1, 1:].flat)
  109. def test_wrapped_axes(self):
  110. null = np.empty(0, object).flat
  111. g = ag.FacetGrid(self.df, col="a", col_wrap=2)
  112. npt.assert_array_equal(g._bottom_axes,
  113. g.axes[np.array([1, 2])].flat)
  114. npt.assert_array_equal(g._not_bottom_axes, g.axes[:1].flat)
  115. npt.assert_array_equal(g._left_axes, g.axes[np.array([0, 2])].flat)
  116. npt.assert_array_equal(g._not_left_axes, g.axes[np.array([1])].flat)
  117. npt.assert_array_equal(g._inner_axes, null)
  118. def test_figure_size(self):
  119. g = ag.FacetGrid(self.df, row="a", col="b")
  120. npt.assert_array_equal(g.fig.get_size_inches(), (6, 9))
  121. g = ag.FacetGrid(self.df, row="a", col="b", height=6)
  122. npt.assert_array_equal(g.fig.get_size_inches(), (12, 18))
  123. g = ag.FacetGrid(self.df, col="c", height=4, aspect=.5)
  124. npt.assert_array_equal(g.fig.get_size_inches(), (6, 4))
  125. def test_figure_size_with_legend(self):
  126. g1 = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5)
  127. npt.assert_array_equal(g1.fig.get_size_inches(), (6, 4))
  128. g1.add_legend()
  129. nt.assert_greater(g1.fig.get_size_inches()[0], 6)
  130. g2 = ag.FacetGrid(self.df, col="a", hue="c", height=4, aspect=.5,
  131. legend_out=False)
  132. npt.assert_array_equal(g2.fig.get_size_inches(), (6, 4))
  133. g2.add_legend()
  134. npt.assert_array_equal(g2.fig.get_size_inches(), (6, 4))
  135. def test_legend_data(self):
  136. g1 = ag.FacetGrid(self.df, hue="a")
  137. g1.map(plt.plot, "x", "y")
  138. g1.add_legend()
  139. palette = color_palette(n_colors=3)
  140. nt.assert_equal(g1._legend.get_title().get_text(), "a")
  141. a_levels = sorted(self.df.a.unique())
  142. lines = g1._legend.get_lines()
  143. nt.assert_equal(len(lines), len(a_levels))
  144. for line, hue in zip(lines, palette):
  145. nt.assert_equal(line.get_color(), hue)
  146. labels = g1._legend.get_texts()
  147. nt.assert_equal(len(labels), len(a_levels))
  148. for label, level in zip(labels, a_levels):
  149. nt.assert_equal(label.get_text(), level)
  150. def test_legend_data_missing_level(self):
  151. g1 = ag.FacetGrid(self.df, hue="a", hue_order=list("azbc"))
  152. g1.map(plt.plot, "x", "y")
  153. g1.add_legend()
  154. b, g, r, p = color_palette(n_colors=4)
  155. palette = [b, r, p]
  156. nt.assert_equal(g1._legend.get_title().get_text(), "a")
  157. a_levels = sorted(self.df.a.unique())
  158. lines = g1._legend.get_lines()
  159. nt.assert_equal(len(lines), len(a_levels))
  160. for line, hue in zip(lines, palette):
  161. nt.assert_equal(line.get_color(), hue)
  162. labels = g1._legend.get_texts()
  163. nt.assert_equal(len(labels), 4)
  164. for label, level in zip(labels, list("azbc")):
  165. nt.assert_equal(label.get_text(), level)
  166. def test_get_boolean_legend_data(self):
  167. self.df["b_bool"] = self.df.b == "m"
  168. g1 = ag.FacetGrid(self.df, hue="b_bool")
  169. g1.map(plt.plot, "x", "y")
  170. g1.add_legend()
  171. palette = color_palette(n_colors=2)
  172. nt.assert_equal(g1._legend.get_title().get_text(), "b_bool")
  173. b_levels = list(map(str, categorical_order(self.df.b_bool)))
  174. lines = g1._legend.get_lines()
  175. nt.assert_equal(len(lines), len(b_levels))
  176. for line, hue in zip(lines, palette):
  177. nt.assert_equal(line.get_color(), hue)
  178. labels = g1._legend.get_texts()
  179. nt.assert_equal(len(labels), len(b_levels))
  180. for label, level in zip(labels, b_levels):
  181. nt.assert_equal(label.get_text(), level)
  182. def test_legend_tuples(self):
  183. g = ag.FacetGrid(self.df, hue="a")
  184. g.map(plt.plot, "x", "y")
  185. handles, labels = g.ax.get_legend_handles_labels()
  186. label_tuples = [("", l) for l in labels]
  187. legend_data = dict(zip(label_tuples, handles))
  188. g.add_legend(legend_data, label_tuples)
  189. for entry, label in zip(g._legend.get_texts(), labels):
  190. assert entry.get_text() == label
  191. def test_legend_options(self):
  192. g1 = ag.FacetGrid(self.df, hue="b")
  193. g1.map(plt.plot, "x", "y")
  194. g1.add_legend()
  195. def test_legendout_with_colwrap(self):
  196. g = ag.FacetGrid(self.df, col="d", hue='b',
  197. col_wrap=4, legend_out=False)
  198. g.map(plt.plot, "x", "y", linewidth=3)
  199. g.add_legend()
  200. def test_subplot_kws(self):
  201. g = ag.FacetGrid(self.df, despine=False,
  202. subplot_kws=dict(projection="polar"))
  203. for ax in g.axes.flat:
  204. nt.assert_true("PolarAxesSubplot" in str(type(ax)))
  205. def test_gridspec_kws(self):
  206. ratios = [3, 1, 2]
  207. gskws = dict(width_ratios=ratios)
  208. g = ag.FacetGrid(self.df, col='c', row='a', gridspec_kws=gskws)
  209. for ax in g.axes.flat:
  210. ax.set_xticks([])
  211. ax.set_yticks([])
  212. g.fig.tight_layout()
  213. for (l, m, r) in g.axes:
  214. assert l.get_position().width > m.get_position().width
  215. assert r.get_position().width > m.get_position().width
  216. def test_gridspec_kws_col_wrap(self):
  217. ratios = [3, 1, 2, 1, 1]
  218. gskws = dict(width_ratios=ratios)
  219. with warnings.catch_warnings():
  220. warnings.resetwarnings()
  221. warnings.simplefilter("always")
  222. npt.assert_warns(UserWarning, ag.FacetGrid, self.df, col='d',
  223. col_wrap=5, gridspec_kws=gskws)
  224. def test_data_generator(self):
  225. g = ag.FacetGrid(self.df, row="a")
  226. d = list(g.facet_data())
  227. nt.assert_equal(len(d), 3)
  228. tup, data = d[0]
  229. nt.assert_equal(tup, (0, 0, 0))
  230. nt.assert_true((data["a"] == "a").all())
  231. tup, data = d[1]
  232. nt.assert_equal(tup, (1, 0, 0))
  233. nt.assert_true((data["a"] == "b").all())
  234. g = ag.FacetGrid(self.df, row="a", col="b")
  235. d = list(g.facet_data())
  236. nt.assert_equal(len(d), 6)
  237. tup, data = d[0]
  238. nt.assert_equal(tup, (0, 0, 0))
  239. nt.assert_true((data["a"] == "a").all())
  240. nt.assert_true((data["b"] == "m").all())
  241. tup, data = d[1]
  242. nt.assert_equal(tup, (0, 1, 0))
  243. nt.assert_true((data["a"] == "a").all())
  244. nt.assert_true((data["b"] == "n").all())
  245. tup, data = d[2]
  246. nt.assert_equal(tup, (1, 0, 0))
  247. nt.assert_true((data["a"] == "b").all())
  248. nt.assert_true((data["b"] == "m").all())
  249. g = ag.FacetGrid(self.df, hue="c")
  250. d = list(g.facet_data())
  251. nt.assert_equal(len(d), 3)
  252. tup, data = d[1]
  253. nt.assert_equal(tup, (0, 0, 1))
  254. nt.assert_true((data["c"] == "u").all())
  255. def test_map(self):
  256. g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
  257. g.map(plt.plot, "x", "y", linewidth=3)
  258. lines = g.axes[0, 0].lines
  259. nt.assert_equal(len(lines), 3)
  260. line1, _, _ = lines
  261. nt.assert_equal(line1.get_linewidth(), 3)
  262. x, y = line1.get_data()
  263. mask = (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t")
  264. npt.assert_array_equal(x, self.df.x[mask])
  265. npt.assert_array_equal(y, self.df.y[mask])
  266. def test_map_dataframe(self):
  267. g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
  268. def plot(x, y, data=None, **kws):
  269. plt.plot(data[x], data[y], **kws)
  270. g.map_dataframe(plot, "x", "y", linestyle="--")
  271. lines = g.axes[0, 0].lines
  272. nt.assert_equal(len(lines), 3)
  273. line1, _, _ = lines
  274. nt.assert_equal(line1.get_linestyle(), "--")
  275. x, y = line1.get_data()
  276. mask = (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t")
  277. npt.assert_array_equal(x, self.df.x[mask])
  278. npt.assert_array_equal(y, self.df.y[mask])
  279. def test_set(self):
  280. g = ag.FacetGrid(self.df, row="a", col="b")
  281. xlim = (-2, 5)
  282. ylim = (3, 6)
  283. xticks = [-2, 0, 3, 5]
  284. yticks = [3, 4.5, 6]
  285. g.set(xlim=xlim, ylim=ylim, xticks=xticks, yticks=yticks)
  286. for ax in g.axes.flat:
  287. npt.assert_array_equal(ax.get_xlim(), xlim)
  288. npt.assert_array_equal(ax.get_ylim(), ylim)
  289. npt.assert_array_equal(ax.get_xticks(), xticks)
  290. npt.assert_array_equal(ax.get_yticks(), yticks)
  291. def test_set_titles(self):
  292. g = ag.FacetGrid(self.df, row="a", col="b")
  293. g.map(plt.plot, "x", "y")
  294. # Test the default titles
  295. nt.assert_equal(g.axes[0, 0].get_title(), "a = a | b = m")
  296. nt.assert_equal(g.axes[0, 1].get_title(), "a = a | b = n")
  297. nt.assert_equal(g.axes[1, 0].get_title(), "a = b | b = m")
  298. # Test a provided title
  299. g.set_titles("{row_var} == {row_name} \\/ {col_var} == {col_name}")
  300. nt.assert_equal(g.axes[0, 0].get_title(), "a == a \\/ b == m")
  301. nt.assert_equal(g.axes[0, 1].get_title(), "a == a \\/ b == n")
  302. nt.assert_equal(g.axes[1, 0].get_title(), "a == b \\/ b == m")
  303. # Test a single row
  304. g = ag.FacetGrid(self.df, col="b")
  305. g.map(plt.plot, "x", "y")
  306. # Test the default titles
  307. nt.assert_equal(g.axes[0, 0].get_title(), "b = m")
  308. nt.assert_equal(g.axes[0, 1].get_title(), "b = n")
  309. # test with dropna=False
  310. g = ag.FacetGrid(self.df, col="b", hue="b", dropna=False)
  311. g.map(plt.plot, 'x', 'y')
  312. def test_set_titles_margin_titles(self):
  313. g = ag.FacetGrid(self.df, row="a", col="b", margin_titles=True)
  314. g.map(plt.plot, "x", "y")
  315. # Test the default titles
  316. nt.assert_equal(g.axes[0, 0].get_title(), "b = m")
  317. nt.assert_equal(g.axes[0, 1].get_title(), "b = n")
  318. nt.assert_equal(g.axes[1, 0].get_title(), "")
  319. # Test the row "titles"
  320. nt.assert_equal(g.axes[0, 1].texts[0].get_text(), "a = a")
  321. nt.assert_equal(g.axes[1, 1].texts[0].get_text(), "a = b")
  322. # Test a provided title
  323. g.set_titles(col_template="{col_var} == {col_name}")
  324. nt.assert_equal(g.axes[0, 0].get_title(), "b == m")
  325. nt.assert_equal(g.axes[0, 1].get_title(), "b == n")
  326. nt.assert_equal(g.axes[1, 0].get_title(), "")
  327. def test_set_ticklabels(self):
  328. g = ag.FacetGrid(self.df, row="a", col="b")
  329. g.map(plt.plot, "x", "y")
  330. xlab = [l.get_text() + "h" for l in g.axes[1, 0].get_xticklabels()]
  331. ylab = [l.get_text() for l in g.axes[1, 0].get_yticklabels()]
  332. g.set_xticklabels(xlab)
  333. g.set_yticklabels(ylab)
  334. got_x = [l.get_text() for l in g.axes[1, 1].get_xticklabels()]
  335. got_y = [l.get_text() for l in g.axes[0, 0].get_yticklabels()]
  336. npt.assert_array_equal(got_x, xlab)
  337. npt.assert_array_equal(got_y, ylab)
  338. x, y = np.arange(10), np.arange(10)
  339. df = pd.DataFrame(np.c_[x, y], columns=["x", "y"])
  340. g = ag.FacetGrid(df).map(pointplot, "x", "y", order=x)
  341. g.set_xticklabels(step=2)
  342. got_x = [int(l.get_text()) for l in g.axes[0, 0].get_xticklabels()]
  343. npt.assert_array_equal(x[::2], got_x)
  344. g = ag.FacetGrid(self.df, col="d", col_wrap=5)
  345. g.map(plt.plot, "x", "y")
  346. g.set_xticklabels(rotation=45)
  347. g.set_yticklabels(rotation=75)
  348. for ax in g._bottom_axes:
  349. for l in ax.get_xticklabels():
  350. nt.assert_equal(l.get_rotation(), 45)
  351. for ax in g._left_axes:
  352. for l in ax.get_yticklabels():
  353. nt.assert_equal(l.get_rotation(), 75)
  354. def test_set_axis_labels(self):
  355. g = ag.FacetGrid(self.df, row="a", col="b")
  356. g.map(plt.plot, "x", "y")
  357. xlab = 'xx'
  358. ylab = 'yy'
  359. g.set_axis_labels(xlab, ylab)
  360. got_x = [ax.get_xlabel() for ax in g.axes[-1, :]]
  361. got_y = [ax.get_ylabel() for ax in g.axes[:, 0]]
  362. npt.assert_array_equal(got_x, xlab)
  363. npt.assert_array_equal(got_y, ylab)
  364. def test_axis_lims(self):
  365. g = ag.FacetGrid(self.df, row="a", col="b", xlim=(0, 4), ylim=(-2, 3))
  366. nt.assert_equal(g.axes[0, 0].get_xlim(), (0, 4))
  367. nt.assert_equal(g.axes[0, 0].get_ylim(), (-2, 3))
  368. def test_data_orders(self):
  369. g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
  370. nt.assert_equal(g.row_names, list("abc"))
  371. nt.assert_equal(g.col_names, list("mn"))
  372. nt.assert_equal(g.hue_names, list("tuv"))
  373. nt.assert_equal(g.axes.shape, (3, 2))
  374. g = ag.FacetGrid(self.df, row="a", col="b", hue="c",
  375. row_order=list("bca"),
  376. col_order=list("nm"),
  377. hue_order=list("vtu"))
  378. nt.assert_equal(g.row_names, list("bca"))
  379. nt.assert_equal(g.col_names, list("nm"))
  380. nt.assert_equal(g.hue_names, list("vtu"))
  381. nt.assert_equal(g.axes.shape, (3, 2))
  382. g = ag.FacetGrid(self.df, row="a", col="b", hue="c",
  383. row_order=list("bcda"),
  384. col_order=list("nom"),
  385. hue_order=list("qvtu"))
  386. nt.assert_equal(g.row_names, list("bcda"))
  387. nt.assert_equal(g.col_names, list("nom"))
  388. nt.assert_equal(g.hue_names, list("qvtu"))
  389. nt.assert_equal(g.axes.shape, (4, 3))
  390. def test_palette(self):
  391. rcmod.set()
  392. g = ag.FacetGrid(self.df, hue="c")
  393. assert g._colors == color_palette(n_colors=len(self.df.c.unique()))
  394. g = ag.FacetGrid(self.df, hue="d")
  395. assert g._colors == color_palette("husl", len(self.df.d.unique()))
  396. g = ag.FacetGrid(self.df, hue="c", palette="Set2")
  397. assert g._colors == color_palette("Set2", len(self.df.c.unique()))
  398. dict_pal = dict(t="red", u="green", v="blue")
  399. list_pal = color_palette(["red", "green", "blue"], 3)
  400. g = ag.FacetGrid(self.df, hue="c", palette=dict_pal)
  401. assert g._colors == list_pal
  402. list_pal = color_palette(["green", "blue", "red"], 3)
  403. g = ag.FacetGrid(self.df, hue="c", hue_order=list("uvt"),
  404. palette=dict_pal)
  405. assert g._colors == list_pal
  406. def test_hue_kws(self):
  407. kws = dict(marker=["o", "s", "D"])
  408. g = ag.FacetGrid(self.df, hue="c", hue_kws=kws)
  409. g.map(plt.plot, "x", "y")
  410. for line, marker in zip(g.axes[0, 0].lines, kws["marker"]):
  411. nt.assert_equal(line.get_marker(), marker)
  412. def test_dropna(self):
  413. df = self.df.copy()
  414. hasna = pd.Series(np.tile(np.arange(6), 10), dtype=np.float)
  415. hasna[hasna == 5] = np.nan
  416. df["hasna"] = hasna
  417. g = ag.FacetGrid(df, dropna=False, row="hasna")
  418. nt.assert_equal(g._not_na.sum(), 60)
  419. g = ag.FacetGrid(df, dropna=True, row="hasna")
  420. nt.assert_equal(g._not_na.sum(), 50)
  421. def test_categorical_column_missing_categories(self):
  422. df = self.df.copy()
  423. df['a'] = df['a'].astype('category')
  424. g = ag.FacetGrid(df[df['a'] == 'a'], col="a", col_wrap=1)
  425. nt.assert_equal(g.axes.shape, (len(df['a'].cat.categories),))
  426. def test_categorical_warning(self):
  427. g = ag.FacetGrid(self.df, col="b")
  428. with warnings.catch_warnings():
  429. warnings.resetwarnings()
  430. warnings.simplefilter("always")
  431. npt.assert_warns(UserWarning, g.map, pointplot, "b", "x")
  432. class TestPairGrid(object):
  433. rs = np.random.RandomState(sum(map(ord, "PairGrid")))
  434. df = pd.DataFrame(dict(x=rs.normal(size=60),
  435. y=rs.randint(0, 4, size=(60)),
  436. z=rs.gamma(3, size=60),
  437. a=np.repeat(list("abc"), 20),
  438. b=np.repeat(list("abcdefghijkl"), 5)))
  439. def test_self_data(self):
  440. g = ag.PairGrid(self.df)
  441. nt.assert_is(g.data, self.df)
  442. def test_ignore_datelike_data(self):
  443. df = self.df.copy()
  444. df['date'] = pd.date_range('2010-01-01', periods=len(df), freq='d')
  445. result = ag.PairGrid(self.df).data
  446. expected = df.drop('date', axis=1)
  447. tm.assert_frame_equal(result, expected)
  448. def test_self_fig(self):
  449. g = ag.PairGrid(self.df)
  450. nt.assert_is_instance(g.fig, plt.Figure)
  451. def test_self_axes(self):
  452. g = ag.PairGrid(self.df)
  453. for ax in g.axes.flat:
  454. nt.assert_is_instance(ax, plt.Axes)
  455. def test_default_axes(self):
  456. g = ag.PairGrid(self.df)
  457. nt.assert_equal(g.axes.shape, (3, 3))
  458. nt.assert_equal(g.x_vars, ["x", "y", "z"])
  459. nt.assert_equal(g.y_vars, ["x", "y", "z"])
  460. nt.assert_true(g.square_grid)
  461. def test_specific_square_axes(self):
  462. vars = ["z", "x"]
  463. g = ag.PairGrid(self.df, vars=vars)
  464. nt.assert_equal(g.axes.shape, (len(vars), len(vars)))
  465. nt.assert_equal(g.x_vars, vars)
  466. nt.assert_equal(g.y_vars, vars)
  467. nt.assert_true(g.square_grid)
  468. def test_remove_hue_from_default(self):
  469. hue = "z"
  470. g = ag.PairGrid(self.df, hue=hue)
  471. assert hue not in g.x_vars
  472. assert hue not in g.y_vars
  473. vars = ["x", "y", "z"]
  474. g = ag.PairGrid(self.df, hue=hue, vars=vars)
  475. assert hue in g.x_vars
  476. assert hue in g.y_vars
  477. def test_specific_nonsquare_axes(self):
  478. x_vars = ["x", "y"]
  479. y_vars = ["z", "y", "x"]
  480. g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
  481. nt.assert_equal(g.axes.shape, (len(y_vars), len(x_vars)))
  482. nt.assert_equal(g.x_vars, x_vars)
  483. nt.assert_equal(g.y_vars, y_vars)
  484. nt.assert_true(not g.square_grid)
  485. x_vars = ["x", "y"]
  486. y_vars = "z"
  487. g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
  488. nt.assert_equal(g.axes.shape, (len(y_vars), len(x_vars)))
  489. nt.assert_equal(g.x_vars, list(x_vars))
  490. nt.assert_equal(g.y_vars, list(y_vars))
  491. nt.assert_true(not g.square_grid)
  492. def test_specific_square_axes_with_array(self):
  493. vars = np.array(["z", "x"])
  494. g = ag.PairGrid(self.df, vars=vars)
  495. nt.assert_equal(g.axes.shape, (len(vars), len(vars)))
  496. nt.assert_equal(g.x_vars, list(vars))
  497. nt.assert_equal(g.y_vars, list(vars))
  498. nt.assert_true(g.square_grid)
  499. def test_specific_nonsquare_axes_with_array(self):
  500. x_vars = np.array(["x", "y"])
  501. y_vars = np.array(["z", "y", "x"])
  502. g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
  503. nt.assert_equal(g.axes.shape, (len(y_vars), len(x_vars)))
  504. nt.assert_equal(g.x_vars, list(x_vars))
  505. nt.assert_equal(g.y_vars, list(y_vars))
  506. nt.assert_true(not g.square_grid)
  507. def test_corner(self):
  508. plot_vars = ["x", "y", "z"]
  509. g1 = ag.PairGrid(self.df, vars=plot_vars, corner=True)
  510. corner_size = sum([i + 1 for i in range(len(plot_vars))])
  511. assert len(g1.fig.axes) == corner_size
  512. g1.map_diag(plt.hist)
  513. assert len(g1.fig.axes) == (corner_size + len(plot_vars))
  514. for ax in np.diag(g1.axes):
  515. assert not ax.yaxis.get_visible()
  516. assert not g1.axes[0, 0].get_ylabel()
  517. def test_size(self):
  518. g1 = ag.PairGrid(self.df, height=3)
  519. npt.assert_array_equal(g1.fig.get_size_inches(), (9, 9))
  520. g2 = ag.PairGrid(self.df, height=4, aspect=.5)
  521. npt.assert_array_equal(g2.fig.get_size_inches(), (6, 12))
  522. g3 = ag.PairGrid(self.df, y_vars=["z"], x_vars=["x", "y"],
  523. height=2, aspect=2)
  524. npt.assert_array_equal(g3.fig.get_size_inches(), (8, 2))
  525. def test_map(self):
  526. vars = ["x", "y", "z"]
  527. g1 = ag.PairGrid(self.df)
  528. g1.map(plt.scatter)
  529. for i, axes_i in enumerate(g1.axes):
  530. for j, ax in enumerate(axes_i):
  531. x_in = self.df[vars[j]]
  532. y_in = self.df[vars[i]]
  533. x_out, y_out = ax.collections[0].get_offsets().T
  534. npt.assert_array_equal(x_in, x_out)
  535. npt.assert_array_equal(y_in, y_out)
  536. g2 = ag.PairGrid(self.df, "a")
  537. g2.map(plt.scatter)
  538. for i, axes_i in enumerate(g2.axes):
  539. for j, ax in enumerate(axes_i):
  540. x_in = self.df[vars[j]]
  541. y_in = self.df[vars[i]]
  542. for k, k_level in enumerate(self.df.a.unique()):
  543. x_in_k = x_in[self.df.a == k_level]
  544. y_in_k = y_in[self.df.a == k_level]
  545. x_out, y_out = ax.collections[k].get_offsets().T
  546. npt.assert_array_equal(x_in_k, x_out)
  547. npt.assert_array_equal(y_in_k, y_out)
  548. def test_map_nonsquare(self):
  549. x_vars = ["x"]
  550. y_vars = ["y", "z"]
  551. g = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
  552. g.map(plt.scatter)
  553. x_in = self.df.x
  554. for i, i_var in enumerate(y_vars):
  555. ax = g.axes[i, 0]
  556. y_in = self.df[i_var]
  557. x_out, y_out = ax.collections[0].get_offsets().T
  558. npt.assert_array_equal(x_in, x_out)
  559. npt.assert_array_equal(y_in, y_out)
  560. def test_map_lower(self):
  561. vars = ["x", "y", "z"]
  562. g = ag.PairGrid(self.df)
  563. g.map_lower(plt.scatter)
  564. for i, j in zip(*np.tril_indices_from(g.axes, -1)):
  565. ax = g.axes[i, j]
  566. x_in = self.df[vars[j]]
  567. y_in = self.df[vars[i]]
  568. x_out, y_out = ax.collections[0].get_offsets().T
  569. npt.assert_array_equal(x_in, x_out)
  570. npt.assert_array_equal(y_in, y_out)
  571. for i, j in zip(*np.triu_indices_from(g.axes)):
  572. ax = g.axes[i, j]
  573. nt.assert_equal(len(ax.collections), 0)
  574. def test_map_upper(self):
  575. vars = ["x", "y", "z"]
  576. g = ag.PairGrid(self.df)
  577. g.map_upper(plt.scatter)
  578. for i, j in zip(*np.triu_indices_from(g.axes, 1)):
  579. ax = g.axes[i, j]
  580. x_in = self.df[vars[j]]
  581. y_in = self.df[vars[i]]
  582. x_out, y_out = ax.collections[0].get_offsets().T
  583. npt.assert_array_equal(x_in, x_out)
  584. npt.assert_array_equal(y_in, y_out)
  585. for i, j in zip(*np.tril_indices_from(g.axes)):
  586. ax = g.axes[i, j]
  587. nt.assert_equal(len(ax.collections), 0)
  588. def test_map_diag(self):
  589. g1 = ag.PairGrid(self.df)
  590. g1.map_diag(plt.hist)
  591. for var, ax in zip(g1.diag_vars, g1.diag_axes):
  592. nt.assert_equal(len(ax.patches), 10)
  593. assert pytest.approx(ax.patches[0].get_x()) == self.df[var].min()
  594. g2 = ag.PairGrid(self.df, hue="a")
  595. g2.map_diag(plt.hist)
  596. for ax in g2.diag_axes:
  597. nt.assert_equal(len(ax.patches), 30)
  598. g3 = ag.PairGrid(self.df, hue="a")
  599. g3.map_diag(plt.hist, histtype='step')
  600. for ax in g3.diag_axes:
  601. for ptch in ax.patches:
  602. nt.assert_equal(ptch.fill, False)
  def test_map_diag_rectangular(self):
      """map_diag on non-square grids only targets variables shared by x and y."""

      def check_plain_hist_grid(grid, shared_vars):
          # Non-hue diagonal histograms: 10 bins starting at the data
          # minimum, and each diag axis placed where np.diag(grid.axes) is.
          assert set(grid.diag_vars) == shared_vars
          for var_name, diag_ax in zip(grid.diag_vars, grid.diag_axes):
              nt.assert_equal(len(diag_ax.patches), 10)
              left_edge = diag_ax.patches[0].get_x()
              assert pytest.approx(left_edge) == self.df[var_name].min()
          for idx, grid_ax in enumerate(np.diag(grid.axes)):
              assert grid_ax.bbox.bounds == grid.diag_axes[idx].bbox.bounds

      # Taller than wide: 3 row variables, 2 column variables.
      tall_x, tall_y = ["x", "y"], ["x", "y", "z"]
      tall = ag.PairGrid(self.df, x_vars=tall_x, y_vars=tall_y)
      tall.map_diag(plt.hist)
      check_plain_hist_grid(tall, set(tall_x) & set(tall_y))

      # Same shape with hue: one histogram per level on each diagonal.
      tall_hue = ag.PairGrid(self.df, x_vars=tall_x, y_vars=tall_y, hue="a")
      tall_hue.map_diag(plt.hist)
      assert set(tall_hue.diag_vars) == (set(tall_x) & set(tall_y))
      for diag_ax in tall_hue.diag_axes:
          nt.assert_equal(len(diag_ax.patches), 30)

      # Wider than tall: 2 row variables, 3 column variables.
      wide_x, wide_y = ["x", "y", "z"], ["x", "y"]
      wide = ag.PairGrid(self.df, x_vars=wide_x, y_vars=wide_y)
      wide.map_diag(plt.hist)
      check_plain_hist_grid(wide, set(wide_x) & set(wide_y))
  def test_map_diag_color(self):
      """A ``color`` keyword given to map_diag reaches every diagonal artist."""
      color = "red"
      rgba_color = mpl.colors.colorConverter.to_rgba(color)

      # Histogram bars take the requested face color.
      hist_grid = ag.PairGrid(self.df)
      hist_grid.map_diag(plt.hist, color=color)
      hist_patches = [p for ax in hist_grid.diag_axes for p in ax.patches]
      for patch in hist_patches:
          assert patch.get_facecolor() == rgba_color

      # KDE lines report the color exactly as it was passed in.
      kde_grid = ag.PairGrid(self.df)
      kde_grid.map_diag(kdeplot, color=color)
      kde_lines = [ln for ax in kde_grid.diag_axes for ln in ax.lines]
      for line in kde_lines:
          assert line.get_color() == color
  def test_map_diag_palette(self):
      """With hue, diagonal KDE lines follow the default palette in order."""
      n_levels = len(self.df.a.unique())
      expected_colors = color_palette(n_colors=n_levels)
      grid = ag.PairGrid(self.df, hue="a")
      grid.map_diag(kdeplot)
      for diag_ax in grid.diag_axes:
          for drawn, expected in zip(diag_ax.lines, expected_colors):
              assert drawn.get_color() == expected
  def test_map_diag_and_offdiag(self):
      """map_offdiag scatters every off-diagonal cell; map_diag fills the rest."""
      plot_vars = ["x", "y", "z"]
      grid = ag.PairGrid(self.df)
      grid.map_offdiag(plt.scatter)
      grid.map_diag(plt.hist)

      for diag_ax in grid.diag_axes:
          nt.assert_equal(len(diag_ax.patches), 10)

      # Upper triangle first, then lower triangle; diagonal excluded.
      off_diagonal = (list(zip(*np.triu_indices_from(grid.axes, 1)))
                      + list(zip(*np.tril_indices_from(grid.axes, -1))))
      for row, col in off_diagonal:
          x_out, y_out = grid.axes[row, col].collections[0].get_offsets().T
          npt.assert_array_equal(self.df[plot_vars[col]], x_out)
          npt.assert_array_equal(self.df[plot_vars[row]], y_out)

      # Diagonal cells must not have received a scatter collection.
      for row, col in zip(*np.diag_indices_from(grid.axes)):
          nt.assert_equal(len(grid.axes[row, col].collections), 0)
  def test_diag_sharey(self):
      """diag_sharey=True gives every diagonal axis the same y limits."""
      grid = ag.PairGrid(self.df, diag_sharey=True)
      grid.map_diag(kdeplot)
      reference_ylim = grid.diag_axes[0].get_ylim()
      for diag_ax in grid.diag_axes[1:]:
          assert diag_ax.get_ylim() == reference_ylim
  def test_palette(self):
      """PairGrid resolves its palette from defaults, names, and dicts."""
      rcmod.set()
      n_a = len(self.df.a.unique())
      n_b = len(self.df.b.unique())

      # Default palettes: hue "a" matches the current color cycle, while
      # hue "b" is expected to get husl colors.
      grid = ag.PairGrid(self.df, hue="a")
      assert grid.palette == color_palette(n_colors=n_a)
      grid = ag.PairGrid(self.df, hue="b")
      assert grid.palette == color_palette("husl", n_b)

      # A named palette is sampled once per hue level.
      grid = ag.PairGrid(self.df, hue="a", palette="Set2")
      assert grid.palette == color_palette("Set2", n_a)

      # A dict palette is laid out following the hue order.
      level_colors = dict(a="red", b="green", c="blue")
      in_default_order = color_palette(["red", "green", "blue"])
      grid = ag.PairGrid(self.df, hue="a", palette=level_colors)
      assert grid.palette == in_default_order

      in_custom_order = color_palette(["blue", "red", "green"])
      grid = ag.PairGrid(self.df, hue="a", hue_order=list("cab"),
                         palette=level_colors)
      assert grid.palette == in_custom_order
  def test_hue_kws(self):
      """hue_kws values are assigned to hue levels in hue order."""
      kws = dict(marker=["o", "s", "d", "+"])
      # Run once with the default hue order and once with an explicit one.
      for extra_kws in ({}, {"hue_order": list("dcab")}):
          grid = ag.PairGrid(self.df, hue="a", hue_kws=kws, **extra_kws)
          grid.map(plt.plot)
          for line, marker in zip(grid.axes[0, 0].lines, kws["marker"]):
              nt.assert_equal(line.get_marker(), marker)
  def test_hue_order(self):
      """Every mapping method draws hue levels in the order given by hue_order."""
      order = list("dcab")
      # (mapping method, axes (row, col) to inspect, x column, y column)
      cases = [
          ("map", (1, 0), "x", "y"),
          ("map_diag", (0, 0), "x", "x"),
          ("map_lower", (1, 0), "x", "y"),
          ("map_upper", (0, 1), "y", "x"),
      ]
      for method_name, (row, col), x_col, y_col in cases:
          grid = ag.PairGrid(self.df, hue="a", hue_order=order)
          getattr(grid, method_name)(plt.plot)
          for line, level in zip(grid.axes[row, col].lines, order):
              x, y = line.get_xydata().T
              level_rows = self.df.a == level
              npt.assert_array_equal(x, self.df.loc[level_rows, x_col])
              npt.assert_array_equal(y, self.df.loc[level_rows, y_col])
          plt.close("all")
  def test_hue_order_missing_level(self):
      """Artist order follows a hue_order that includes the extra level "e"."""
      order = list("dcaeb")
      # (mapping method, axes (row, col) to inspect, x column, y column)
      cases = [
          ("map", (1, 0), "x", "y"),
          ("map_diag", (0, 0), "x", "x"),
          ("map_lower", (1, 0), "x", "y"),
          ("map_upper", (0, 1), "y", "x"),
      ]
      for method_name, (row, col), x_col, y_col in cases:
          grid = ag.PairGrid(self.df, hue="a", hue_order=order)
          getattr(grid, method_name)(plt.plot)
          for line, level in zip(grid.axes[row, col].lines, order):
              x, y = line.get_xydata().T
              level_rows = self.df.a == level
              npt.assert_array_equal(x, self.df.loc[level_rows, x_col])
              npt.assert_array_equal(y, self.df.loc[level_rows, y_col])
          plt.close("all")
  def test_nondefault_index(self):
      """PairGrid plots the right values when the frame has a non-range index."""
      df = self.df.copy().set_index("b")
      plot_vars = ["x", "y", "z"]

      # No hue: each cell holds one scatter of (column j, row i) variables.
      plain = ag.PairGrid(df)
      plain.map(plt.scatter)
      for (i, j), ax in np.ndenumerate(plain.axes):
          x_out, y_out = ax.collections[0].get_offsets().T
          npt.assert_array_equal(self.df[plot_vars[j]], x_out)
          npt.assert_array_equal(self.df[plot_vars[i]], y_out)

      # Hue variable "a" passed positionally: one collection per level,
      # in order of first appearance in the data.
      hued = ag.PairGrid(df, "a")
      hued.map(plt.scatter)
      levels = self.df.a.unique()
      for (i, j), ax in np.ndenumerate(hued.axes):
          x_in = self.df[plot_vars[j]]
          y_in = self.df[plot_vars[i]]
          for k, level in enumerate(levels):
              in_level = self.df.a == level
              x_out, y_out = ax.collections[k].get_offsets().T
              npt.assert_array_equal(x_in[in_level], x_out)
              npt.assert_array_equal(y_in[in_level], y_out)
  def test_dropna(self):
      """dropna=True drops points where either plotted variable is null."""
      df = self.df.copy()
      n_null = 20
      df.loc[np.arange(n_null), "x"] = np.nan
      plot_vars = ["x", "y", "z"]

      grid = ag.PairGrid(df, vars=plot_vars, dropna=True)
      grid.map(plt.scatter)

      for (i, j), ax in np.ndenumerate(grid.axes):
          # The product is null wherever either operand is null, so its
          # non-null count is the number of points that should survive.
          n_valid = (df[plot_vars[j]] * df[plot_vars[i]]).notnull().sum()
          x_out, y_out = ax.collections[0].get_offsets().T
          assert n_valid == len(x_out)
          assert n_valid == len(y_out)
  def test_pairplot(self):
      """Default pairplot: scatters off the diagonal, multi-patch diagonals."""
      plot_vars = ["x", "y", "z"]
      grid = ag.pairplot(self.df)

      for diag_ax in grid.diag_axes:
          assert len(diag_ax.patches) > 1

      # Upper triangle first, then lower triangle; diagonal excluded.
      off_diagonal = (list(zip(*np.triu_indices_from(grid.axes, 1)))
                      + list(zip(*np.tril_indices_from(grid.axes, -1))))
      for i, j in off_diagonal:
          x_out, y_out = grid.axes[i, j].collections[0].get_offsets().T
          npt.assert_array_equal(self.df[plot_vars[j]], x_out)
          npt.assert_array_equal(self.df[plot_vars[i]], y_out)

      for i, j in zip(*np.diag_indices_from(grid.axes)):
          nt.assert_equal(len(grid.axes[i, j].collections), 0)

      # With hue: one line and one collection per level on each diagonal.
      hued = ag.pairplot(self.df, hue="a")
      n_levels = len(self.df.a.unique())
      for diag_ax in hued.diag_axes:
          assert len(diag_ax.lines) == n_levels
          assert len(diag_ax.collections) == n_levels
  def test_pairplot_reg(self):
      """kind="reg" adds one line and one more collection to each off-diag cell."""
      plot_vars = ["x", "y", "z"]
      grid = ag.pairplot(self.df, diag_kind="hist", kind="reg")

      for diag_ax in grid.diag_axes:
          nt.assert_equal(len(diag_ax.patches), 10)

      # Upper triangle first, then lower triangle; diagonal excluded.
      off_diagonal = (list(zip(*np.triu_indices_from(grid.axes, 1)))
                      + list(zip(*np.tril_indices_from(grid.axes, -1))))
      for i, j in off_diagonal:
          ax = grid.axes[i, j]
          x_out, y_out = ax.collections[0].get_offsets().T
          npt.assert_array_equal(self.df[plot_vars[j]], x_out)
          npt.assert_array_equal(self.df[plot_vars[i]], y_out)
          # One fitted line; the second collection is presumably the
          # confidence band drawn around it.
          nt.assert_equal(len(ax.lines), 1)
          nt.assert_equal(len(ax.collections), 2)

      for i, j in zip(*np.diag_indices_from(grid.axes)):
          nt.assert_equal(len(grid.axes[i, j].collections), 0)
  def test_pairplot_kde(self):
      """diag_kind="kde" draws a single line on each diagonal axis."""
      plot_vars = ["x", "y", "z"]
      grid = ag.pairplot(self.df, diag_kind="kde")

      for diag_ax in grid.diag_axes:
          nt.assert_equal(len(diag_ax.lines), 1)

      # Upper triangle first, then lower triangle; diagonal excluded.
      off_diagonal = (list(zip(*np.triu_indices_from(grid.axes, 1)))
                      + list(zip(*np.tril_indices_from(grid.axes, -1))))
      for i, j in off_diagonal:
          x_out, y_out = grid.axes[i, j].collections[0].get_offsets().T
          npt.assert_array_equal(self.df[plot_vars[j]], x_out)
          npt.assert_array_equal(self.df[plot_vars[i]], y_out)

      for i, j in zip(*np.diag_indices_from(grid.axes)):
          nt.assert_equal(len(grid.axes[i, j].collections), 0)
  def test_pairplot_markers(self):
      """markers become hue_kws["marker"]; too few markers raise ValueError."""
      plot_vars = ["x", "y", "z"]
      markers = ["o", "x", "s"]
      grid = ag.pairplot(self.df, hue="a", vars=plot_vars, markers=markers)
      assert grid.hue_kws["marker"] == markers
      plt.close("all")

      too_few = markers[:-2]
      with pytest.raises(ValueError):
          ag.pairplot(self.df, hue="a", vars=plot_vars, markers=too_few)
  class TestJointGrid(object):
      """Tests for constructing a JointGrid and drawing onto its axes."""

      rs = np.random.RandomState(sum(map(ord, "JointGrid")))
      x = rs.randn(100)
      y = rs.randn(100)
      # Copy of x with two entries set to NaN, used by the dropna test.
      x_na = x.copy()
      x_na[10] = np.nan
      x_na[20] = np.nan
      data = pd.DataFrame(dict(x=x, y=y, x_na=x_na))

      def test_margin_grid_from_lists(self):
          """Plain Python lists are accepted for x and y."""
          g = ag.JointGrid(self.x.tolist(), self.y.tolist())
          npt.assert_array_equal(g.x, self.x)
          npt.assert_array_equal(g.y, self.y)

      def test_margin_grid_from_arrays(self):
          """NumPy arrays are accepted for x and y."""
          g = ag.JointGrid(self.x, self.y)
          npt.assert_array_equal(g.x, self.x)
          npt.assert_array_equal(g.y, self.y)

      def test_margin_grid_from_series(self):
          """pandas Series are accepted for x and y."""
          g = ag.JointGrid(self.data.x, self.data.y)
          npt.assert_array_equal(g.x, self.x)
          npt.assert_array_equal(g.y, self.y)

      def test_margin_grid_from_dataframe(self):
          """Column names are resolved against the ``data`` frame."""
          g = ag.JointGrid("x", "y", self.data)
          npt.assert_array_equal(g.x, self.x)
          npt.assert_array_equal(g.y, self.y)

      def test_margin_grid_from_dataframe_bad_variable(self):
          """An unknown column name raises ValueError."""
          with nt.assert_raises(ValueError):
              ag.JointGrid("x", "bad_column", self.data)

      def test_margin_grid_axis_labels(self):
          """Joint axis labels default to the column names and can be reset."""
          g = ag.JointGrid("x", "y", self.data)
          xlabel, ylabel = g.ax_joint.get_xlabel(), g.ax_joint.get_ylabel()
          nt.assert_equal(xlabel, "x")
          nt.assert_equal(ylabel, "y")

          g.set_axis_labels("x variable", "y variable")
          xlabel, ylabel = g.ax_joint.get_xlabel(), g.ax_joint.get_ylabel()
          nt.assert_equal(xlabel, "x variable")
          nt.assert_equal(ylabel, "y variable")

      def test_dropna(self):
          """dropna decides whether NaN entries are kept in the stored x."""
          g = ag.JointGrid("x_na", "y", self.data, dropna=False)
          nt.assert_equal(len(g.x), len(self.x_na))

          g = ag.JointGrid("x_na", "y", self.data, dropna=True)
          nt.assert_equal(len(g.x), pd.notnull(self.x_na).sum())

      def test_axlims(self):
          """xlim/ylim apply to the joint axes and the matching marginals."""
          lim = (-3, 3)
          g = ag.JointGrid("x", "y", self.data, xlim=lim, ylim=lim)
          nt.assert_equal(g.ax_joint.get_xlim(), lim)
          nt.assert_equal(g.ax_joint.get_ylim(), lim)
          nt.assert_equal(g.ax_marg_x.get_xlim(), lim)
          nt.assert_equal(g.ax_marg_y.get_ylim(), lim)

      def test_marginal_ticks(self):
          """Marginal axes hide the tick labels along their measure axis."""
          g = ag.JointGrid("x", "y", self.data)
          # Check label visibility explicitly: the x axis of ax_marg_x and
          # the y axis of ax_marg_y should show no tick labels.
          x_labels = g.ax_marg_x.get_xticklabels()
          y_labels = g.ax_marg_y.get_yticklabels()
          assert not any(label.get_visible() for label in x_labels)
          assert not any(label.get_visible() for label in y_labels)

      def test_bivariate_plot(self):
          """plot_joint draws the x/y data onto the joint axes."""
          g = ag.JointGrid("x", "y", self.data)
          g.plot_joint(plt.plot)
          x, y = g.ax_joint.lines[0].get_xydata().T
          npt.assert_array_equal(x, self.x)
          npt.assert_array_equal(y, self.y)

      def test_univariate_plot(self):
          """plot_marginals draws matching curves on both marginal axes."""
          g = ag.JointGrid("x", "x", self.data)
          g.plot_marginals(kdeplot)
          # With x and y the same column, the x-marginal curve's y data
          # should equal the y-marginal curve's x data.
          _, y1 = g.ax_marg_x.lines[0].get_xydata().T
          y2, _ = g.ax_marg_y.lines[0].get_xydata().T
          npt.assert_array_equal(y1, y2)

      def test_plot(self):
          """plot() draws the joint function and both marginal functions."""
          g = ag.JointGrid("x", "x", self.data)
          g.plot(plt.plot, kdeplot)

          x, y = g.ax_joint.lines[0].get_xydata().T
          npt.assert_array_equal(x, self.x)
          npt.assert_array_equal(y, self.x)

          _, y1 = g.ax_marg_x.lines[0].get_xydata().T
          y2, _ = g.ax_marg_y.lines[0].get_xydata().T
          npt.assert_array_equal(y1, y2)

      def test_annotate(self):
          """annotate writes the statistic into the joint-axes legend."""
          g = ag.JointGrid("x", "y", self.data)
          rp = stats.pearsonr(self.x, self.y)

          with pytest.warns(UserWarning):
              g.annotate(stats.pearsonr)
          annotation = g.ax_joint.legend_.texts[0].get_text()
          nt.assert_equal(annotation, "pearsonr = %.2g; p = %.2g" % rp)

          with pytest.warns(UserWarning):
              g.annotate(stats.pearsonr, stat="correlation")
          annotation = g.ax_joint.legend_.texts[0].get_text()
          nt.assert_equal(annotation, "correlation = %.2g; p = %.2g" % rp)

          # A stat function returning a single value omits the p-value.
          def rsquared(x, y):
              return stats.pearsonr(x, y)[0] ** 2

          r2 = rsquared(self.x, self.y)
          with pytest.warns(UserWarning):
              g.annotate(rsquared)
          annotation = g.ax_joint.legend_.texts[0].get_text()
          nt.assert_equal(annotation, "rsquared = %.2g" % r2)

          template = "{stat} = {val:.3g} (p = {p:.3g})"
          with pytest.warns(UserWarning):
              g.annotate(stats.pearsonr, template=template)
          annotation = g.ax_joint.legend_.texts[0].get_text()
          nt.assert_equal(annotation, template.format(stat="pearsonr",
                                                      val=rp[0], p=rp[1]))

      def test_space(self):
          """With space=0, the marginals span the joint axes exactly."""
          g = ag.JointGrid("x", "y", self.data, space=0)
          # bbox.bounds is (x0, y0, width, height): ax_marg_x must match
          # the joint width and ax_marg_y must match the joint height.
          joint_bounds = g.ax_joint.bbox.bounds
          marg_x_bounds = g.ax_marg_x.bbox.bounds
          marg_y_bounds = g.ax_marg_y.bbox.bounds
          nt.assert_equal(joint_bounds[2], marg_x_bounds[2])
          nt.assert_equal(joint_bounds[3], marg_y_bounds[3])
  class TestJointPlot(object):
      """Tests for the ``jointplot`` figure-level interface."""

      rs = np.random.RandomState(sum(map(ord, "jointplot")))
      x = rs.randn(100)
      y = rs.randn(100)
      data = pd.DataFrame(dict(x=x, y=y))

      def _check_joint_points(self, grid):
          # The first joint collection holds exactly the input points.
          x_out, y_out = grid.ax_joint.collections[0].get_offsets().T
          npt.assert_array_equal(self.x, x_out)
          npt.assert_array_equal(self.y, y_out)

      def _check_marginal_bins(self, grid):
          # Each marginal histogram has the Freedman-Diaconis bin count.
          n_x_bins = _freedman_diaconis_bins(self.x)
          nt.assert_equal(len(grid.ax_marg_x.patches), n_x_bins)
          n_y_bins = _freedman_diaconis_bins(self.y)
          nt.assert_equal(len(grid.ax_marg_y.patches), n_y_bins)

      def test_scatter(self):
          """Default kind: one joint scatter plus histogram marginals."""
          grid = ag.jointplot("x", "y", self.data)
          nt.assert_equal(len(grid.ax_joint.collections), 1)
          self._check_joint_points(grid)
          self._check_marginal_bins(grid)

      def test_reg(self):
          """kind="reg": two joint collections and one line on every axes."""
          grid = ag.jointplot("x", "y", self.data, kind="reg")
          nt.assert_equal(len(grid.ax_joint.collections), 2)
          self._check_joint_points(grid)
          self._check_marginal_bins(grid)
          for ax in (grid.ax_joint, grid.ax_marg_x, grid.ax_marg_y):
              nt.assert_equal(len(ax.lines), 1)

      def test_resid(self):
          """kind="resid": a joint scatter and line; only ax_marg_y has a line."""
          grid = ag.jointplot("x", "y", self.data, kind="resid")
          nt.assert_equal(len(grid.ax_joint.collections), 1)
          nt.assert_equal(len(grid.ax_joint.lines), 1)
          nt.assert_equal(len(grid.ax_marg_x.lines), 0)
          nt.assert_equal(len(grid.ax_marg_y.lines), 1)

      def test_hex(self):
          """kind="hex": a single joint collection plus histogram marginals."""
          grid = ag.jointplot("x", "y", self.data, kind="hex")
          nt.assert_equal(len(grid.ax_joint.collections), 1)
          self._check_marginal_bins(grid)

      def test_kde(self):
          """kind="kde": joint collections plus one curve per marginal."""
          grid = ag.jointplot("x", "y", self.data, kind="kde")
          nt.assert_true(len(grid.ax_joint.collections) > 0)
          margins = (grid.ax_marg_x, grid.ax_marg_y)
          for marg_ax in margins:
              nt.assert_equal(len(marg_ax.collections), 1)
          for marg_ax in margins:
              nt.assert_equal(len(marg_ax.lines), 1)

      def test_color(self):
          """The color keyword applies to the joint scatter and marginal bars."""
          grid = ag.jointplot("x", "y", self.data, color="purple")
          purple = mpl.colors.colorConverter.to_rgb("purple")

          first_point = grid.ax_joint.collections[0].get_facecolor()[0, :3]
          nt.assert_equal(tuple(first_point), purple)

          first_bar = grid.ax_marg_x.patches[0].get_facecolor()[:3]
          nt.assert_equal(first_bar, purple)

      def test_annotation(self):
          """stat_func adds a one-entry legend; stat_func=None adds none."""
          with pytest.warns(UserWarning):
              annotated = ag.jointplot("x", "y", self.data,
                                       stat_func=stats.pearsonr)
          nt.assert_equal(len(annotated.ax_joint.legend_.get_texts()), 1)

          bare = ag.jointplot("x", "y", self.data, stat_func=None)
          nt.assert_is(bare.ax_joint.legend_, None)

      def test_hex_customise(self):
          """joint_kws can override the default hexbin gridsize."""
          grid = ag.jointplot("x", "y", self.data, kind="hex",
                              joint_kws=dict(gridsize=5))
          nt.assert_equal(len(grid.ax_joint.collections), 1)
          cell_values = grid.ax_joint.collections[0].get_array()
          # gridsize=5 is expected to produce 28 hexagons.
          nt.assert_equal(28, cell_values.shape[0])

      def test_bad_kind(self):
          """An unrecognized kind raises ValueError."""
          with nt.assert_raises(ValueError):
              ag.jointplot("x", "y", self.data, kind="not_a_kind")

      def test_leaky_dict(self):
          """jointplot must not mutate the keyword dicts it is given."""
          dict_kwargs = ("joint_kws", "marginal_kws", "annot_kws")
          plot_kinds = ("hex", "kde", "resid", "reg", "scatter")
          for kwarg in dict_kwargs:
              for kind in plot_kinds:
                  passed_in = {}
                  ag.jointplot("x", "y", self.data, kind=kind,
                               **{kwarg: passed_in})
                  assert passed_in == {}