algorithms.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. """Algorithms to support fitting routines in seaborn plotting functions."""
  2. import numbers
  3. import numpy as np
  4. import warnings
  5. def bootstrap(*args, **kwargs):
  6. """Resample one or more arrays with replacement and store aggregate values.
  7. Positional arguments are a sequence of arrays to bootstrap along the first
  8. axis and pass to a summary function.
  9. Keyword arguments:
  10. n_boot : int, default 10000
  11. Number of iterations
  12. axis : int, default None
  13. Will pass axis to ``func`` as a keyword argument.
  14. units : array, default None
  15. Array of sampling unit IDs. When used the bootstrap resamples units
  16. and then observations within units instead of individual
  17. datapoints.
  18. func : string or callable, default np.mean
  19. Function to call on the args that are passed in. If string, tries
  20. to use as named method on numpy array.
  21. seed : Generator | SeedSequence | RandomState | int | None
  22. Seed for the random number generator; useful if you want
  23. reproducible resamples.
  24. Returns
  25. -------
  26. boot_dist: array
  27. array of bootstrapped statistic values
  28. """
  29. # Ensure list of arrays are same length
  30. if len(np.unique(list(map(len, args)))) > 1:
  31. raise ValueError("All input arrays must have the same length")
  32. n = len(args[0])
  33. # Default keyword arguments
  34. n_boot = kwargs.get("n_boot", 10000)
  35. func = kwargs.get("func", np.mean)
  36. axis = kwargs.get("axis", None)
  37. units = kwargs.get("units", None)
  38. random_seed = kwargs.get("random_seed", None)
  39. if random_seed is not None:
  40. msg = "`random_seed` has been renamed to `seed` and will be removed"
  41. warnings.warn(msg)
  42. seed = kwargs.get("seed", random_seed)
  43. if axis is None:
  44. func_kwargs = dict()
  45. else:
  46. func_kwargs = dict(axis=axis)
  47. # Initialize the resampler
  48. rng = _handle_random_seed(seed)
  49. # Coerce to arrays
  50. args = list(map(np.asarray, args))
  51. if units is not None:
  52. units = np.asarray(units)
  53. # Allow for a function that is the name of a method on an array
  54. if isinstance(func, str):
  55. def f(x):
  56. return getattr(x, func)()
  57. else:
  58. f = func
  59. # Handle numpy changes
  60. try:
  61. integers = rng.integers
  62. except AttributeError:
  63. integers = rng.randint
  64. # Do the bootstrap
  65. if units is not None:
  66. return _structured_bootstrap(args, n_boot, units, f,
  67. func_kwargs, integers)
  68. boot_dist = []
  69. for i in range(int(n_boot)):
  70. resampler = integers(0, n, n, dtype=np.intp) # intp is indexing dtype
  71. sample = [a.take(resampler, axis=0) for a in args]
  72. boot_dist.append(f(*sample, **func_kwargs))
  73. return np.array(boot_dist)
  74. def _structured_bootstrap(args, n_boot, units, func, func_kwargs, integers):
  75. """Resample units instead of datapoints."""
  76. unique_units = np.unique(units)
  77. n_units = len(unique_units)
  78. args = [[a[units == unit] for unit in unique_units] for a in args]
  79. boot_dist = []
  80. for i in range(int(n_boot)):
  81. resampler = integers(0, n_units, n_units, dtype=np.intp)
  82. sample = [np.take(a, resampler, axis=0) for a in args]
  83. lengths = map(len, sample[0])
  84. resampler = [integers(0, n, n, dtype=np.intp) for n in lengths]
  85. sample = [[c.take(r, axis=0) for c, r in zip(a, resampler)]
  86. for a in sample]
  87. sample = list(map(np.concatenate, sample))
  88. boot_dist.append(func(*sample, **func_kwargs))
  89. return np.array(boot_dist)
  90. def _handle_random_seed(seed=None):
  91. """Given a seed in one of many formats, return a random number generator.
  92. Generalizes across the numpy 1.17 changes, preferring newer functionality.
  93. """
  94. if isinstance(seed, np.random.RandomState):
  95. rng = seed
  96. else:
  97. try:
  98. # General interface for seeding on numpy >= 1.17
  99. rng = np.random.default_rng(seed)
  100. except AttributeError:
  101. # We are on numpy < 1.17, handle options ourselves
  102. if isinstance(seed, (numbers.Integral, np.integer)):
  103. rng = np.random.RandomState(seed)
  104. elif seed is None:
  105. rng = np.random.RandomState()
  106. else:
  107. err = "{} cannot be used to seed the randomn number generator"
  108. raise ValueError(err.format(seed))
  109. return rng