test_algorithms.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. import numpy as np
  2. import numpy.random as npr
  3. import pytest
  4. from numpy.testing import assert_array_equal
  5. from distutils.version import LooseVersion
  6. from .. import algorithms as algo
  7. @pytest.fixture
  8. def random():
  9. np.random.seed(sum(map(ord, "test_algorithms")))
  10. def test_bootstrap(random):
  11. """Test that bootstrapping gives the right answer in dumb cases."""
  12. a_ones = np.ones(10)
  13. n_boot = 5
  14. out1 = algo.bootstrap(a_ones, n_boot=n_boot)
  15. assert_array_equal(out1, np.ones(n_boot))
  16. out2 = algo.bootstrap(a_ones, n_boot=n_boot, func=np.median)
  17. assert_array_equal(out2, np.ones(n_boot))
  18. def test_bootstrap_length(random):
  19. """Test that we get a bootstrap array of the right shape."""
  20. a_norm = np.random.randn(1000)
  21. out = algo.bootstrap(a_norm)
  22. assert len(out) == 10000
  23. n_boot = 100
  24. out = algo.bootstrap(a_norm, n_boot=n_boot)
  25. assert len(out) == n_boot
  26. def test_bootstrap_range(random):
  27. """Test that boostrapping a random array stays within the right range."""
  28. a_norm = np.random.randn(1000)
  29. amin, amax = a_norm.min(), a_norm.max()
  30. out = algo.bootstrap(a_norm)
  31. assert amin <= out.min()
  32. assert amax >= out.max()
  33. def test_bootstrap_multiarg(random):
  34. """Test that bootstrap works with multiple input arrays."""
  35. x = np.vstack([[1, 10] for i in range(10)])
  36. y = np.vstack([[5, 5] for i in range(10)])
  37. def f(x, y):
  38. return np.vstack((x, y)).max(axis=0)
  39. out_actual = algo.bootstrap(x, y, n_boot=2, func=f)
  40. out_wanted = np.array([[5, 10], [5, 10]])
  41. assert_array_equal(out_actual, out_wanted)
  42. def test_bootstrap_axis(random):
  43. """Test axis kwarg to bootstrap function."""
  44. x = np.random.randn(10, 20)
  45. n_boot = 100
  46. out_default = algo.bootstrap(x, n_boot=n_boot)
  47. assert out_default.shape == (n_boot,)
  48. out_axis = algo.bootstrap(x, n_boot=n_boot, axis=0)
  49. assert out_axis.shape, (n_boot, x.shape[1])
  50. def test_bootstrap_seed(random):
  51. """Test that we can get reproducible resamples by seeding the RNG."""
  52. data = np.random.randn(50)
  53. seed = 42
  54. boots1 = algo.bootstrap(data, seed=seed)
  55. boots2 = algo.bootstrap(data, seed=seed)
  56. assert_array_equal(boots1, boots2)
  57. def test_bootstrap_ols(random):
  58. """Test bootstrap of OLS model fit."""
  59. def ols_fit(X, y):
  60. XtXinv = np.linalg.inv(np.dot(X.T, X))
  61. return XtXinv.dot(X.T).dot(y)
  62. X = np.column_stack((np.random.randn(50, 4), np.ones(50)))
  63. w = [2, 4, 0, 3, 5]
  64. y_noisy = np.dot(X, w) + np.random.randn(50) * 20
  65. y_lownoise = np.dot(X, w) + np.random.randn(50)
  66. n_boot = 500
  67. w_boot_noisy = algo.bootstrap(X, y_noisy,
  68. n_boot=n_boot,
  69. func=ols_fit)
  70. w_boot_lownoise = algo.bootstrap(X, y_lownoise,
  71. n_boot=n_boot,
  72. func=ols_fit)
  73. assert w_boot_noisy.shape == (n_boot, 5)
  74. assert w_boot_lownoise.shape == (n_boot, 5)
  75. assert w_boot_noisy.std() > w_boot_lownoise.std()
  76. def test_bootstrap_units(random):
  77. """Test that results make sense when passing unit IDs to bootstrap."""
  78. data = np.random.randn(50)
  79. ids = np.repeat(range(10), 5)
  80. bwerr = np.random.normal(0, 2, 10)
  81. bwerr = bwerr[ids]
  82. data_rm = data + bwerr
  83. seed = 77
  84. boots_orig = algo.bootstrap(data_rm, seed=seed)
  85. boots_rm = algo.bootstrap(data_rm, units=ids, seed=seed)
  86. assert boots_rm.std() > boots_orig.std()
  87. def test_bootstrap_arglength():
  88. """Test that different length args raise ValueError."""
  89. with pytest.raises(ValueError):
  90. algo.bootstrap(np.arange(5), np.arange(10))
  91. def test_bootstrap_string_func():
  92. """Test that named numpy methods are the same as the numpy function."""
  93. x = np.random.randn(100)
  94. res_a = algo.bootstrap(x, func="mean", seed=0)
  95. res_b = algo.bootstrap(x, func=np.mean, seed=0)
  96. assert np.array_equal(res_a, res_b)
  97. res_a = algo.bootstrap(x, func="std", seed=0)
  98. res_b = algo.bootstrap(x, func=np.std, seed=0)
  99. assert np.array_equal(res_a, res_b)
  100. with pytest.raises(AttributeError):
  101. algo.bootstrap(x, func="not_a_method_name")
  102. def test_bootstrap_reproducibility(random):
  103. """Test that bootstrapping uses the internal random state."""
  104. data = np.random.randn(50)
  105. boots1 = algo.bootstrap(data, seed=100)
  106. boots2 = algo.bootstrap(data, seed=100)
  107. assert_array_equal(boots1, boots2)
  108. with pytest.warns(UserWarning):
  109. # Deprecatd, remove when removing random_seed
  110. boots1 = algo.bootstrap(data, random_seed=100)
  111. boots2 = algo.bootstrap(data, random_seed=100)
  112. assert_array_equal(boots1, boots2)
  113. @pytest.mark.skipif(LooseVersion(np.__version__) < "1.17",
  114. reason="Tests new numpy random functionality")
  115. def test_seed_new():
  116. # Can't use pytest parametrize because tests will fail where the new
  117. # Generator object and related function are not defined
  118. test_bank = [
  119. (None, None, npr.Generator, False),
  120. (npr.RandomState(0), npr.RandomState(0), npr.RandomState, True),
  121. (npr.RandomState(0), npr.RandomState(1), npr.RandomState, False),
  122. (npr.default_rng(1), npr.default_rng(1), npr.Generator, True),
  123. (npr.default_rng(1), npr.default_rng(2), npr.Generator, False),
  124. (npr.SeedSequence(10), npr.SeedSequence(10), npr.Generator, True),
  125. (npr.SeedSequence(10), npr.SeedSequence(20), npr.Generator, False),
  126. (100, 100, npr.Generator, True),
  127. (100, 200, npr.Generator, False),
  128. ]
  129. for seed1, seed2, rng_class, match in test_bank:
  130. rng1 = algo._handle_random_seed(seed1)
  131. rng2 = algo._handle_random_seed(seed2)
  132. assert isinstance(rng1, rng_class)
  133. assert isinstance(rng2, rng_class)
  134. assert (rng1.uniform() == rng2.uniform()) == match
  135. @pytest.mark.skipif(LooseVersion(np.__version__) >= "1.17",
  136. reason="Tests old numpy random functionality")
  137. @pytest.mark.parametrize("seed1, seed2, match", [
  138. (None, None, False),
  139. (npr.RandomState(0), npr.RandomState(0), True),
  140. (npr.RandomState(0), npr.RandomState(1), False),
  141. (100, 100, True),
  142. (100, 200, False),
  143. ])
  144. def test_seed_old(seed1, seed2, match):
  145. rng1 = algo._handle_random_seed(seed1)
  146. rng2 = algo._handle_random_seed(seed2)
  147. assert isinstance(rng1, np.random.RandomState)
  148. assert isinstance(rng2, np.random.RandomState)
  149. assert (rng1.uniform() == rng2.uniform()) == match
  150. @pytest.mark.skipif(LooseVersion(np.__version__) >= "1.17",
  151. reason="Tests old numpy random functionality")
  152. def test_bad_seed_old():
  153. with pytest.raises(ValueError):
  154. algo._handle_random_seed("not_a_random_seed")