1"""
2Module containing utilities for NDFrame.sample() and .GroupBy.sample()
3"""
4from __future__ import annotations
5
6from typing import TYPE_CHECKING
7
8import numpy as np
9
10from pandas._libs import lib
11
12from pandas.core.dtypes.generic import (
13 ABCDataFrame,
14 ABCSeries,
15)
16
17if TYPE_CHECKING:
18 from pandas._typing import AxisInt
19
20 from pandas.core.generic import NDFrame
21
22
def preprocess_weights(obj: NDFrame, weights, axis: AxisInt) -> np.ndarray:
    """
    Process and validate the `weights` argument to `NDFrame.sample` and
    `.GroupBy.sample`.

    Parameters
    ----------
    obj : NDFrame
        The Series or DataFrame being sampled.
    weights : str, Series or array-like
        The raw weights. A Series is reindexed to ``obj.axes[axis]``; a string
        is looked up as a column of `obj` (only permitted for a DataFrame
        with ``axis == 0``); anything else is passed to the Series
        constructor of `obj` with ``dtype="float64"``.
    axis : int
        The axis of `obj` that is being sampled.

    Returns
    -------
    np.ndarray[np.float64]
        The weights, validated except for normalizing them (because that must
        be done groupwise in groupby sampling). Missing values are replaced
        by zero.

    Raises
    ------
    KeyError
        If `weights` is a string that is not a valid column of `obj`.
    ValueError
        If `weights` is a string and `obj` is not a DataFrame or ``axis`` is
        not 0, if the number of weights differs from ``obj.shape[axis]``, or
        if the weights contain ``inf`` or negative values.
    """
    # If a series, align with frame
    if isinstance(weights, ABCSeries):
        weights = weights.reindex(obj.axes[axis])

    # Strings acceptable if a dataframe and axis = 0
    if isinstance(weights, str):
        if not isinstance(obj, ABCDataFrame):
            raise ValueError(
                "Strings cannot be passed as weights when sampling from a Series."
            )
        if axis != 0:
            raise ValueError(
                "Strings can only be passed to "
                "weights when sampling from rows on "
                "a DataFrame"
            )
        try:
            weights = obj[weights]
        except KeyError as err:
            raise KeyError("String passed to weights not a valid column") from err

    # Use the Series constructor matching `obj` to coerce the weights to float64
    if isinstance(obj, ABCSeries):
        func = obj._constructor
    else:
        func = obj._constructor_sliced

    weights = func(weights, dtype="float64")._values

    if len(weights) != obj.shape[axis]:
        raise ValueError("Weights and axis to be sampled must be of same length")

    if lib.has_infs(weights):
        raise ValueError("weight vector may not include `inf` values")

    if (weights < 0).any():
        raise ValueError("weight vector may not include negative values")

    missing = np.isnan(weights)
    if missing.any():
        # Don't modify weights in place: the array may share memory with
        # the caller's data (e.g. a column of `obj`).
        weights = weights.copy()
        weights[missing] = 0
    return weights
78
79
def process_sampling_size(
    n: int | None, frac: float | None, replace: bool
) -> int | None:
    """
    Validate the `n` and `frac` arguments of `NDFrame.sample` and
    `.GroupBy.sample` and resolve them to a sampling size.

    Parameters
    ----------
    n : int or None
        Requested constant number of items to sample.
    frac : float or None
        Requested fraction of items to sample.
    replace : bool
        Whether sampling is done with replacement; required for ``frac > 1``.

    Returns
    -------
    int or None
        The constant sampling size, or None when `frac` should be used
        (variable sampling sizes). When neither `n` nor `frac` is given the
        size defaults to 1.

    Raises
    ------
    ValueError
        If both `n` and `frac` are given, if either is negative, if `n` is
        not a whole number, or if ``frac > 1`` without replacement.
    """
    if n is None:
        # Neither argument supplied: fall back to a single sampled item.
        if frac is None:
            return 1

        # Only `frac` supplied: validate it and signal "use frac" with None.
        if frac > 1 and not replace:
            raise ValueError(
                "Replace has to be set to `True` when "
                "upsampling the population `frac` > 1."
            )
        if frac < 0:
            raise ValueError(
                "A negative number of rows requested. Please provide `frac` >= 0."
            )
        return None

    # From here on `n` was supplied.
    if frac is not None:
        raise ValueError("Please enter a value for `frac` OR `n`, not both")
    if n < 0:
        raise ValueError(
            "A negative number of rows requested. Please provide `n` >= 0."
        )
    if n % 1 != 0:
        raise ValueError("Only integers accepted as `n` values")
    return n
115
116
def sample(
    obj_len: int,
    size: int,
    replace: bool,
    weights: np.ndarray | None,
    random_state: np.random.RandomState | np.random.Generator,
) -> np.ndarray:
    """
    Draw `size` random positions from `np.arange(obj_len)`.

    Parameters
    ----------
    obj_len : int
        Number of positions available to draw from.
    size : int
        Number of positions to draw.
    replace : bool
        Whether the same position may be drawn more than once.
    weights : np.ndarray[np.float64] or None
        Unnormalized selection weights. They are divided by their sum before
        use; None means every position is equally likely.
    random_state : np.random.RandomState or np.random.Generator
        Source of randomness used for the draw.

    Returns
    -------
    np.ndarray[np.intp]
        The sampled positions.

    Raises
    ------
    ValueError
        If `weights` is given and sums to zero.
    """
    probabilities = weights
    if probabilities is not None:
        total = probabilities.sum()
        if total == 0:
            raise ValueError("Invalid weights: weights sum to zero")
        # Out-of-place division so the caller's array is left untouched.
        probabilities = probabilities / total

    chosen = random_state.choice(
        obj_len, size=size, replace=replace, p=probabilities
    )
    return chosen.astype(np.intp, copy=False)