# test_array.py (12 KB)
  1. import datetime
  2. import decimal
  3. import numpy as np
  4. import pytest
  5. import pytz
  6. from pandas.core.dtypes.dtypes import registry
  7. import pandas as pd
  8. import pandas._testing as tm
  9. from pandas.api.extensions import register_extension_dtype
  10. from pandas.api.types import is_scalar
  11. from pandas.arrays import (
  12. BooleanArray,
  13. DatetimeArray,
  14. IntegerArray,
  15. IntervalArray,
  16. SparseArray,
  17. StringArray,
  18. TimedeltaArray,
  19. )
  20. from pandas.core.arrays import PandasArray, integer_array, period_array
  21. from pandas.tests.extension.decimal import DecimalArray, DecimalDtype, to_decimal
  22. @pytest.mark.parametrize(
  23. "data, dtype, expected",
  24. [
  25. # Basic NumPy defaults.
  26. ([1, 2], None, IntegerArray._from_sequence([1, 2])),
  27. ([1, 2], object, PandasArray(np.array([1, 2], dtype=object))),
  28. (
  29. [1, 2],
  30. np.dtype("float32"),
  31. PandasArray(np.array([1.0, 2.0], dtype=np.dtype("float32"))),
  32. ),
  33. (np.array([1, 2], dtype="int64"), None, IntegerArray._from_sequence([1, 2]),),
  34. # String alias passes through to NumPy
  35. ([1, 2], "float32", PandasArray(np.array([1, 2], dtype="float32"))),
  36. # Period alias
  37. (
  38. [pd.Period("2000", "D"), pd.Period("2001", "D")],
  39. "Period[D]",
  40. period_array(["2000", "2001"], freq="D"),
  41. ),
  42. # Period dtype
  43. (
  44. [pd.Period("2000", "D")],
  45. pd.PeriodDtype("D"),
  46. period_array(["2000"], freq="D"),
  47. ),
  48. # Datetime (naive)
  49. (
  50. [1, 2],
  51. np.dtype("datetime64[ns]"),
  52. DatetimeArray._from_sequence(np.array([1, 2], dtype="datetime64[ns]")),
  53. ),
  54. (
  55. np.array([1, 2], dtype="datetime64[ns]"),
  56. None,
  57. DatetimeArray._from_sequence(np.array([1, 2], dtype="datetime64[ns]")),
  58. ),
  59. (
  60. pd.DatetimeIndex(["2000", "2001"]),
  61. np.dtype("datetime64[ns]"),
  62. DatetimeArray._from_sequence(["2000", "2001"]),
  63. ),
  64. (
  65. pd.DatetimeIndex(["2000", "2001"]),
  66. None,
  67. DatetimeArray._from_sequence(["2000", "2001"]),
  68. ),
  69. (
  70. ["2000", "2001"],
  71. np.dtype("datetime64[ns]"),
  72. DatetimeArray._from_sequence(["2000", "2001"]),
  73. ),
  74. # Datetime (tz-aware)
  75. (
  76. ["2000", "2001"],
  77. pd.DatetimeTZDtype(tz="CET"),
  78. DatetimeArray._from_sequence(
  79. ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET")
  80. ),
  81. ),
  82. # Timedelta
  83. (
  84. ["1H", "2H"],
  85. np.dtype("timedelta64[ns]"),
  86. TimedeltaArray._from_sequence(["1H", "2H"]),
  87. ),
  88. (
  89. pd.TimedeltaIndex(["1H", "2H"]),
  90. np.dtype("timedelta64[ns]"),
  91. TimedeltaArray._from_sequence(["1H", "2H"]),
  92. ),
  93. (
  94. pd.TimedeltaIndex(["1H", "2H"]),
  95. None,
  96. TimedeltaArray._from_sequence(["1H", "2H"]),
  97. ),
  98. # Category
  99. (["a", "b"], "category", pd.Categorical(["a", "b"])),
  100. (
  101. ["a", "b"],
  102. pd.CategoricalDtype(None, ordered=True),
  103. pd.Categorical(["a", "b"], ordered=True),
  104. ),
  105. # Interval
  106. (
  107. [pd.Interval(1, 2), pd.Interval(3, 4)],
  108. "interval",
  109. IntervalArray.from_tuples([(1, 2), (3, 4)]),
  110. ),
  111. # Sparse
  112. ([0, 1], "Sparse[int64]", SparseArray([0, 1], dtype="int64")),
  113. # IntegerNA
  114. ([1, None], "Int16", integer_array([1, None], dtype="Int16")),
  115. (pd.Series([1, 2]), None, PandasArray(np.array([1, 2], dtype=np.int64))),
  116. # String
  117. (["a", None], "string", StringArray._from_sequence(["a", None])),
  118. (["a", None], pd.StringDtype(), StringArray._from_sequence(["a", None]),),
  119. # Boolean
  120. ([True, None], "boolean", BooleanArray._from_sequence([True, None])),
  121. ([True, None], pd.BooleanDtype(), BooleanArray._from_sequence([True, None]),),
  122. # Index
  123. (pd.Index([1, 2]), None, PandasArray(np.array([1, 2], dtype=np.int64))),
  124. # Series[EA] returns the EA
  125. (
  126. pd.Series(pd.Categorical(["a", "b"], categories=["a", "b", "c"])),
  127. None,
  128. pd.Categorical(["a", "b"], categories=["a", "b", "c"]),
  129. ),
  130. # "3rd party" EAs work
  131. ([decimal.Decimal(0), decimal.Decimal(1)], "decimal", to_decimal([0, 1])),
  132. # pass an ExtensionArray, but a different dtype
  133. (
  134. period_array(["2000", "2001"], freq="D"),
  135. "category",
  136. pd.Categorical([pd.Period("2000", "D"), pd.Period("2001", "D")]),
  137. ),
  138. ],
  139. )
  140. def test_array(data, dtype, expected):
  141. result = pd.array(data, dtype=dtype)
  142. tm.assert_equal(result, expected)
  143. def test_array_copy():
  144. a = np.array([1, 2])
  145. # default is to copy
  146. b = pd.array(a, dtype=a.dtype)
  147. assert np.shares_memory(a, b._ndarray) is False
  148. # copy=True
  149. b = pd.array(a, dtype=a.dtype, copy=True)
  150. assert np.shares_memory(a, b._ndarray) is False
  151. # copy=False
  152. b = pd.array(a, dtype=a.dtype, copy=False)
  153. assert np.shares_memory(a, b._ndarray) is True
  154. cet = pytz.timezone("CET")
  155. @pytest.mark.parametrize(
  156. "data, expected",
  157. [
  158. # period
  159. (
  160. [pd.Period("2000", "D"), pd.Period("2001", "D")],
  161. period_array(["2000", "2001"], freq="D"),
  162. ),
  163. # interval
  164. ([pd.Interval(0, 1), pd.Interval(1, 2)], IntervalArray.from_breaks([0, 1, 2]),),
  165. # datetime
  166. (
  167. [pd.Timestamp("2000"), pd.Timestamp("2001")],
  168. DatetimeArray._from_sequence(["2000", "2001"]),
  169. ),
  170. (
  171. [datetime.datetime(2000, 1, 1), datetime.datetime(2001, 1, 1)],
  172. DatetimeArray._from_sequence(["2000", "2001"]),
  173. ),
  174. (
  175. np.array([1, 2], dtype="M8[ns]"),
  176. DatetimeArray(np.array([1, 2], dtype="M8[ns]")),
  177. ),
  178. (
  179. np.array([1, 2], dtype="M8[us]"),
  180. DatetimeArray(np.array([1000, 2000], dtype="M8[ns]")),
  181. ),
  182. # datetimetz
  183. (
  184. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2001", tz="CET")],
  185. DatetimeArray._from_sequence(
  186. ["2000", "2001"], dtype=pd.DatetimeTZDtype(tz="CET")
  187. ),
  188. ),
  189. (
  190. [
  191. datetime.datetime(2000, 1, 1, tzinfo=cet),
  192. datetime.datetime(2001, 1, 1, tzinfo=cet),
  193. ],
  194. DatetimeArray._from_sequence(["2000", "2001"], tz=cet),
  195. ),
  196. # timedelta
  197. (
  198. [pd.Timedelta("1H"), pd.Timedelta("2H")],
  199. TimedeltaArray._from_sequence(["1H", "2H"]),
  200. ),
  201. (
  202. np.array([1, 2], dtype="m8[ns]"),
  203. TimedeltaArray(np.array([1, 2], dtype="m8[ns]")),
  204. ),
  205. (
  206. np.array([1, 2], dtype="m8[us]"),
  207. TimedeltaArray(np.array([1000, 2000], dtype="m8[ns]")),
  208. ),
  209. # integer
  210. ([1, 2], IntegerArray._from_sequence([1, 2])),
  211. ([1, None], IntegerArray._from_sequence([1, None])),
  212. # string
  213. (["a", "b"], StringArray._from_sequence(["a", "b"])),
  214. (["a", None], StringArray._from_sequence(["a", None])),
  215. # Boolean
  216. ([True, False], BooleanArray._from_sequence([True, False])),
  217. ([True, None], BooleanArray._from_sequence([True, None])),
  218. ],
  219. )
  220. def test_array_inference(data, expected):
  221. result = pd.array(data)
  222. tm.assert_equal(result, expected)
  223. @pytest.mark.parametrize(
  224. "data",
  225. [
  226. # mix of frequencies
  227. [pd.Period("2000", "D"), pd.Period("2001", "A")],
  228. # mix of closed
  229. [pd.Interval(0, 1, closed="left"), pd.Interval(1, 2, closed="right")],
  230. # Mix of timezones
  231. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000", tz="UTC")],
  232. # Mix of tz-aware and tz-naive
  233. [pd.Timestamp("2000", tz="CET"), pd.Timestamp("2000")],
  234. np.array([pd.Timestamp("2000"), pd.Timestamp("2000", tz="CET")]),
  235. ],
  236. )
  237. def test_array_inference_fails(data):
  238. result = pd.array(data)
  239. expected = PandasArray(np.array(data, dtype=object))
  240. tm.assert_extension_array_equal(result, expected)
  241. @pytest.mark.parametrize("data", [np.array([[1, 2], [3, 4]]), [[1, 2], [3, 4]]])
  242. def test_nd_raises(data):
  243. with pytest.raises(ValueError, match="PandasArray must be 1-dimensional"):
  244. pd.array(data, dtype="int64")
  245. def test_scalar_raises():
  246. with pytest.raises(ValueError, match="Cannot pass scalar '1'"):
  247. pd.array(1)
  248. # ---------------------------------------------------------------------------
  249. # A couple dummy classes to ensure that Series and Indexes are unboxed before
  250. # getting to the EA classes.
  251. @register_extension_dtype
  252. class DecimalDtype2(DecimalDtype):
  253. name = "decimal2"
  254. @classmethod
  255. def construct_array_type(cls):
  256. """
  257. Return the array type associated with this dtype.
  258. Returns
  259. -------
  260. type
  261. """
  262. return DecimalArray2
  263. class DecimalArray2(DecimalArray):
  264. @classmethod
  265. def _from_sequence(cls, scalars, dtype=None, copy=False):
  266. if isinstance(scalars, (pd.Series, pd.Index)):
  267. raise TypeError
  268. return super()._from_sequence(scalars, dtype=dtype, copy=copy)
  269. def test_array_unboxes(index_or_series):
  270. box = index_or_series
  271. data = box([decimal.Decimal("1"), decimal.Decimal("2")])
  272. # make sure it works
  273. with pytest.raises(TypeError):
  274. DecimalArray2._from_sequence(data)
  275. result = pd.array(data, dtype="decimal2")
  276. expected = DecimalArray2._from_sequence(data.values)
  277. tm.assert_equal(result, expected)
  278. @pytest.fixture
  279. def registry_without_decimal():
  280. idx = registry.dtypes.index(DecimalDtype)
  281. registry.dtypes.pop(idx)
  282. yield
  283. registry.dtypes.append(DecimalDtype)
  284. def test_array_not_registered(registry_without_decimal):
  285. # check we aren't on it
  286. assert registry.find("decimal") is None
  287. data = [decimal.Decimal("1"), decimal.Decimal("2")]
  288. result = pd.array(data, dtype=DecimalDtype)
  289. expected = DecimalArray._from_sequence(data)
  290. tm.assert_equal(result, expected)
  291. class TestArrayAnalytics:
  292. def test_searchsorted(self, string_dtype):
  293. arr = pd.array(["a", "b", "c"], dtype=string_dtype)
  294. result = arr.searchsorted("a", side="left")
  295. assert is_scalar(result)
  296. assert result == 0
  297. result = arr.searchsorted("a", side="right")
  298. assert is_scalar(result)
  299. assert result == 1
  300. def test_searchsorted_numeric_dtypes_scalar(self, any_real_dtype):
  301. arr = pd.array([1, 3, 90], dtype=any_real_dtype)
  302. result = arr.searchsorted(30)
  303. assert is_scalar(result)
  304. assert result == 2
  305. result = arr.searchsorted([30])
  306. expected = np.array([2], dtype=np.intp)
  307. tm.assert_numpy_array_equal(result, expected)
  308. def test_searchsorted_numeric_dtypes_vector(self, any_real_dtype):
  309. arr = pd.array([1, 3, 90], dtype=any_real_dtype)
  310. result = arr.searchsorted([2, 30])
  311. expected = np.array([1, 2], dtype=np.intp)
  312. tm.assert_numpy_array_equal(result, expected)
  313. @pytest.mark.parametrize(
  314. "arr, val",
  315. [
  316. [
  317. pd.date_range("20120101", periods=10, freq="2D"),
  318. pd.Timestamp("20120102"),
  319. ],
  320. [
  321. pd.date_range("20120101", periods=10, freq="2D", tz="Asia/Hong_Kong"),
  322. pd.Timestamp("20120102", tz="Asia/Hong_Kong"),
  323. ],
  324. [
  325. pd.timedelta_range(start="1 day", end="10 days", periods=10),
  326. pd.Timedelta("2 days"),
  327. ],
  328. ],
  329. )
  330. def test_search_sorted_datetime64_scalar(self, arr, val):
  331. arr = pd.array(arr)
  332. result = arr.searchsorted(val)
  333. assert is_scalar(result)
  334. assert result == 1
  335. def test_searchsorted_sorter(self, any_real_dtype):
  336. arr = pd.array([3, 1, 2], dtype=any_real_dtype)
  337. result = arr.searchsorted([0, 3], sorter=np.argsort(arr))
  338. expected = np.array([0, 2], dtype=np.intp)
  339. tm.assert_numpy_array_equal(result, expected)