1"""
2Module containing utilities for NDFrame.sample() and .GroupBy.sample()
3"""
4from __future__ import annotations
5
6from typing import TYPE_CHECKING
7
8import numpy as np
9
10from pandas._libs import lib
11from pandas._typing import AxisInt
12
13from pandas.core.dtypes.generic import (
14 ABCDataFrame,
15 ABCSeries,
16)
17
18if TYPE_CHECKING:
19 from pandas.core.generic import NDFrame
20
21
def preprocess_weights(obj: NDFrame, weights, axis: AxisInt) -> np.ndarray:
    """
    Process and validate the `weights` argument to `NDFrame.sample` and
    `.GroupBy.sample`.

    Parameters
    ----------
    obj : NDFrame
        The Series or DataFrame being sampled from.
    weights : str, Series, or array-like
        Sampling weights. A Series is aligned (reindexed) to ``obj.axes[axis]``.
        A string names a column of ``obj`` and is only accepted when ``obj`` is
        a DataFrame sampled along axis 0.
    axis : int
        The axis of ``obj`` along which sampling happens.

    Returns
    -------
    np.ndarray[np.float64]
        Validated weights with missing values replaced by 0. The weights are
        NOT normalized, because that must be done groupwise in groupby sampling.

    Raises
    ------
    KeyError
        If a string is passed that is not a valid column of the DataFrame.
    ValueError
        If a string is passed for a Series or with ``axis != 0``, if the number
        of weights differs from ``obj.shape[axis]``, or if the weights contain
        ``inf`` or negative values.
    """
    # If a series, align with the sampled axis. Labels absent from `weights`
    # become NaN after reindexing and are treated as zero weight below.
    if isinstance(weights, ABCSeries):
        weights = weights.reindex(obj.axes[axis])

    # Strings are only acceptable for a DataFrame sampled along axis 0, where
    # they name the column holding the weights.
    if isinstance(weights, str):
        if not isinstance(obj, ABCDataFrame):
            raise ValueError(
                "Strings cannot be passed as weights when sampling from a Series."
            )
        if axis != 0:
            raise ValueError(
                "Strings can only be passed to "
                "weights when sampling from rows on "
                "a DataFrame"
            )
        try:
            weights = obj[weights]
        except KeyError as err:
            raise KeyError("String passed to weights not a valid column") from err

    # Coerce to a 1-D float64 array using the 1-D constructor matching `obj`.
    func = obj._constructor if isinstance(obj, ABCSeries) else obj._constructor_sliced
    weights = func(weights, dtype="float64")._values

    if len(weights) != obj.shape[axis]:
        raise ValueError("Weights and axis to be sampled must be of same length")

    if lib.has_infs(weights):
        raise ValueError("weight vector may not include `inf` values")

    if (weights < 0).any():
        raise ValueError("weight vector may not include negative values")

    missing = np.isnan(weights)
    if missing.any():
        # `weights` may share memory with the caller's data (e.g. a DataFrame
        # column), so copy before zeroing the missing entries.
        weights = weights.copy()
        weights[missing] = 0
    return weights
77
78
def process_sampling_size(
    n: int | None, frac: float | None, replace: bool
) -> int | None:
    """
    Validate the `n` and `frac` arguments to `NDFrame.sample` and
    `.GroupBy.sample`.

    Parameters
    ----------
    n : int or None
        Fixed number of items to sample.
    frac : float or None
        Fraction of items to sample; mutually exclusive with ``n``.
    replace : bool
        Whether sampling is done with replacement; required when ``frac > 1``.

    Returns
    -------
    int or None
        The constant sample size, or None when ``frac`` governs the size
        (which may then vary, e.g. per group). With neither argument given
        the size defaults to 1.

    Raises
    ------
    ValueError
        If both ``n`` and ``frac`` are given, ``n`` is negative or not
        integral, ``frac`` is negative, or ``frac > 1`` without replacement.
    """
    if n is not None and frac is not None:
        raise ValueError("Please enter a value for `frac` OR `n`, not both")

    if n is None:
        if frac is None:
            # Neither argument supplied: sample a single item.
            return 1
        if frac > 1 and not replace:
            raise ValueError(
                "Replace has to be set to `True` when "
                "upsampling the population `frac` > 1."
            )
        if frac < 0:
            raise ValueError(
                "A negative number of rows requested. Please provide `frac` >= 0."
            )
        # Signal to the caller that the size must be derived from `frac`.
        return None

    if n < 0:
        raise ValueError(
            "A negative number of rows requested. Please provide `n` >= 0."
        )
    if n % 1 != 0:
        raise ValueError("Only integers accepted as `n` values")
    return n
114
115
def sample(
    obj_len: int,
    size: int,
    replace: bool,
    weights: np.ndarray | None,
    random_state: np.random.RandomState | np.random.Generator,
) -> np.ndarray:
    """
    Randomly draw `size` positions from ``np.arange(obj_len)``.

    Parameters
    ----------
    obj_len : int
        Number of candidate positions.
    size : int
        Number of positions to draw.
    replace : bool
        Whether the same position may be drawn more than once.
    weights : np.ndarray[np.float64] or None
        Unnormalized sampling weights; None means uniform probability.
        They are divided by their sum before sampling.
    random_state : np.random.RandomState or np.random.Generator
        Source of randomness used for the draw.

    Returns
    -------
    np.ndarray[np.intp]
        The sampled positions.

    Raises
    ------
    ValueError
        If ``weights`` sum to zero.
    """
    probabilities = None
    if weights is not None:
        total = weights.sum()
        if total == 0:
            raise ValueError("Invalid weights: weights sum to zero")
        # Build a new normalized array; the caller's `weights` are left intact.
        probabilities = weights / total

    drawn = random_state.choice(obj_len, size=size, replace=replace, p=probabilities)
    return drawn.astype(np.intp, copy=False)