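"""
Functions for changing the (context-local) ufunc configuration.

This module provides helpers that wrap `_get_extobj_dict`, `_make_extobj`
and `_extobj_contextvar` from umath. The settings are stored in a context
variable, so changes made here apply to the current context.
"""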
7import contextlib
8import contextvars
9import functools
10
11from .._utils import set_module
12from .umath import _make_extobj, _get_extobj_dict, _extobj_contextvar
13
# Public API of this module: the names exported by ``from ... import *``.
__all__ = [
    "seterr", "geterr", "setbufsize", "getbufsize", "seterrcall", "geterrcall",
    "errstate"
]
18
19
20@set_module('numpy')
21def seterr(all=None, divide=None, over=None, under=None, invalid=None):
22 """
23 Set how floating-point errors are handled.
24
25 Note that operations on integer scalar types (such as `int16`) are
26 handled like floating point, and are affected by these settings.
27
28 Parameters
29 ----------
30 all : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
31 Set treatment for all types of floating-point errors at once:
32
33 - ignore: Take no action when the exception occurs.
34 - warn: Print a :exc:`RuntimeWarning` (via the Python `warnings`
35 module).
36 - raise: Raise a :exc:`FloatingPointError`.
37 - call: Call a function specified using the `seterrcall` function.
38 - print: Print a warning directly to ``stdout``.
39 - log: Record error in a Log object specified by `seterrcall`.
40
41 The default is not to change the current behavior.
42 divide : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
43 Treatment for division by zero.
44 over : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
45 Treatment for floating-point overflow.
46 under : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
47 Treatment for floating-point underflow.
48 invalid : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
49 Treatment for invalid floating-point operation.
50
51 Returns
52 -------
53 old_settings : dict
54 Dictionary containing the old settings.
55
56 See also
57 --------
58 seterrcall : Set a callback function for the 'call' mode.
59 geterr, geterrcall, errstate
60
61 Notes
62 -----
63 The floating-point exceptions are defined in the IEEE 754 standard [1]_:
64
65 - Division by zero: infinite result obtained from finite numbers.
66 - Overflow: result too large to be expressed.
67 - Underflow: result so close to zero that some precision
68 was lost.
69 - Invalid operation: result is not an expressible number, typically
70 indicates that a NaN was produced.
71
72 .. [1] https://en.wikipedia.org/wiki/IEEE_754
73
74 Examples
75 --------
76 >>> import numpy as np
77 >>> orig_settings = np.seterr(all='ignore') # seterr to known value
78 >>> np.int16(32000) * np.int16(3)
79 np.int16(30464)
80 >>> np.seterr(over='raise')
81 {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
82 >>> old_settings = np.seterr(all='warn', over='raise')
83 >>> np.int16(32000) * np.int16(3)
84 Traceback (most recent call last):
85 File "<stdin>", line 1, in <module>
86 FloatingPointError: overflow encountered in scalar multiply
87
88 >>> old_settings = np.seterr(all='print')
89 >>> np.geterr()
90 {'divide': 'print', 'over': 'print', 'under': 'print', 'invalid': 'print'}
91 >>> np.int16(32000) * np.int16(3)
92 np.int16(30464)
93 >>> np.seterr(**orig_settings) # restore original
94 {'divide': 'print', 'over': 'print', 'under': 'print', 'invalid': 'print'}
95
96 """
97
98 old = _get_extobj_dict()
99 # The errstate doesn't include call and bufsize, so pop them:
100 old.pop("call", None)
101 old.pop("bufsize", None)
102
103 extobj = _make_extobj(
104 all=all, divide=divide, over=over, under=under, invalid=invalid)
105 _extobj_contextvar.set(extobj)
106 return old
107
108
109@set_module('numpy')
110def geterr():
111 """
112 Get the current way of handling floating-point errors.
113
114 Returns
115 -------
116 res : dict
117 A dictionary with keys "divide", "over", "under", and "invalid",
118 whose values are from the strings "ignore", "print", "log", "warn",
119 "raise", and "call". The keys represent possible floating-point
120 exceptions, and the values define how these exceptions are handled.
121
122 See Also
123 --------
124 geterrcall, seterr, seterrcall
125
126 Notes
127 -----
128 For complete documentation of the types of floating-point exceptions and
129 treatment options, see `seterr`.
130
131 Examples
132 --------
133 >>> import numpy as np
134 >>> np.geterr()
135 {'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}
136 >>> np.arange(3.) / np.arange(3.) # doctest: +SKIP
137 array([nan, 1., 1.])
138 RuntimeWarning: invalid value encountered in divide
139
140 >>> oldsettings = np.seterr(all='warn', invalid='raise')
141 >>> np.geterr()
142 {'divide': 'warn', 'over': 'warn', 'under': 'warn', 'invalid': 'raise'}
143 >>> np.arange(3.) / np.arange(3.)
144 Traceback (most recent call last):
145 ...
146 FloatingPointError: invalid value encountered in divide
147 >>> oldsettings = np.seterr(**oldsettings) # restore original
148
149 """
150 res = _get_extobj_dict()
151 # The "geterr" doesn't include call and bufsize,:
152 res.pop("call", None)
153 res.pop("bufsize", None)
154 return res
155
156
157@set_module('numpy')
158def setbufsize(size):
159 """
160 Set the size of the buffer used in ufuncs.
161
162 .. versionchanged:: 2.0
163 The scope of setting the buffer is tied to the `numpy.errstate`
164 context. Exiting a ``with errstate():`` will also restore the bufsize.
165
166 Parameters
167 ----------
168 size : int
169 Size of buffer.
170
171 Returns
172 -------
173 bufsize : int
174 Previous size of ufunc buffer in bytes.
175
176 Examples
177 --------
178 When exiting a `numpy.errstate` context manager the bufsize is restored:
179
180 >>> import numpy as np
181 >>> with np.errstate():
182 ... np.setbufsize(4096)
183 ... print(np.getbufsize())
184 ...
185 8192
186 4096
187 >>> np.getbufsize()
188 8192
189
190 """
191 old = _get_extobj_dict()["bufsize"]
192 extobj = _make_extobj(bufsize=size)
193 _extobj_contextvar.set(extobj)
194 return old
195
196
197@set_module('numpy')
198def getbufsize():
199 """
200 Return the size of the buffer used in ufuncs.
201
202 Returns
203 -------
204 getbufsize : int
205 Size of ufunc buffer in bytes.
206
207 Examples
208 --------
209 >>> import numpy as np
210 >>> np.getbufsize()
211 8192
212
213 """
214 return _get_extobj_dict()["bufsize"]
215
216
217@set_module('numpy')
218def seterrcall(func):
219 """
220 Set the floating-point error callback function or log object.
221
222 There are two ways to capture floating-point error messages. The first
223 is to set the error-handler to 'call', using `seterr`. Then, set
224 the function to call using this function.
225
226 The second is to set the error-handler to 'log', using `seterr`.
227 Floating-point errors then trigger a call to the 'write' method of
228 the provided object.
229
230 Parameters
231 ----------
232 func : callable f(err, flag) or object with write method
233 Function to call upon floating-point errors ('call'-mode) or
234 object whose 'write' method is used to log such message ('log'-mode).
235
236 The call function takes two arguments. The first is a string describing
237 the type of error (such as "divide by zero", "overflow", "underflow",
238 or "invalid value"), and the second is the status flag. The flag is a
239 byte, whose four least-significant bits indicate the type of error, one
240 of "divide", "over", "under", "invalid"::
241
242 [0 0 0 0 divide over under invalid]
243
244 In other words, ``flags = divide + 2*over + 4*under + 8*invalid``.
245
246 If an object is provided, its write method should take one argument,
247 a string.
248
249 Returns
250 -------
251 h : callable, log instance or None
252 The old error handler.
253
254 See Also
255 --------
256 seterr, geterr, geterrcall
257
258 Examples
259 --------
260 Callback upon error:
261
262 >>> def err_handler(type, flag):
263 ... print("Floating point error (%s), with flag %s" % (type, flag))
264 ...
265
266 >>> import numpy as np
267
268 >>> orig_handler = np.seterrcall(err_handler)
269 >>> orig_err = np.seterr(all='call')
270
271 >>> np.array([1, 2, 3]) / 0.0
272 Floating point error (divide by zero), with flag 1
273 array([inf, inf, inf])
274
275 >>> np.seterrcall(orig_handler)
276 <function err_handler at 0x...>
277 >>> np.seterr(**orig_err)
278 {'divide': 'call', 'over': 'call', 'under': 'call', 'invalid': 'call'}
279
280 Log error message:
281
282 >>> class Log:
283 ... def write(self, msg):
284 ... print("LOG: %s" % msg)
285 ...
286
287 >>> log = Log()
288 >>> saved_handler = np.seterrcall(log)
289 >>> save_err = np.seterr(all='log')
290
291 >>> np.array([1, 2, 3]) / 0.0
292 LOG: Warning: divide by zero encountered in divide
293 array([inf, inf, inf])
294
295 >>> np.seterrcall(orig_handler)
296 <numpy.Log object at 0x...>
297 >>> np.seterr(**orig_err)
298 {'divide': 'log', 'over': 'log', 'under': 'log', 'invalid': 'log'}
299
300 """
301 old = _get_extobj_dict()["call"]
302 extobj = _make_extobj(call=func)
303 _extobj_contextvar.set(extobj)
304 return old
305
306
307@set_module('numpy')
308def geterrcall():
309 """
310 Return the current callback function used on floating-point errors.
311
312 When the error handling for a floating-point error (one of "divide",
313 "over", "under", or "invalid") is set to 'call' or 'log', the function
314 that is called or the log instance that is written to is returned by
315 `geterrcall`. This function or log instance has been set with
316 `seterrcall`.
317
318 Returns
319 -------
320 errobj : callable, log instance or None
321 The current error handler. If no handler was set through `seterrcall`,
322 ``None`` is returned.
323
324 See Also
325 --------
326 seterrcall, seterr, geterr
327
328 Notes
329 -----
330 For complete documentation of the types of floating-point exceptions and
331 treatment options, see `seterr`.
332
333 Examples
334 --------
335 >>> import numpy as np
336 >>> np.geterrcall() # we did not yet set a handler, returns None
337
338 >>> orig_settings = np.seterr(all='call')
339 >>> def err_handler(type, flag):
340 ... print("Floating point error (%s), with flag %s" % (type, flag))
341 >>> old_handler = np.seterrcall(err_handler)
342 >>> np.array([1, 2, 3]) / 0.0
343 Floating point error (divide by zero), with flag 1
344 array([inf, inf, inf])
345
346 >>> cur_handler = np.geterrcall()
347 >>> cur_handler is err_handler
348 True
349 >>> old_settings = np.seterr(**orig_settings) # restore original
350 >>> old_handler = np.seterrcall(None) # restore original
351
352 """
353 return _get_extobj_dict()["call"]
354
355
class _unspecified:
    """Sentinel type marking that an argument was not passed at all."""
    pass


# Sentinel instance used by ``errstate`` as the default for ``call``: when it
# is left at this value, ``call`` is omitted from the ``_make_extobj`` call, so
# an explicit ``call=None`` remains distinguishable from "not specified".
_Unspecified = _unspecified()
361
362
363@set_module('numpy')
364class errstate:
365 """
366 errstate(**kwargs)
367
368 Context manager for floating-point error handling.
369
370 Using an instance of `errstate` as a context manager allows statements in
371 that context to execute with a known error handling behavior. Upon entering
372 the context the error handling is set with `seterr` and `seterrcall`, and
373 upon exiting it is reset to what it was before.
374
375 .. versionchanged:: 1.17.0
376 `errstate` is also usable as a function decorator, saving
377 a level of indentation if an entire function is wrapped.
378
379 .. versionchanged:: 2.0
380 `errstate` is now fully thread and asyncio safe, but may not be
381 entered more than once.
382 It is not safe to decorate async functions using ``errstate``.
383
384 Parameters
385 ----------
386 kwargs : {divide, over, under, invalid}
387 Keyword arguments. The valid keywords are the possible floating-point
388 exceptions. Each keyword should have a string value that defines the
389 treatment for the particular error. Possible values are
390 {'ignore', 'warn', 'raise', 'call', 'print', 'log'}.
391
392 See Also
393 --------
394 seterr, geterr, seterrcall, geterrcall
395
396 Notes
397 -----
398 For complete documentation of the types of floating-point exceptions and
399 treatment options, see `seterr`.
400
401 Examples
402 --------
403 >>> import numpy as np
404 >>> olderr = np.seterr(all='ignore') # Set error handling to known state.
405
406 >>> np.arange(3) / 0.
407 array([nan, inf, inf])
408 >>> with np.errstate(divide='ignore'):
409 ... np.arange(3) / 0.
410 array([nan, inf, inf])
411
412 >>> np.sqrt(-1)
413 np.float64(nan)
414 >>> with np.errstate(invalid='raise'):
415 ... np.sqrt(-1)
416 Traceback (most recent call last):
417 File "<stdin>", line 2, in <module>
418 FloatingPointError: invalid value encountered in sqrt
419
420 Outside the context the error handling behavior has not changed:
421
422 >>> np.geterr()
423 {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
424 >>> olderr = np.seterr(**olderr) # restore original state
425
426 """
427 __slots__ = (
428 "_call", "_all", "_divide", "_over", "_under", "_invalid", "_token")
429
430 def __init__(self, *, call=_Unspecified,
431 all=None, divide=None, over=None, under=None, invalid=None):
432 self._token = None
433 self._call = call
434 self._all = all
435 self._divide = divide
436 self._over = over
437 self._under = under
438 self._invalid = invalid
439
440 def __enter__(self):
441 # Note that __call__ duplicates much of this logic
442 if self._token is not None:
443 raise TypeError("Cannot enter `np.errstate` twice.")
444 if self._call is _Unspecified:
445 extobj = _make_extobj(
446 all=self._all, divide=self._divide, over=self._over,
447 under=self._under, invalid=self._invalid)
448 else:
449 extobj = _make_extobj(
450 call=self._call,
451 all=self._all, divide=self._divide, over=self._over,
452 under=self._under, invalid=self._invalid)
453
454 self._token = _extobj_contextvar.set(extobj)
455
456 def __exit__(self, *exc_info):
457 _extobj_contextvar.reset(self._token)
458
459 def __call__(self, func):
460 # We need to customize `__call__` compared to `ContextDecorator`
461 # because we must store the token per-thread so cannot store it on
462 # the instance (we could create a new instance for this).
463 # This duplicates the code from `__enter__`.
464 @functools.wraps(func)
465 def inner(*args, **kwargs):
466 if self._call is _Unspecified:
467 extobj = _make_extobj(
468 all=self._all, divide=self._divide, over=self._over,
469 under=self._under, invalid=self._invalid)
470 else:
471 extobj = _make_extobj(
472 call=self._call,
473 all=self._all, divide=self._divide, over=self._over,
474 under=self._under, invalid=self._invalid)
475
476 _token = _extobj_contextvar.set(extobj)
477 try:
478 # Call the original, decorated, function:
479 return func(*args, **kwargs)
480 finally:
481 _extobj_contextvar.reset(_token)
482
483 return inner