test_setops.py 8.3 KB


  1. from datetime import datetime, timedelta
  2. import numpy as np
  3. import pytest
  4. from pandas import Index, Int64Index, RangeIndex
  5. import pandas._testing as tm
  class TestRangeIndexSetOps:
      """Set-operation tests (``intersection`` / ``union``) for ``RangeIndex``."""

      @staticmethod
      def _numpy_intersection(left, right):
          """Return the NumPy reference intersection of ``left`` and ``right``.

          The values come from ``np.intersect1d`` (passed through ``np.sort``)
          and are wrapped in a plain ``Index``.
          """
          return Index(np.sort(np.intersect1d(left.values, right.values)))

      @pytest.mark.parametrize("sort", [None, False])
      def test_intersection(self, sort):
          evens = RangeIndex(start=0, stop=20, step=2)

          # against a non-range integer Index, in both directions
          ints = Index(np.arange(1, 6))
          want = self._numpy_intersection(evens, ints)
          tm.assert_index_equal(evens.intersection(ints, sort=sort), want)
          tm.assert_index_equal(ints.intersection(evens, sort=sort), want)

          # against an increasing RangeIndex
          rising = RangeIndex(1, 6)
          tm.assert_index_equal(
              evens.intersection(rising, sort=sort),
              self._numpy_intersection(evens, rising),
          )

          # against a decreasing RangeIndex, then with operands swapped (GH 17296)
          falling = RangeIndex(5, 0, -1)
          want = self._numpy_intersection(evens, falling)
          tm.assert_index_equal(evens.intersection(falling, sort=sort), want)
          tm.assert_index_equal(falling.intersection(evens, sort=sort), want)

          # GH 17296: two decreasing RangeIndexes, in both orders, must agree
          # with the intersection of their integer-cast counterparts
          down_by_two = RangeIndex(10, -2, -2)
          down_by_one = RangeIndex(5, -4, -1)
          want = down_by_two.astype(int).intersection(
              down_by_one.astype(int), sort=sort
          )
          for lhs, rhs in ((down_by_two, down_by_one), (down_by_one, down_by_two)):
              got = lhs.intersection(rhs, sort=sort).astype(int)
              tm.assert_index_equal(got, want)

          # every pair below shares no values, so the result is an empty range
          first_five = RangeIndex(5)
          no_values = RangeIndex(0, 0, 1)
          disjoint_pairs = (
              (first_five, RangeIndex(5, 10, 1)),  # increasing, past the end
              (first_five, RangeIndex(-1, -5, -1)),  # decreasing, below the start
              (first_five, no_values),  # empty right operand
              (no_values, first_five),  # empty left operand
              # odd values vs. multiples of 4: start values and gcd of the
              # steps rule out any overlap
              (RangeIndex(1, 10, 2), RangeIndex(0, 10, 4)),
          )
          for lhs, rhs in disjoint_pairs:
              tm.assert_index_equal(
                  lhs.intersection(rhs, sort=sort), RangeIndex(0, 0, 1)
              )

      @pytest.mark.parametrize("sort", [False, None])
      def test_union_noncomparable(self, sort):
          # corner case: the other operand is an object-dtype Index of
          # datetimes, and the expected union is the plain concatenation
          rng = RangeIndex(start=0, stop=20, step=2)
          stamps = Index(
              [datetime.now() + timedelta(days=offset) for offset in range(4)],
              dtype=object,
          )

          tm.assert_index_equal(
              rng.union(stamps, sort=sort), Index(np.concatenate((rng, stamps)))
          )
          tm.assert_index_equal(
              stamps.union(rng, sort=sort), Index(np.concatenate((stamps, rng)))
          )

      # Each entry: (left, right, expected with sort=None, expected with sort=False)
      _UNION_CASES = [
          # identical increasing ranges
          (
              RangeIndex(0, 10, 1),
              RangeIndex(0, 10, 1),
              RangeIndex(0, 10, 1),
              RangeIndex(0, 10, 1),
          ),
          # overlapping increasing ranges
          (
              RangeIndex(0, 10, 1),
              RangeIndex(5, 20, 1),
              RangeIndex(0, 20, 1),
              Int64Index(range(20)),
          ),
          # adjacent increasing ranges
          (
              RangeIndex(0, 10, 1),
              RangeIndex(10, 20, 1),
              RangeIndex(0, 20, 1),
              Int64Index(range(20)),
          ),
          # identical decreasing ranges
          (
              RangeIndex(0, -10, -1),
              RangeIndex(0, -10, -1),
              RangeIndex(0, -10, -1),
              RangeIndex(0, -10, -1),
          ),
          # adjacent decreasing ranges
          (
              RangeIndex(0, -10, -1),
              RangeIndex(-10, -20, -1),
              RangeIndex(-19, 1, 1),
              Int64Index(range(0, -20, -1)),
          ),
          # interleaved evens and odds
          (
              RangeIndex(0, 10, 2),
              RangeIndex(1, 10, 2),
              RangeIndex(0, 10, 1),
              Int64Index(list(range(0, 10, 2)) + list(range(1, 10, 2))),
          ),
          # interleaved evens and odds with different stops
          (
              RangeIndex(0, 11, 2),
              RangeIndex(1, 12, 2),
              RangeIndex(0, 12, 1),
              Int64Index(list(range(0, 11, 2)) + list(range(1, 12, 2))),
          ),
          # step-4 ranges offset by 2, merging into a step-2 range
          (
              RangeIndex(0, 21, 4),
              RangeIndex(-2, 24, 4),
              RangeIndex(-2, 24, 2),
              Int64Index(list(range(0, 21, 4)) + list(range(-2, 24, 4))),
          ),
          # interleaved decreasing ranges
          (
              RangeIndex(0, -20, -2),
              RangeIndex(-1, -21, -2),
              RangeIndex(-19, 1, 1),
              Int64Index(list(range(0, -20, -2)) + list(range(-1, -21, -2))),
          ),
          # right operand is a subset of the left one
          (
              RangeIndex(0, 100, 5),
              RangeIndex(0, 100, 20),
              RangeIndex(0, 100, 5),
              Int64Index(range(0, 100, 5)),
          ),
          # right operand adds one value past the left range's start
          (
              RangeIndex(0, -100, -5),
              RangeIndex(5, -100, -20),
              RangeIndex(-95, 10, 5),
              Int64Index(list(range(0, -100, -5)) + [5]),
          ),
          # sparse right operand reaching one past each end
          (
              RangeIndex(0, -11, -1),
              RangeIndex(1, -12, -4),
              RangeIndex(-11, 2, 1),
              Int64Index(list(range(0, -11, -1)) + [1, -11]),
          ),
          # both operands empty
          (RangeIndex(0), RangeIndex(0), RangeIndex(0), RangeIndex(0)),
          # empty right operand
          (
              RangeIndex(0, -10, -2),
              RangeIndex(0),
              RangeIndex(0, -10, -2),
              RangeIndex(0, -10, -2),
          ),
          # single-element right operand continuing the left range
          (
              RangeIndex(0, 100, 2),
              RangeIndex(100, 150, 200),
              RangeIndex(0, 102, 2),
              Int64Index(range(0, 102, 2)),
          ),
          # two-element right operand just outside both ends
          (
              RangeIndex(0, -100, -2),
              RangeIndex(-100, 50, 102),
              RangeIndex(-100, 4, 2),
              Int64Index(list(range(0, -100, -2)) + [-100, 2]),
          ),
          # decreasing subset with a different step
          (
              RangeIndex(0, -100, -1),
              RangeIndex(0, -50, -3),
              RangeIndex(-99, 1, 1),
              Int64Index(list(range(0, -100, -1))),
          ),
          # two singletons forming a step-5 range
          (
              RangeIndex(0, 1, 1),
              RangeIndex(5, 6, 10),
              RangeIndex(0, 6, 5),
              Int64Index([0, 5]),
          ),
          # singleton extending the range downwards
          (
              RangeIndex(0, 10, 5),
              RangeIndex(-5, -6, -20),
              RangeIndex(-5, 10, 5),
              Int64Index([0, 5, -5]),
          ),
          # gap at 3: the result is not representable as a RangeIndex
          (
              RangeIndex(0, 3, 1),
              RangeIndex(4, 5, 1),
              Int64Index([0, 1, 2, 4]),
              Int64Index([0, 1, 2, 4]),
          ),
          # empty Int64Index as the right operand
          (
              RangeIndex(0, 10, 1),
              Int64Index([]),
              RangeIndex(0, 10, 1),
              RangeIndex(0, 10, 1),
          ),
          # empty RangeIndex as the left operand
          (
              RangeIndex(0),
              Int64Index([1, 5, 6]),
              Int64Index([1, 5, 6]),
              Int64Index([1, 5, 6]),
          ),
      ]

      @pytest.fixture(params=_UNION_CASES)
      def unions(self, request):
          """Return one ``(left, right, expected_sorted, expected_notsorted)`` case."""
          return request.param

      def test_union_sorted(self, unions):
          left, right, want_sorted, want_unsorted = unions

          # exact=True makes assert_index_equal compare the Index class as well
          tm.assert_index_equal(left.union(right, sort=None), want_sorted, exact=True)
          tm.assert_index_equal(
              left.union(right, sort=False), want_unsorted, exact=True
          )

          # swapped operands must still give the sorted expectation
          swapped = right.union(left, sort=None)
          # ``_int64index`` is presumably the Int64Index form of ``left``;
          # only the values are compared here (no exact=True)
          via_int64 = left._int64index.union(right, sort=None)
          tm.assert_index_equal(swapped, want_sorted, exact=True)
          tm.assert_index_equal(via_int64, want_sorted)