function.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. """
  2. For compatibility with numpy libraries, pandas functions or
  3. methods have to accept '*args' and '**kwargs' parameters to
  4. accommodate numpy arguments that are not actually used or
  5. respected in the pandas implementation.
  6. To ensure that users do not abuse these parameters, validation
  7. is performed in 'validators.py' to make sure that any extra
  8. parameters passed correspond ONLY to those in the numpy signature.
  9. Part of that validation includes whether or not the user attempted
  10. to pass in non-default values for these extraneous parameters. As we
  11. want to discourage users from relying on these parameters when calling
  12. the pandas implementation, we want them only to pass in the default values
  13. for these parameters.
  14. This module provides a set of commonly used default arguments for functions
  15. and methods that are spread throughout the codebase. This module will make it
  16. easier to adjust to future upstream changes in the analogous numpy signatures.
  17. """
  18. from collections import OrderedDict
  19. from distutils.version import LooseVersion
  20. from typing import Any, Dict, Optional, Union
  21. from numpy import __version__ as _np_version, ndarray
  22. from pandas._libs.lib import is_bool, is_integer
  23. from pandas.errors import UnsupportedFunctionCall
  24. from pandas.util._validators import (
  25. validate_args,
  26. validate_args_and_kwargs,
  27. validate_kwargs,
  28. )
  29. class CompatValidator:
  30. def __init__(self, defaults, fname=None, method=None, max_fname_arg_count=None):
  31. self.fname = fname
  32. self.method = method
  33. self.defaults = defaults
  34. self.max_fname_arg_count = max_fname_arg_count
  35. def __call__(self, args, kwargs, fname=None, max_fname_arg_count=None, method=None):
  36. if args or kwargs:
  37. fname = self.fname if fname is None else fname
  38. max_fname_arg_count = (
  39. self.max_fname_arg_count
  40. if max_fname_arg_count is None
  41. else max_fname_arg_count
  42. )
  43. method = self.method if method is None else method
  44. if method == "args":
  45. validate_args(fname, args, max_fname_arg_count, self.defaults)
  46. elif method == "kwargs":
  47. validate_kwargs(fname, kwargs, self.defaults)
  48. elif method == "both":
  49. validate_args_and_kwargs(
  50. fname, args, kwargs, max_fname_arg_count, self.defaults
  51. )
  52. else:
  53. raise ValueError(f"invalid validation method '{method}'")
  54. ARGMINMAX_DEFAULTS = dict(out=None)
  55. validate_argmin = CompatValidator(
  56. ARGMINMAX_DEFAULTS, fname="argmin", method="both", max_fname_arg_count=1
  57. )
  58. validate_argmax = CompatValidator(
  59. ARGMINMAX_DEFAULTS, fname="argmax", method="both", max_fname_arg_count=1
  60. )
  61. def process_skipna(skipna, args):
  62. if isinstance(skipna, ndarray) or skipna is None:
  63. args = (skipna,) + args
  64. skipna = True
  65. return skipna, args
  66. def validate_argmin_with_skipna(skipna, args, kwargs):
  67. """
  68. If 'Series.argmin' is called via the 'numpy' library,
  69. the third parameter in its signature is 'out', which
  70. takes either an ndarray or 'None', so check if the
  71. 'skipna' parameter is either an instance of ndarray or
  72. is None, since 'skipna' itself should be a boolean
  73. """
  74. skipna, args = process_skipna(skipna, args)
  75. validate_argmin(args, kwargs)
  76. return skipna
  77. def validate_argmax_with_skipna(skipna, args, kwargs):
  78. """
  79. If 'Series.argmax' is called via the 'numpy' library,
  80. the third parameter in its signature is 'out', which
  81. takes either an ndarray or 'None', so check if the
  82. 'skipna' parameter is either an instance of ndarray or
  83. is None, since 'skipna' itself should be a boolean
  84. """
  85. skipna, args = process_skipna(skipna, args)
  86. validate_argmax(args, kwargs)
  87. return skipna
  88. ARGSORT_DEFAULTS: "OrderedDict[str, Optional[Union[int, str]]]" = OrderedDict()
  89. ARGSORT_DEFAULTS["axis"] = -1
  90. ARGSORT_DEFAULTS["kind"] = "quicksort"
  91. ARGSORT_DEFAULTS["order"] = None
  92. if LooseVersion(_np_version) >= LooseVersion("1.17.0"):
  93. # GH-26361. NumPy added radix sort and changed default to None.
  94. ARGSORT_DEFAULTS["kind"] = None
  95. validate_argsort = CompatValidator(
  96. ARGSORT_DEFAULTS, fname="argsort", max_fname_arg_count=0, method="both"
  97. )
  98. # two different signatures of argsort, this second validation
  99. # for when the `kind` param is supported
  100. ARGSORT_DEFAULTS_KIND: "OrderedDict[str, Optional[int]]" = OrderedDict()
  101. ARGSORT_DEFAULTS_KIND["axis"] = -1
  102. ARGSORT_DEFAULTS_KIND["order"] = None
  103. validate_argsort_kind = CompatValidator(
  104. ARGSORT_DEFAULTS_KIND, fname="argsort", max_fname_arg_count=0, method="both"
  105. )
  106. def validate_argsort_with_ascending(ascending, args, kwargs):
  107. """
  108. If 'Categorical.argsort' is called via the 'numpy' library, the
  109. first parameter in its signature is 'axis', which takes either
  110. an integer or 'None', so check if the 'ascending' parameter has
  111. either integer type or is None, since 'ascending' itself should
  112. be a boolean
  113. """
  114. if is_integer(ascending) or ascending is None:
  115. args = (ascending,) + args
  116. ascending = True
  117. validate_argsort_kind(args, kwargs, max_fname_arg_count=3)
  118. return ascending
  119. CLIP_DEFAULTS = dict(out=None) # type Dict[str, Any]
  120. validate_clip = CompatValidator(
  121. CLIP_DEFAULTS, fname="clip", method="both", max_fname_arg_count=3
  122. )
  123. def validate_clip_with_axis(axis, args, kwargs):
  124. """
  125. If 'NDFrame.clip' is called via the numpy library, the third
  126. parameter in its signature is 'out', which can takes an ndarray,
  127. so check if the 'axis' parameter is an instance of ndarray, since
  128. 'axis' itself should either be an integer or None
  129. """
  130. if isinstance(axis, ndarray):
  131. args = (axis,) + args
  132. axis = None
  133. validate_clip(args, kwargs)
  134. return axis
  135. CUM_FUNC_DEFAULTS: "OrderedDict[str, Any]" = OrderedDict()
  136. CUM_FUNC_DEFAULTS["dtype"] = None
  137. CUM_FUNC_DEFAULTS["out"] = None
  138. validate_cum_func = CompatValidator(
  139. CUM_FUNC_DEFAULTS, method="both", max_fname_arg_count=1
  140. )
  141. validate_cumsum = CompatValidator(
  142. CUM_FUNC_DEFAULTS, fname="cumsum", method="both", max_fname_arg_count=1
  143. )
  144. def validate_cum_func_with_skipna(skipna, args, kwargs, name):
  145. """
  146. If this function is called via the 'numpy' library, the third
  147. parameter in its signature is 'dtype', which takes either a
  148. 'numpy' dtype or 'None', so check if the 'skipna' parameter is
  149. a boolean or not
  150. """
  151. if not is_bool(skipna):
  152. args = (skipna,) + args
  153. skipna = True
  154. validate_cum_func(args, kwargs, fname=name)
  155. return skipna
  156. ALLANY_DEFAULTS: "OrderedDict[str, Optional[bool]]" = OrderedDict()
  157. ALLANY_DEFAULTS["dtype"] = None
  158. ALLANY_DEFAULTS["out"] = None
  159. ALLANY_DEFAULTS["keepdims"] = False
  160. validate_all = CompatValidator(
  161. ALLANY_DEFAULTS, fname="all", method="both", max_fname_arg_count=1
  162. )
  163. validate_any = CompatValidator(
  164. ALLANY_DEFAULTS, fname="any", method="both", max_fname_arg_count=1
  165. )
  166. LOGICAL_FUNC_DEFAULTS = dict(out=None, keepdims=False)
  167. validate_logical_func = CompatValidator(LOGICAL_FUNC_DEFAULTS, method="kwargs")
  168. MINMAX_DEFAULTS = dict(out=None, keepdims=False)
  169. validate_min = CompatValidator(
  170. MINMAX_DEFAULTS, fname="min", method="both", max_fname_arg_count=1
  171. )
  172. validate_max = CompatValidator(
  173. MINMAX_DEFAULTS, fname="max", method="both", max_fname_arg_count=1
  174. )
  175. RESHAPE_DEFAULTS: Dict[str, str] = dict(order="C")
  176. validate_reshape = CompatValidator(
  177. RESHAPE_DEFAULTS, fname="reshape", method="both", max_fname_arg_count=1
  178. )
  179. REPEAT_DEFAULTS: Dict[str, Any] = dict(axis=None)
  180. validate_repeat = CompatValidator(
  181. REPEAT_DEFAULTS, fname="repeat", method="both", max_fname_arg_count=1
  182. )
  183. ROUND_DEFAULTS: Dict[str, Any] = dict(out=None)
  184. validate_round = CompatValidator(
  185. ROUND_DEFAULTS, fname="round", method="both", max_fname_arg_count=1
  186. )
  187. SORT_DEFAULTS: "OrderedDict[str, Optional[Union[int, str]]]" = OrderedDict()
  188. SORT_DEFAULTS["axis"] = -1
  189. SORT_DEFAULTS["kind"] = "quicksort"
  190. SORT_DEFAULTS["order"] = None
  191. validate_sort = CompatValidator(SORT_DEFAULTS, fname="sort", method="kwargs")
  192. STAT_FUNC_DEFAULTS: "OrderedDict[str, Optional[Any]]" = OrderedDict()
  193. STAT_FUNC_DEFAULTS["dtype"] = None
  194. STAT_FUNC_DEFAULTS["out"] = None
  195. PROD_DEFAULTS = SUM_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  196. SUM_DEFAULTS["keepdims"] = False
  197. SUM_DEFAULTS["initial"] = None
  198. MEDIAN_DEFAULTS = STAT_FUNC_DEFAULTS.copy()
  199. MEDIAN_DEFAULTS["overwrite_input"] = False
  200. MEDIAN_DEFAULTS["keepdims"] = False
  201. STAT_FUNC_DEFAULTS["keepdims"] = False
  202. validate_stat_func = CompatValidator(STAT_FUNC_DEFAULTS, method="kwargs")
  203. validate_sum = CompatValidator(
  204. SUM_DEFAULTS, fname="sum", method="both", max_fname_arg_count=1
  205. )
  206. validate_prod = CompatValidator(
  207. PROD_DEFAULTS, fname="prod", method="both", max_fname_arg_count=1
  208. )
  209. validate_mean = CompatValidator(
  210. STAT_FUNC_DEFAULTS, fname="mean", method="both", max_fname_arg_count=1
  211. )
  212. validate_median = CompatValidator(
  213. MEDIAN_DEFAULTS, fname="median", method="both", max_fname_arg_count=1
  214. )
  215. STAT_DDOF_FUNC_DEFAULTS: "OrderedDict[str, Optional[bool]]" = OrderedDict()
  216. STAT_DDOF_FUNC_DEFAULTS["dtype"] = None
  217. STAT_DDOF_FUNC_DEFAULTS["out"] = None
  218. STAT_DDOF_FUNC_DEFAULTS["keepdims"] = False
  219. validate_stat_ddof_func = CompatValidator(STAT_DDOF_FUNC_DEFAULTS, method="kwargs")
  220. TAKE_DEFAULTS: "OrderedDict[str, Optional[str]]" = OrderedDict()
  221. TAKE_DEFAULTS["out"] = None
  222. TAKE_DEFAULTS["mode"] = "raise"
  223. validate_take = CompatValidator(TAKE_DEFAULTS, fname="take", method="kwargs")
  224. def validate_take_with_convert(convert, args, kwargs):
  225. """
  226. If this function is called via the 'numpy' library, the third
  227. parameter in its signature is 'axis', which takes either an
  228. ndarray or 'None', so check if the 'convert' parameter is either
  229. an instance of ndarray or is None
  230. """
  231. if isinstance(convert, ndarray) or convert is None:
  232. args = (convert,) + args
  233. convert = True
  234. validate_take(args, kwargs, max_fname_arg_count=3, method="both")
  235. return convert
  236. TRANSPOSE_DEFAULTS = dict(axes=None)
  237. validate_transpose = CompatValidator(
  238. TRANSPOSE_DEFAULTS, fname="transpose", method="both", max_fname_arg_count=0
  239. )
  240. def validate_window_func(name, args, kwargs):
  241. numpy_args = ("axis", "dtype", "out")
  242. msg = (
  243. f"numpy operations are not valid with window objects. "
  244. f"Use .{name}() directly instead "
  245. )
  246. if len(args) > 0:
  247. raise UnsupportedFunctionCall(msg)
  248. for arg in numpy_args:
  249. if arg in kwargs:
  250. raise UnsupportedFunctionCall(msg)
  251. def validate_rolling_func(name, args, kwargs):
  252. numpy_args = ("axis", "dtype", "out")
  253. msg = (
  254. f"numpy operations are not valid with window objects. "
  255. f"Use .rolling(...).{name}() instead "
  256. )
  257. if len(args) > 0:
  258. raise UnsupportedFunctionCall(msg)
  259. for arg in numpy_args:
  260. if arg in kwargs:
  261. raise UnsupportedFunctionCall(msg)
  262. def validate_expanding_func(name, args, kwargs):
  263. numpy_args = ("axis", "dtype", "out")
  264. msg = (
  265. f"numpy operations are not valid with window objects. "
  266. f"Use .expanding(...).{name}() instead "
  267. )
  268. if len(args) > 0:
  269. raise UnsupportedFunctionCall(msg)
  270. for arg in numpy_args:
  271. if arg in kwargs:
  272. raise UnsupportedFunctionCall(msg)
  273. def validate_groupby_func(name, args, kwargs, allowed=None):
  274. """
  275. 'args' and 'kwargs' should be empty, except for allowed
  276. kwargs because all of
  277. their necessary parameters are explicitly listed in
  278. the function signature
  279. """
  280. if allowed is None:
  281. allowed = []
  282. kwargs = set(kwargs) - set(allowed)
  283. if len(args) + len(kwargs) > 0:
  284. raise UnsupportedFunctionCall(
  285. f"numpy operations are not valid with "
  286. f"groupby. Use .groupby(...).{name}() "
  287. f"instead"
  288. )
  289. RESAMPLER_NUMPY_OPS = ("min", "max", "sum", "prod", "mean", "std", "var")
  290. def validate_resampler_func(method, args, kwargs):
  291. """
  292. 'args' and 'kwargs' should be empty because all of
  293. their necessary parameters are explicitly listed in
  294. the function signature
  295. """
  296. if len(args) + len(kwargs) > 0:
  297. if method in RESAMPLER_NUMPY_OPS:
  298. raise UnsupportedFunctionCall(
  299. f"numpy operations are not "
  300. f"valid with resample. Use "
  301. f".resample(...).{method}() instead"
  302. )
  303. else:
  304. raise TypeError("too many arguments passed in")
  305. def validate_minmax_axis(axis):
  306. """
  307. Ensure that the axis argument passed to min, max, argmin, or argmax is
  308. zero or None, as otherwise it will be incorrectly ignored.
  309. Parameters
  310. ----------
  311. axis : int or None
  312. Raises
  313. ------
  314. ValueError
  315. """
  316. ndim = 1 # hard-coded for Index
  317. if axis is None:
  318. return
  319. if axis >= ndim or (axis < 0 and ndim + axis < 0):
  320. raise ValueError(f"`axis` must be fewer than the number of dimensions ({ndim})")