1
0

test_categorical.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286
"""
This file contains a minimal set of tests for compliance with the extension
array interface test suite, and should contain no other tests.
The test suite for the full functionality of the array is located in
`pandas/tests/arrays/`.

The tests in this file are inherited from the BaseExtensionTests, and only
minimal tweaks should be applied to get the tests passing (by overwriting a
parent method).

Additional tests should either be added to one of the BaseExtensionTests
classes (if they are relevant for the extension interface for all dtypes), or
be added to the array-specific tests in `pandas/tests/arrays/`.
"""
  13. import string
  14. import numpy as np
  15. import pytest
  16. import pandas as pd
  17. from pandas import Categorical, CategoricalIndex, Timestamp
  18. import pandas._testing as tm
  19. from pandas.api.types import CategoricalDtype
  20. from pandas.tests.extension import base
  21. def make_data():
  22. while True:
  23. values = np.random.choice(list(string.ascii_letters), size=100)
  24. # ensure we meet the requirements
  25. # 1. first two not null
  26. # 2. first and second are different
  27. if values[0] != values[1]:
  28. break
  29. return values
  30. @pytest.fixture
  31. def dtype():
  32. return CategoricalDtype()
  33. @pytest.fixture
  34. def data():
  35. """Length-100 array for this type.
  36. * data[0] and data[1] should both be non missing
  37. * data[0] and data[1] should not gbe equal
  38. """
  39. return Categorical(make_data())
  40. @pytest.fixture
  41. def data_missing():
  42. """Length 2 array with [NA, Valid]"""
  43. return Categorical([np.nan, "A"])
  44. @pytest.fixture
  45. def data_for_sorting():
  46. return Categorical(["A", "B", "C"], categories=["C", "A", "B"], ordered=True)
  47. @pytest.fixture
  48. def data_missing_for_sorting():
  49. return Categorical(["A", None, "B"], categories=["B", "A"], ordered=True)
  50. @pytest.fixture
  51. def na_value():
  52. return np.nan
  53. @pytest.fixture
  54. def data_for_grouping():
  55. return Categorical(["a", "a", None, None, "b", "b", "a", "c"])
  56. class TestDtype(base.BaseDtypeTests):
  57. pass
  58. class TestInterface(base.BaseInterfaceTests):
  59. @pytest.mark.skip(reason="Memory usage doesn't match")
  60. def test_memory_usage(self, data):
  61. # Is this deliberate?
  62. super().test_memory_usage(data)
  63. class TestConstructors(base.BaseConstructorsTests):
  64. pass
  65. class TestReshaping(base.BaseReshapingTests):
  66. pass
  67. class TestGetitem(base.BaseGetitemTests):
  68. skip_take = pytest.mark.skip(reason="GH-20664.")
  69. @pytest.mark.skip(reason="Backwards compatibility")
  70. def test_getitem_scalar(self, data):
  71. # CategoricalDtype.type isn't "correct" since it should
  72. # be a parent of the elements (object). But don't want
  73. # to break things by changing.
  74. super().test_getitem_scalar(data)
  75. @skip_take
  76. def test_take(self, data, na_value, na_cmp):
  77. # TODO remove this once Categorical.take is fixed
  78. super().test_take(data, na_value, na_cmp)
  79. @skip_take
  80. def test_take_negative(self, data):
  81. super().test_take_negative(data)
  82. @skip_take
  83. def test_take_pandas_style_negative_raises(self, data, na_value):
  84. super().test_take_pandas_style_negative_raises(data, na_value)
  85. @skip_take
  86. def test_take_non_na_fill_value(self, data_missing):
  87. super().test_take_non_na_fill_value(data_missing)
  88. @skip_take
  89. def test_take_out_of_bounds_raises(self, data, allow_fill):
  90. return super().test_take_out_of_bounds_raises(data, allow_fill)
  91. @pytest.mark.skip(reason="GH-20747. Unobserved categories.")
  92. def test_take_series(self, data):
  93. super().test_take_series(data)
  94. @skip_take
  95. def test_reindex_non_na_fill_value(self, data_missing):
  96. super().test_reindex_non_na_fill_value(data_missing)
  97. @pytest.mark.skip(reason="Categorical.take buggy")
  98. def test_take_empty(self, data, na_value, na_cmp):
  99. super().test_take_empty(data, na_value, na_cmp)
  100. @pytest.mark.skip(reason="test not written correctly for categorical")
  101. def test_reindex(self, data, na_value):
  102. super().test_reindex(data, na_value)
  103. class TestSetitem(base.BaseSetitemTests):
  104. pass
  105. class TestMissing(base.BaseMissingTests):
  106. @pytest.mark.skip(reason="Not implemented")
  107. def test_fillna_limit_pad(self, data_missing):
  108. super().test_fillna_limit_pad(data_missing)
  109. @pytest.mark.skip(reason="Not implemented")
  110. def test_fillna_limit_backfill(self, data_missing):
  111. super().test_fillna_limit_backfill(data_missing)
  112. class TestReduce(base.BaseNoReduceTests):
  113. pass
  114. class TestMethods(base.BaseMethodsTests):
  115. @pytest.mark.skip(reason="Unobserved categories included")
  116. def test_value_counts(self, all_data, dropna):
  117. return super().test_value_counts(all_data, dropna)
  118. def test_combine_add(self, data_repeated):
  119. # GH 20825
  120. # When adding categoricals in combine, result is a string
  121. orig_data1, orig_data2 = data_repeated(2)
  122. s1 = pd.Series(orig_data1)
  123. s2 = pd.Series(orig_data2)
  124. result = s1.combine(s2, lambda x1, x2: x1 + x2)
  125. expected = pd.Series(
  126. ([a + b for (a, b) in zip(list(orig_data1), list(orig_data2))])
  127. )
  128. self.assert_series_equal(result, expected)
  129. val = s1.iloc[0]
  130. result = s1.combine(val, lambda x1, x2: x1 + x2)
  131. expected = pd.Series([a + val for a in list(orig_data1)])
  132. self.assert_series_equal(result, expected)
  133. @pytest.mark.skip(reason="Not Applicable")
  134. def test_fillna_length_mismatch(self, data_missing):
  135. super().test_fillna_length_mismatch(data_missing)
  136. def test_searchsorted(self, data_for_sorting):
  137. if not data_for_sorting.ordered:
  138. raise pytest.skip(reason="searchsorted requires ordered data.")
  139. class TestCasting(base.BaseCastingTests):
  140. @pytest.mark.parametrize("cls", [Categorical, CategoricalIndex])
  141. @pytest.mark.parametrize("values", [[1, np.nan], [Timestamp("2000"), pd.NaT]])
  142. def test_cast_nan_to_int(self, cls, values):
  143. # GH 28406
  144. s = cls(values)
  145. msg = "Cannot (cast|convert)"
  146. with pytest.raises((ValueError, TypeError), match=msg):
  147. s.astype(int)
  148. @pytest.mark.parametrize(
  149. "expected",
  150. [
  151. pd.Series(["2019", "2020"], dtype="datetime64[ns, UTC]"),
  152. pd.Series([0, 0], dtype="timedelta64[ns]"),
  153. pd.Series([pd.Period("2019"), pd.Period("2020")], dtype="period[A-DEC]"),
  154. pd.Series([pd.Interval(0, 1), pd.Interval(1, 2)], dtype="interval"),
  155. pd.Series([1, np.nan], dtype="Int64"),
  156. ],
  157. )
  158. def test_cast_category_to_extension_dtype(self, expected):
  159. # GH 28668
  160. result = expected.astype("category").astype(expected.dtype)
  161. tm.assert_series_equal(result, expected)
  162. @pytest.mark.parametrize(
  163. "dtype, expected",
  164. [
  165. (
  166. "datetime64[ns]",
  167. np.array(["2015-01-01T00:00:00.000000000"], dtype="datetime64[ns]"),
  168. ),
  169. (
  170. "datetime64[ns, MET]",
  171. pd.DatetimeIndex(
  172. [pd.Timestamp("2015-01-01 00:00:00+0100", tz="MET")]
  173. ).array,
  174. ),
  175. ],
  176. )
  177. def test_consistent_casting(self, dtype, expected):
  178. # GH 28448
  179. result = pd.Categorical("2015-01-01").astype(dtype)
  180. assert result == expected
  181. class TestArithmeticOps(base.BaseArithmeticOpsTests):
  182. def test_arith_series_with_scalar(self, data, all_arithmetic_operators):
  183. op_name = all_arithmetic_operators
  184. if op_name != "__rmod__":
  185. super().test_arith_series_with_scalar(data, op_name)
  186. else:
  187. pytest.skip("rmod never called when string is first argument")
  188. def test_add_series_with_extension_array(self, data):
  189. ser = pd.Series(data)
  190. with pytest.raises(TypeError, match="cannot perform|unsupported operand"):
  191. ser + data
  192. def test_divmod_series_array(self):
  193. # GH 23287
  194. # skipping because it is not implemented
  195. pass
  196. def _check_divmod_op(self, s, op, other, exc=NotImplementedError):
  197. return super()._check_divmod_op(s, op, other, exc=TypeError)
  198. class TestComparisonOps(base.BaseComparisonOpsTests):
  199. def _compare_other(self, s, data, op_name, other):
  200. op = self.get_op_from_name(op_name)
  201. if op_name == "__eq__":
  202. result = op(s, other)
  203. expected = s.combine(other, lambda x, y: x == y)
  204. assert (result == expected).all()
  205. elif op_name == "__ne__":
  206. result = op(s, other)
  207. expected = s.combine(other, lambda x, y: x != y)
  208. assert (result == expected).all()
  209. else:
  210. with pytest.raises(TypeError):
  211. op(data, other)
  212. class TestParsing(base.BaseParsingTests):
  213. pass