common.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224
  1. """ common utilities """
  2. import itertools
  3. from warnings import catch_warnings
  4. import numpy as np
  5. from pandas.core.dtypes.common import is_scalar
  6. from pandas import DataFrame, Float64Index, MultiIndex, Series, UInt64Index, date_range
  7. import pandas._testing as tm
  8. def _mklbl(prefix, n):
  9. return ["{prefix}{i}".format(prefix=prefix, i=i) for i in range(n)]
  10. def _axify(obj, key, axis):
  11. # create a tuple accessor
  12. axes = [slice(None)] * obj.ndim
  13. axes[axis] = key
  14. return tuple(axes)
  15. class Base:
  16. """ indexing comprehensive base class """
  17. _kinds = {"series", "frame"}
  18. _typs = {
  19. "ints",
  20. "uints",
  21. "labels",
  22. "mixed",
  23. "ts",
  24. "floats",
  25. "empty",
  26. "ts_rev",
  27. "multi",
  28. }
  29. def setup_method(self, method):
  30. self.series_ints = Series(np.random.rand(4), index=np.arange(0, 8, 2))
  31. self.frame_ints = DataFrame(
  32. np.random.randn(4, 4), index=np.arange(0, 8, 2), columns=np.arange(0, 12, 3)
  33. )
  34. self.series_uints = Series(
  35. np.random.rand(4), index=UInt64Index(np.arange(0, 8, 2))
  36. )
  37. self.frame_uints = DataFrame(
  38. np.random.randn(4, 4),
  39. index=UInt64Index(range(0, 8, 2)),
  40. columns=UInt64Index(range(0, 12, 3)),
  41. )
  42. self.series_floats = Series(
  43. np.random.rand(4), index=Float64Index(range(0, 8, 2))
  44. )
  45. self.frame_floats = DataFrame(
  46. np.random.randn(4, 4),
  47. index=Float64Index(range(0, 8, 2)),
  48. columns=Float64Index(range(0, 12, 3)),
  49. )
  50. m_idces = [
  51. MultiIndex.from_product([[1, 2], [3, 4]]),
  52. MultiIndex.from_product([[5, 6], [7, 8]]),
  53. MultiIndex.from_product([[9, 10], [11, 12]]),
  54. ]
  55. self.series_multi = Series(np.random.rand(4), index=m_idces[0])
  56. self.frame_multi = DataFrame(
  57. np.random.randn(4, 4), index=m_idces[0], columns=m_idces[1]
  58. )
  59. self.series_labels = Series(np.random.randn(4), index=list("abcd"))
  60. self.frame_labels = DataFrame(
  61. np.random.randn(4, 4), index=list("abcd"), columns=list("ABCD")
  62. )
  63. self.series_mixed = Series(np.random.randn(4), index=[2, 4, "null", 8])
  64. self.frame_mixed = DataFrame(np.random.randn(4, 4), index=[2, 4, "null", 8])
  65. self.series_ts = Series(
  66. np.random.randn(4), index=date_range("20130101", periods=4)
  67. )
  68. self.frame_ts = DataFrame(
  69. np.random.randn(4, 4), index=date_range("20130101", periods=4)
  70. )
  71. dates_rev = date_range("20130101", periods=4).sort_values(ascending=False)
  72. self.series_ts_rev = Series(np.random.randn(4), index=dates_rev)
  73. self.frame_ts_rev = DataFrame(np.random.randn(4, 4), index=dates_rev)
  74. self.frame_empty = DataFrame()
  75. self.series_empty = Series(dtype=object)
  76. # form agglomerates
  77. for kind in self._kinds:
  78. d = dict()
  79. for typ in self._typs:
  80. d[typ] = getattr(self, "{kind}_{typ}".format(kind=kind, typ=typ))
  81. setattr(self, kind, d)
  82. def generate_indices(self, f, values=False):
  83. """ generate the indices
  84. if values is True , use the axis values
  85. is False, use the range
  86. """
  87. axes = f.axes
  88. if values:
  89. axes = (list(range(len(ax))) for ax in axes)
  90. return itertools.product(*axes)
  91. def get_result(self, obj, method, key, axis):
  92. """ return the result for this obj with this key and this axis """
  93. if isinstance(key, dict):
  94. key = key[axis]
  95. # use an artificial conversion to map the key as integers to the labels
  96. # so ix can work for comparisons
  97. if method == "indexer":
  98. method = "ix"
  99. key = obj._get_axis(axis)[key]
  100. # in case we actually want 0 index slicing
  101. with catch_warnings(record=True):
  102. try:
  103. xp = getattr(obj, method).__getitem__(_axify(obj, key, axis))
  104. except AttributeError:
  105. xp = getattr(obj, method).__getitem__(key)
  106. return xp
  107. def get_value(self, name, f, i, values=False):
  108. """ return the value for the location i """
  109. # check against values
  110. if values:
  111. return f.values[i]
  112. elif name == "iat":
  113. return f.iloc[i]
  114. else:
  115. assert name == "at"
  116. return f.loc[i]
  117. def check_values(self, f, func, values=False):
  118. if f is None:
  119. return
  120. axes = f.axes
  121. indicies = itertools.product(*axes)
  122. for i in indicies:
  123. result = getattr(f, func)[i]
  124. # check against values
  125. if values:
  126. expected = f.values[i]
  127. else:
  128. expected = f
  129. for a in reversed(i):
  130. expected = expected.__getitem__(a)
  131. tm.assert_almost_equal(result, expected)
  132. def check_result(
  133. self, method1, key1, method2, key2, typs=None, axes=None, fails=None,
  134. ):
  135. def _eq(axis, obj, key1, key2):
  136. """ compare equal for these 2 keys """
  137. if axis > obj.ndim - 1:
  138. return
  139. try:
  140. rs = getattr(obj, method1).__getitem__(_axify(obj, key1, axis))
  141. try:
  142. xp = self.get_result(obj=obj, method=method2, key=key2, axis=axis)
  143. except (KeyError, IndexError):
  144. # TODO: why is this allowed?
  145. return
  146. if is_scalar(rs) and is_scalar(xp):
  147. assert rs == xp
  148. else:
  149. tm.assert_equal(rs, xp)
  150. except (IndexError, TypeError, KeyError) as detail:
  151. # if we are in fails, the ok, otherwise raise it
  152. if fails is not None:
  153. if isinstance(detail, fails):
  154. result = f"ok ({type(detail).__name__})"
  155. return
  156. result = type(detail).__name__
  157. raise AssertionError(result, detail)
  158. if typs is None:
  159. typs = self._typs
  160. if axes is None:
  161. axes = [0, 1]
  162. elif not isinstance(axes, (tuple, list)):
  163. assert isinstance(axes, int)
  164. axes = [axes]
  165. # check
  166. for kind in self._kinds:
  167. d = getattr(self, kind)
  168. for ax in axes:
  169. for typ in typs:
  170. if typ not in self._typs:
  171. continue
  172. obj = d[typ]
  173. _eq(axis=ax, obj=obj, key1=key1, key2=key2)