test_collections.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720
  1. """
  2. Tests specific to the collections module.
  3. """
  4. import io
  5. import platform
  6. import numpy as np
  7. from numpy.testing import assert_array_equal, assert_array_almost_equal
  8. import pytest
  9. import matplotlib.pyplot as plt
  10. import matplotlib.collections as mcollections
  11. import matplotlib.transforms as mtransforms
  12. from matplotlib.collections import Collection, LineCollection, EventCollection
  13. from matplotlib.testing.decorators import image_comparison
  14. def generate_EventCollection_plot():
  15. '''
  16. generate the initial collection and plot it
  17. '''
  18. positions = np.array([0., 1., 2., 3., 5., 8., 13., 21.])
  19. extra_positions = np.array([34., 55., 89.])
  20. orientation = 'horizontal'
  21. lineoffset = 1
  22. linelength = .5
  23. linewidth = 2
  24. color = [1, 0, 0, 1]
  25. linestyle = 'solid'
  26. antialiased = True
  27. coll = EventCollection(positions,
  28. orientation=orientation,
  29. lineoffset=lineoffset,
  30. linelength=linelength,
  31. linewidth=linewidth,
  32. color=color,
  33. linestyle=linestyle,
  34. antialiased=antialiased
  35. )
  36. fig = plt.figure()
  37. splt = fig.add_subplot(1, 1, 1)
  38. splt.add_collection(coll)
  39. splt.set_title('EventCollection: default')
  40. props = {'positions': positions,
  41. 'extra_positions': extra_positions,
  42. 'orientation': orientation,
  43. 'lineoffset': lineoffset,
  44. 'linelength': linelength,
  45. 'linewidth': linewidth,
  46. 'color': color,
  47. 'linestyle': linestyle,
  48. 'antialiased': antialiased
  49. }
  50. splt.set_xlim(-1, 22)
  51. splt.set_ylim(0, 2)
  52. return splt, coll, props
  53. @image_comparison(['EventCollection_plot__default'])
  54. def test__EventCollection__get_segments():
  55. '''
  56. check to make sure the default segments have the correct coordinates
  57. '''
  58. _, coll, props = generate_EventCollection_plot()
  59. check_segments(coll,
  60. props['positions'],
  61. props['linelength'],
  62. props['lineoffset'],
  63. props['orientation'])
  64. def test__EventCollection__get_positions():
  65. '''
  66. check to make sure the default positions match the input positions
  67. '''
  68. _, coll, props = generate_EventCollection_plot()
  69. np.testing.assert_array_equal(props['positions'], coll.get_positions())
  70. def test__EventCollection__get_orientation():
  71. '''
  72. check to make sure the default orientation matches the input
  73. orientation
  74. '''
  75. _, coll, props = generate_EventCollection_plot()
  76. assert props['orientation'] == coll.get_orientation()
  77. def test__EventCollection__is_horizontal():
  78. '''
  79. check to make sure the default orientation matches the input
  80. orientation
  81. '''
  82. _, coll, _ = generate_EventCollection_plot()
  83. assert coll.is_horizontal()
  84. def test__EventCollection__get_linelength():
  85. '''
  86. check to make sure the default linelength matches the input linelength
  87. '''
  88. _, coll, props = generate_EventCollection_plot()
  89. assert props['linelength'] == coll.get_linelength()
  90. def test__EventCollection__get_lineoffset():
  91. '''
  92. check to make sure the default lineoffset matches the input lineoffset
  93. '''
  94. _, coll, props = generate_EventCollection_plot()
  95. assert props['lineoffset'] == coll.get_lineoffset()
  96. def test__EventCollection__get_linestyle():
  97. '''
  98. check to make sure the default linestyle matches the input linestyle
  99. '''
  100. _, coll, _ = generate_EventCollection_plot()
  101. assert coll.get_linestyle() == [(None, None)]
  102. def test__EventCollection__get_color():
  103. '''
  104. check to make sure the default color matches the input color
  105. '''
  106. _, coll, props = generate_EventCollection_plot()
  107. np.testing.assert_array_equal(props['color'], coll.get_color())
  108. check_allprop_array(coll.get_colors(), props['color'])
  109. @image_comparison(['EventCollection_plot__set_positions'])
  110. def test__EventCollection__set_positions():
  111. '''
  112. check to make sure set_positions works properly
  113. '''
  114. splt, coll, props = generate_EventCollection_plot()
  115. new_positions = np.hstack([props['positions'], props['extra_positions']])
  116. coll.set_positions(new_positions)
  117. np.testing.assert_array_equal(new_positions, coll.get_positions())
  118. check_segments(coll, new_positions,
  119. props['linelength'],
  120. props['lineoffset'],
  121. props['orientation'])
  122. splt.set_title('EventCollection: set_positions')
  123. splt.set_xlim(-1, 90)
  124. @image_comparison(['EventCollection_plot__add_positions'])
  125. def test__EventCollection__add_positions():
  126. '''
  127. check to make sure add_positions works properly
  128. '''
  129. splt, coll, props = generate_EventCollection_plot()
  130. new_positions = np.hstack([props['positions'],
  131. props['extra_positions'][0]])
  132. coll.add_positions(props['extra_positions'][0])
  133. np.testing.assert_array_equal(new_positions, coll.get_positions())
  134. check_segments(coll,
  135. new_positions,
  136. props['linelength'],
  137. props['lineoffset'],
  138. props['orientation'])
  139. splt.set_title('EventCollection: add_positions')
  140. splt.set_xlim(-1, 35)
  141. @image_comparison(['EventCollection_plot__append_positions'])
  142. def test__EventCollection__append_positions():
  143. '''
  144. check to make sure append_positions works properly
  145. '''
  146. splt, coll, props = generate_EventCollection_plot()
  147. new_positions = np.hstack([props['positions'],
  148. props['extra_positions'][2]])
  149. coll.append_positions(props['extra_positions'][2])
  150. np.testing.assert_array_equal(new_positions, coll.get_positions())
  151. check_segments(coll,
  152. new_positions,
  153. props['linelength'],
  154. props['lineoffset'],
  155. props['orientation'])
  156. splt.set_title('EventCollection: append_positions')
  157. splt.set_xlim(-1, 90)
  158. @image_comparison(['EventCollection_plot__extend_positions'])
  159. def test__EventCollection__extend_positions():
  160. '''
  161. check to make sure extend_positions works properly
  162. '''
  163. splt, coll, props = generate_EventCollection_plot()
  164. new_positions = np.hstack([props['positions'],
  165. props['extra_positions'][1:]])
  166. coll.extend_positions(props['extra_positions'][1:])
  167. np.testing.assert_array_equal(new_positions, coll.get_positions())
  168. check_segments(coll,
  169. new_positions,
  170. props['linelength'],
  171. props['lineoffset'],
  172. props['orientation'])
  173. splt.set_title('EventCollection: extend_positions')
  174. splt.set_xlim(-1, 90)
  175. @image_comparison(['EventCollection_plot__switch_orientation'])
  176. def test__EventCollection__switch_orientation():
  177. '''
  178. check to make sure switch_orientation works properly
  179. '''
  180. splt, coll, props = generate_EventCollection_plot()
  181. new_orientation = 'vertical'
  182. coll.switch_orientation()
  183. assert new_orientation == coll.get_orientation()
  184. assert not coll.is_horizontal()
  185. new_positions = coll.get_positions()
  186. check_segments(coll,
  187. new_positions,
  188. props['linelength'],
  189. props['lineoffset'], new_orientation)
  190. splt.set_title('EventCollection: switch_orientation')
  191. splt.set_ylim(-1, 22)
  192. splt.set_xlim(0, 2)
  193. @image_comparison(['EventCollection_plot__switch_orientation__2x'])
  194. def test__EventCollection__switch_orientation_2x():
  195. '''
  196. check to make sure calling switch_orientation twice sets the
  197. orientation back to the default
  198. '''
  199. splt, coll, props = generate_EventCollection_plot()
  200. coll.switch_orientation()
  201. coll.switch_orientation()
  202. new_positions = coll.get_positions()
  203. assert props['orientation'] == coll.get_orientation()
  204. assert coll.is_horizontal()
  205. np.testing.assert_array_equal(props['positions'], new_positions)
  206. check_segments(coll,
  207. new_positions,
  208. props['linelength'],
  209. props['lineoffset'],
  210. props['orientation'])
  211. splt.set_title('EventCollection: switch_orientation 2x')
  212. @image_comparison(['EventCollection_plot__set_orientation'])
  213. def test__EventCollection__set_orientation():
  214. '''
  215. check to make sure set_orientation works properly
  216. '''
  217. splt, coll, props = generate_EventCollection_plot()
  218. new_orientation = 'vertical'
  219. coll.set_orientation(new_orientation)
  220. assert new_orientation == coll.get_orientation()
  221. assert not coll.is_horizontal()
  222. check_segments(coll,
  223. props['positions'],
  224. props['linelength'],
  225. props['lineoffset'],
  226. new_orientation)
  227. splt.set_title('EventCollection: set_orientation')
  228. splt.set_ylim(-1, 22)
  229. splt.set_xlim(0, 2)
  230. @image_comparison(['EventCollection_plot__set_linelength'])
  231. def test__EventCollection__set_linelength():
  232. '''
  233. check to make sure set_linelength works properly
  234. '''
  235. splt, coll, props = generate_EventCollection_plot()
  236. new_linelength = 15
  237. coll.set_linelength(new_linelength)
  238. assert new_linelength == coll.get_linelength()
  239. check_segments(coll,
  240. props['positions'],
  241. new_linelength,
  242. props['lineoffset'],
  243. props['orientation'])
  244. splt.set_title('EventCollection: set_linelength')
  245. splt.set_ylim(-20, 20)
  246. @image_comparison(['EventCollection_plot__set_lineoffset'])
  247. def test__EventCollection__set_lineoffset():
  248. '''
  249. check to make sure set_lineoffset works properly
  250. '''
  251. splt, coll, props = generate_EventCollection_plot()
  252. new_lineoffset = -5.
  253. coll.set_lineoffset(new_lineoffset)
  254. assert new_lineoffset == coll.get_lineoffset()
  255. check_segments(coll,
  256. props['positions'],
  257. props['linelength'],
  258. new_lineoffset,
  259. props['orientation'])
  260. splt.set_title('EventCollection: set_lineoffset')
  261. splt.set_ylim(-6, -4)
  262. @image_comparison(['EventCollection_plot__set_linestyle'])
  263. def test__EventCollection__set_linestyle():
  264. '''
  265. check to make sure set_linestyle works properly
  266. '''
  267. splt, coll, _ = generate_EventCollection_plot()
  268. new_linestyle = 'dashed'
  269. coll.set_linestyle(new_linestyle)
  270. assert coll.get_linestyle() == [(0, (6.0, 6.0))]
  271. splt.set_title('EventCollection: set_linestyle')
  272. @image_comparison(['EventCollection_plot__set_ls_dash'], remove_text=True)
  273. def test__EventCollection__set_linestyle_single_dash():
  274. '''
  275. check to make sure set_linestyle accepts a single dash pattern
  276. '''
  277. splt, coll, _ = generate_EventCollection_plot()
  278. new_linestyle = (0, (6., 6.))
  279. coll.set_linestyle(new_linestyle)
  280. assert coll.get_linestyle() == [(0, (6.0, 6.0))]
  281. splt.set_title('EventCollection: set_linestyle')
  282. @image_comparison(['EventCollection_plot__set_linewidth'])
  283. def test__EventCollection__set_linewidth():
  284. '''
  285. check to make sure set_linestyle works properly
  286. '''
  287. splt, coll, _ = generate_EventCollection_plot()
  288. new_linewidth = 5
  289. coll.set_linewidth(new_linewidth)
  290. assert coll.get_linewidth() == new_linewidth
  291. splt.set_title('EventCollection: set_linewidth')
  292. @image_comparison(['EventCollection_plot__set_color'])
  293. def test__EventCollection__set_color():
  294. '''
  295. check to make sure set_color works properly
  296. '''
  297. splt, coll, _ = generate_EventCollection_plot()
  298. new_color = np.array([0, 1, 1, 1])
  299. coll.set_color(new_color)
  300. np.testing.assert_array_equal(new_color, coll.get_color())
  301. check_allprop_array(coll.get_colors(), new_color)
  302. splt.set_title('EventCollection: set_color')
  303. def check_segments(coll, positions, linelength, lineoffset, orientation):
  304. '''
  305. check to make sure all values in the segment are correct, given a
  306. particular set of inputs
  307. note: this is not a test, it is used by tests
  308. '''
  309. segments = coll.get_segments()
  310. if (orientation.lower() == 'horizontal'
  311. or orientation.lower() == 'none' or orientation is None):
  312. # if horizontal, the position in is in the y-axis
  313. pos1 = 1
  314. pos2 = 0
  315. elif orientation.lower() == 'vertical':
  316. # if vertical, the position in is in the x-axis
  317. pos1 = 0
  318. pos2 = 1
  319. else:
  320. raise ValueError("orientation must be 'horizontal' or 'vertical'")
  321. # test to make sure each segment is correct
  322. for i, segment in enumerate(segments):
  323. assert segment[0, pos1] == lineoffset + linelength / 2
  324. assert segment[1, pos1] == lineoffset - linelength / 2
  325. assert segment[0, pos2] == positions[i]
  326. assert segment[1, pos2] == positions[i]
  327. def check_allprop_array(values, target):
  328. '''
  329. check to make sure all values match the given target if arrays
  330. note: this is not a test, it is used by tests
  331. '''
  332. for value in values:
  333. np.testing.assert_array_equal(value, target)
  334. def test_null_collection_datalim():
  335. col = mcollections.PathCollection([])
  336. col_data_lim = col.get_datalim(mtransforms.IdentityTransform())
  337. assert_array_equal(col_data_lim.get_points(),
  338. mtransforms.Bbox.null().get_points())
  339. def test_add_collection():
  340. # Test if data limits are unchanged by adding an empty collection.
  341. # GitHub issue #1490, pull #1497.
  342. plt.figure()
  343. ax = plt.axes()
  344. coll = ax.scatter([0, 1], [0, 1])
  345. ax.add_collection(coll)
  346. bounds = ax.dataLim.bounds
  347. coll = ax.scatter([], [])
  348. assert ax.dataLim.bounds == bounds
  349. def test_quiver_limits():
  350. ax = plt.axes()
  351. x, y = np.arange(8), np.arange(10)
  352. u = v = np.linspace(0, 10, 80).reshape(10, 8)
  353. q = plt.quiver(x, y, u, v)
  354. assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.)
  355. plt.figure()
  356. ax = plt.axes()
  357. x = np.linspace(-5, 10, 20)
  358. y = np.linspace(-2, 4, 10)
  359. y, x = np.meshgrid(y, x)
  360. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  361. plt.quiver(x, y, np.sin(x), np.cos(y), transform=trans)
  362. assert ax.dataLim.bounds == (20.0, 30.0, 15.0, 6.0)
  363. def test_barb_limits():
  364. ax = plt.axes()
  365. x = np.linspace(-5, 10, 20)
  366. y = np.linspace(-2, 4, 10)
  367. y, x = np.meshgrid(y, x)
  368. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  369. plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
  370. # The calculated bounds are approximately the bounds of the original data,
  371. # this is because the entire path is taken into account when updating the
  372. # datalim.
  373. assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
  374. decimal=1)
  375. @image_comparison(['EllipseCollection_test_image.png'], remove_text=True,
  376. tol={'aarch64': 0.02}.get(platform.machine(), 0.0))
  377. def test_EllipseCollection():
  378. # Test basic functionality
  379. fig, ax = plt.subplots()
  380. x = np.arange(4)
  381. y = np.arange(3)
  382. X, Y = np.meshgrid(x, y)
  383. XY = np.vstack((X.ravel(), Y.ravel())).T
  384. ww = X / x[-1]
  385. hh = Y / y[-1]
  386. aa = np.ones_like(ww) * 20 # first axis is 20 degrees CCW from x axis
  387. ec = mcollections.EllipseCollection(ww, hh, aa,
  388. units='x',
  389. offsets=XY,
  390. transOffset=ax.transData,
  391. facecolors='none')
  392. ax.add_collection(ec)
  393. ax.autoscale_view()
  394. @image_comparison(['polycollection_close.png'], remove_text=True)
  395. def test_polycollection_close():
  396. from mpl_toolkits.mplot3d import Axes3D
  397. vertsQuad = [
  398. [[0., 0.], [0., 1.], [1., 1.], [1., 0.]],
  399. [[0., 1.], [2., 3.], [2., 2.], [1., 1.]],
  400. [[2., 2.], [2., 3.], [4., 1.], [3., 1.]],
  401. [[3., 0.], [3., 1.], [4., 1.], [4., 0.]]]
  402. fig = plt.figure()
  403. ax = Axes3D(fig)
  404. colors = ['r', 'g', 'b', 'y', 'k']
  405. zpos = list(range(5))
  406. poly = mcollections.PolyCollection(
  407. vertsQuad * len(zpos), linewidth=0.25)
  408. poly.set_alpha(0.7)
  409. # need to have a z-value for *each* polygon = element!
  410. zs = []
  411. cs = []
  412. for z, c in zip(zpos, colors):
  413. zs.extend([z] * len(vertsQuad))
  414. cs.extend([c] * len(vertsQuad))
  415. poly.set_color(cs)
  416. ax.add_collection3d(poly, zs=zs, zdir='y')
  417. # axis limit settings:
  418. ax.set_xlim3d(0, 4)
  419. ax.set_zlim3d(0, 3)
  420. ax.set_ylim3d(0, 4)
  421. @image_comparison(['regularpolycollection_rotate.png'], remove_text=True)
  422. def test_regularpolycollection_rotate():
  423. xx, yy = np.mgrid[:10, :10]
  424. xy_points = np.transpose([xx.flatten(), yy.flatten()])
  425. rotations = np.linspace(0, 2*np.pi, len(xy_points))
  426. fig, ax = plt.subplots()
  427. for xy, alpha in zip(xy_points, rotations):
  428. col = mcollections.RegularPolyCollection(
  429. 4, sizes=(100,), rotation=alpha,
  430. offsets=[xy], transOffset=ax.transData)
  431. ax.add_collection(col, autolim=True)
  432. ax.autoscale_view()
  433. @image_comparison(['regularpolycollection_scale.png'], remove_text=True)
  434. def test_regularpolycollection_scale():
  435. # See issue #3860
  436. class SquareCollection(mcollections.RegularPolyCollection):
  437. def __init__(self, **kwargs):
  438. super().__init__(4, rotation=np.pi/4., **kwargs)
  439. def get_transform(self):
  440. """Return transform scaling circle areas to data space."""
  441. ax = self.axes
  442. pts2pixels = 72.0 / ax.figure.dpi
  443. scale_x = pts2pixels * ax.bbox.width / ax.viewLim.width
  444. scale_y = pts2pixels * ax.bbox.height / ax.viewLim.height
  445. return mtransforms.Affine2D().scale(scale_x, scale_y)
  446. fig, ax = plt.subplots()
  447. xy = [(0, 0)]
  448. # Unit square has a half-diagonal of `1 / sqrt(2)`, so `pi * r**2`
  449. # equals...
  450. circle_areas = [np.pi / 2]
  451. squares = SquareCollection(sizes=circle_areas, offsets=xy,
  452. transOffset=ax.transData)
  453. ax.add_collection(squares, autolim=True)
  454. ax.axis([-1, 1, -1, 1])
  455. def test_picking():
  456. fig, ax = plt.subplots()
  457. col = ax.scatter([0], [0], [1000], picker=True)
  458. fig.savefig(io.BytesIO(), dpi=fig.dpi)
  459. class MouseEvent:
  460. pass
  461. event = MouseEvent()
  462. event.x = 325
  463. event.y = 240
  464. found, indices = col.contains(event)
  465. assert found
  466. assert_array_equal(indices['ind'], [0])
  467. def test_linestyle_single_dashes():
  468. plt.scatter([0, 1, 2], [0, 1, 2], linestyle=(0., [2., 2.]))
  469. plt.draw()
  470. @image_comparison(['size_in_xy.png'], remove_text=True)
  471. def test_size_in_xy():
  472. fig, ax = plt.subplots()
  473. widths, heights, angles = (10, 10), 10, 0
  474. widths = 10, 10
  475. coords = [(10, 10), (15, 15)]
  476. e = mcollections.EllipseCollection(
  477. widths, heights, angles,
  478. units='xy',
  479. offsets=coords,
  480. transOffset=ax.transData)
  481. ax.add_collection(e)
  482. ax.set_xlim(0, 30)
  483. ax.set_ylim(0, 30)
  484. def test_pandas_indexing(pd):
  485. # Should not fail break when faced with a
  486. # non-zero indexed series
  487. index = [11, 12, 13]
  488. ec = fc = pd.Series(['red', 'blue', 'green'], index=index)
  489. lw = pd.Series([1, 2, 3], index=index)
  490. ls = pd.Series(['solid', 'dashed', 'dashdot'], index=index)
  491. aa = pd.Series([True, False, True], index=index)
  492. Collection(edgecolors=ec)
  493. Collection(facecolors=fc)
  494. Collection(linewidths=lw)
  495. Collection(linestyles=ls)
  496. Collection(antialiaseds=aa)
  497. @pytest.mark.style('default')
  498. def test_lslw_bcast():
  499. col = mcollections.PathCollection([])
  500. col.set_linestyles(['-', '-'])
  501. col.set_linewidths([1, 2, 3])
  502. assert col.get_linestyles() == [(None, None)] * 6
  503. assert col.get_linewidths() == [1, 2, 3] * 2
  504. col.set_linestyles(['-', '-', '-'])
  505. assert col.get_linestyles() == [(None, None)] * 3
  506. assert (col.get_linewidths() == [1, 2, 3]).all()
  507. @pytest.mark.style('default')
  508. def test_capstyle():
  509. col = mcollections.PathCollection([], capstyle='round')
  510. assert col.get_capstyle() == 'round'
  511. col.set_capstyle('butt')
  512. assert col.get_capstyle() == 'butt'
  513. @pytest.mark.style('default')
  514. def test_joinstyle():
  515. col = mcollections.PathCollection([], joinstyle='round')
  516. assert col.get_joinstyle() == 'round'
  517. col.set_joinstyle('miter')
  518. assert col.get_joinstyle() == 'miter'
  519. @image_comparison(['cap_and_joinstyle.png'])
  520. def test_cap_and_joinstyle_image():
  521. fig = plt.figure()
  522. ax = fig.add_subplot(1, 1, 1)
  523. ax.set_xlim([-0.5, 1.5])
  524. ax.set_ylim([-0.5, 2.5])
  525. x = np.array([0.0, 1.0, 0.5])
  526. ys = np.array([[0.0], [0.5], [1.0]]) + np.array([[0.0, 0.0, 1.0]])
  527. segs = np.zeros((3, 3, 2))
  528. segs[:, :, 0] = x
  529. segs[:, :, 1] = ys
  530. line_segments = LineCollection(segs, linewidth=[10, 15, 20])
  531. line_segments.set_capstyle("round")
  532. line_segments.set_joinstyle("miter")
  533. ax.add_collection(line_segments)
  534. ax.set_title('Line collection with customized caps and joinstyle')
  535. @image_comparison(['scatter_post_alpha.png'],
  536. remove_text=True, style='default')
  537. def test_scatter_post_alpha():
  538. fig, ax = plt.subplots()
  539. sc = ax.scatter(range(5), range(5), c=range(5))
  540. # this needs to be here to update internal state
  541. fig.canvas.draw()
  542. sc.set_alpha(.1)
  543. def test_pathcollection_legend_elements():
  544. np.random.seed(19680801)
  545. x, y = np.random.rand(2, 10)
  546. y = np.random.rand(10)
  547. c = np.random.randint(0, 5, size=10)
  548. s = np.random.randint(10, 300, size=10)
  549. fig, ax = plt.subplots()
  550. sc = ax.scatter(x, y, c=c, s=s, cmap="jet", marker="o", linewidths=0)
  551. h, l = sc.legend_elements(fmt="{x:g}")
  552. assert len(h) == 5
  553. assert_array_equal(np.array(l).astype(float), np.arange(5))
  554. colors = np.array([line.get_color() for line in h])
  555. colors2 = sc.cmap(np.arange(5)/4)
  556. assert_array_equal(colors, colors2)
  557. l1 = ax.legend(h, l, loc=1)
  558. h2, lab2 = sc.legend_elements(num=9)
  559. assert len(h2) == 9
  560. l2 = ax.legend(h2, lab2, loc=2)
  561. h, l = sc.legend_elements(prop="sizes", alpha=0.5, color="red")
  562. alpha = np.array([line.get_alpha() for line in h])
  563. assert_array_equal(alpha, 0.5)
  564. color = np.array([line.get_markerfacecolor() for line in h])
  565. assert_array_equal(color, "red")
  566. l3 = ax.legend(h, l, loc=4)
  567. h, l = sc.legend_elements(prop="sizes", num=4, fmt="{x:.2f}",
  568. func=lambda x: 2*x)
  569. actsizes = [line.get_markersize() for line in h]
  570. labeledsizes = np.sqrt(np.array(l).astype(float)/2)
  571. assert_array_almost_equal(actsizes, labeledsizes)
  572. l4 = ax.legend(h, l, loc=3)
  573. import matplotlib.ticker as mticker
  574. loc = mticker.MaxNLocator(nbins=9, min_n_ticks=9-1,
  575. steps=[1, 2, 2.5, 3, 5, 6, 8, 10])
  576. h5, lab5 = sc.legend_elements(num=loc)
  577. assert len(h2) == len(h5)
  578. levels = [-1, 0, 55.4, 260]
  579. h6, lab6 = sc.legend_elements(num=levels, prop="sizes", fmt="{x:g}")
  580. assert_array_equal(np.array(lab6).astype(float), levels[2:])
  581. for l in [l1, l2, l3, l4]:
  582. ax.add_artist(l)
  583. fig.canvas.draw()
  584. def test_EventCollection_nosort():
  585. # Check that EventCollection doesn't modify input in place
  586. arr = np.array([3, 2, 1, 10])
  587. coll = EventCollection(arr)
  588. np.testing.assert_array_equal(arr, np.array([3, 2, 1, 10]))