test_take.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. from datetime import datetime
  2. import re
  3. import numpy as np
  4. import pytest
  5. from pandas._libs.tslib import iNaT
  6. import pandas._testing as tm
  7. import pandas.core.algorithms as algos
  8. @pytest.fixture(params=[True, False])
  9. def writeable(request):
  10. return request.param
  11. # Check that take_nd works both with writeable arrays
  12. # (in which case fast typed memory-views implementation)
  13. # and read-only arrays alike.
  14. @pytest.fixture(
  15. params=[
  16. (np.float64, True),
  17. (np.float32, True),
  18. (np.uint64, False),
  19. (np.uint32, False),
  20. (np.uint16, False),
  21. (np.uint8, False),
  22. (np.int64, False),
  23. (np.int32, False),
  24. (np.int16, False),
  25. (np.int8, False),
  26. (np.object_, True),
  27. (np.bool, False),
  28. ]
  29. )
  30. def dtype_can_hold_na(request):
  31. return request.param
  32. @pytest.fixture(
  33. params=[
  34. (np.int8, np.int16(127), np.int8),
  35. (np.int8, np.int16(128), np.int16),
  36. (np.int32, 1, np.int32),
  37. (np.int32, 2.0, np.float64),
  38. (np.int32, 3.0 + 4.0j, np.complex128),
  39. (np.int32, True, np.object_),
  40. (np.int32, "", np.object_),
  41. (np.float64, 1, np.float64),
  42. (np.float64, 2.0, np.float64),
  43. (np.float64, 3.0 + 4.0j, np.complex128),
  44. (np.float64, True, np.object_),
  45. (np.float64, "", np.object_),
  46. (np.complex128, 1, np.complex128),
  47. (np.complex128, 2.0, np.complex128),
  48. (np.complex128, 3.0 + 4.0j, np.complex128),
  49. (np.complex128, True, np.object_),
  50. (np.complex128, "", np.object_),
  51. (np.bool_, 1, np.object_),
  52. (np.bool_, 2.0, np.object_),
  53. (np.bool_, 3.0 + 4.0j, np.object_),
  54. (np.bool_, True, np.bool_),
  55. (np.bool_, "", np.object_),
  56. ]
  57. )
  58. def dtype_fill_out_dtype(request):
  59. return request.param
  60. class TestTake:
  61. # Standard incompatible fill error.
  62. fill_error = re.compile("Incompatible type for fill_value")
  63. def test_1d_with_out(self, dtype_can_hold_na, writeable):
  64. dtype, can_hold_na = dtype_can_hold_na
  65. data = np.random.randint(0, 2, 4).astype(dtype)
  66. data.flags.writeable = writeable
  67. indexer = [2, 1, 0, 1]
  68. out = np.empty(4, dtype=dtype)
  69. algos.take_1d(data, indexer, out=out)
  70. expected = data.take(indexer)
  71. tm.assert_almost_equal(out, expected)
  72. indexer = [2, 1, 0, -1]
  73. out = np.empty(4, dtype=dtype)
  74. if can_hold_na:
  75. algos.take_1d(data, indexer, out=out)
  76. expected = data.take(indexer)
  77. expected[3] = np.nan
  78. tm.assert_almost_equal(out, expected)
  79. else:
  80. with pytest.raises(TypeError, match=self.fill_error):
  81. algos.take_1d(data, indexer, out=out)
  82. # No Exception otherwise.
  83. data.take(indexer, out=out)
  84. def test_1d_fill_nonna(self, dtype_fill_out_dtype):
  85. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  86. data = np.random.randint(0, 2, 4).astype(dtype)
  87. indexer = [2, 1, 0, -1]
  88. result = algos.take_1d(data, indexer, fill_value=fill_value)
  89. assert (result[[0, 1, 2]] == data[[2, 1, 0]]).all()
  90. assert result[3] == fill_value
  91. assert result.dtype == out_dtype
  92. indexer = [2, 1, 0, 1]
  93. result = algos.take_1d(data, indexer, fill_value=fill_value)
  94. assert (result[[0, 1, 2, 3]] == data[indexer]).all()
  95. assert result.dtype == dtype
  96. def test_2d_with_out(self, dtype_can_hold_na, writeable):
  97. dtype, can_hold_na = dtype_can_hold_na
  98. data = np.random.randint(0, 2, (5, 3)).astype(dtype)
  99. data.flags.writeable = writeable
  100. indexer = [2, 1, 0, 1]
  101. out0 = np.empty((4, 3), dtype=dtype)
  102. out1 = np.empty((5, 4), dtype=dtype)
  103. algos.take_nd(data, indexer, out=out0, axis=0)
  104. algos.take_nd(data, indexer, out=out1, axis=1)
  105. expected0 = data.take(indexer, axis=0)
  106. expected1 = data.take(indexer, axis=1)
  107. tm.assert_almost_equal(out0, expected0)
  108. tm.assert_almost_equal(out1, expected1)
  109. indexer = [2, 1, 0, -1]
  110. out0 = np.empty((4, 3), dtype=dtype)
  111. out1 = np.empty((5, 4), dtype=dtype)
  112. if can_hold_na:
  113. algos.take_nd(data, indexer, out=out0, axis=0)
  114. algos.take_nd(data, indexer, out=out1, axis=1)
  115. expected0 = data.take(indexer, axis=0)
  116. expected1 = data.take(indexer, axis=1)
  117. expected0[3, :] = np.nan
  118. expected1[:, 3] = np.nan
  119. tm.assert_almost_equal(out0, expected0)
  120. tm.assert_almost_equal(out1, expected1)
  121. else:
  122. for i, out in enumerate([out0, out1]):
  123. with pytest.raises(TypeError, match=self.fill_error):
  124. algos.take_nd(data, indexer, out=out, axis=i)
  125. # No Exception otherwise.
  126. data.take(indexer, out=out, axis=i)
  127. def test_2d_fill_nonna(self, dtype_fill_out_dtype):
  128. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  129. data = np.random.randint(0, 2, (5, 3)).astype(dtype)
  130. indexer = [2, 1, 0, -1]
  131. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  132. assert (result[[0, 1, 2], :] == data[[2, 1, 0], :]).all()
  133. assert (result[3, :] == fill_value).all()
  134. assert result.dtype == out_dtype
  135. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  136. assert (result[:, [0, 1, 2]] == data[:, [2, 1, 0]]).all()
  137. assert (result[:, 3] == fill_value).all()
  138. assert result.dtype == out_dtype
  139. indexer = [2, 1, 0, 1]
  140. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  141. assert (result[[0, 1, 2, 3], :] == data[indexer, :]).all()
  142. assert result.dtype == dtype
  143. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  144. assert (result[:, [0, 1, 2, 3]] == data[:, indexer]).all()
  145. assert result.dtype == dtype
  146. def test_3d_with_out(self, dtype_can_hold_na):
  147. dtype, can_hold_na = dtype_can_hold_na
  148. data = np.random.randint(0, 2, (5, 4, 3)).astype(dtype)
  149. indexer = [2, 1, 0, 1]
  150. out0 = np.empty((4, 4, 3), dtype=dtype)
  151. out1 = np.empty((5, 4, 3), dtype=dtype)
  152. out2 = np.empty((5, 4, 4), dtype=dtype)
  153. algos.take_nd(data, indexer, out=out0, axis=0)
  154. algos.take_nd(data, indexer, out=out1, axis=1)
  155. algos.take_nd(data, indexer, out=out2, axis=2)
  156. expected0 = data.take(indexer, axis=0)
  157. expected1 = data.take(indexer, axis=1)
  158. expected2 = data.take(indexer, axis=2)
  159. tm.assert_almost_equal(out0, expected0)
  160. tm.assert_almost_equal(out1, expected1)
  161. tm.assert_almost_equal(out2, expected2)
  162. indexer = [2, 1, 0, -1]
  163. out0 = np.empty((4, 4, 3), dtype=dtype)
  164. out1 = np.empty((5, 4, 3), dtype=dtype)
  165. out2 = np.empty((5, 4, 4), dtype=dtype)
  166. if can_hold_na:
  167. algos.take_nd(data, indexer, out=out0, axis=0)
  168. algos.take_nd(data, indexer, out=out1, axis=1)
  169. algos.take_nd(data, indexer, out=out2, axis=2)
  170. expected0 = data.take(indexer, axis=0)
  171. expected1 = data.take(indexer, axis=1)
  172. expected2 = data.take(indexer, axis=2)
  173. expected0[3, :, :] = np.nan
  174. expected1[:, 3, :] = np.nan
  175. expected2[:, :, 3] = np.nan
  176. tm.assert_almost_equal(out0, expected0)
  177. tm.assert_almost_equal(out1, expected1)
  178. tm.assert_almost_equal(out2, expected2)
  179. else:
  180. for i, out in enumerate([out0, out1, out2]):
  181. with pytest.raises(TypeError, match=self.fill_error):
  182. algos.take_nd(data, indexer, out=out, axis=i)
  183. # No Exception otherwise.
  184. data.take(indexer, out=out, axis=i)
  185. def test_3d_fill_nonna(self, dtype_fill_out_dtype):
  186. dtype, fill_value, out_dtype = dtype_fill_out_dtype
  187. data = np.random.randint(0, 2, (5, 4, 3)).astype(dtype)
  188. indexer = [2, 1, 0, -1]
  189. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  190. assert (result[[0, 1, 2], :, :] == data[[2, 1, 0], :, :]).all()
  191. assert (result[3, :, :] == fill_value).all()
  192. assert result.dtype == out_dtype
  193. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  194. assert (result[:, [0, 1, 2], :] == data[:, [2, 1, 0], :]).all()
  195. assert (result[:, 3, :] == fill_value).all()
  196. assert result.dtype == out_dtype
  197. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  198. assert (result[:, :, [0, 1, 2]] == data[:, :, [2, 1, 0]]).all()
  199. assert (result[:, :, 3] == fill_value).all()
  200. assert result.dtype == out_dtype
  201. indexer = [2, 1, 0, 1]
  202. result = algos.take_nd(data, indexer, axis=0, fill_value=fill_value)
  203. assert (result[[0, 1, 2, 3], :, :] == data[indexer, :, :]).all()
  204. assert result.dtype == dtype
  205. result = algos.take_nd(data, indexer, axis=1, fill_value=fill_value)
  206. assert (result[:, [0, 1, 2, 3], :] == data[:, indexer, :]).all()
  207. assert result.dtype == dtype
  208. result = algos.take_nd(data, indexer, axis=2, fill_value=fill_value)
  209. assert (result[:, :, [0, 1, 2, 3]] == data[:, :, indexer]).all()
  210. assert result.dtype == dtype
  211. def test_1d_other_dtypes(self):
  212. arr = np.random.randn(10).astype(np.float32)
  213. indexer = [1, 2, 3, -1]
  214. result = algos.take_1d(arr, indexer)
  215. expected = arr.take(indexer)
  216. expected[-1] = np.nan
  217. tm.assert_almost_equal(result, expected)
  218. def test_2d_other_dtypes(self):
  219. arr = np.random.randn(10, 5).astype(np.float32)
  220. indexer = [1, 2, 3, -1]
  221. # axis=0
  222. result = algos.take_nd(arr, indexer, axis=0)
  223. expected = arr.take(indexer, axis=0)
  224. expected[-1] = np.nan
  225. tm.assert_almost_equal(result, expected)
  226. # axis=1
  227. result = algos.take_nd(arr, indexer, axis=1)
  228. expected = arr.take(indexer, axis=1)
  229. expected[:, -1] = np.nan
  230. tm.assert_almost_equal(result, expected)
  231. def test_1d_bool(self):
  232. arr = np.array([0, 1, 0], dtype=bool)
  233. result = algos.take_1d(arr, [0, 2, 2, 1])
  234. expected = arr.take([0, 2, 2, 1])
  235. tm.assert_numpy_array_equal(result, expected)
  236. result = algos.take_1d(arr, [0, 2, -1])
  237. assert result.dtype == np.object_
  238. def test_2d_bool(self):
  239. arr = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=bool)
  240. result = algos.take_nd(arr, [0, 2, 2, 1])
  241. expected = arr.take([0, 2, 2, 1], axis=0)
  242. tm.assert_numpy_array_equal(result, expected)
  243. result = algos.take_nd(arr, [0, 2, 2, 1], axis=1)
  244. expected = arr.take([0, 2, 2, 1], axis=1)
  245. tm.assert_numpy_array_equal(result, expected)
  246. result = algos.take_nd(arr, [0, 2, -1])
  247. assert result.dtype == np.object_
  248. def test_2d_float32(self):
  249. arr = np.random.randn(4, 3).astype(np.float32)
  250. indexer = [0, 2, -1, 1, -1]
  251. # axis=0
  252. result = algos.take_nd(arr, indexer, axis=0)
  253. result2 = np.empty_like(result)
  254. algos.take_nd(arr, indexer, axis=0, out=result2)
  255. tm.assert_almost_equal(result, result2)
  256. expected = arr.take(indexer, axis=0)
  257. expected[[2, 4], :] = np.nan
  258. tm.assert_almost_equal(result, expected)
  259. # this now accepts a float32! # test with float64 out buffer
  260. out = np.empty((len(indexer), arr.shape[1]), dtype="float32")
  261. algos.take_nd(arr, indexer, out=out) # it works!
  262. # axis=1
  263. result = algos.take_nd(arr, indexer, axis=1)
  264. result2 = np.empty_like(result)
  265. algos.take_nd(arr, indexer, axis=1, out=result2)
  266. tm.assert_almost_equal(result, result2)
  267. expected = arr.take(indexer, axis=1)
  268. expected[:, [2, 4]] = np.nan
  269. tm.assert_almost_equal(result, expected)
  270. def test_2d_datetime64(self):
  271. # 2005/01/01 - 2006/01/01
  272. arr = np.random.randint(11045376, 11360736, (5, 3)) * 100000000000
  273. arr = arr.view(dtype="datetime64[ns]")
  274. indexer = [0, 2, -1, 1, -1]
  275. # axis=0
  276. result = algos.take_nd(arr, indexer, axis=0)
  277. result2 = np.empty_like(result)
  278. algos.take_nd(arr, indexer, axis=0, out=result2)
  279. tm.assert_almost_equal(result, result2)
  280. expected = arr.take(indexer, axis=0)
  281. expected.view(np.int64)[[2, 4], :] = iNaT
  282. tm.assert_almost_equal(result, expected)
  283. result = algos.take_nd(arr, indexer, axis=0, fill_value=datetime(2007, 1, 1))
  284. result2 = np.empty_like(result)
  285. algos.take_nd(
  286. arr, indexer, out=result2, axis=0, fill_value=datetime(2007, 1, 1)
  287. )
  288. tm.assert_almost_equal(result, result2)
  289. expected = arr.take(indexer, axis=0)
  290. expected[[2, 4], :] = datetime(2007, 1, 1)
  291. tm.assert_almost_equal(result, expected)
  292. # axis=1
  293. result = algos.take_nd(arr, indexer, axis=1)
  294. result2 = np.empty_like(result)
  295. algos.take_nd(arr, indexer, axis=1, out=result2)
  296. tm.assert_almost_equal(result, result2)
  297. expected = arr.take(indexer, axis=1)
  298. expected.view(np.int64)[:, [2, 4]] = iNaT
  299. tm.assert_almost_equal(result, expected)
  300. result = algos.take_nd(arr, indexer, axis=1, fill_value=datetime(2007, 1, 1))
  301. result2 = np.empty_like(result)
  302. algos.take_nd(
  303. arr, indexer, out=result2, axis=1, fill_value=datetime(2007, 1, 1)
  304. )
  305. tm.assert_almost_equal(result, result2)
  306. expected = arr.take(indexer, axis=1)
  307. expected[:, [2, 4]] = datetime(2007, 1, 1)
  308. tm.assert_almost_equal(result, expected)
  309. def test_take_axis_0(self):
  310. arr = np.arange(12).reshape(4, 3)
  311. result = algos.take(arr, [0, -1])
  312. expected = np.array([[0, 1, 2], [9, 10, 11]])
  313. tm.assert_numpy_array_equal(result, expected)
  314. # allow_fill=True
  315. result = algos.take(arr, [0, -1], allow_fill=True, fill_value=0)
  316. expected = np.array([[0, 1, 2], [0, 0, 0]])
  317. tm.assert_numpy_array_equal(result, expected)
  318. def test_take_axis_1(self):
  319. arr = np.arange(12).reshape(4, 3)
  320. result = algos.take(arr, [0, -1], axis=1)
  321. expected = np.array([[0, 2], [3, 5], [6, 8], [9, 11]])
  322. tm.assert_numpy_array_equal(result, expected)
  323. # allow_fill=True
  324. result = algos.take(arr, [0, -1], axis=1, allow_fill=True, fill_value=0)
  325. expected = np.array([[0, 0], [3, 0], [6, 0], [9, 0]])
  326. tm.assert_numpy_array_equal(result, expected)
  327. # GH#26976 make sure we validate along the correct axis
  328. with pytest.raises(IndexError, match="indices are out-of-bounds"):
  329. algos.take(arr, [0, 3], axis=1, allow_fill=True, fill_value=0)
  330. class TestExtensionTake:
  331. # The take method found in pd.api.extensions
  332. def test_bounds_check_large(self):
  333. arr = np.array([1, 2])
  334. with pytest.raises(IndexError):
  335. algos.take(arr, [2, 3], allow_fill=True)
  336. with pytest.raises(IndexError):
  337. algos.take(arr, [2, 3], allow_fill=False)
  338. def test_bounds_check_small(self):
  339. arr = np.array([1, 2, 3], dtype=np.int64)
  340. indexer = [0, -1, -2]
  341. with pytest.raises(ValueError):
  342. algos.take(arr, indexer, allow_fill=True)
  343. result = algos.take(arr, indexer)
  344. expected = np.array([1, 3, 2], dtype=np.int64)
  345. tm.assert_numpy_array_equal(result, expected)
  346. @pytest.mark.parametrize("allow_fill", [True, False])
  347. def test_take_empty(self, allow_fill):
  348. arr = np.array([], dtype=np.int64)
  349. # empty take is ok
  350. result = algos.take(arr, [], allow_fill=allow_fill)
  351. tm.assert_numpy_array_equal(arr, result)
  352. with pytest.raises(IndexError):
  353. algos.take(arr, [0], allow_fill=allow_fill)
  354. def test_take_na_empty(self):
  355. result = algos.take(np.array([]), [-1, -1], allow_fill=True, fill_value=0.0)
  356. expected = np.array([0.0, 0.0])
  357. tm.assert_numpy_array_equal(result, expected)
  358. def test_take_coerces_list(self):
  359. arr = [1, 2, 3]
  360. result = algos.take(arr, [0, 0])
  361. expected = np.array([1, 1])
  362. tm.assert_numpy_array_equal(result, expected)