test_regression.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628
  1. import numpy as np
  2. import matplotlib as mpl
  3. import matplotlib.pyplot as plt
  4. import pandas as pd
  5. import pytest
  6. import nose.tools as nt
  7. import numpy.testing as npt
  8. try:
  9. import pandas.testing as pdt
  10. except ImportError:
  11. import pandas.util.testing as pdt
  12. try:
  13. import statsmodels.regression.linear_model as smlm
  14. _no_statsmodels = False
  15. except ImportError:
  16. _no_statsmodels = True
  17. from .. import regression as lm
  18. from ..palettes import color_palette
  19. rs = np.random.RandomState(0)
  20. class TestLinearPlotter(object):
  21. rs = np.random.RandomState(77)
  22. df = pd.DataFrame(dict(x=rs.normal(size=60),
  23. d=rs.randint(-2, 3, 60),
  24. y=rs.gamma(4, size=60),
  25. s=np.tile(list("abcdefghij"), 6)))
  26. df["z"] = df.y + rs.randn(60)
  27. df["y_na"] = df.y.copy()
  28. df.loc[[10, 20, 30], 'y_na'] = np.nan
  29. def test_establish_variables_from_frame(self):
  30. p = lm._LinearPlotter()
  31. p.establish_variables(self.df, x="x", y="y")
  32. pdt.assert_series_equal(p.x, self.df.x)
  33. pdt.assert_series_equal(p.y, self.df.y)
  34. pdt.assert_frame_equal(p.data, self.df)
  35. def test_establish_variables_from_series(self):
  36. p = lm._LinearPlotter()
  37. p.establish_variables(None, x=self.df.x, y=self.df.y)
  38. pdt.assert_series_equal(p.x, self.df.x)
  39. pdt.assert_series_equal(p.y, self.df.y)
  40. nt.assert_is(p.data, None)
  41. def test_establish_variables_from_array(self):
  42. p = lm._LinearPlotter()
  43. p.establish_variables(None,
  44. x=self.df.x.values,
  45. y=self.df.y.values)
  46. npt.assert_array_equal(p.x, self.df.x)
  47. npt.assert_array_equal(p.y, self.df.y)
  48. nt.assert_is(p.data, None)
  49. def test_establish_variables_from_lists(self):
  50. p = lm._LinearPlotter()
  51. p.establish_variables(None,
  52. x=self.df.x.values.tolist(),
  53. y=self.df.y.values.tolist())
  54. npt.assert_array_equal(p.x, self.df.x)
  55. npt.assert_array_equal(p.y, self.df.y)
  56. nt.assert_is(p.data, None)
  57. def test_establish_variables_from_mix(self):
  58. p = lm._LinearPlotter()
  59. p.establish_variables(self.df, x="x", y=self.df.y)
  60. pdt.assert_series_equal(p.x, self.df.x)
  61. pdt.assert_series_equal(p.y, self.df.y)
  62. pdt.assert_frame_equal(p.data, self.df)
  63. def test_establish_variables_from_bad(self):
  64. p = lm._LinearPlotter()
  65. with nt.assert_raises(ValueError):
  66. p.establish_variables(None, x="x", y=self.df.y)
  67. def test_dropna(self):
  68. p = lm._LinearPlotter()
  69. p.establish_variables(self.df, x="x", y_na="y_na")
  70. pdt.assert_series_equal(p.x, self.df.x)
  71. pdt.assert_series_equal(p.y_na, self.df.y_na)
  72. p.dropna("x", "y_na")
  73. mask = self.df.y_na.notnull()
  74. pdt.assert_series_equal(p.x, self.df.x[mask])
  75. pdt.assert_series_equal(p.y_na, self.df.y_na[mask])
  76. class TestRegressionPlotter(object):
  77. rs = np.random.RandomState(49)
  78. grid = np.linspace(-3, 3, 30)
  79. n_boot = 100
  80. bins_numeric = 3
  81. bins_given = [-1, 0, 1]
  82. df = pd.DataFrame(dict(x=rs.normal(size=60),
  83. d=rs.randint(-2, 3, 60),
  84. y=rs.gamma(4, size=60),
  85. s=np.tile(list(range(6)), 10)))
  86. df["z"] = df.y + rs.randn(60)
  87. df["y_na"] = df.y.copy()
  88. bw_err = rs.randn(6)[df.s.values] * 2
  89. df.y += bw_err
  90. p = 1 / (1 + np.exp(-(df.x * 2 + rs.randn(60))))
  91. df["c"] = [rs.binomial(1, p_i) for p_i in p]
  92. df.loc[[10, 20, 30], 'y_na'] = np.nan
  93. def test_variables_from_frame(self):
  94. p = lm._RegressionPlotter("x", "y", data=self.df, units="s")
  95. pdt.assert_series_equal(p.x, self.df.x)
  96. pdt.assert_series_equal(p.y, self.df.y)
  97. pdt.assert_series_equal(p.units, self.df.s)
  98. pdt.assert_frame_equal(p.data, self.df)
  99. def test_variables_from_series(self):
  100. p = lm._RegressionPlotter(self.df.x, self.df.y, units=self.df.s)
  101. npt.assert_array_equal(p.x, self.df.x)
  102. npt.assert_array_equal(p.y, self.df.y)
  103. npt.assert_array_equal(p.units, self.df.s)
  104. nt.assert_is(p.data, None)
  105. def test_variables_from_mix(self):
  106. p = lm._RegressionPlotter("x", self.df.y + 1, data=self.df)
  107. npt.assert_array_equal(p.x, self.df.x)
  108. npt.assert_array_equal(p.y, self.df.y + 1)
  109. pdt.assert_frame_equal(p.data, self.df)
  110. def test_variables_must_be_1d(self):
  111. array_2d = np.random.randn(20, 2)
  112. array_1d = np.random.randn(20)
  113. with pytest.raises(ValueError):
  114. lm._RegressionPlotter(array_2d, array_1d)
  115. with pytest.raises(ValueError):
  116. lm._RegressionPlotter(array_1d, array_2d)
  117. def test_dropna(self):
  118. p = lm._RegressionPlotter("x", "y_na", data=self.df)
  119. nt.assert_equal(len(p.x), pd.notnull(self.df.y_na).sum())
  120. p = lm._RegressionPlotter("x", "y_na", data=self.df, dropna=False)
  121. nt.assert_equal(len(p.x), len(self.df.y_na))
  122. @pytest.mark.parametrize("x,y",
  123. [([1.5], [2]),
  124. (np.array([1.5]), np.array([2])),
  125. (pd.Series(1.5), pd.Series(2))])
  126. def test_singleton(self, x, y):
  127. p = lm._RegressionPlotter(x, y)
  128. assert not p.fit_reg
  129. def test_ci(self):
  130. p = lm._RegressionPlotter("x", "y", data=self.df, ci=95)
  131. nt.assert_equal(p.ci, 95)
  132. nt.assert_equal(p.x_ci, 95)
  133. p = lm._RegressionPlotter("x", "y", data=self.df, ci=95, x_ci=68)
  134. nt.assert_equal(p.ci, 95)
  135. nt.assert_equal(p.x_ci, 68)
  136. p = lm._RegressionPlotter("x", "y", data=self.df, ci=95, x_ci="sd")
  137. nt.assert_equal(p.ci, 95)
  138. nt.assert_equal(p.x_ci, "sd")
  139. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  140. def test_fast_regression(self):
  141. p = lm._RegressionPlotter("x", "y", data=self.df, n_boot=self.n_boot)
  142. # Fit with the "fast" function, which just does linear algebra
  143. yhat_fast, _ = p.fit_fast(self.grid)
  144. # Fit using the statsmodels function with an OLS model
  145. yhat_smod, _ = p.fit_statsmodels(self.grid, smlm.OLS)
  146. # Compare the vector of y_hat values
  147. npt.assert_array_almost_equal(yhat_fast, yhat_smod)
  148. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  149. def test_regress_poly(self):
  150. p = lm._RegressionPlotter("x", "y", data=self.df, n_boot=self.n_boot)
  151. # Fit an first-order polynomial
  152. yhat_poly, _ = p.fit_poly(self.grid, 1)
  153. # Fit using the statsmodels function with an OLS model
  154. yhat_smod, _ = p.fit_statsmodels(self.grid, smlm.OLS)
  155. # Compare the vector of y_hat values
  156. npt.assert_array_almost_equal(yhat_poly, yhat_smod)
  157. def test_regress_logx(self):
  158. x = np.arange(1, 10)
  159. y = np.arange(1, 10)
  160. grid = np.linspace(1, 10, 100)
  161. p = lm._RegressionPlotter(x, y, n_boot=self.n_boot)
  162. yhat_lin, _ = p.fit_fast(grid)
  163. yhat_log, _ = p.fit_logx(grid)
  164. nt.assert_greater(yhat_lin[0], yhat_log[0])
  165. nt.assert_greater(yhat_log[20], yhat_lin[20])
  166. nt.assert_greater(yhat_lin[90], yhat_log[90])
  167. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  168. def test_regress_n_boot(self):
  169. p = lm._RegressionPlotter("x", "y", data=self.df, n_boot=self.n_boot)
  170. # Fast (linear algebra) version
  171. _, boots_fast = p.fit_fast(self.grid)
  172. npt.assert_equal(boots_fast.shape, (self.n_boot, self.grid.size))
  173. # Slower (np.polyfit) version
  174. _, boots_poly = p.fit_poly(self.grid, 1)
  175. npt.assert_equal(boots_poly.shape, (self.n_boot, self.grid.size))
  176. # Slowest (statsmodels) version
  177. _, boots_smod = p.fit_statsmodels(self.grid, smlm.OLS)
  178. npt.assert_equal(boots_smod.shape, (self.n_boot, self.grid.size))
  179. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  180. def test_regress_without_bootstrap(self):
  181. p = lm._RegressionPlotter("x", "y", data=self.df,
  182. n_boot=self.n_boot, ci=None)
  183. # Fast (linear algebra) version
  184. _, boots_fast = p.fit_fast(self.grid)
  185. nt.assert_is(boots_fast, None)
  186. # Slower (np.polyfit) version
  187. _, boots_poly = p.fit_poly(self.grid, 1)
  188. nt.assert_is(boots_poly, None)
  189. # Slowest (statsmodels) version
  190. _, boots_smod = p.fit_statsmodels(self.grid, smlm.OLS)
  191. nt.assert_is(boots_smod, None)
  192. def test_regress_bootstrap_seed(self):
  193. seed = 200
  194. p1 = lm._RegressionPlotter("x", "y", data=self.df,
  195. n_boot=self.n_boot, seed=seed)
  196. p2 = lm._RegressionPlotter("x", "y", data=self.df,
  197. n_boot=self.n_boot, seed=seed)
  198. _, boots1 = p1.fit_fast(self.grid)
  199. _, boots2 = p2.fit_fast(self.grid)
  200. npt.assert_array_equal(boots1, boots2)
  201. def test_numeric_bins(self):
  202. p = lm._RegressionPlotter(self.df.x, self.df.y)
  203. x_binned, bins = p.bin_predictor(self.bins_numeric)
  204. npt.assert_equal(len(bins), self.bins_numeric)
  205. npt.assert_array_equal(np.unique(x_binned), bins)
  206. def test_provided_bins(self):
  207. p = lm._RegressionPlotter(self.df.x, self.df.y)
  208. x_binned, bins = p.bin_predictor(self.bins_given)
  209. npt.assert_array_equal(np.unique(x_binned), self.bins_given)
  210. def test_bin_results(self):
  211. p = lm._RegressionPlotter(self.df.x, self.df.y)
  212. x_binned, bins = p.bin_predictor(self.bins_given)
  213. nt.assert_greater(self.df.x[x_binned == 0].min(),
  214. self.df.x[x_binned == -1].max())
  215. nt.assert_greater(self.df.x[x_binned == 1].min(),
  216. self.df.x[x_binned == 0].max())
  217. def test_scatter_data(self):
  218. p = lm._RegressionPlotter(self.df.x, self.df.y)
  219. x, y = p.scatter_data
  220. npt.assert_array_equal(x, self.df.x)
  221. npt.assert_array_equal(y, self.df.y)
  222. p = lm._RegressionPlotter(self.df.d, self.df.y)
  223. x, y = p.scatter_data
  224. npt.assert_array_equal(x, self.df.d)
  225. npt.assert_array_equal(y, self.df.y)
  226. p = lm._RegressionPlotter(self.df.d, self.df.y, x_jitter=.1)
  227. x, y = p.scatter_data
  228. nt.assert_true((x != self.df.d).any())
  229. npt.assert_array_less(np.abs(self.df.d - x), np.repeat(.1, len(x)))
  230. npt.assert_array_equal(y, self.df.y)
  231. p = lm._RegressionPlotter(self.df.d, self.df.y, y_jitter=.05)
  232. x, y = p.scatter_data
  233. npt.assert_array_equal(x, self.df.d)
  234. npt.assert_array_less(np.abs(self.df.y - y), np.repeat(.1, len(y)))
  235. def test_estimate_data(self):
  236. p = lm._RegressionPlotter(self.df.d, self.df.y, x_estimator=np.mean)
  237. x, y, ci = p.estimate_data
  238. npt.assert_array_equal(x, np.sort(np.unique(self.df.d)))
  239. npt.assert_array_almost_equal(y, self.df.groupby("d").y.mean())
  240. npt.assert_array_less(np.array(ci)[:, 0], y)
  241. npt.assert_array_less(y, np.array(ci)[:, 1])
  242. def test_estimate_cis(self):
  243. seed = 123
  244. p = lm._RegressionPlotter(self.df.d, self.df.y,
  245. x_estimator=np.mean, ci=95, seed=seed)
  246. _, _, ci_big = p.estimate_data
  247. p = lm._RegressionPlotter(self.df.d, self.df.y,
  248. x_estimator=np.mean, ci=50, seed=seed)
  249. _, _, ci_wee = p.estimate_data
  250. npt.assert_array_less(np.diff(ci_wee), np.diff(ci_big))
  251. p = lm._RegressionPlotter(self.df.d, self.df.y,
  252. x_estimator=np.mean, ci=None)
  253. _, _, ci_nil = p.estimate_data
  254. npt.assert_array_equal(ci_nil, [None] * len(ci_nil))
  255. def test_estimate_units(self):
  256. # Seed the RNG locally
  257. seed = 345
  258. p = lm._RegressionPlotter("x", "y", data=self.df,
  259. units="s", seed=seed, x_bins=3)
  260. _, _, ci_big = p.estimate_data
  261. ci_big = np.diff(ci_big, axis=1)
  262. p = lm._RegressionPlotter("x", "y", data=self.df, seed=seed, x_bins=3)
  263. _, _, ci_wee = p.estimate_data
  264. ci_wee = np.diff(ci_wee, axis=1)
  265. npt.assert_array_less(ci_wee, ci_big)
  266. def test_partial(self):
  267. x = self.rs.randn(100)
  268. y = x + self.rs.randn(100)
  269. z = x + self.rs.randn(100)
  270. p = lm._RegressionPlotter(y, z)
  271. _, r_orig = np.corrcoef(p.x, p.y)[0]
  272. p = lm._RegressionPlotter(y, z, y_partial=x)
  273. _, r_semipartial = np.corrcoef(p.x, p.y)[0]
  274. nt.assert_less(r_semipartial, r_orig)
  275. p = lm._RegressionPlotter(y, z, x_partial=x, y_partial=x)
  276. _, r_partial = np.corrcoef(p.x, p.y)[0]
  277. nt.assert_less(r_partial, r_orig)
  278. x = pd.Series(x)
  279. y = pd.Series(y)
  280. p = lm._RegressionPlotter(y, z, x_partial=x, y_partial=x)
  281. _, r_partial = np.corrcoef(p.x, p.y)[0]
  282. nt.assert_less(r_partial, r_orig)
  283. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  284. def test_logistic_regression(self):
  285. p = lm._RegressionPlotter("x", "c", data=self.df,
  286. logistic=True, n_boot=self.n_boot)
  287. _, yhat, _ = p.fit_regression(x_range=(-3, 3))
  288. npt.assert_array_less(yhat, 1)
  289. npt.assert_array_less(0, yhat)
  290. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  291. def test_logistic_perfect_separation(self):
  292. y = self.df.x > self.df.x.mean()
  293. p = lm._RegressionPlotter("x", y, data=self.df,
  294. logistic=True, n_boot=10)
  295. with np.errstate(all="ignore"):
  296. _, yhat, _ = p.fit_regression(x_range=(-3, 3))
  297. nt.assert_true(np.isnan(yhat).all())
  298. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  299. def test_robust_regression(self):
  300. p_ols = lm._RegressionPlotter("x", "y", data=self.df,
  301. n_boot=self.n_boot)
  302. _, ols_yhat, _ = p_ols.fit_regression(x_range=(-3, 3))
  303. p_robust = lm._RegressionPlotter("x", "y", data=self.df,
  304. robust=True, n_boot=self.n_boot)
  305. _, robust_yhat, _ = p_robust.fit_regression(x_range=(-3, 3))
  306. nt.assert_equal(len(ols_yhat), len(robust_yhat))
  307. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  308. def test_lowess_regression(self):
  309. p = lm._RegressionPlotter("x", "y", data=self.df, lowess=True)
  310. grid, yhat, err_bands = p.fit_regression(x_range=(-3, 3))
  311. nt.assert_equal(len(grid), len(yhat))
  312. nt.assert_is(err_bands, None)
  313. def test_regression_options(self):
  314. with nt.assert_raises(ValueError):
  315. lm._RegressionPlotter("x", "y", data=self.df,
  316. lowess=True, order=2)
  317. with nt.assert_raises(ValueError):
  318. lm._RegressionPlotter("x", "y", data=self.df,
  319. lowess=True, logistic=True)
  320. def test_regression_limits(self):
  321. f, ax = plt.subplots()
  322. ax.scatter(self.df.x, self.df.y)
  323. p = lm._RegressionPlotter("x", "y", data=self.df)
  324. grid, _, _ = p.fit_regression(ax)
  325. xlim = ax.get_xlim()
  326. nt.assert_equal(grid.min(), xlim[0])
  327. nt.assert_equal(grid.max(), xlim[1])
  328. p = lm._RegressionPlotter("x", "y", data=self.df, truncate=True)
  329. grid, _, _ = p.fit_regression()
  330. nt.assert_equal(grid.min(), self.df.x.min())
  331. nt.assert_equal(grid.max(), self.df.x.max())
  332. class TestRegressionPlots(object):
  333. rs = np.random.RandomState(56)
  334. df = pd.DataFrame(dict(x=rs.randn(90),
  335. y=rs.randn(90) + 5,
  336. z=rs.randint(0, 1, 90),
  337. g=np.repeat(list("abc"), 30),
  338. h=np.tile(list("xy"), 45),
  339. u=np.tile(np.arange(6), 15)))
  340. bw_err = rs.randn(6)[df.u.values]
  341. df.y += bw_err
  342. def test_regplot_basic(self):
  343. f, ax = plt.subplots()
  344. lm.regplot("x", "y", self.df)
  345. nt.assert_equal(len(ax.lines), 1)
  346. nt.assert_equal(len(ax.collections), 2)
  347. x, y = ax.collections[0].get_offsets().T
  348. npt.assert_array_equal(x, self.df.x)
  349. npt.assert_array_equal(y, self.df.y)
  350. def test_regplot_selective(self):
  351. f, ax = plt.subplots()
  352. ax = lm.regplot("x", "y", self.df, scatter=False, ax=ax)
  353. nt.assert_equal(len(ax.lines), 1)
  354. nt.assert_equal(len(ax.collections), 1)
  355. ax.clear()
  356. f, ax = plt.subplots()
  357. ax = lm.regplot("x", "y", self.df, fit_reg=False)
  358. nt.assert_equal(len(ax.lines), 0)
  359. nt.assert_equal(len(ax.collections), 1)
  360. ax.clear()
  361. f, ax = plt.subplots()
  362. ax = lm.regplot("x", "y", self.df, ci=None)
  363. nt.assert_equal(len(ax.lines), 1)
  364. nt.assert_equal(len(ax.collections), 1)
  365. ax.clear()
  366. def test_regplot_scatter_kws_alpha(self):
  367. f, ax = plt.subplots()
  368. color = np.array([[0.3, 0.8, 0.5, 0.5]])
  369. ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color})
  370. nt.assert_is(ax.collections[0]._alpha, None)
  371. nt.assert_equal(ax.collections[0]._facecolors[0, 3], 0.5)
  372. f, ax = plt.subplots()
  373. color = np.array([[0.3, 0.8, 0.5]])
  374. ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color})
  375. nt.assert_equal(ax.collections[0]._alpha, 0.8)
  376. f, ax = plt.subplots()
  377. color = np.array([[0.3, 0.8, 0.5]])
  378. ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color,
  379. 'alpha': 0.4})
  380. nt.assert_equal(ax.collections[0]._alpha, 0.4)
  381. f, ax = plt.subplots()
  382. color = 'r'
  383. ax = lm.regplot("x", "y", self.df, scatter_kws={'color': color})
  384. nt.assert_equal(ax.collections[0]._alpha, 0.8)
  385. def test_regplot_binned(self):
  386. ax = lm.regplot("x", "y", self.df, x_bins=5)
  387. nt.assert_equal(len(ax.lines), 6)
  388. nt.assert_equal(len(ax.collections), 2)
  389. def test_lmplot_basic(self):
  390. g = lm.lmplot("x", "y", self.df)
  391. ax = g.axes[0, 0]
  392. nt.assert_equal(len(ax.lines), 1)
  393. nt.assert_equal(len(ax.collections), 2)
  394. x, y = ax.collections[0].get_offsets().T
  395. npt.assert_array_equal(x, self.df.x)
  396. npt.assert_array_equal(y, self.df.y)
  397. def test_lmplot_hue(self):
  398. g = lm.lmplot("x", "y", data=self.df, hue="h")
  399. ax = g.axes[0, 0]
  400. nt.assert_equal(len(ax.lines), 2)
  401. nt.assert_equal(len(ax.collections), 4)
  402. def test_lmplot_markers(self):
  403. g1 = lm.lmplot("x", "y", data=self.df, hue="h", markers="s")
  404. nt.assert_equal(g1.hue_kws, {"marker": ["s", "s"]})
  405. g2 = lm.lmplot("x", "y", data=self.df, hue="h", markers=["o", "s"])
  406. nt.assert_equal(g2.hue_kws, {"marker": ["o", "s"]})
  407. with nt.assert_raises(ValueError):
  408. lm.lmplot("x", "y", data=self.df, hue="h", markers=["o", "s", "d"])
  409. def test_lmplot_marker_linewidths(self):
  410. g = lm.lmplot("x", "y", data=self.df, hue="h",
  411. fit_reg=False, markers=["o", "+"])
  412. c = g.axes[0, 0].collections
  413. nt.assert_equal(c[1].get_linewidths()[0],
  414. mpl.rcParams["lines.linewidth"])
  415. def test_lmplot_facets(self):
  416. g = lm.lmplot("x", "y", data=self.df, row="g", col="h")
  417. nt.assert_equal(g.axes.shape, (3, 2))
  418. g = lm.lmplot("x", "y", data=self.df, col="u", col_wrap=4)
  419. nt.assert_equal(g.axes.shape, (6,))
  420. g = lm.lmplot("x", "y", data=self.df, hue="h", col="u")
  421. nt.assert_equal(g.axes.shape, (1, 6))
  422. def test_lmplot_hue_col_nolegend(self):
  423. g = lm.lmplot("x", "y", data=self.df, col="h", hue="h")
  424. nt.assert_is(g._legend, None)
  425. def test_lmplot_scatter_kws(self):
  426. g = lm.lmplot("x", "y", hue="h", data=self.df, ci=None)
  427. red_scatter, blue_scatter = g.axes[0, 0].collections
  428. red, blue = color_palette(n_colors=2)
  429. npt.assert_array_equal(red, red_scatter.get_facecolors()[0, :3])
  430. npt.assert_array_equal(blue, blue_scatter.get_facecolors()[0, :3])
  431. def test_residplot(self):
  432. x, y = self.df.x, self.df.y
  433. ax = lm.residplot(x, y)
  434. resid = y - np.polyval(np.polyfit(x, y, 1), x)
  435. x_plot, y_plot = ax.collections[0].get_offsets().T
  436. npt.assert_array_equal(x, x_plot)
  437. npt.assert_array_almost_equal(resid, y_plot)
  438. @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
  439. def test_residplot_lowess(self):
  440. ax = lm.residplot("x", "y", self.df, lowess=True)
  441. nt.assert_equal(len(ax.lines), 2)
  442. x, y = ax.lines[1].get_xydata().T
  443. npt.assert_array_equal(x, np.sort(self.df.x))
  444. def test_three_point_colors(self):
  445. x, y = np.random.randn(2, 3)
  446. ax = lm.regplot(x, y, color=(1, 0, 0))
  447. color = ax.collections[0].get_facecolors()
  448. npt.assert_almost_equal(color[0, :3],
  449. (1, 0, 0))
  450. def test_regplot_xlim(self):
  451. f, ax = plt.subplots()
  452. x, y1, y2 = np.random.randn(3, 50)
  453. lm.regplot(x, y1, truncate=False)
  454. lm.regplot(x, y2, truncate=False)
  455. line1, line2 = ax.lines
  456. assert np.array_equal(line1.get_xdata(), line2.get_xdata())