# test_utils.py (53 KB)
  1. from __future__ import division, absolute_import, print_function
  2. import warnings
  3. import sys
  4. import os
  5. import itertools
  6. import textwrap
  7. import pytest
  8. import weakref
  9. import numpy as np
  10. from numpy.testing import (
  11. assert_equal, assert_array_equal, assert_almost_equal,
  12. assert_array_almost_equal, assert_array_less, build_err_msg, raises,
  13. assert_raises, assert_warns, assert_no_warnings, assert_allclose,
  14. assert_approx_equal, assert_array_almost_equal_nulp, assert_array_max_ulp,
  15. clear_and_catch_warnings, suppress_warnings, assert_string_equal, assert_,
  16. tempdir, temppath, assert_no_gc_cycles, HAS_REFCOUNT
  17. )
  18. from numpy.core.overrides import ARRAY_FUNCTION_ENABLED
  19. class _GenericTest(object):
  20. def _test_equal(self, a, b):
  21. self._assert_func(a, b)
  22. def _test_not_equal(self, a, b):
  23. with assert_raises(AssertionError):
  24. self._assert_func(a, b)
  25. def test_array_rank1_eq(self):
  26. """Test two equal array of rank 1 are found equal."""
  27. a = np.array([1, 2])
  28. b = np.array([1, 2])
  29. self._test_equal(a, b)
  30. def test_array_rank1_noteq(self):
  31. """Test two different array of rank 1 are found not equal."""
  32. a = np.array([1, 2])
  33. b = np.array([2, 2])
  34. self._test_not_equal(a, b)
  35. def test_array_rank2_eq(self):
  36. """Test two equal array of rank 2 are found equal."""
  37. a = np.array([[1, 2], [3, 4]])
  38. b = np.array([[1, 2], [3, 4]])
  39. self._test_equal(a, b)
  40. def test_array_diffshape(self):
  41. """Test two arrays with different shapes are found not equal."""
  42. a = np.array([1, 2])
  43. b = np.array([[1, 2], [1, 2]])
  44. self._test_not_equal(a, b)
  45. def test_objarray(self):
  46. """Test object arrays."""
  47. a = np.array([1, 1], dtype=object)
  48. self._test_equal(a, 1)
  49. def test_array_likes(self):
  50. self._test_equal([1, 2, 3], (1, 2, 3))
  51. class TestArrayEqual(_GenericTest):
  52. def setup(self):
  53. self._assert_func = assert_array_equal
  54. def test_generic_rank1(self):
  55. """Test rank 1 array for all dtypes."""
  56. def foo(t):
  57. a = np.empty(2, t)
  58. a.fill(1)
  59. b = a.copy()
  60. c = a.copy()
  61. c.fill(0)
  62. self._test_equal(a, b)
  63. self._test_not_equal(c, b)
  64. # Test numeric types and object
  65. for t in '?bhilqpBHILQPfdgFDG':
  66. foo(t)
  67. # Test strings
  68. for t in ['S1', 'U1']:
  69. foo(t)
  70. def test_0_ndim_array(self):
  71. x = np.array(473963742225900817127911193656584771)
  72. y = np.array(18535119325151578301457182298393896)
  73. assert_raises(AssertionError, self._assert_func, x, y)
  74. y = x
  75. self._assert_func(x, y)
  76. x = np.array(43)
  77. y = np.array(10)
  78. assert_raises(AssertionError, self._assert_func, x, y)
  79. y = x
  80. self._assert_func(x, y)
  81. def test_generic_rank3(self):
  82. """Test rank 3 array for all dtypes."""
  83. def foo(t):
  84. a = np.empty((4, 2, 3), t)
  85. a.fill(1)
  86. b = a.copy()
  87. c = a.copy()
  88. c.fill(0)
  89. self._test_equal(a, b)
  90. self._test_not_equal(c, b)
  91. # Test numeric types and object
  92. for t in '?bhilqpBHILQPfdgFDG':
  93. foo(t)
  94. # Test strings
  95. for t in ['S1', 'U1']:
  96. foo(t)
  97. def test_nan_array(self):
  98. """Test arrays with nan values in them."""
  99. a = np.array([1, 2, np.nan])
  100. b = np.array([1, 2, np.nan])
  101. self._test_equal(a, b)
  102. c = np.array([1, 2, 3])
  103. self._test_not_equal(c, b)
  104. def test_string_arrays(self):
  105. """Test two arrays with different shapes are found not equal."""
  106. a = np.array(['floupi', 'floupa'])
  107. b = np.array(['floupi', 'floupa'])
  108. self._test_equal(a, b)
  109. c = np.array(['floupipi', 'floupa'])
  110. self._test_not_equal(c, b)
  111. def test_recarrays(self):
  112. """Test record arrays."""
  113. a = np.empty(2, [('floupi', float), ('floupa', float)])
  114. a['floupi'] = [1, 2]
  115. a['floupa'] = [1, 2]
  116. b = a.copy()
  117. self._test_equal(a, b)
  118. c = np.empty(2, [('floupipi', float), ('floupa', float)])
  119. c['floupipi'] = a['floupi'].copy()
  120. c['floupa'] = a['floupa'].copy()
  121. with suppress_warnings() as sup:
  122. l = sup.record(FutureWarning, message="elementwise == ")
  123. self._test_not_equal(c, b)
  124. assert_equal(len(l), 1)
  125. def test_masked_nan_inf(self):
  126. # Regression test for gh-11121
  127. a = np.ma.MaskedArray([3., 4., 6.5], mask=[False, True, False])
  128. b = np.array([3., np.nan, 6.5])
  129. self._test_equal(a, b)
  130. self._test_equal(b, a)
  131. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, False, False])
  132. b = np.array([np.inf, 4., 6.5])
  133. self._test_equal(a, b)
  134. self._test_equal(b, a)
  135. def test_subclass_that_overrides_eq(self):
  136. # While we cannot guarantee testing functions will always work for
  137. # subclasses, the tests should ideally rely only on subclasses having
  138. # comparison operators, not on them being able to store booleans
  139. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  140. class MyArray(np.ndarray):
  141. def __eq__(self, other):
  142. return bool(np.equal(self, other).all())
  143. def __ne__(self, other):
  144. return not self == other
  145. a = np.array([1., 2.]).view(MyArray)
  146. b = np.array([2., 3.]).view(MyArray)
  147. assert_(type(a == a), bool)
  148. assert_(a == a)
  149. assert_(a != b)
  150. self._test_equal(a, a)
  151. self._test_not_equal(a, b)
  152. self._test_not_equal(b, a)
  153. @pytest.mark.skipif(
  154. not ARRAY_FUNCTION_ENABLED, reason='requires __array_function__')
  155. def test_subclass_that_does_not_implement_npall(self):
  156. class MyArray(np.ndarray):
  157. def __array_function__(self, *args, **kwargs):
  158. return NotImplemented
  159. a = np.array([1., 2.]).view(MyArray)
  160. b = np.array([2., 3.]).view(MyArray)
  161. with assert_raises(TypeError):
  162. np.all(a)
  163. self._test_equal(a, a)
  164. self._test_not_equal(a, b)
  165. self._test_not_equal(b, a)
  166. class TestBuildErrorMessage(object):
  167. def test_build_err_msg_defaults(self):
  168. x = np.array([1.00001, 2.00002, 3.00003])
  169. y = np.array([1.00002, 2.00003, 3.00004])
  170. err_msg = 'There is a mismatch'
  171. a = build_err_msg([x, y], err_msg)
  172. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  173. '1.00001, 2.00002, 3.00003])\n DESIRED: array([1.00002, '
  174. '2.00003, 3.00004])')
  175. assert_equal(a, b)
  176. def test_build_err_msg_no_verbose(self):
  177. x = np.array([1.00001, 2.00002, 3.00003])
  178. y = np.array([1.00002, 2.00003, 3.00004])
  179. err_msg = 'There is a mismatch'
  180. a = build_err_msg([x, y], err_msg, verbose=False)
  181. b = '\nItems are not equal: There is a mismatch'
  182. assert_equal(a, b)
  183. def test_build_err_msg_custom_names(self):
  184. x = np.array([1.00001, 2.00002, 3.00003])
  185. y = np.array([1.00002, 2.00003, 3.00004])
  186. err_msg = 'There is a mismatch'
  187. a = build_err_msg([x, y], err_msg, names=('FOO', 'BAR'))
  188. b = ('\nItems are not equal: There is a mismatch\n FOO: array(['
  189. '1.00001, 2.00002, 3.00003])\n BAR: array([1.00002, 2.00003, '
  190. '3.00004])')
  191. assert_equal(a, b)
  192. def test_build_err_msg_custom_precision(self):
  193. x = np.array([1.000000001, 2.00002, 3.00003])
  194. y = np.array([1.000000002, 2.00003, 3.00004])
  195. err_msg = 'There is a mismatch'
  196. a = build_err_msg([x, y], err_msg, precision=10)
  197. b = ('\nItems are not equal: There is a mismatch\n ACTUAL: array(['
  198. '1.000000001, 2.00002 , 3.00003 ])\n DESIRED: array(['
  199. '1.000000002, 2.00003 , 3.00004 ])')
  200. assert_equal(a, b)
  201. class TestEqual(TestArrayEqual):
  202. def setup(self):
  203. self._assert_func = assert_equal
  204. def test_nan_items(self):
  205. self._assert_func(np.nan, np.nan)
  206. self._assert_func([np.nan], [np.nan])
  207. self._test_not_equal(np.nan, [np.nan])
  208. self._test_not_equal(np.nan, 1)
  209. def test_inf_items(self):
  210. self._assert_func(np.inf, np.inf)
  211. self._assert_func([np.inf], [np.inf])
  212. self._test_not_equal(np.inf, [np.inf])
  213. def test_datetime(self):
  214. self._test_equal(
  215. np.datetime64("2017-01-01", "s"),
  216. np.datetime64("2017-01-01", "s")
  217. )
  218. self._test_equal(
  219. np.datetime64("2017-01-01", "s"),
  220. np.datetime64("2017-01-01", "m")
  221. )
  222. # gh-10081
  223. self._test_not_equal(
  224. np.datetime64("2017-01-01", "s"),
  225. np.datetime64("2017-01-02", "s")
  226. )
  227. self._test_not_equal(
  228. np.datetime64("2017-01-01", "s"),
  229. np.datetime64("2017-01-02", "m")
  230. )
  231. def test_nat_items(self):
  232. # not a datetime
  233. nadt_no_unit = np.datetime64("NaT")
  234. nadt_s = np.datetime64("NaT", "s")
  235. nadt_d = np.datetime64("NaT", "ns")
  236. # not a timedelta
  237. natd_no_unit = np.timedelta64("NaT")
  238. natd_s = np.timedelta64("NaT", "s")
  239. natd_d = np.timedelta64("NaT", "ns")
  240. dts = [nadt_no_unit, nadt_s, nadt_d]
  241. tds = [natd_no_unit, natd_s, natd_d]
  242. for a, b in itertools.product(dts, dts):
  243. self._assert_func(a, b)
  244. self._assert_func([a], [b])
  245. self._test_not_equal([a], b)
  246. for a, b in itertools.product(tds, tds):
  247. self._assert_func(a, b)
  248. self._assert_func([a], [b])
  249. self._test_not_equal([a], b)
  250. for a, b in itertools.product(tds, dts):
  251. self._test_not_equal(a, b)
  252. self._test_not_equal(a, [b])
  253. self._test_not_equal([a], [b])
  254. self._test_not_equal([a], np.datetime64("2017-01-01", "s"))
  255. self._test_not_equal([b], np.datetime64("2017-01-01", "s"))
  256. self._test_not_equal([a], np.timedelta64(123, "s"))
  257. self._test_not_equal([b], np.timedelta64(123, "s"))
  258. def test_non_numeric(self):
  259. self._assert_func('ab', 'ab')
  260. self._test_not_equal('ab', 'abb')
  261. def test_complex_item(self):
  262. self._assert_func(complex(1, 2), complex(1, 2))
  263. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  264. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  265. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  266. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  267. def test_negative_zero(self):
  268. self._test_not_equal(np.PZERO, np.NZERO)
  269. def test_complex(self):
  270. x = np.array([complex(1, 2), complex(1, np.nan)])
  271. y = np.array([complex(1, 2), complex(1, 2)])
  272. self._assert_func(x, x)
  273. self._test_not_equal(x, y)
  274. def test_error_message(self):
  275. with pytest.raises(AssertionError) as exc_info:
  276. self._assert_func(np.array([1, 2]), np.array([[1, 2]]))
  277. msg = str(exc_info.value)
  278. msg2 = msg.replace("shapes (2L,), (1L, 2L)", "shapes (2,), (1, 2)")
  279. msg_reference = textwrap.dedent("""\
  280. Arrays are not equal
  281. (shapes (2,), (1, 2) mismatch)
  282. x: array([1, 2])
  283. y: array([[1, 2]])""")
  284. try:
  285. assert_equal(msg, msg_reference)
  286. except AssertionError:
  287. assert_equal(msg2, msg_reference)
  288. def test_object(self):
  289. #gh-12942
  290. import datetime
  291. a = np.array([datetime.datetime(2000, 1, 1),
  292. datetime.datetime(2000, 1, 2)])
  293. self._test_not_equal(a, a[::-1])
  294. class TestArrayAlmostEqual(_GenericTest):
  295. def setup(self):
  296. self._assert_func = assert_array_almost_equal
  297. def test_closeness(self):
  298. # Note that in the course of time we ended up with
  299. # `abs(x - y) < 1.5 * 10**(-decimal)`
  300. # instead of the previously documented
  301. # `abs(x - y) < 0.5 * 10**(-decimal)`
  302. # so this check serves to preserve the wrongness.
  303. # test scalars
  304. self._assert_func(1.499999, 0.0, decimal=0)
  305. assert_raises(AssertionError,
  306. lambda: self._assert_func(1.5, 0.0, decimal=0))
  307. # test arrays
  308. self._assert_func([1.499999], [0.0], decimal=0)
  309. assert_raises(AssertionError,
  310. lambda: self._assert_func([1.5], [0.0], decimal=0))
  311. def test_simple(self):
  312. x = np.array([1234.2222])
  313. y = np.array([1234.2223])
  314. self._assert_func(x, y, decimal=3)
  315. self._assert_func(x, y, decimal=4)
  316. assert_raises(AssertionError,
  317. lambda: self._assert_func(x, y, decimal=5))
  318. def test_nan(self):
  319. anan = np.array([np.nan])
  320. aone = np.array([1])
  321. ainf = np.array([np.inf])
  322. self._assert_func(anan, anan)
  323. assert_raises(AssertionError,
  324. lambda: self._assert_func(anan, aone))
  325. assert_raises(AssertionError,
  326. lambda: self._assert_func(anan, ainf))
  327. assert_raises(AssertionError,
  328. lambda: self._assert_func(ainf, anan))
  329. def test_inf(self):
  330. a = np.array([[1., 2.], [3., 4.]])
  331. b = a.copy()
  332. a[0, 0] = np.inf
  333. assert_raises(AssertionError,
  334. lambda: self._assert_func(a, b))
  335. b[0, 0] = -np.inf
  336. assert_raises(AssertionError,
  337. lambda: self._assert_func(a, b))
  338. def test_subclass(self):
  339. a = np.array([[1., 2.], [3., 4.]])
  340. b = np.ma.masked_array([[1., 2.], [0., 4.]],
  341. [[False, False], [True, False]])
  342. self._assert_func(a, b)
  343. self._assert_func(b, a)
  344. self._assert_func(b, b)
  345. # Test fully masked as well (see gh-11123).
  346. a = np.ma.MaskedArray(3.5, mask=True)
  347. b = np.array([3., 4., 6.5])
  348. self._test_equal(a, b)
  349. self._test_equal(b, a)
  350. a = np.ma.masked
  351. b = np.array([3., 4., 6.5])
  352. self._test_equal(a, b)
  353. self._test_equal(b, a)
  354. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  355. b = np.array([1., 2., 3.])
  356. self._test_equal(a, b)
  357. self._test_equal(b, a)
  358. a = np.ma.MaskedArray([3., 4., 6.5], mask=[True, True, True])
  359. b = np.array(1.)
  360. self._test_equal(a, b)
  361. self._test_equal(b, a)
  362. def test_subclass_that_cannot_be_bool(self):
  363. # While we cannot guarantee testing functions will always work for
  364. # subclasses, the tests should ideally rely only on subclasses having
  365. # comparison operators, not on them being able to store booleans
  366. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  367. class MyArray(np.ndarray):
  368. def __eq__(self, other):
  369. return super(MyArray, self).__eq__(other).view(np.ndarray)
  370. def __lt__(self, other):
  371. return super(MyArray, self).__lt__(other).view(np.ndarray)
  372. def all(self, *args, **kwargs):
  373. raise NotImplementedError
  374. a = np.array([1., 2.]).view(MyArray)
  375. self._assert_func(a, a)
  376. class TestAlmostEqual(_GenericTest):
  377. def setup(self):
  378. self._assert_func = assert_almost_equal
  379. def test_closeness(self):
  380. # Note that in the course of time we ended up with
  381. # `abs(x - y) < 1.5 * 10**(-decimal)`
  382. # instead of the previously documented
  383. # `abs(x - y) < 0.5 * 10**(-decimal)`
  384. # so this check serves to preserve the wrongness.
  385. # test scalars
  386. self._assert_func(1.499999, 0.0, decimal=0)
  387. assert_raises(AssertionError,
  388. lambda: self._assert_func(1.5, 0.0, decimal=0))
  389. # test arrays
  390. self._assert_func([1.499999], [0.0], decimal=0)
  391. assert_raises(AssertionError,
  392. lambda: self._assert_func([1.5], [0.0], decimal=0))
  393. def test_nan_item(self):
  394. self._assert_func(np.nan, np.nan)
  395. assert_raises(AssertionError,
  396. lambda: self._assert_func(np.nan, 1))
  397. assert_raises(AssertionError,
  398. lambda: self._assert_func(np.nan, np.inf))
  399. assert_raises(AssertionError,
  400. lambda: self._assert_func(np.inf, np.nan))
  401. def test_inf_item(self):
  402. self._assert_func(np.inf, np.inf)
  403. self._assert_func(-np.inf, -np.inf)
  404. assert_raises(AssertionError,
  405. lambda: self._assert_func(np.inf, 1))
  406. assert_raises(AssertionError,
  407. lambda: self._assert_func(-np.inf, np.inf))
  408. def test_simple_item(self):
  409. self._test_not_equal(1, 2)
  410. def test_complex_item(self):
  411. self._assert_func(complex(1, 2), complex(1, 2))
  412. self._assert_func(complex(1, np.nan), complex(1, np.nan))
  413. self._assert_func(complex(np.inf, np.nan), complex(np.inf, np.nan))
  414. self._test_not_equal(complex(1, np.nan), complex(1, 2))
  415. self._test_not_equal(complex(np.nan, 1), complex(1, np.nan))
  416. self._test_not_equal(complex(np.nan, np.inf), complex(np.nan, 2))
  417. def test_complex(self):
  418. x = np.array([complex(1, 2), complex(1, np.nan)])
  419. z = np.array([complex(1, 2), complex(np.nan, 1)])
  420. y = np.array([complex(1, 2), complex(1, 2)])
  421. self._assert_func(x, x)
  422. self._test_not_equal(x, y)
  423. self._test_not_equal(x, z)
  424. def test_error_message(self):
  425. """Check the message is formatted correctly for the decimal value.
  426. Also check the message when input includes inf or nan (gh12200)"""
  427. x = np.array([1.00000000001, 2.00000000002, 3.00003])
  428. y = np.array([1.00000000002, 2.00000000003, 3.00004])
  429. # Test with a different amount of decimal digits
  430. with pytest.raises(AssertionError) as exc_info:
  431. self._assert_func(x, y, decimal=12)
  432. msgs = str(exc_info.value).split('\n')
  433. assert_equal(msgs[3], 'Mismatched elements: 3 / 3 (100%)')
  434. assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
  435. assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
  436. assert_equal(
  437. msgs[6],
  438. ' x: array([1.00000000001, 2.00000000002, 3.00003 ])')
  439. assert_equal(
  440. msgs[7],
  441. ' y: array([1.00000000002, 2.00000000003, 3.00004 ])')
  442. # With the default value of decimal digits, only the 3rd element
  443. # differs. Note that we only check for the formatting of the arrays
  444. # themselves.
  445. with pytest.raises(AssertionError) as exc_info:
  446. self._assert_func(x, y)
  447. msgs = str(exc_info.value).split('\n')
  448. assert_equal(msgs[3], 'Mismatched elements: 1 / 3 (33.3%)')
  449. assert_equal(msgs[4], 'Max absolute difference: 1.e-05')
  450. assert_equal(msgs[5], 'Max relative difference: 3.33328889e-06')
  451. assert_equal(msgs[6], ' x: array([1. , 2. , 3.00003])')
  452. assert_equal(msgs[7], ' y: array([1. , 2. , 3.00004])')
  453. # Check the error message when input includes inf
  454. x = np.array([np.inf, 0])
  455. y = np.array([np.inf, 1])
  456. with pytest.raises(AssertionError) as exc_info:
  457. self._assert_func(x, y)
  458. msgs = str(exc_info.value).split('\n')
  459. assert_equal(msgs[3], 'Mismatched elements: 1 / 2 (50%)')
  460. assert_equal(msgs[4], 'Max absolute difference: 1.')
  461. assert_equal(msgs[5], 'Max relative difference: 1.')
  462. assert_equal(msgs[6], ' x: array([inf, 0.])')
  463. assert_equal(msgs[7], ' y: array([inf, 1.])')
  464. # Check the error message when dividing by zero
  465. x = np.array([1, 2])
  466. y = np.array([0, 0])
  467. with pytest.raises(AssertionError) as exc_info:
  468. self._assert_func(x, y)
  469. msgs = str(exc_info.value).split('\n')
  470. assert_equal(msgs[3], 'Mismatched elements: 2 / 2 (100%)')
  471. assert_equal(msgs[4], 'Max absolute difference: 2')
  472. assert_equal(msgs[5], 'Max relative difference: inf')
  473. def test_error_message_2(self):
  474. """Check the message is formatted correctly when either x or y is a scalar."""
  475. x = 2
  476. y = np.ones(20)
  477. with pytest.raises(AssertionError) as exc_info:
  478. self._assert_func(x, y)
  479. msgs = str(exc_info.value).split('\n')
  480. assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
  481. assert_equal(msgs[4], 'Max absolute difference: 1.')
  482. assert_equal(msgs[5], 'Max relative difference: 1.')
  483. y = 2
  484. x = np.ones(20)
  485. with pytest.raises(AssertionError) as exc_info:
  486. self._assert_func(x, y)
  487. msgs = str(exc_info.value).split('\n')
  488. assert_equal(msgs[3], 'Mismatched elements: 20 / 20 (100%)')
  489. assert_equal(msgs[4], 'Max absolute difference: 1.')
  490. assert_equal(msgs[5], 'Max relative difference: 0.5')
  491. def test_subclass_that_cannot_be_bool(self):
  492. # While we cannot guarantee testing functions will always work for
  493. # subclasses, the tests should ideally rely only on subclasses having
  494. # comparison operators, not on them being able to store booleans
  495. # (which, e.g., astropy Quantity cannot usefully do). See gh-8452.
  496. class MyArray(np.ndarray):
  497. def __eq__(self, other):
  498. return super(MyArray, self).__eq__(other).view(np.ndarray)
  499. def __lt__(self, other):
  500. return super(MyArray, self).__lt__(other).view(np.ndarray)
  501. def all(self, *args, **kwargs):
  502. raise NotImplementedError
  503. a = np.array([1., 2.]).view(MyArray)
  504. self._assert_func(a, a)
  505. class TestApproxEqual(object):
  506. def setup(self):
  507. self._assert_func = assert_approx_equal
  508. def test_simple_arrays(self):
  509. x = np.array([1234.22])
  510. y = np.array([1234.23])
  511. self._assert_func(x, y, significant=5)
  512. self._assert_func(x, y, significant=6)
  513. assert_raises(AssertionError,
  514. lambda: self._assert_func(x, y, significant=7))
  515. def test_simple_items(self):
  516. x = 1234.22
  517. y = 1234.23
  518. self._assert_func(x, y, significant=4)
  519. self._assert_func(x, y, significant=5)
  520. self._assert_func(x, y, significant=6)
  521. assert_raises(AssertionError,
  522. lambda: self._assert_func(x, y, significant=7))
  523. def test_nan_array(self):
  524. anan = np.array(np.nan)
  525. aone = np.array(1)
  526. ainf = np.array(np.inf)
  527. self._assert_func(anan, anan)
  528. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  529. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  530. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  531. def test_nan_items(self):
  532. anan = np.array(np.nan)
  533. aone = np.array(1)
  534. ainf = np.array(np.inf)
  535. self._assert_func(anan, anan)
  536. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  537. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  538. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  539. class TestArrayAssertLess(object):
  540. def setup(self):
  541. self._assert_func = assert_array_less
  542. def test_simple_arrays(self):
  543. x = np.array([1.1, 2.2])
  544. y = np.array([1.2, 2.3])
  545. self._assert_func(x, y)
  546. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  547. y = np.array([1.0, 2.3])
  548. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  549. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  550. def test_rank2(self):
  551. x = np.array([[1.1, 2.2], [3.3, 4.4]])
  552. y = np.array([[1.2, 2.3], [3.4, 4.5]])
  553. self._assert_func(x, y)
  554. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  555. y = np.array([[1.0, 2.3], [3.4, 4.5]])
  556. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  557. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  558. def test_rank3(self):
  559. x = np.ones(shape=(2, 2, 2))
  560. y = np.ones(shape=(2, 2, 2))+1
  561. self._assert_func(x, y)
  562. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  563. y[0, 0, 0] = 0
  564. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  565. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  566. def test_simple_items(self):
  567. x = 1.1
  568. y = 2.2
  569. self._assert_func(x, y)
  570. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  571. y = np.array([2.2, 3.3])
  572. self._assert_func(x, y)
  573. assert_raises(AssertionError, lambda: self._assert_func(y, x))
  574. y = np.array([1.0, 3.3])
  575. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  576. def test_nan_noncompare(self):
  577. anan = np.array(np.nan)
  578. aone = np.array(1)
  579. ainf = np.array(np.inf)
  580. self._assert_func(anan, anan)
  581. assert_raises(AssertionError, lambda: self._assert_func(aone, anan))
  582. assert_raises(AssertionError, lambda: self._assert_func(anan, aone))
  583. assert_raises(AssertionError, lambda: self._assert_func(anan, ainf))
  584. assert_raises(AssertionError, lambda: self._assert_func(ainf, anan))
  585. def test_nan_noncompare_array(self):
  586. x = np.array([1.1, 2.2, 3.3])
  587. anan = np.array(np.nan)
  588. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  589. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  590. x = np.array([1.1, 2.2, np.nan])
  591. assert_raises(AssertionError, lambda: self._assert_func(x, anan))
  592. assert_raises(AssertionError, lambda: self._assert_func(anan, x))
  593. y = np.array([1.0, 2.0, np.nan])
  594. self._assert_func(y, x)
  595. assert_raises(AssertionError, lambda: self._assert_func(x, y))
  596. def test_inf_compare(self):
  597. aone = np.array(1)
  598. ainf = np.array(np.inf)
  599. self._assert_func(aone, ainf)
  600. self._assert_func(-ainf, aone)
  601. self._assert_func(-ainf, ainf)
  602. assert_raises(AssertionError, lambda: self._assert_func(ainf, aone))
  603. assert_raises(AssertionError, lambda: self._assert_func(aone, -ainf))
  604. assert_raises(AssertionError, lambda: self._assert_func(ainf, ainf))
  605. assert_raises(AssertionError, lambda: self._assert_func(ainf, -ainf))
  606. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -ainf))
  607. def test_inf_compare_array(self):
  608. x = np.array([1.1, 2.2, np.inf])
  609. ainf = np.array(np.inf)
  610. assert_raises(AssertionError, lambda: self._assert_func(x, ainf))
  611. assert_raises(AssertionError, lambda: self._assert_func(ainf, x))
  612. assert_raises(AssertionError, lambda: self._assert_func(x, -ainf))
  613. assert_raises(AssertionError, lambda: self._assert_func(-x, -ainf))
  614. assert_raises(AssertionError, lambda: self._assert_func(-ainf, -x))
  615. self._assert_func(-ainf, x)
  616. @pytest.mark.skip(reason="The raises decorator depends on Nose")
  617. class TestRaises(object):
  618. def setup(self):
  619. class MyException(Exception):
  620. pass
  621. self.e = MyException
  622. def raises_exception(self, e):
  623. raise e
  624. def does_not_raise_exception(self):
  625. pass
  626. def test_correct_catch(self):
  627. raises(self.e)(self.raises_exception)(self.e) # raises?
  628. def test_wrong_exception(self):
  629. try:
  630. raises(self.e)(self.raises_exception)(RuntimeError) # raises?
  631. except RuntimeError:
  632. return
  633. else:
  634. raise AssertionError("should have caught RuntimeError")
  635. def test_catch_no_raise(self):
  636. try:
  637. raises(self.e)(self.does_not_raise_exception)() # raises?
  638. except AssertionError:
  639. return
  640. else:
  641. raise AssertionError("should have raised an AssertionError")
  642. class TestWarns(object):
  643. def test_warn(self):
  644. def f():
  645. warnings.warn("yo")
  646. return 3
  647. before_filters = sys.modules['warnings'].filters[:]
  648. assert_equal(assert_warns(UserWarning, f), 3)
  649. after_filters = sys.modules['warnings'].filters
  650. assert_raises(AssertionError, assert_no_warnings, f)
  651. assert_equal(assert_no_warnings(lambda x: x, 1), 1)
  652. # Check that the warnings state is unchanged
  653. assert_equal(before_filters, after_filters,
  654. "assert_warns does not preserver warnings state")
  655. def test_context_manager(self):
  656. before_filters = sys.modules['warnings'].filters[:]
  657. with assert_warns(UserWarning):
  658. warnings.warn("yo")
  659. after_filters = sys.modules['warnings'].filters
  660. def no_warnings():
  661. with assert_no_warnings():
  662. warnings.warn("yo")
  663. assert_raises(AssertionError, no_warnings)
  664. assert_equal(before_filters, after_filters,
  665. "assert_warns does not preserver warnings state")
  666. def test_warn_wrong_warning(self):
  667. def f():
  668. warnings.warn("yo", DeprecationWarning)
  669. failed = False
  670. with warnings.catch_warnings():
  671. warnings.simplefilter("error", DeprecationWarning)
  672. try:
  673. # Should raise a DeprecationWarning
  674. assert_warns(UserWarning, f)
  675. failed = True
  676. except DeprecationWarning:
  677. pass
  678. if failed:
  679. raise AssertionError("wrong warning caught by assert_warn")
  680. class TestAssertAllclose(object):
  681. def test_simple(self):
  682. x = 1e-3
  683. y = 1e-9
  684. assert_allclose(x, y, atol=1)
  685. assert_raises(AssertionError, assert_allclose, x, y)
  686. a = np.array([x, y, x, y])
  687. b = np.array([x, y, x, x])
  688. assert_allclose(a, b, atol=1)
  689. assert_raises(AssertionError, assert_allclose, a, b)
  690. b[-1] = y * (1 + 1e-8)
  691. assert_allclose(a, b)
  692. assert_raises(AssertionError, assert_allclose, a, b, rtol=1e-9)
  693. assert_allclose(6, 10, rtol=0.5)
  694. assert_raises(AssertionError, assert_allclose, 10, 6, rtol=0.5)
  695. def test_min_int(self):
  696. a = np.array([np.iinfo(np.int_).min], dtype=np.int_)
  697. # Should not raise:
  698. assert_allclose(a, a)
  699. def test_report_fail_percentage(self):
  700. a = np.array([1, 1, 1, 1])
  701. b = np.array([1, 1, 1, 2])
  702. with pytest.raises(AssertionError) as exc_info:
  703. assert_allclose(a, b)
  704. msg = str(exc_info.value)
  705. assert_('Mismatched elements: 1 / 4 (25%)\n'
  706. 'Max absolute difference: 1\n'
  707. 'Max relative difference: 0.5' in msg)
  708. def test_equal_nan(self):
  709. a = np.array([np.nan])
  710. b = np.array([np.nan])
  711. # Should not raise:
  712. assert_allclose(a, b, equal_nan=True)
  713. def test_not_equal_nan(self):
  714. a = np.array([np.nan])
  715. b = np.array([np.nan])
  716. assert_raises(AssertionError, assert_allclose, a, b, equal_nan=False)
  717. def test_equal_nan_default(self):
  718. # Make sure equal_nan default behavior remains unchanged. (All
  719. # of these functions use assert_array_compare under the hood.)
  720. # None of these should raise.
  721. a = np.array([np.nan])
  722. b = np.array([np.nan])
  723. assert_array_equal(a, b)
  724. assert_array_almost_equal(a, b)
  725. assert_array_less(a, b)
  726. assert_allclose(a, b)
  727. def test_report_max_relative_error(self):
  728. a = np.array([0, 1])
  729. b = np.array([0, 2])
  730. with pytest.raises(AssertionError) as exc_info:
  731. assert_allclose(a, b)
  732. msg = str(exc_info.value)
  733. assert_('Max relative difference: 0.5' in msg)
  734. class TestArrayAlmostEqualNulp(object):
  735. def test_float64_pass(self):
  736. # The number of units of least precision
  737. # In this case, use a few places above the lowest level (ie nulp=1)
  738. nulp = 5
  739. x = np.linspace(-20, 20, 50, dtype=np.float64)
  740. x = 10**x
  741. x = np.r_[-x, x]
  742. # Addition
  743. eps = np.finfo(x.dtype).eps
  744. y = x + x*eps*nulp/2.
  745. assert_array_almost_equal_nulp(x, y, nulp)
  746. # Subtraction
  747. epsneg = np.finfo(x.dtype).epsneg
  748. y = x - x*epsneg*nulp/2.
  749. assert_array_almost_equal_nulp(x, y, nulp)
  750. def test_float64_fail(self):
  751. nulp = 5
  752. x = np.linspace(-20, 20, 50, dtype=np.float64)
  753. x = 10**x
  754. x = np.r_[-x, x]
  755. eps = np.finfo(x.dtype).eps
  756. y = x + x*eps*nulp*2.
  757. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  758. x, y, nulp)
  759. epsneg = np.finfo(x.dtype).epsneg
  760. y = x - x*epsneg*nulp*2.
  761. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  762. x, y, nulp)
  763. def test_float32_pass(self):
  764. nulp = 5
  765. x = np.linspace(-20, 20, 50, dtype=np.float32)
  766. x = 10**x
  767. x = np.r_[-x, x]
  768. eps = np.finfo(x.dtype).eps
  769. y = x + x*eps*nulp/2.
  770. assert_array_almost_equal_nulp(x, y, nulp)
  771. epsneg = np.finfo(x.dtype).epsneg
  772. y = x - x*epsneg*nulp/2.
  773. assert_array_almost_equal_nulp(x, y, nulp)
  774. def test_float32_fail(self):
  775. nulp = 5
  776. x = np.linspace(-20, 20, 50, dtype=np.float32)
  777. x = 10**x
  778. x = np.r_[-x, x]
  779. eps = np.finfo(x.dtype).eps
  780. y = x + x*eps*nulp*2.
  781. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  782. x, y, nulp)
  783. epsneg = np.finfo(x.dtype).epsneg
  784. y = x - x*epsneg*nulp*2.
  785. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  786. x, y, nulp)
  787. def test_float16_pass(self):
  788. nulp = 5
  789. x = np.linspace(-4, 4, 10, dtype=np.float16)
  790. x = 10**x
  791. x = np.r_[-x, x]
  792. eps = np.finfo(x.dtype).eps
  793. y = x + x*eps*nulp/2.
  794. assert_array_almost_equal_nulp(x, y, nulp)
  795. epsneg = np.finfo(x.dtype).epsneg
  796. y = x - x*epsneg*nulp/2.
  797. assert_array_almost_equal_nulp(x, y, nulp)
  798. def test_float16_fail(self):
  799. nulp = 5
  800. x = np.linspace(-4, 4, 10, dtype=np.float16)
  801. x = 10**x
  802. x = np.r_[-x, x]
  803. eps = np.finfo(x.dtype).eps
  804. y = x + x*eps*nulp*2.
  805. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  806. x, y, nulp)
  807. epsneg = np.finfo(x.dtype).epsneg
  808. y = x - x*epsneg*nulp*2.
  809. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  810. x, y, nulp)
  811. def test_complex128_pass(self):
  812. nulp = 5
  813. x = np.linspace(-20, 20, 50, dtype=np.float64)
  814. x = 10**x
  815. x = np.r_[-x, x]
  816. xi = x + x*1j
  817. eps = np.finfo(x.dtype).eps
  818. y = x + x*eps*nulp/2.
  819. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  820. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  821. # The test condition needs to be at least a factor of sqrt(2) smaller
  822. # because the real and imaginary parts both change
  823. y = x + x*eps*nulp/4.
  824. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  825. epsneg = np.finfo(x.dtype).epsneg
  826. y = x - x*epsneg*nulp/2.
  827. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  828. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  829. y = x - x*epsneg*nulp/4.
  830. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  831. def test_complex128_fail(self):
  832. nulp = 5
  833. x = np.linspace(-20, 20, 50, dtype=np.float64)
  834. x = 10**x
  835. x = np.r_[-x, x]
  836. xi = x + x*1j
  837. eps = np.finfo(x.dtype).eps
  838. y = x + x*eps*nulp*2.
  839. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  840. xi, x + y*1j, nulp)
  841. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  842. xi, y + x*1j, nulp)
  843. # The test condition needs to be at least a factor of sqrt(2) smaller
  844. # because the real and imaginary parts both change
  845. y = x + x*eps*nulp
  846. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  847. xi, y + y*1j, nulp)
  848. epsneg = np.finfo(x.dtype).epsneg
  849. y = x - x*epsneg*nulp*2.
  850. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  851. xi, x + y*1j, nulp)
  852. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  853. xi, y + x*1j, nulp)
  854. y = x - x*epsneg*nulp
  855. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  856. xi, y + y*1j, nulp)
  857. def test_complex64_pass(self):
  858. nulp = 5
  859. x = np.linspace(-20, 20, 50, dtype=np.float32)
  860. x = 10**x
  861. x = np.r_[-x, x]
  862. xi = x + x*1j
  863. eps = np.finfo(x.dtype).eps
  864. y = x + x*eps*nulp/2.
  865. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  866. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  867. y = x + x*eps*nulp/4.
  868. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  869. epsneg = np.finfo(x.dtype).epsneg
  870. y = x - x*epsneg*nulp/2.
  871. assert_array_almost_equal_nulp(xi, x + y*1j, nulp)
  872. assert_array_almost_equal_nulp(xi, y + x*1j, nulp)
  873. y = x - x*epsneg*nulp/4.
  874. assert_array_almost_equal_nulp(xi, y + y*1j, nulp)
  875. def test_complex64_fail(self):
  876. nulp = 5
  877. x = np.linspace(-20, 20, 50, dtype=np.float32)
  878. x = 10**x
  879. x = np.r_[-x, x]
  880. xi = x + x*1j
  881. eps = np.finfo(x.dtype).eps
  882. y = x + x*eps*nulp*2.
  883. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  884. xi, x + y*1j, nulp)
  885. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  886. xi, y + x*1j, nulp)
  887. y = x + x*eps*nulp
  888. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  889. xi, y + y*1j, nulp)
  890. epsneg = np.finfo(x.dtype).epsneg
  891. y = x - x*epsneg*nulp*2.
  892. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  893. xi, x + y*1j, nulp)
  894. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  895. xi, y + x*1j, nulp)
  896. y = x - x*epsneg*nulp
  897. assert_raises(AssertionError, assert_array_almost_equal_nulp,
  898. xi, y + y*1j, nulp)
  899. class TestULP(object):
  900. def test_equal(self):
  901. x = np.random.randn(10)
  902. assert_array_max_ulp(x, x, maxulp=0)
  903. def test_single(self):
  904. # Generate 1 + small deviation, check that adding eps gives a few UNL
  905. x = np.ones(10).astype(np.float32)
  906. x += 0.01 * np.random.randn(10).astype(np.float32)
  907. eps = np.finfo(np.float32).eps
  908. assert_array_max_ulp(x, x+eps, maxulp=20)
  909. def test_double(self):
  910. # Generate 1 + small deviation, check that adding eps gives a few UNL
  911. x = np.ones(10).astype(np.float64)
  912. x += 0.01 * np.random.randn(10).astype(np.float64)
  913. eps = np.finfo(np.float64).eps
  914. assert_array_max_ulp(x, x+eps, maxulp=200)
  915. def test_inf(self):
  916. for dt in [np.float32, np.float64]:
  917. inf = np.array([np.inf]).astype(dt)
  918. big = np.array([np.finfo(dt).max])
  919. assert_array_max_ulp(inf, big, maxulp=200)
  920. def test_nan(self):
  921. # Test that nan is 'far' from small, tiny, inf, max and min
  922. for dt in [np.float32, np.float64]:
  923. if dt == np.float32:
  924. maxulp = 1e6
  925. else:
  926. maxulp = 1e12
  927. inf = np.array([np.inf]).astype(dt)
  928. nan = np.array([np.nan]).astype(dt)
  929. big = np.array([np.finfo(dt).max])
  930. tiny = np.array([np.finfo(dt).tiny])
  931. zero = np.array([np.PZERO]).astype(dt)
  932. nzero = np.array([np.NZERO]).astype(dt)
  933. assert_raises(AssertionError,
  934. lambda: assert_array_max_ulp(nan, inf,
  935. maxulp=maxulp))
  936. assert_raises(AssertionError,
  937. lambda: assert_array_max_ulp(nan, big,
  938. maxulp=maxulp))
  939. assert_raises(AssertionError,
  940. lambda: assert_array_max_ulp(nan, tiny,
  941. maxulp=maxulp))
  942. assert_raises(AssertionError,
  943. lambda: assert_array_max_ulp(nan, zero,
  944. maxulp=maxulp))
  945. assert_raises(AssertionError,
  946. lambda: assert_array_max_ulp(nan, nzero,
  947. maxulp=maxulp))
  948. class TestStringEqual(object):
  949. def test_simple(self):
  950. assert_string_equal("hello", "hello")
  951. assert_string_equal("hello\nmultiline", "hello\nmultiline")
  952. with pytest.raises(AssertionError) as exc_info:
  953. assert_string_equal("foo\nbar", "hello\nbar")
  954. msg = str(exc_info.value)
  955. assert_equal(msg, "Differences in strings:\n- foo\n+ hello")
  956. assert_raises(AssertionError,
  957. lambda: assert_string_equal("foo", "hello"))
  958. def test_regex(self):
  959. assert_string_equal("a+*b", "a+*b")
  960. assert_raises(AssertionError,
  961. lambda: assert_string_equal("aaa", "a+b"))
  962. def assert_warn_len_equal(mod, n_in_context, py34=None, py37=None):
  963. try:
  964. mod_warns = mod.__warningregistry__
  965. except AttributeError:
  966. # the lack of a __warningregistry__
  967. # attribute means that no warning has
  968. # occurred; this can be triggered in
  969. # a parallel test scenario, while in
  970. # a serial test scenario an initial
  971. # warning (and therefore the attribute)
  972. # are always created first
  973. mod_warns = {}
  974. num_warns = len(mod_warns)
  975. # Python 3.4 appears to clear any pre-existing warnings of the same type,
  976. # when raising warnings inside a catch_warnings block. So, there is a
  977. # warning generated by the tests within the context manager, but no
  978. # previous warnings.
  979. if 'version' in mod_warns:
  980. # Python 3 adds a 'version' entry to the registry,
  981. # do not count it.
  982. num_warns -= 1
  983. # Behavior of warnings is Python version dependent. Adjust the
  984. # expected result to compensate. In particular, Python 3.7 does
  985. # not make an entry for ignored warnings.
  986. if sys.version_info[:2] >= (3, 7):
  987. if py37 is not None:
  988. n_in_context = py37
  989. elif sys.version_info[:2] >= (3, 4):
  990. if py34 is not None:
  991. n_in_context = py34
  992. assert_equal(num_warns, n_in_context)
  993. def test_warn_len_equal_call_scenarios():
  994. # assert_warn_len_equal is called under
  995. # varying circumstances depending on serial
  996. # vs. parallel test scenarios; this test
  997. # simply aims to probe both code paths and
  998. # check that no assertion is uncaught
  999. # parallel scenario -- no warning issued yet
  1000. class mod(object):
  1001. pass
  1002. mod_inst = mod()
  1003. assert_warn_len_equal(mod=mod_inst,
  1004. n_in_context=0)
  1005. # serial test scenario -- the __warningregistry__
  1006. # attribute should be present
  1007. class mod(object):
  1008. def __init__(self):
  1009. self.__warningregistry__ = {'warning1':1,
  1010. 'warning2':2}
  1011. mod_inst = mod()
  1012. assert_warn_len_equal(mod=mod_inst,
  1013. n_in_context=2)
  1014. def _get_fresh_mod():
  1015. # Get this module, with warning registry empty
  1016. my_mod = sys.modules[__name__]
  1017. try:
  1018. my_mod.__warningregistry__.clear()
  1019. except AttributeError:
  1020. # will not have a __warningregistry__ unless warning has been
  1021. # raised in the module at some point
  1022. pass
  1023. return my_mod
def test_clear_and_catch_warnings():
    """Check this module's warning registry around ``clear_and_catch_warnings``.

    Four contexts are entered in turn, alternating between passing
    ``modules=[my_mod]`` and passing nothing; after each one the number of
    registry entries is compared with a (Python-version dependent)
    expectation via ``assert_warn_len_equal``.
    """
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})
    # With the module listed, the registry must be empty again afterwards.
    with clear_and_catch_warnings(modules=[my_mod]):
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_equal(my_mod.__warningregistry__, {})
    # Without specified modules, don't clear warnings during context
    # Python 3.7 catch_warnings doesn't make an entry for 'ignore'.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
    # Confirm that specifying module keeps old warning, does not add new
    with clear_and_catch_warnings(modules=[my_mod]):
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
    # Another warning, no module spec does add to warnings dict, except on
    # Python 3.4 (see comments in `assert_warn_len_equal`)
    # Python 3.7 catch_warnings doesn't make an entry for 'ignore'.
    with clear_and_catch_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Another warning')
    assert_warn_len_equal(my_mod, 2, py34=1, py37=0)
def test_suppress_warnings_module():
    """Check ``suppress_warnings`` filtering by the module a warning is from.

    Besides the recorded log, the test inspects this module's warning
    registry after every context via ``assert_warn_len_equal``.
    """
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})

    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning 2", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    # Test module based warning suppression:
    assert_warn_len_equal(my_mod, 0)
    with suppress_warnings() as sup:
        sup.record(UserWarning)
        # suppress warning from other module (may have .pyc ending),
        # if apply_along_axis is moved, had to be changed.
        sup.filter(module=np.lib.shape_base)
        warnings.warn("Some warning")
        warn_other_module()
    # Check that the suppression did test the file correctly (this module
    # got filtered): only the warning issued directly above was recorded.
    assert_equal(len(sup.log), 1)
    assert_equal(sup.log[0].message.args[0], "Some warning")
    assert_warn_len_equal(my_mod, 0, py37=0)
    sup = suppress_warnings()
    # This time filter on the test module itself (my_mod), set up before
    # the context is entered:
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # And test repeat works:
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)

    # Without specified modules, don't clear warnings during context
    # Python 3.7 does not add ignored warnings.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
def test_suppress_warnings_type():
    """Check ``suppress_warnings`` filtering by warning category.

    After every context the number of entries in this module's warning
    registry is checked with ``assert_warn_len_equal``.
    """
    # Initial state of module, no warnings
    my_mod = _get_fresh_mod()
    assert_equal(getattr(my_mod, '__warningregistry__', {}), {})

    # Test type (category) based warning suppression, filter added inside
    # the context:
    with suppress_warnings() as sup:
        sup.filter(UserWarning)
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # Same category filter, but added before the context is entered:
    sup = suppress_warnings()
    sup.filter(UserWarning)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)
    # And test repeat works (an extra module filter is added to the same
    # instance before re-entering it):
    sup.filter(module=my_mod)
    with sup:
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 0)

    # Without specified modules, don't clear warnings during context
    # Python 3.7 does not add ignored warnings.
    with suppress_warnings():
        warnings.simplefilter('ignore')
        warnings.warn('Some warning')
    assert_warn_len_equal(my_mod, 1, py37=0)
  1117. def test_suppress_warnings_decorate_no_record():
  1118. sup = suppress_warnings()
  1119. sup.filter(UserWarning)
  1120. @sup
  1121. def warn(category):
  1122. warnings.warn('Some warning', category)
  1123. with warnings.catch_warnings(record=True) as w:
  1124. warnings.simplefilter("always")
  1125. warn(UserWarning) # should be supppressed
  1126. warn(RuntimeWarning)
  1127. assert_equal(len(w), 1)
  1128. def test_suppress_warnings_record():
  1129. sup = suppress_warnings()
  1130. log1 = sup.record()
  1131. with sup:
  1132. log2 = sup.record(message='Some other warning 2')
  1133. sup.filter(message='Some warning')
  1134. warnings.warn('Some warning')
  1135. warnings.warn('Some other warning')
  1136. warnings.warn('Some other warning 2')
  1137. assert_equal(len(sup.log), 2)
  1138. assert_equal(len(log1), 1)
  1139. assert_equal(len(log2),1)
  1140. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1141. # Do it again, with the same context to see if some warnings survived:
  1142. with sup:
  1143. log2 = sup.record(message='Some other warning 2')
  1144. sup.filter(message='Some warning')
  1145. warnings.warn('Some warning')
  1146. warnings.warn('Some other warning')
  1147. warnings.warn('Some other warning 2')
  1148. assert_equal(len(sup.log), 2)
  1149. assert_equal(len(log1), 1)
  1150. assert_equal(len(log2), 1)
  1151. assert_equal(log2[0].message.args[0], 'Some other warning 2')
  1152. # Test nested:
  1153. with suppress_warnings() as sup:
  1154. sup.record()
  1155. with suppress_warnings() as sup2:
  1156. sup2.record(message='Some warning')
  1157. warnings.warn('Some warning')
  1158. warnings.warn('Some other warning')
  1159. assert_equal(len(sup2.log), 1)
  1160. assert_equal(len(sup.log), 1)
def test_suppress_warnings_forwarding():
    """Check what an inner ``suppress_warnings(rule)`` lets through.

    For each of the rules "always", "location", "module" and "once", an
    outer recording context wraps an inner context created with that rule;
    in every scenario the outer log is expected to hold exactly two entries.
    """
    def warn_other_module():
        # Apply along axis is implemented in python; stacklevel=2 means
        # we end up inside its module, not ours.
        def warn(arr):
            warnings.warn("Some warning", stacklevel=2)
            return arr
        np.apply_along_axis(warn, 0, [0])

    # "always": the same call site warns twice.
    with suppress_warnings() as sup:
        sup.record()
        with suppress_warnings("always"):
            for i in range(2):
                warnings.warn("Some warning")

        assert_equal(len(sup.log), 2)

    # "location": the same message comes from two distinct source lines.
    with suppress_warnings() as sup:
        sup.record()
        with suppress_warnings("location"):
            for i in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some warning")

        assert_equal(len(sup.log), 2)

    # "module": warnings from this module plus one attributed elsewhere
    # (see warn_other_module).
    with suppress_warnings() as sup:
        sup.record()
        with suppress_warnings("module"):
            for i in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some warning")
                warn_other_module()

        assert_equal(len(sup.log), 2)

    # "once": two distinct messages, one of them also raised from the
    # other module.
    with suppress_warnings() as sup:
        sup.record()
        with suppress_warnings("once"):
            for i in range(2):
                warnings.warn("Some warning")
                warnings.warn("Some other warning")
                warn_other_module()

        assert_equal(len(sup.log), 2)
  1198. def test_tempdir():
  1199. with tempdir() as tdir:
  1200. fpath = os.path.join(tdir, 'tmp')
  1201. with open(fpath, 'w'):
  1202. pass
  1203. assert_(not os.path.isdir(tdir))
  1204. raised = False
  1205. try:
  1206. with tempdir() as tdir:
  1207. raise ValueError()
  1208. except ValueError:
  1209. raised = True
  1210. assert_(raised)
  1211. assert_(not os.path.isdir(tdir))
  1212. def test_temppath():
  1213. with temppath() as fpath:
  1214. with open(fpath, 'w'):
  1215. pass
  1216. assert_(not os.path.isfile(fpath))
  1217. raised = False
  1218. try:
  1219. with temppath() as fpath:
  1220. raise ValueError()
  1221. except ValueError:
  1222. raised = True
  1223. assert_(raised)
  1224. assert_(not os.path.isfile(fpath))
  1225. class my_cacw(clear_and_catch_warnings):
  1226. class_modules = (sys.modules[__name__],)
  1227. def test_clear_and_catch_warnings_inherit():
  1228. # Test can subclass and add default modules
  1229. my_mod = _get_fresh_mod()
  1230. with my_cacw():
  1231. warnings.simplefilter('ignore')
  1232. warnings.warn('Some warning')
  1233. assert_equal(my_mod.__warningregistry__, {})
@pytest.mark.skipif(not HAS_REFCOUNT, reason="Python lacks refcounts")
class TestAssertNoGcCycles(object):
    """ Test assert_no_gc_cycles """
    def test_passes(self):
        # A nested list without self-references: no cycle is created.
        def no_cycle():
            b = []
            b.append([])
            return b

        # Both the context-manager form and the call form must accept it.
        with assert_no_gc_cycles():
            no_cycle()

        assert_no_gc_cycles(no_cycle)

    def test_asserts(self):
        # A list that contains itself (twice) is a reference cycle.
        def make_cycle():
            a = []
            a.append(a)
            a.append(a)
            return a

        # Both forms must report the cycle as an AssertionError.
        with assert_raises(AssertionError):
            with assert_no_gc_cycles():
                make_cycle()

        with assert_raises(AssertionError):
            assert_no_gc_cycles(make_cycle)

    @pytest.mark.slow
    def test_fails(self):
        """
        Test that in cases where the garbage cannot be collected, we raise an
        error, instead of hanging forever trying to clear it.
        """

        class ReferenceCycleInDel(object):
            """
            An object that not only contains a reference cycle, but creates new
            cycles whenever it's garbage-collected and its __del__ runs
            """
            # Class-level switch; turned off in the ``finally`` below.
            make_cycle = True

            def __init__(self):
                self.cycle = self

            def __del__(self):
                # break the current cycle so that `self` can be freed
                self.cycle = None

                if ReferenceCycleInDel.make_cycle:
                    # but create a new one so that the garbage collector has more
                    # work to do.
                    ReferenceCycleInDel()

        try:
            # Weak reference: lets us see later whether the object was freed.
            w = weakref.ref(ReferenceCycleInDel())
            try:
                with assert_raises(RuntimeError):
                    # this will be unable to get a baseline empty garbage
                    assert_no_gc_cycles(lambda: None)
            except AssertionError:
                # the above test is only necessary if the GC actually tried to free
                # our object anyway, which python 2.7 does not.
                if w() is not None:
                    pytest.skip("GC does not call __del__ on cyclic objects")
                # NOTE(review): the re-raise is kept at handler level so a
                # missing RuntimeError is still reported when the object was
                # freed -- confirm against the canonical file layout.
                raise

        finally:
            # make sure that we stop creating reference cycles
            ReferenceCycleInDel.make_cycle = False