test_palettes.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. import colorsys
  2. import numpy as np
  3. import matplotlib as mpl
  4. import pytest
  5. import nose.tools as nt
  6. import numpy.testing as npt
  7. from .. import palettes, utils, rcmod
  8. from ..external import husl
  9. from ..colors import xkcd_rgb, crayons
  class TestColorPalettes(object):
      """Exercise the palette constructors and helpers in ``palettes``.

      Tests that modify the global matplotlib color cycle or the
      single-letter color codes restore the defaults in a ``finally``
      clause, so a failing assertion cannot leak state into later tests.
      Checks use plain ``assert`` and ``pytest.raises`` throughout.
      """

      def test_current_palette(self):
          """``set_palette`` installs the palette as the color cycle."""
          pal = palettes.color_palette(["red", "blue", "green"])
          rcmod.set_palette(pal)
          try:
              assert pal == utils.get_color_cycle()
          finally:
              # Restore the defaults even when the assertion fails.
              rcmod.set()

      def test_palette_context(self):
          """A palette used as a context manager is active only inside it."""
          default_pal = palettes.color_palette()
          context_pal = palettes.color_palette("muted")

          with palettes.color_palette(context_pal):
              assert utils.get_color_cycle() == context_pal

          assert utils.get_color_cycle() == default_pal

      def test_big_palette_context(self):
          """A 10-color context palette is applied, then the old one restored."""
          original_pal = palettes.color_palette("deep", n_colors=8)
          context_pal = palettes.color_palette("husl", 10)

          rcmod.set_palette(original_pal)
          try:
              with palettes.color_palette(context_pal, 10):
                  assert utils.get_color_cycle() == context_pal

              assert utils.get_color_cycle() == original_pal
          finally:
              # Reset default, even when one of the assertions fails.
              rcmod.set()

      def test_palette_size(self):
          """Default palette lengths match the expected sizes."""
          # Qualitative palettes default to their registered size.
          for name in ("deep", "pastel6", "Set3"):
              pal = palettes.color_palette(name)
              assert len(pal) == palettes.QUAL_PALETTE_SIZES[name]

          # These palettes are expected to default to six colors.
          for name in ("husl", "Greens"):
              pal = palettes.color_palette(name)
              assert len(pal) == 6

      def test_seaborn_palettes(self):
          """Each 6-color variant is a fixed subset of the 10-color palette."""
          pals = "deep", "muted", "pastel", "bright", "dark", "colorblind"
          for name in pals:
              full = palettes.color_palette(name, 10).as_hex()
              short = palettes.color_palette(name + "6", 6).as_hex()
              # The unpacking also checks that ``full`` has exactly 10 entries.
              b, _, g, r, m, _, _, _, y, c = full
              assert [b, g, r, m, y, c] == list(short)

      def test_hls_palette(self):
          """``hls_palette()`` equals ``color_palette("hls")``."""
          hls_pal1 = palettes.hls_palette()
          hls_pal2 = palettes.color_palette("hls")
          npt.assert_array_equal(hls_pal1, hls_pal2)

      def test_husl_palette(self):
          """``husl_palette()`` equals ``color_palette("husl")``."""
          husl_pal1 = palettes.husl_palette()
          husl_pal2 = palettes.color_palette("husl")
          npt.assert_array_equal(husl_pal1, husl_pal2)

      def test_mpl_palette(self):
          """``mpl_palette("Reds")`` equals ``color_palette("Reds")``."""
          mpl_pal1 = palettes.mpl_palette("Reds")
          mpl_pal2 = palettes.color_palette("Reds")
          npt.assert_array_equal(mpl_pal1, mpl_pal2)

      def test_mpl_dark_palette(self):
          """``mpl_palette("Blues_d")`` equals ``color_palette("Blues_d")``."""
          mpl_pal1 = palettes.mpl_palette("Blues_d")
          mpl_pal2 = palettes.color_palette("Blues_d")
          npt.assert_array_equal(mpl_pal1, mpl_pal2)

      def test_bad_palette_name(self):
          """An unknown palette name raises ``ValueError``."""
          with pytest.raises(ValueError):
              palettes.color_palette("IAmNotAPalette")

      def test_terrible_palette_name(self):
          """Requesting ``"jet"`` by name raises ``ValueError``."""
          with pytest.raises(ValueError):
              palettes.color_palette("jet")

      def test_bad_palette_colors(self):
          """A color list containing an invalid name raises ``ValueError``."""
          pal = ["red", "blue", "iamnotacolor"]
          with pytest.raises(ValueError):
              palettes.color_palette(pal)

      def test_palette_desat(self):
          """``desat=.5`` matches desaturating each color by hand."""
          pal1 = palettes.husl_palette(6)
          pal1 = [utils.desaturate(c, .5) for c in pal1]
          pal2 = palettes.color_palette("husl", desat=.5)
          npt.assert_array_equal(pal1, pal2)

      def test_palette_is_list_of_tuples(self):
          """An array of color names yields a list of 3-tuples of floats."""
          pal_in = np.array(["red", "blue", "green"])
          pal_out = palettes.color_palette(pal_in, 3)

          assert isinstance(pal_out, list)
          assert isinstance(pal_out[0], tuple)
          assert isinstance(pal_out[0][0], float)
          assert len(pal_out[0]) == 3

      def test_palette_cycles(self):
          """Asking for 12 colors from ``deep6`` repeats the palette twice."""
          deep = palettes.color_palette("deep6")
          double_deep = palettes.color_palette("deep6", 12)
          assert double_deep == deep + deep

      def test_hls_values(self):
          """Check the effect of the ``h``, ``l`` and ``s`` arguments (HLS)."""
          # Starting at h=.5 gives the h=0 palette rotated by three entries.
          pal1 = palettes.hls_palette(6, h=0)
          pal2 = palettes.hls_palette(6, h=.5)
          pal2 = pal2[3:] + pal2[:3]
          npt.assert_array_almost_equal(pal1, pal2)

          # Lower ``l`` gives smaller per-color channel sums.
          pal_dark = palettes.hls_palette(5, l=.2)  # noqa
          pal_bright = palettes.hls_palette(5, l=.8)  # noqa
          npt.assert_array_less(list(map(sum, pal_dark)),
                                list(map(sum, pal_bright)))

          # Lower ``s`` gives a smaller spread across each color's channels.
          pal_flat = palettes.hls_palette(5, s=.1)
          pal_bold = palettes.hls_palette(5, s=.9)
          npt.assert_array_less(list(map(np.std, pal_flat)),
                                list(map(np.std, pal_bold)))

      def test_husl_values(self):
          """Check the effect of the ``h``, ``l`` and ``s`` arguments (HUSL)."""
          # Starting at h=.5 gives the h=0 palette rotated by three entries.
          pal1 = palettes.husl_palette(6, h=0)
          pal2 = palettes.husl_palette(6, h=.5)
          pal2 = pal2[3:] + pal2[:3]
          npt.assert_array_almost_equal(pal1, pal2)

          # Lower ``l`` gives smaller per-color channel sums.
          pal_dark = palettes.husl_palette(5, l=.2)  # noqa
          pal_bright = palettes.husl_palette(5, l=.8)  # noqa
          npt.assert_array_less(list(map(sum, pal_dark)),
                                list(map(sum, pal_bright)))

          # Lower ``s`` gives a smaller spread across each color's channels.
          pal_flat = palettes.husl_palette(5, s=.1)
          pal_bold = palettes.husl_palette(5, s=.9)
          npt.assert_array_less(list(map(np.std, pal_flat)),
                                list(map(np.std, pal_bold)))

      def test_cbrewer_qual(self):
          """Shorter qualitative palettes are prefixes of longer ones."""
          pal_short = palettes.mpl_palette("Set1", 4)
          pal_long = palettes.mpl_palette("Set1", 6)
          assert pal_short == pal_long[:4]

          pal_full = palettes.mpl_palette("Set2", 8)
          pal_long = palettes.mpl_palette("Set2", 10)
          assert pal_full == pal_long[:8]

      def test_mpl_reversal(self):
          """The ``_r`` suffix yields the same colors in reverse order."""
          pal_forward = palettes.mpl_palette("BuPu", 6)
          pal_reverse = palettes.mpl_palette("BuPu_r", 6)
          npt.assert_array_almost_equal(pal_forward, pal_reverse[::-1])

      def test_rgb_from_hls(self):
          """HLS input is converted like ``colorsys.hls_to_rgb``."""
          color = .5, .8, .4
          rgb_got = palettes._color_to_rgb(color, "hls")
          rgb_want = colorsys.hls_to_rgb(*color)
          assert rgb_got == rgb_want

      def test_rgb_from_husl(self):
          """HUSL input is converted like ``husl.husl_to_rgb``, within [0, 1]."""
          color = 120, 50, 40
          rgb_got = palettes._color_to_rgb(color, "husl")
          rgb_want = tuple(husl.husl_to_rgb(*color))
          assert rgb_got == rgb_want

          # Every integer hue at s=100, l=100 must stay inside the unit range.
          for h in range(0, 360):
              color = h, 100, 100
              rgb = palettes._color_to_rgb(color, "husl")
              assert min(rgb) >= 0
              assert max(rgb) <= 1

      def test_rgb_from_xkcd(self):
          """xkcd input is looked up in the ``xkcd_rgb`` table."""
          color = "dull red"
          rgb_got = palettes._color_to_rgb(color, "xkcd")
          rgb_want = xkcd_rgb[color]
          assert rgb_got == rgb_want

      def test_light_palette(self):
          """``light_palette`` ends on the seed color and supports options."""
          pal_forward = palettes.light_palette("red")
          pal_reverse = palettes.light_palette("red", reverse=True)
          assert np.allclose(pal_forward, pal_reverse[::-1])

          red = mpl.colors.colorConverter.to_rgb("red")
          assert pal_forward[-1] == red

          pal_cmap = palettes.light_palette("blue", as_cmap=True)
          assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

      def test_dark_palette(self):
          """``dark_palette`` ends on the seed color and supports options."""
          pal_forward = palettes.dark_palette("red")
          pal_reverse = palettes.dark_palette("red", reverse=True)
          assert np.allclose(pal_forward, pal_reverse[::-1])

          red = mpl.colors.colorConverter.to_rgb("red")
          assert pal_forward[-1] == red

          pal_cmap = palettes.dark_palette("blue", as_cmap=True)
          assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

      def test_diverging_palette(self):
          """Check endpoints, length, dark center and cmap output."""
          h_neg, h_pos = 100, 200
          sat, lum = 70, 50
          args = h_neg, h_pos, sat, lum

          n = 12
          half = n // 2
          pal = palettes.diverging_palette(*args, n=n)
          neg_pal = palettes.light_palette((h_neg, sat, lum), half,
                                           input="husl")
          pos_pal = palettes.light_palette((h_pos, sat, lum), half,
                                           input="husl")
          assert len(pal) == n
          # The extremes equal the last color of each one-sided light palette.
          assert pal[0] == neg_pal[-1]
          assert pal[-1] == pos_pal[-1]

          # With center="dark" the color at index n // 2 has a lower mean.
          pal_dark = palettes.diverging_palette(*args, n=n, center="dark")
          assert np.mean(pal[half]) > np.mean(pal_dark[half])

          pal_cmap = palettes.diverging_palette(*args, as_cmap=True)
          assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

      def test_blend_palette(self):
          """``blend_palette(..., as_cmap=True)`` returns a segmented cmap."""
          colors = ["red", "yellow", "white"]
          pal_cmap = palettes.blend_palette(colors, as_cmap=True)
          assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

      def test_cubehelix_against_matplotlib(self):
          """These parameters reproduce matplotlib's ``cubehelix`` colormap."""
          x = np.linspace(0, 1, 8)
          mpl_pal = mpl.cm.cubehelix(x)[:, :3].tolist()

          sns_pal = palettes.cubehelix_palette(
              8, start=0.5, rot=-1.5, hue=1, dark=0, light=1, reverse=True)

          assert sns_pal == mpl_pal

      def test_cubehelix_n_colors(self):
          """``cubehelix_palette(n)`` returns ``n`` colors."""
          for n in [3, 5, 8]:
              pal = palettes.cubehelix_palette(n)
              assert len(pal) == n

      def test_cubehelix_reverse(self):
          """``reverse=True`` returns the same colors in reverse order."""
          pal_forward = palettes.cubehelix_palette()
          pal_reverse = palettes.cubehelix_palette(reverse=True)
          assert pal_forward == pal_reverse[::-1]

      def test_cubehelix_cmap(self):
          """The colormap form agrees with the list form, also when reversed."""
          cmap = palettes.cubehelix_palette(as_cmap=True)
          assert isinstance(cmap, mpl.colors.ListedColormap)

          pal = palettes.cubehelix_palette()
          x = np.linspace(0, 1, 6)
          npt.assert_array_equal(cmap(x)[:, :3], pal)

          # Sampling the reversed cmap backwards must give the forward colors.
          cmap_rev = palettes.cubehelix_palette(as_cmap=True, reverse=True)
          pal_forward = cmap(x).tolist()
          pal_reverse = cmap_rev(x[::-1]).tolist()
          assert pal_forward == pal_reverse

      def test_cubehelix_code(self):
          """``"ch:"`` strings match explicit ``cubehelix_palette`` calls."""
          color_palette = palettes.color_palette
          cubehelix_palette = palettes.cubehelix_palette

          pal1 = color_palette("ch:", 8)
          pal2 = color_palette(cubehelix_palette(8))
          assert pal1 == pal2

          pal1 = color_palette("ch:.5, -.25,hue = .5,light=.75", 8)
          pal2 = color_palette(
              cubehelix_palette(8, .5, -.25, hue=.5, light=.75))
          assert pal1 == pal2

          pal1 = color_palette("ch:h=1,r=.5", 9)
          pal2 = color_palette(cubehelix_palette(9, hue=1, rot=.5))
          assert pal1 == pal2

          pal1 = color_palette("ch:_r", 6)
          pal2 = color_palette(cubehelix_palette(6, reverse=True))
          assert pal1 == pal2

      def test_xkcd_palette(self):
          """``xkcd_palette`` colors convert back to the ``xkcd_rgb`` hex."""
          names = list(xkcd_rgb.keys())[10:15]
          colors = palettes.xkcd_palette(names)
          for name, color in zip(names, colors):
              as_hex = mpl.colors.rgb2hex(color)
              assert as_hex == xkcd_rgb[name]

      def test_crayon_palette(self):
          """``crayon_palette`` colors convert back to the ``crayons`` hex."""
          names = list(crayons.keys())[10:15]
          colors = palettes.crayon_palette(names)
          for name, color in zip(names, colors):
              as_hex = mpl.colors.rgb2hex(color)
              # The table entry is lower-cased before comparison.
              assert as_hex == crayons[name].lower()

      def test_color_codes(self):
          """``set_color_codes`` remaps the single-letter matplotlib codes."""
          palettes.set_color_codes("deep")
          try:
              colors = palettes.color_palette("deep6") + [".1"]
              for code, color in zip("bgrmyck", colors):
                  rgb_want = mpl.colors.colorConverter.to_rgb(color)
                  rgb_got = mpl.colors.colorConverter.to_rgb(code)
                  assert rgb_want == rgb_got
          finally:
              # Undo the remapping even when a comparison above fails.
              palettes.set_color_codes("reset")

          with pytest.raises(ValueError):
              palettes.set_color_codes("Set1")

      def test_as_hex(self):
          """``as_hex`` matches ``rgb2hex`` applied to each color."""
          pal = palettes.color_palette("deep")
          for rgb, hex_color in zip(pal, pal.as_hex()):
              assert mpl.colors.rgb2hex(rgb) == hex_color

      def test_preserved_palette_length(self):
          """Passing a palette back to ``color_palette`` keeps all its colors."""
          pal_in = palettes.color_palette("Set1", 10)
          pal_out = palettes.color_palette(pal_in)
          assert pal_in == pal_out