1"""
2Functions to generate methods and pin them to the appropriate classes.
3"""
4from __future__ import annotations
5
6import operator
7
8from pandas.core.dtypes.generic import (
9 ABCDataFrame,
10 ABCSeries,
11)
12
13from pandas.core.ops import roperator
14
15
16def _get_method_wrappers(cls):
17 """
18 Find the appropriate operation-wrappers to use when defining flex/special
19 arithmetic, boolean, and comparison operations with the given class.
20
21 Parameters
22 ----------
23 cls : class
24
25 Returns
26 -------
27 arith_flex : function or None
28 comp_flex : function or None
29 """
30 # TODO: make these non-runtime imports once the relevant functions
31 # are no longer in __init__
32 from pandas.core.ops import (
33 flex_arith_method_FRAME,
34 flex_comp_method_FRAME,
35 flex_method_SERIES,
36 )
37
38 if issubclass(cls, ABCSeries):
39 # Just Series
40 arith_flex = flex_method_SERIES
41 comp_flex = flex_method_SERIES
42 elif issubclass(cls, ABCDataFrame):
43 arith_flex = flex_arith_method_FRAME
44 comp_flex = flex_comp_method_FRAME
45 return arith_flex, comp_flex
46
47
def add_flex_arithmetic_methods(cls) -> None:
    """
    Adds the full suite of flex arithmetic and comparison methods
    (``add``, ``sub``, ``mul``, ``pow``, ``eq``, ``lt``, ...) to the class,
    plus the ``multiply``, ``subtract`` and ``divide`` aliases.

    Parameters
    ----------
    cls : class
        flex methods will be defined and pinned to this class

    Raises
    ------
    TypeError
        If no flex-method wrappers are available for ``cls`` (i.e. it is
        neither a Series nor a DataFrame subclass).
    """
    flex_arith_method, flex_comp_method = _get_method_wrappers(cls)
    if flex_arith_method is None or flex_comp_method is None:
        # Fail with a clear message rather than "'NoneType' object is not
        # callable" deep inside _create_methods.
        raise TypeError(
            f"Cannot add flex methods to {cls.__name__}; "
            "expected a Series or DataFrame subclass"
        )

    new_methods = _create_methods(cls, flex_arith_method, flex_comp_method)
    new_methods.update(
        {
            "multiply": new_methods["mul"],
            "subtract": new_methods["sub"],
            "divide": new_methods["div"],
        }
    )
    # opt out of bool flex methods for now. _create_methods strips "_" from
    # every name, so ``ror_``/``rand_`` would appear as ``ror``/``rand``;
    # checking the underscored spellings could never match.
    assert not any(kname in new_methods for kname in ("ror", "rxor", "rand"))

    _add_methods(cls, new_methods=new_methods)
71
72
73def _create_methods(cls, arith_method, comp_method):
74 # creates actual flex methods based upon arithmetic, and comp method
75 # constructors.
76
77 have_divmod = issubclass(cls, ABCSeries)
78 # divmod is available for Series
79
80 new_methods = {}
81
82 new_methods.update(
83 {
84 "add": arith_method(operator.add),
85 "radd": arith_method(roperator.radd),
86 "sub": arith_method(operator.sub),
87 "mul": arith_method(operator.mul),
88 "truediv": arith_method(operator.truediv),
89 "floordiv": arith_method(operator.floordiv),
90 "mod": arith_method(operator.mod),
91 "pow": arith_method(operator.pow),
92 "rmul": arith_method(roperator.rmul),
93 "rsub": arith_method(roperator.rsub),
94 "rtruediv": arith_method(roperator.rtruediv),
95 "rfloordiv": arith_method(roperator.rfloordiv),
96 "rpow": arith_method(roperator.rpow),
97 "rmod": arith_method(roperator.rmod),
98 }
99 )
100 new_methods["div"] = new_methods["truediv"]
101 new_methods["rdiv"] = new_methods["rtruediv"]
102 if have_divmod:
103 # divmod doesn't have an op that is supported by numexpr
104 new_methods["divmod"] = arith_method(divmod)
105 new_methods["rdivmod"] = arith_method(roperator.rdivmod)
106
107 new_methods.update(
108 {
109 "eq": comp_method(operator.eq),
110 "ne": comp_method(operator.ne),
111 "lt": comp_method(operator.lt),
112 "gt": comp_method(operator.gt),
113 "le": comp_method(operator.le),
114 "ge": comp_method(operator.ge),
115 }
116 )
117
118 new_methods = {k.strip("_"): v for k, v in new_methods.items()}
119 return new_methods
120
121
122def _add_methods(cls, new_methods) -> None:
123 for name, method in new_methods.items():
124 setattr(cls, name, method)