1"""
2Functions for changing global ufunc configuration
3
4This provides helpers which wrap `umath.geterrobj` and `umath.seterrobj`
5"""
6import collections.abc
7import contextlib
8import contextvars
9
10from .overrides import set_module
11from .umath import (
12 UFUNC_BUFSIZE_DEFAULT,
13 ERR_IGNORE, ERR_WARN, ERR_RAISE, ERR_CALL, ERR_PRINT, ERR_LOG, ERR_DEFAULT,
14 SHIFT_DIVIDEBYZERO, SHIFT_OVERFLOW, SHIFT_UNDERFLOW, SHIFT_INVALID,
15)
16from . import umath
17
__all__ = [
    "seterr", "geterr", "setbufsize", "getbufsize", "seterrcall", "geterrcall",
    "errstate", "_no_nep50_warning",
]

# Mode names accepted by `seterr` mapped to the umath ERR_* codes that are
# packed into the error mask.
_errdict = {
    "ignore": ERR_IGNORE,
    "warn": ERR_WARN,
    "raise": ERR_RAISE,
    "call": ERR_CALL,
    "print": ERR_PRINT,
    "log": ERR_LOG,
}

# Inverse lookup (ERR_* code -> mode name), used by `geterr` to decode the mask.
_errdict_rev = dict(zip(_errdict.values(), _errdict.keys()))
31
32
@set_module('numpy')
def seterr(all=None, divide=None, over=None, under=None, invalid=None):
    """
    Set how floating-point errors are handled.

    Operations on integer scalar types (such as `int16`) are handled like
    floating point here, so they are affected by these settings as well.

    Parameters
    ----------
    all : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Set treatment for all types of floating-point errors at once:

        - ignore: Take no action when the exception occurs.
        - warn: Print a `RuntimeWarning` (via the Python `warnings` module).
        - raise: Raise a `FloatingPointError`.
        - call: Call a function specified using the `seterrcall` function.
        - print: Print a warning directly to ``stdout``.
        - log: Record error in a Log object specified by `seterrcall`.

        The default is not to change the current behavior.
    divide : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for division by zero.
    over : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for floating-point overflow.
    under : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for floating-point underflow.
    invalid : {'ignore', 'warn', 'raise', 'call', 'print', 'log'}, optional
        Treatment for invalid floating-point operation.

    A category given explicitly takes precedence over `all`; a category that
    is neither given nor covered by `all` keeps its current setting.

    Returns
    -------
    old_settings : dict
        Dictionary containing the old settings.

    See also
    --------
    seterrcall : Set a callback function for the 'call' mode.
    geterr, geterrcall, errstate

    Notes
    -----
    The floating-point exceptions are defined in the IEEE 754 standard [1]_:

    - Division by zero: infinite result obtained from finite numbers.
    - Overflow: result too large to be expressed.
    - Underflow: result so close to zero that some precision
      was lost.
    - Invalid operation: result is not an expressible number, typically
      indicates that a NaN was produced.

    .. [1] https://en.wikipedia.org/wiki/IEEE_754

    Examples
    --------
    >>> old_settings = np.seterr(all='ignore')  #seterr to known value
    >>> np.seterr(over='raise')
    {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}
    >>> np.seterr(**old_settings)  # reset to default
    {'divide': 'ignore', 'over': 'raise', 'under': 'ignore', 'invalid': 'ignore'}

    >>> np.int16(32000) * np.int16(3)
    30464
    >>> old_settings = np.seterr(all='warn', over='raise')
    >>> np.int16(32000) * np.int16(3)
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
    FloatingPointError: overflow encountered in scalar multiply

    >>> old_settings = np.seterr(all='print')
    >>> np.geterr()
    {'divide': 'print', 'over': 'print', 'under': 'print', 'invalid': 'print'}
    >>> np.int16(32000) * np.int16(3)
    30464

    """
    errobj = umath.geterrobj()
    previous = geterr()

    # category -> (requested mode, bit offset of its field in the mask),
    # in the same order the mask fields are assembled.
    requested = {
        'divide': (divide, SHIFT_DIVIDEBYZERO),
        'over': (over, SHIFT_OVERFLOW),
        'under': (under, SHIFT_UNDERFLOW),
        'invalid': (invalid, SHIFT_INVALID),
    }

    new_mask = 0
    for category, (mode, shift) in requested.items():
        if mode is None:
            # Fall back to `all`, and then to the currently active setting.
            mode = all or previous[category]
        new_mask += _errdict[mode] << shift

    # Slot 1 of the error object holds the packed error mask.
    errobj[1] = new_mask
    umath.seterrobj(errobj)
    return previous
130
131
@set_module('numpy')
def geterr():
    """
    Get the current way of handling floating-point errors.

    Returns
    -------
    res : dict
        A dictionary with keys "divide", "over", "under", and "invalid",
        whose values are from the strings "ignore", "print", "log", "warn",
        "raise", and "call". The keys represent possible floating-point
        exceptions, and the values define how these exceptions are handled.

    See Also
    --------
    geterrcall, seterr, seterrcall

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> np.geterr()
    {'divide': 'warn', 'over': 'warn', 'under': 'ignore', 'invalid': 'warn'}
    >>> np.arange(3.) / np.arange(3.)
    array([nan,  1.,  1.])

    >>> oldsettings = np.seterr(all='warn', over='raise')
    >>> np.geterr()
    {'divide': 'warn', 'over': 'raise', 'under': 'warn', 'invalid': 'warn'}
    >>> np.arange(3.) / np.arange(3.)
    array([nan,  1.,  1.])

    """
    # Slot 1 of the error object is the packed error mask.
    current_mask = umath.geterrobj()[1]
    # Each category occupies a 3-bit field (0b111) at its own offset.
    field_bits = 7
    categories = (
        ('divide', SHIFT_DIVIDEBYZERO),
        ('over', SHIFT_OVERFLOW),
        ('under', SHIFT_UNDERFLOW),
        ('invalid', SHIFT_INVALID),
    )
    return {name: _errdict_rev[(current_mask >> shift) & field_bits]
            for name, shift in categories}
180
181
@set_module('numpy')
def setbufsize(size):
    """
    Set the size of the buffer used in ufuncs.

    Parameters
    ----------
    size : int
        Size of buffer. Must be a multiple of 16, at least 5 and at most
        10e6.

    Returns
    -------
    old_size : int
        The buffer size that was in effect before this call.

    Raises
    ------
    ValueError
        If `size` is too big, too small, or not a multiple of 16.

    """
    if size > 10e6:
        raise ValueError("Buffer size, %s, is too big." % size)
    if size < 5:
        raise ValueError("Buffer size, %s, is too small." % size)
    if size % 16 != 0:
        raise ValueError("Buffer size, %s, is not a multiple of 16." % size)

    errobj = umath.geterrobj()
    # Read the previous value before slot 0 is overwritten below.
    previous_size = getbufsize()
    errobj[0] = size
    umath.seterrobj(errobj)
    return previous_size
205
206
@set_module('numpy')
def getbufsize():
    """
    Return the size of the buffer used in ufuncs.

    Returns
    -------
    getbufsize : int
        Size of ufunc buffer in bytes.

    """
    errobj = umath.geterrobj()
    # Slot 0 of the error object holds the buffer size.
    return errobj[0]
219
220
@set_module('numpy')
def seterrcall(func):
    """
    Set the floating-point error callback function or log object.

    There are two ways to capture floating-point error messages.  The first
    is to set the error-handler to 'call', using `seterr`.  Then, set
    the function to call using this function.

    The second is to set the error-handler to 'log', using `seterr`.
    Floating-point errors then trigger a call to the 'write' method of
    the provided object.

    Parameters
    ----------
    func : callable f(err, flag) or object with write method
        Function to call upon floating-point errors ('call'-mode) or
        object whose 'write' method is used to log such message ('log'-mode).

        The call function takes two arguments. The first is a string describing
        the type of error (such as "divide by zero", "overflow", "underflow",
        or "invalid value"), and the second is the status flag.  The flag is a
        byte, whose four least-significant bits indicate the type of error, one
        of "divide", "over", "under", "invalid" (shown most-significant bit
        first)::

          [0 0 0 0 invalid under over divide]

        In other words, ``flags = divide + 2*over + 4*under + 8*invalid``.

        If an object is provided, its write method should take one argument,
        a string.

    Returns
    -------
    h : callable, log instance or None
        The old error handler.

    Raises
    ------
    ValueError
        If `func` is neither None, a callable, nor an object with a callable
        ``write`` attribute.

    See Also
    --------
    seterr, geterr, geterrcall

    Examples
    --------
    Callback upon error:

    >>> def err_handler(type, flag):
    ...     print("Floating point error (%s), with flag %s" % (type, flag))
    ...

    >>> saved_handler = np.seterrcall(err_handler)
    >>> save_err = np.seterr(all='call')

    >>> np.array([1, 2, 3]) / 0.0
    Floating point error (divide by zero), with flag 1
    array([inf, inf, inf])

    >>> np.seterrcall(saved_handler)
    <function err_handler at 0x...>
    >>> np.seterr(**save_err)
    {'divide': 'call', 'over': 'call', 'under': 'call', 'invalid': 'call'}

    Log error message:

    >>> class Log:
    ...     def write(self, msg):
    ...         print("LOG: %s" % msg)
    ...

    >>> log = Log()
    >>> saved_handler = np.seterrcall(log)
    >>> save_err = np.seterr(all='log')

    >>> np.array([1, 2, 3]) / 0.0
    LOG: Warning: divide by zero encountered in divide
    array([inf, inf, inf])

    >>> np.seterrcall(saved_handler)
    <numpy.core.numeric.Log object at 0x...>
    >>> np.seterr(**save_err)
    {'divide': 'log', 'over': 'log', 'under': 'log', 'invalid': 'log'}

    """
    if func is not None and not isinstance(func, collections.abc.Callable):
        # Not directly callable: accept it only as a 'log'-mode object,
        # i.e. one exposing a callable ``write`` attribute.
        if (not hasattr(func, 'write') or
                not isinstance(func.write, collections.abc.Callable)):
            raise ValueError("Only callable can be used as callback")
    pyvals = umath.geterrobj()
    # Capture the previous handler before slot 2 is overwritten.
    old = geterrcall()
    pyvals[2] = func
    umath.seterrobj(pyvals)
    return old
312
313
@set_module('numpy')
def geterrcall():
    """
    Return the current callback function used on floating-point errors.

    When the error handling for a floating-point error (one of "divide",
    "over", "under", or "invalid") is set to 'call' or 'log', the function
    that is called or the log instance that is written to is returned by
    `geterrcall`. This function or log instance has been set with
    `seterrcall`.

    Returns
    -------
    errobj : callable, log instance or None
        The current error handler. If no handler was set through `seterrcall`,
        ``None`` is returned.

    See Also
    --------
    seterrcall, seterr, geterr

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    Examples
    --------
    >>> np.geterrcall()  # we did not yet set a handler, returns None

    >>> oldsettings = np.seterr(all='call')
    >>> def err_handler(type, flag):
    ...     print("Floating point error (%s), with flag %s" % (type, flag))
    >>> oldhandler = np.seterrcall(err_handler)
    >>> np.array([1, 2, 3]) / 0.0
    Floating point error (divide by zero), with flag 1
    array([inf, inf, inf])

    >>> cur_handler = np.geterrcall()
    >>> cur_handler is err_handler
    True

    """
    errobj = umath.geterrobj()
    # Slot 2 of the error object holds the callback / log object.
    return errobj[2]
358
359
class _unspecified:
    """Type of the `_Unspecified` sentinel (distinguishes "not given" from None)."""


# Singleton default for `errstate(call=...)`; ``None`` is a legitimate
# callback value, so a dedicated sentinel marks "argument not passed".
_Unspecified = _unspecified()
365
366
@set_module('numpy')
class errstate(contextlib.ContextDecorator):
    """
    errstate(**kwargs)

    Context manager for floating-point error handling.

    Using an instance of `errstate` as a context manager allows statements in
    that context to execute with a known error handling behavior. Upon entering
    the context the error handling is set with `seterr` and `seterrcall`, and
    upon exiting it is reset to what it was before.

    .. versionchanged:: 1.17.0
        `errstate` is also usable as a function decorator, saving
        a level of indentation if an entire function is wrapped.
        See :py:class:`contextlib.ContextDecorator` for more information.

    Parameters
    ----------
    kwargs : {divide, over, under, invalid}
        Keyword arguments. The valid keywords are the possible floating-point
        exceptions. Each keyword should have a string value that defines the
        treatment for the particular error. Possible values are
        {'ignore', 'warn', 'raise', 'call', 'print', 'log'}.

    See Also
    --------
    seterr, geterr, seterrcall, geterrcall

    Notes
    -----
    For complete documentation of the types of floating-point exceptions and
    treatment options, see `seterr`.

    The same instance may be entered again while it is already active (for
    example when it decorates a recursive function); every exit restores the
    state saved by its matching entry.

    Examples
    --------
    >>> olderr = np.seterr(all='ignore')  # Set error handling to known state.

    >>> np.arange(3) / 0.
    array([nan, inf, inf])
    >>> with np.errstate(divide='warn'):
    ...     np.arange(3) / 0.
    array([nan, inf, inf])

    >>> np.sqrt(-1)
    nan
    >>> with np.errstate(invalid='raise'):
    ...     np.sqrt(-1)
    Traceback (most recent call last):
      File "<stdin>", line 2, in <module>
    FloatingPointError: invalid value encountered in sqrt

    Outside the context the error handling behavior has not changed:

    >>> np.geterr()
    {'divide': 'ignore', 'over': 'ignore', 'under': 'ignore', 'invalid': 'ignore'}

    """

    def __init__(self, *, call=_Unspecified, **kwargs):
        self.call = call
        self.kwargs = kwargs
        # One ``(old_error_settings, old_callback)`` entry per active
        # ``__enter__``.  A stack rather than a single attribute keeps the
        # saved state correct when this instance is re-entered before it has
        # exited (nested ``with`` or a decorated recursive function).
        self._saved = []

    def __enter__(self):
        oldstate = seterr(**self.kwargs)
        oldcall = None
        if self.call is not _Unspecified:
            try:
                oldcall = seterrcall(self.call)
            except BaseException:
                # ``__exit__`` does not run when ``__enter__`` raises, so undo
                # the ``seterr`` above before propagating (e.g. a bad `call`).
                seterr(**oldstate)
                raise
            # Legacy attribute, kept for backward compatibility.
            self.oldcall = oldcall
        # Legacy attribute, kept for backward compatibility.
        self.oldstate = oldstate
        self._saved.append((oldstate, oldcall))

    def __exit__(self, *exc_info):
        oldstate, oldcall = self._saved.pop()
        seterr(**oldstate)
        if self.call is not _Unspecified:
            seterrcall(oldcall)
439
440
def _setdef():
    """Install the default ufunc error object: buffer size, error mask, no callback."""
    umath.seterrobj([UFUNC_BUFSIZE_DEFAULT, ERR_DEFAULT, None])


# Establish the default ufunc configuration when this module is imported.
_setdef()
448
449
# Context-local flag read elsewhere to decide whether NEP 50 warnings are
# suppressed; False (the default) means they are not.
NO_NEP50_WARNING = contextvars.ContextVar("_no_nep50_warning", default=False)


@set_module('numpy')
@contextlib.contextmanager
def _no_nep50_warning():
    """
    Context manager that disables NEP 50 warnings while it is active.

    It only matters when NEP 50 warnings have been enabled globally (a
    setting that is not thread/context safe).  This manager itself is safe:
    it sets a `contextvars.ContextVar`, which is local to the current
    context, and restores the previous value on exit.
    """
    reset_token = NO_NEP50_WARNING.set(True)
    try:
        yield
    finally:
        # Restore the value that was active before entry, even on error.
        NO_NEP50_WARNING.reset(reset_token)