1"""Global configuration state and functions for management
2"""
3import os
4import threading
5from contextlib import contextmanager as contextmanager
6
# Environment-variable spellings (case-insensitive, surrounding whitespace
# ignored) that are read as "off" by ``_env_flag``.
_FALSY_ENV_VALUES = frozenset({"", "0", "false", "no", "off"})


def _env_flag(name, default=False):
    """Read a boolean flag from the environment variable ``name``.

    A plain ``bool(os.environ.get(name))`` treats every non-empty string as
    True, so ``SKLEARN_ASSUME_FINITE=0`` or ``SKLEARN_ASSUME_FINITE=False``
    would *enable* the option. Here the spellings in ``_FALSY_ENV_VALUES``
    are read as False. Any other value that is set is read as True, so
    settings such as ``1`` or ``true`` keep working as before.

    Parameters
    ----------
    name : str
        Name of the environment variable.
    default : bool, default=False
        Value returned when the variable is not set at all.

    Returns
    -------
    flag : bool
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() not in _FALSY_ENV_VALUES


# Process-wide defaults. Each thread lazily copies these into its own
# thread-local dict (see ``_get_threadlocal_config``), so changing a setting
# in one thread does not affect the defaults seen by other threads.
_global_config = {
    "assume_finite": _env_flag("SKLEARN_ASSUME_FINITE", False),
    "working_memory": int(os.environ.get("SKLEARN_WORKING_MEMORY", 1024)),
    "print_changed_only": True,
    "display": "diagram",
    "pairwise_dist_chunk_size": int(
        os.environ.get("SKLEARN_PAIRWISE_DIST_CHUNK_SIZE", 256)
    ),
    "enable_cython_pairwise_dist": True,
    "array_api_dispatch": False,
    "transform_output": "default",
    "enable_metadata_routing": False,
    "skip_parameter_validation": False,
}
# Holds each thread's own mutable copy of the configuration.
_threadlocal = threading.local()
22
23
def _get_threadlocal_config():
    """Return the calling thread's **mutable** configuration dict.

    The first call from a given thread seeds that thread's dict with a
    shallow copy of the module-level defaults in ``_global_config``. Later
    calls from the same thread return that same dict object, so writes to
    it persist for the thread.
    """
    try:
        return _threadlocal.global_config
    except AttributeError:
        # First access from this thread: seed it from the global defaults.
        seeded = _global_config.copy()
        _threadlocal.global_config = seeded
        return seeded
30
31
def get_config():
    """Return the current values of the configuration set by :func:`set_config`.

    Returns
    -------
    config : dict
        Keys are parameter names that can be passed to :func:`set_config`.

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    set_config : Set global scikit-learn configuration.
    """
    # The thread-local dict is the live, mutable store. A shallow copy is
    # handed out so that editing the result cannot alter the configuration.
    live_config = _get_threadlocal_config()
    return dict(live_config)
48
49
def set_config(
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Set global scikit-learn configuration.

    .. versionadded:: 0.19

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. Global default: False.

        .. versionadded:: 0.19

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. Global default: 1024.

        .. versionadded:: 0.20

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()' while the default
        behaviour would be to print 'SVC(C=1.0, cache_size=200, ...)' with
        all the non-changed parameters. Global default: True.

        .. versionadded:: 0.21

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. Default is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    See Also
    --------
    config_context : Context manager for global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    Arguments left as None keep their current value. Changes are written to
    the calling thread's configuration. A value that is rejected during
    validation leaves the configuration untouched rather than half-updated.
    """
    # Validate before mutating anything. Otherwise a rejected
    # ``array_api_dispatch`` would leave the options that precede it in the
    # list below already applied to the live thread-local configuration.
    # NOTE(review): ``_check_array_api_dispatch`` lives outside this file;
    # it is assumed to raise when dispatch cannot be enabled.
    if array_api_dispatch is not None:
        from .utils._array_api import _check_array_api_dispatch

        _check_array_api_dispatch(array_api_dispatch)

    local_config = _get_threadlocal_config()

    if assume_finite is not None:
        local_config["assume_finite"] = assume_finite
    if working_memory is not None:
        local_config["working_memory"] = working_memory
    if print_changed_only is not None:
        local_config["print_changed_only"] = print_changed_only
    if display is not None:
        local_config["display"] = display
    if pairwise_dist_chunk_size is not None:
        local_config["pairwise_dist_chunk_size"] = pairwise_dist_chunk_size
    if enable_cython_pairwise_dist is not None:
        local_config["enable_cython_pairwise_dist"] = enable_cython_pairwise_dist
    if array_api_dispatch is not None:
        local_config["array_api_dispatch"] = array_api_dispatch
    if transform_output is not None:
        local_config["transform_output"] = transform_output
    if enable_metadata_routing is not None:
        local_config["enable_metadata_routing"] = enable_metadata_routing
    if skip_parameter_validation is not None:
        local_config["skip_parameter_validation"] = skip_parameter_validation
197
198
@contextmanager
def config_context(
    *,
    assume_finite=None,
    working_memory=None,
    print_changed_only=None,
    display=None,
    pairwise_dist_chunk_size=None,
    enable_cython_pairwise_dist=None,
    array_api_dispatch=None,
    transform_output=None,
    enable_metadata_routing=None,
    skip_parameter_validation=None,
):
    """Context manager for global scikit-learn configuration.

    Parameters
    ----------
    assume_finite : bool, default=None
        If True, validation for finiteness will be skipped,
        saving time, but leading to potential crashes. If
        False, validation for finiteness will be performed,
        avoiding error. If None, the existing value won't change.
        The default value is False.

    working_memory : int, default=None
        If set, scikit-learn will attempt to limit the size of temporary arrays
        to this number of MiB (per job when parallelised), often saving both
        computation time and memory on expensive operations that can be
        performed in chunks. If None, the existing value won't change.
        The default value is 1024.

    print_changed_only : bool, default=None
        If True, only the parameters that were set to non-default
        values will be printed when printing an estimator. For example,
        ``print(SVC())`` while True will only print 'SVC()', but would print
        'SVC(C=1.0, cache_size=200, ...)' with all the non-changed parameters
        when False. If None, the existing value won't change.
        The default value is True.

        .. versionchanged:: 0.23
           Default changed from False to True.

    display : {'text', 'diagram'}, default=None
        If 'diagram', estimators will be displayed as a diagram in a Jupyter
        lab or notebook context. If 'text', estimators will be displayed as
        text. If None, the existing value won't change.
        The default value is 'diagram'.

        .. versionadded:: 0.23

    pairwise_dist_chunk_size : int, default=None
        The number of row vectors per chunk for the accelerated pairwise-
        distances reduction backend. Default is 256 (suitable for most of
        modern laptops' caches and architectures).

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    enable_cython_pairwise_dist : bool, default=None
        Use the accelerated pairwise-distances reduction backend when
        possible. Global default: True.

        Intended for easier benchmarking and testing of scikit-learn internals.
        End users are not expected to benefit from customizing this configuration
        setting.

        .. versionadded:: 1.1

    array_api_dispatch : bool, default=None
        Use Array API dispatching when inputs follow the Array API standard.
        Default is False.

        See the :ref:`User Guide <array_api>` for more details.

        .. versionadded:: 1.2

    transform_output : str, default=None
        Configure output of `transform` and `fit_transform`.

        See :ref:`sphx_glr_auto_examples_miscellaneous_plot_set_output.py`
        for an example on how to use the API.

        - `"default"`: Default output format of a transformer
        - `"pandas"`: DataFrame output
        - `"polars"`: Polars output
        - `None`: Transform configuration is unchanged

        .. versionadded:: 1.2
        .. versionadded:: 1.4
            `"polars"` option was added.

    enable_metadata_routing : bool, default=None
        Enable metadata routing. By default this feature is disabled.

        Refer to :ref:`metadata routing user guide <metadata_routing>` for more
        details.

        - `True`: Metadata routing is enabled
        - `False`: Metadata routing is disabled, use the old syntax.
        - `None`: Configuration is unchanged

        .. versionadded:: 1.3

    skip_parameter_validation : bool, default=None
        If `True`, disable the validation of the hyper-parameters' types and values in
        the fit method of estimators and for arguments passed to public helper
        functions. It can save time in some situations but can lead to low level
        crashes and exceptions with confusing error messages.

        Note that for data parameters, such as `X` and `y`, only type validation is
        skipped but validation with `check_array` will continue to run.

        .. versionadded:: 1.3

    Yields
    ------
    None.

    See Also
    --------
    set_config : Set global scikit-learn configuration.
    get_config : Retrieve current values of the global configuration.

    Notes
    -----
    All settings, not just those presently modified, will be returned to
    their previous values when the context manager is exited. This also
    holds if applying the new settings fails on entry: the exception
    propagates and the previous configuration is restored first.

    Examples
    --------
    >>> import sklearn
    >>> from sklearn.utils.validation import assert_all_finite
    >>> with sklearn.config_context(assume_finite=True):
    ...     assert_all_finite([float('nan')])
    >>> with sklearn.config_context(assume_finite=True):
    ...     with sklearn.config_context(assume_finite=False):
    ...         assert_all_finite([float('nan')])
    Traceback (most recent call last):
    ...
    ValueError: Input contains NaN...
    """
    # Snapshot of every setting, used to restore the configuration on exit.
    old_config = get_config()
    try:
        # Applying the new settings inside ``try`` means that if ``set_config``
        # raises part-way through, whatever it already applied is rolled back
        # by the ``finally`` below instead of leaking into this thread.
        set_config(
            assume_finite=assume_finite,
            working_memory=working_memory,
            print_changed_only=print_changed_only,
            display=display,
            pairwise_dist_chunk_size=pairwise_dist_chunk_size,
            enable_cython_pairwise_dist=enable_cython_pairwise_dist,
            array_api_dispatch=array_api_dispatch,
            transform_output=transform_output,
            enable_metadata_routing=enable_metadata_routing,
            skip_parameter_validation=skip_parameter_validation,
        )
        yield
    finally:
        set_config(**old_config)