1import functools
2import sys
3import math
4import warnings
5
6import numpy.core.numeric as _nx
7from numpy.core.numeric import (
8 asarray, ScalarType, array, alltrue, cumprod, arange, ndim
9)
10from numpy.core.numerictypes import find_common_type, issubdtype
11
12import numpy.matrixlib as matrixlib
13from .function_base import diff
14from numpy.core.multiarray import ravel_multi_index, unravel_index
15from numpy.core.overrides import set_module
16from numpy.core import overrides, linspace
17from numpy.lib.stride_tricks import as_strided
18
19
# ``overrides.array_function_dispatch`` with ``module='numpy'`` pre-bound, used
# to decorate the dispatchable public functions below (``ix_``,
# ``fill_diagonal``, ``diag_indices_from``).
# NOTE(review): ``module`` is presumably used to present those functions as
# members of the top-level ``numpy`` namespace; confirm in numpy.core.overrides.
array_function_dispatch = functools.partial(
    overrides.array_function_dispatch, module='numpy')


# Public names of this module (what ``from ... import *`` exposes).  The first
# two are re-exported from numpy.core.multiarray, imported at the top.
__all__ = [
    'ravel_multi_index', 'unravel_index', 'mgrid', 'ogrid', 'r_', 'c_',
    's_', 'index_exp', 'ix_', 'ndenumerate', 'ndindex', 'fill_diagonal',
    'diag_indices', 'diag_indices_from'
]
29
30
31def _ix__dispatcher(*args):
32 return args
33
34
@array_function_dispatch(_ix__dispatcher)
def ix_(*args):
    """
    Construct an open mesh from multiple sequences.

    Given N 1-D sequences, return N arrays with N dimensions each. Every
    output has length 1 along all axes but one, and the axis that carries
    the data moves through all N positions in turn.

    With `ix_` it is easy to build index arrays that select the cross
    product of the inputs: ``a[np.ix_([1,3],[2,5])]`` returns the array
    ``[[a[1,2] a[1,5]], [a[3,2] a[3,5]]]``.

    Parameters
    ----------
    args : 1-D sequences
        Each sequence should be of integer or boolean type.
        Boolean sequences will be interpreted as boolean masks for the
        corresponding dimension (equivalent to passing in
        ``np.nonzero(boolean_sequence)``).

    Returns
    -------
    out : tuple of ndarrays
        N arrays with N dimensions each, with N the number of input
        sequences. Together these arrays form an open mesh.

    Raises
    ------
    ValueError
        If any of the inputs is not 1-dimensional.

    See Also
    --------
    ogrid, mgrid, meshgrid

    Examples
    --------
    >>> a = np.arange(10).reshape(2, 5)
    >>> a
    array([[0, 1, 2, 3, 4],
           [5, 6, 7, 8, 9]])
    >>> ixgrid = np.ix_([0, 1], [2, 4])
    >>> ixgrid
    (array([[0],
           [1]]), array([[2, 4]]))
    >>> ixgrid[0].shape, ixgrid[1].shape
    ((2, 1), (1, 2))
    >>> a[ixgrid]
    array([[2, 4],
           [7, 9]])

    >>> ixgrid = np.ix_([True, True], [2, 4])
    >>> a[ixgrid]
    array([[2, 4],
           [7, 9]])
    >>> ixgrid = np.ix_([True, True], [False, False, True, False, True])
    >>> a[ixgrid]
    array([[2, 4],
           [7, 9]])

    """
    ndims = len(args)
    mesh = []
    for axis, seq in enumerate(args):
        if not isinstance(seq, _nx.ndarray):
            seq = asarray(seq)
            # An empty sequence would otherwise default to float, which is
            # not a valid index type; give it an integer dtype instead.
            if seq.size == 0:
                seq = seq.astype(_nx.intp)
        if seq.ndim != 1:
            raise ValueError("Cross index must be 1 dimensional")
        if issubdtype(seq.dtype, _nx.bool_):
            # A boolean mask is replaced by the positions where it is True.
            (seq,) = seq.nonzero()
        # Length 1 everywhere except along the axis owned by this input.
        target_shape = [1] * ndims
        target_shape[axis] = seq.size
        mesh.append(seq.reshape(tuple(target_shape)))
    return tuple(mesh)
108
109
class nd_grid:
    """
    Construct a multi-dimensional "meshgrid".

    ``grid = nd_grid()`` creates an instance which will return a mesh-grid
    when indexed. The dimension and number of the output arrays are equal
    to the number of indexing dimensions. If the step length is not a
    complex number, then the stop is not inclusive.

    However, if the step length is a **complex number** (e.g. 5j), then the
    integer part of its magnitude is interpreted as specifying the
    number of points to create between the start and stop values, where
    the stop value **is inclusive**.

    If instantiated with an argument of ``sparse=True``, the mesh-grid is
    open (or not fleshed out) so that only one dimension of each returned
    argument is greater than 1.

    Parameters
    ----------
    sparse : bool, optional
        Whether the grid is sparse or not. Default is False.

    Notes
    -----
    Two instances of `nd_grid` are made available in the NumPy namespace,
    `mgrid` and `ogrid`, approximately defined as::

        mgrid = nd_grid(sparse=False)
        ogrid = nd_grid(sparse=True)

    Users should use these pre-defined instances instead of using `nd_grid`
    directly.
    """

    def __init__(self, sparse=False):
        # Selects the output form of __getitem__ for tuple keys: a list of
        # open arrays when true, one dense array when false.
        self.sparse = sparse

    def __getitem__(self, key):
        """
        Return the grid described by the slice(s) in `key`.

        For a tuple of slices, the result holds one coordinate array per
        slice: the array built by `_nx.indices` (rescaled in place) when
        ``self.sparse`` is false, or a list of arrays that each have a single
        non-unit axis when it is true. A single slice has no ``len()``, so it
        falls through to the ``except`` branch and yields one 1-D array.

        A complex step ``Nj`` requests N points with an inclusive stop; any
        other step is used like the step of `np.arange`.
        """
        try:
            size = []
            # Mimic the behavior of `np.arange` and use a data type
            # which is at least as large as `np.int_`
            num_list = [0]
            # Pass 1: the number of points along each dimension, plus every
            # start/stop/step value so a common result dtype can be chosen.
            for k in range(len(key)):
                step = key[k].step
                start = key[k].start
                stop = key[k].stop
                if start is None:
                    start = 0
                if step is None:
                    step = 1
                if isinstance(step, (_nx.complexfloating, complex)):
                    # Complex step: its magnitude is the point count.
                    step = abs(step)
                    size.append(int(step))
                else:
                    size.append(
                        int(math.ceil((stop - start) / (step*1.0))))
                num_list += [start, stop, step]
            typ = _nx.result_type(*num_list)
            if self.sparse:
                nn = [_nx.arange(_x, dtype=_t)
                      for _x, _t in zip(size, (typ,)*len(size))]
            else:
                nn = _nx.indices(size, typ)
            # Pass 2: map the integer positions 0, 1, 2, ... along each
            # dimension onto start + position * step.
            for k, kk in enumerate(key):
                step = kk.step
                start = kk.start
                if start is None:
                    start = 0
                if step is None:
                    step = 1
                if isinstance(step, (_nx.complexfloating, complex)):
                    # Turn "N points, stop inclusive" into a spacing; a
                    # single point keeps step 1 and therefore stays at start.
                    step = int(abs(step))
                    if step != 1:
                        step = (kk.stop - start) / float(step - 1)
                nn[k] = (nn[k]*step+start)
            if self.sparse:
                # Index the k-th 1-D array so its data lies along axis k and
                # every other axis is a new length-1 axis.
                slobj = [_nx.newaxis]*len(size)
                for k in range(len(size)):
                    slobj[k] = slice(None, None)
                    nn[k] = nn[k][tuple(slobj)]
                    slobj[k] = _nx.newaxis
            return nn
        except (IndexError, TypeError):
            # Single-slice path.
            # NOTE(review): an IndexError/TypeError raised while processing a
            # tuple key also lands here, where ``key.step`` then raises
            # AttributeError on the tuple.
            step = key.step
            stop = key.stop
            start = key.start
            if start is None:
                start = 0
            if isinstance(step, (_nx.complexfloating, complex)):
                # Prevent the (potential) creation of integer arrays
                step_float = abs(step)
                step = length = int(step_float)
                if step != 1:
                    step = (key.stop-start)/float(step-1)
                typ = _nx.result_type(start, stop, step_float)
                return _nx.arange(0, length, 1, dtype=typ)*step + start
            else:
                return _nx.arange(start, stop, step)
210
211
class MGridClass(nd_grid):
    """
    `nd_grid` instance which returns a dense multi-dimensional "meshgrid".

    An instance of `numpy.lib.index_tricks.nd_grid` that returns a dense
    (fully fleshed out) mesh-grid when indexed, so every returned array has
    the same shape. The number of output arrays, and their dimensionality,
    equal the number of indexing dimensions. With a step length that is not
    complex, the stop value is excluded.

    If the step length is a **complex number** (e.g. 5j) instead, the
    integer part of its magnitude gives the number of points to create
    between the start and stop values, and the stop value **is inclusive**.

    Returns
    -------
    mesh-grid `ndarrays` all of the same dimensions

    See Also
    --------
    lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects
    ogrid : like mgrid but returns open (not fleshed out) mesh grids
    meshgrid: return coordinate matrices from coordinate vectors
    r_ : array concatenator
    :ref:`how-to-partition`

    Examples
    --------
    >>> np.mgrid[0:5, 0:5]
    array([[[0, 0, 0, 0, 0],
            [1, 1, 1, 1, 1],
            [2, 2, 2, 2, 2],
            [3, 3, 3, 3, 3],
            [4, 4, 4, 4, 4]],
           [[0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4],
            [0, 1, 2, 3, 4]]])
    >>> np.mgrid[-1:1:5j]
    array([-1. , -0.5,  0. ,  0.5,  1. ])

    """

    def __init__(self):
        # Dense output: every coordinate array is fully materialized.
        nd_grid.__init__(self, sparse=False)


mgrid = MGridClass()
262
263
class OGridClass(nd_grid):
    """
    `nd_grid` instance which returns an open multi-dimensional "meshgrid".

    An instance of `numpy.lib.index_tricks.nd_grid` that returns an open
    (i.e. not fleshed out) mesh-grid when indexed, so only one dimension of
    each returned array is greater than 1. The number of output arrays, and
    their dimensionality, equal the number of indexing dimensions. With a
    step length that is not complex, the stop value is excluded.

    If the step length is a **complex number** (e.g. 5j) instead, the
    integer part of its magnitude gives the number of points to create
    between the start and stop values, and the stop value **is inclusive**.

    Returns
    -------
    mesh-grid
        `ndarrays` with only one dimension not equal to 1

    See Also
    --------
    np.lib.index_tricks.nd_grid : class of `ogrid` and `mgrid` objects
    mgrid : like `ogrid` but returns dense (or fleshed out) mesh grids
    meshgrid: return coordinate matrices from coordinate vectors
    r_ : array concatenator
    :ref:`how-to-partition`

    Examples
    --------
    >>> from numpy import ogrid
    >>> ogrid[-1:1:5j]
    array([-1. , -0.5,  0. ,  0.5,  1. ])
    >>> ogrid[0:5,0:5]
    [array([[0],
            [1],
            [2],
            [3],
            [4]]), array([[0, 1, 2, 3, 4]])]

    """

    def __init__(self):
        # Sparse output: each coordinate array varies along one axis only.
        nd_grid.__init__(self, sparse=True)


ogrid = OGridClass()
311
312
class AxisConcatenator:
    """
    Translates slice objects to concatenation along an axis.

    For detailed documentation on usage, see `r_`.

    Parameters
    ----------
    axis : int, optional
        Default axis along which the pieces are concatenated.
    matrix : bool, optional
        If True, the concatenated result is converted with ``makemat``.
    ndmin : int, optional
        Minimum number of dimensions non-scalar pieces are promoted to.
    trans1d : int, optional
        Where the original axes of pieces promoted to ``ndmin`` dimensions
        are placed; ``-1`` keeps the layout ``array(..., ndmin=ndmin)``
        produces.

    All four settings can be overridden per call by a string directive
    given as the first entry of the index expression.
    """
    # allow ma.mr_ to override this
    concatenate = staticmethod(_nx.concatenate)
    makemat = staticmethod(matrixlib.matrix)

    def __init__(self, axis=0, matrix=False, ndmin=1, trans1d=-1):
        self.axis = axis
        self.matrix = matrix
        self.trans1d = trans1d
        self.ndmin = ndmin

    def __getitem__(self, key):
        # handle matrix builder syntax; names in the string are resolved in
        # the caller's frame.
        if isinstance(key, str):
            frame = sys._getframe().f_back
            mymat = matrixlib.bmat(key, frame.f_globals, frame.f_locals)
            return mymat

        if not isinstance(key, tuple):
            key = (key,)

        # copy attributes, since they can be overridden in the first argument
        trans1d = self.trans1d
        ndmin = self.ndmin
        matrix = self.matrix
        axis = self.axis
        # Only an 'r'/'c' directive sets this.  Default to row orientation so
        # a concatenator created with matrix=True does not hit an unbound
        # local when the result is 1-D.
        col = False

        objs = []
        scalars = []
        arraytypes = []
        scalartypes = []

        for k, item in enumerate(key):
            scalar = False
            if isinstance(item, slice):
                # start:stop:step -> arange; start:stop:Nj -> linspace with
                # N points and an inclusive stop.
                step = item.step
                start = item.start
                stop = item.stop
                if start is None:
                    start = 0
                if step is None:
                    step = 1
                if isinstance(step, (_nx.complexfloating, complex)):
                    size = int(abs(step))
                    newobj = linspace(start, stop, num=size)
                else:
                    newobj = _nx.arange(start, stop, step)
                if ndmin > 1:
                    newobj = array(newobj, copy=False, ndmin=ndmin)
                    if trans1d != -1:
                        newobj = newobj.swapaxes(-1, trans1d)
            elif isinstance(item, str):
                if k != 0:
                    raise ValueError("special directives must be the "
                                     "first entry.")
                if item in ('r', 'c'):
                    matrix = True
                    col = (item == 'c')
                    continue
                if ',' in item:
                    # "axis,ndmin" or "axis,ndmin,trans1d"
                    vec = item.split(',')
                    try:
                        axis, ndmin = [int(x) for x in vec[:2]]
                        if len(vec) == 3:
                            trans1d = int(vec[2])
                        continue
                    except Exception as e:
                        raise ValueError(
                            "unknown special directive {!r}".format(item)
                        ) from e
                # A lone integer string selects the concatenation axis.
                try:
                    axis = int(item)
                    continue
                except (ValueError, TypeError) as e:
                    raise ValueError(
                        "unknown special directive {!r}".format(item)
                    ) from e
            elif type(item) in ScalarType:
                newobj = array(item, ndmin=ndmin)
                scalars.append(len(objs))
                scalar = True
                scalartypes.append(newobj.dtype)
            else:
                item_ndim = ndim(item)
                newobj = array(item, copy=False, subok=True, ndmin=ndmin)
                if trans1d != -1 and item_ndim < ndmin:
                    # Move the item's original axes so they start at
                    # position trans1d instead of at the end.
                    k2 = ndmin - item_ndim
                    k1 = trans1d
                    if k1 < 0:
                        k1 += k2 + 1
                    defaxes = list(range(ndmin))
                    axes = defaxes[:k1] + defaxes[k2:] + defaxes[k1:k2]
                    newobj = newobj.transpose(axes)
            objs.append(newobj)
            if not scalar and isinstance(newobj, _nx.ndarray):
                arraytypes.append(newobj.dtype)

        # Ensure that scalars won't up-cast unless warranted
        final_dtype = find_common_type(arraytypes, scalartypes)
        if final_dtype is not None:
            for k in scalars:
                objs[k] = objs[k].astype(final_dtype)

        res = self.concatenate(tuple(objs), axis=axis)

        if matrix:
            oldndim = res.ndim
            res = self.makemat(res)
            if oldndim == 1 and col:
                res = res.T
        return res

    def __len__(self):
        # NOTE(review): the reason for reporting length 0 is not visible
        # here; kept unchanged for compatibility.
        return 0
430
# Separate classes are used here instead of just making r_ = AxisConcatenator(0),
# etc., because otherwise we couldn't get the docstring to come out right
# in help(r_).
434
435
class RClass(AxisConcatenator):
    """
    Translates slice objects to concatenation along the first axis.

    This provides a quick way to build up arrays. It covers two use cases.

    1. If the index expression contains comma separated arrays, they are
       stacked along their first axis.
    2. If the index expression contains slice notation or scalars, a 1-D
       array with the range indicated by the slice notation is created.

    When slice notation is used, ``start:stop:step`` inside the brackets is
    equivalent to ``np.arange(start, stop, step)``. If ``step`` is an
    imaginary number (e.g. 100j), however, its integer portion is taken as
    the desired number of points and both start and stop are inclusive. In
    other words ``start:stop:stepj`` is interpreted as
    ``np.linspace(start, stop, step, endpoint=1)`` inside the brackets.
    Once slice notation has been expanded, all comma separated sequences
    are concatenated together.

    Optional character strings placed as the first element of the index
    expression change the output. The strings 'r' or 'c' produce matrix
    output. If the result is 1-D, 'r' yields a 1 x N (row) matrix and 'c'
    yields an N x 1 (column) matrix. If the result is 2-D, both give the
    same matrix.

    A string holding one integer selects the axis along which the comma
    separated arrays are stacked. A string of two comma-separated integers
    additionally gives, as its second integer, the minimum number of
    dimensions each entry is forced into (the first integer is still the
    axis to concatenate along).

    A string of three comma-separated integers specifies the axis to
    concatenate along, the minimum number of dimensions for the entries,
    and which axis should hold the start of the arrays that have fewer than
    the specified number of dimensions. Put differently, the third integer
    controls where the 1's are placed in the shape of arrays whose shapes
    are upgraded. By default they are placed at the front of the shape
    tuple; the third argument lets you say where the start of the array
    should go instead. Thus a third argument of '0' places the 1's at the
    end of the array shape. Negative integers specify where in the new
    shape tuple the last dimension of upgraded arrays should be placed, so
    the default is '-1'.

    Parameters
    ----------
    Not a function, so takes no parameters


    Returns
    -------
    A concatenated ndarray or matrix.

    See Also
    --------
    concatenate : Join a sequence of arrays along an existing axis.
    c_ : Translates slice objects to concatenation along the second axis.

    Examples
    --------
    >>> np.r_[np.array([1,2,3]), 0, 0, np.array([4,5,6])]
    array([1, 2, 3, ..., 4, 5, 6])
    >>> np.r_[-1:1:6j, [0]*3, 5, 6]
    array([-1. , -0.6, -0.2,  0.2,  0.6,  1. ,  0. ,  0. ,  0. ,  5. ,  6. ])

    String integers specify the axis to concatenate along or the minimum
    number of dimensions to force entries into.

    >>> a = np.array([[0, 1, 2], [3, 4, 5]])
    >>> np.r_['-1', a, a] # concatenate along last axis
    array([[0, 1, 2, 0, 1, 2],
           [3, 4, 5, 3, 4, 5]])
    >>> np.r_['0,2', [1,2,3], [4,5,6]] # concatenate along first axis, dim>=2
    array([[1, 2, 3],
           [4, 5, 6]])

    >>> np.r_['0,2,0', [1,2,3], [4,5,6]]
    array([[1],
           [2],
           [3],
           [4],
           [5],
           [6]])
    >>> np.r_['1,2,0', [1,2,3], [4,5,6]]
    array([[1, 4],
           [2, 5],
           [3, 6]])

    Using 'r' or 'c' as a first string argument creates a matrix.

    >>> np.r_['r',[1,2,3], [4,5,6]]
    matrix([[1, 2, 3, 4, 5, 6]])

    """

    def __init__(self):
        # Concatenate along axis 0; every other setting keeps its default.
        super().__init__(axis=0)


r_ = RClass()
536
537
class CClass(AxisConcatenator):
    """
    Translates slice objects to concatenation along the second axis.

    This is shorthand for ``np.r_['-1,2,0', index expression]``, provided
    because that pattern is so common. Arrays are stacked along their last
    axis after being upgraded to at least 2-D, with 1's appended to the
    shape (so 1-D arrays become column vectors).

    See Also
    --------
    column_stack : Stack 1-D arrays as columns into a 2-D array.
    r_ : For more detailed documentation.

    Examples
    --------
    >>> np.c_[np.array([1,2,3]), np.array([4,5,6])]
    array([[1, 4],
           [2, 5],
           [3, 6]])
    >>> np.c_[np.array([[1,2,3]]), 0, 0, np.array([[4,5,6]])]
    array([[1, 2, 3, ..., 4, 5, 6]])

    """

    def __init__(self):
        # Equivalent to the '-1,2,0' directive: last axis, at least 2-D,
        # original data placed first in upgraded shapes.
        super().__init__(axis=-1, ndmin=2, trans1d=0)


c_ = CClass()
568
569
@set_module('numpy')
class ndenumerate:
    """
    Multidimensional index iterator.

    Iterate over an array, yielding each element together with its
    coordinates.

    Parameters
    ----------
    arr : ndarray
        Input array.

    See Also
    --------
    ndindex, flatiter

    Examples
    --------
    >>> a = np.array([[1, 2], [3, 4]])
    >>> for index, x in np.ndenumerate(a):
    ...     print(index, x)
    (0, 0) 1
    (0, 1) 2
    (1, 0) 3
    (1, 1) 4

    """

    def __init__(self, arr):
        flat_view = asarray(arr).flat
        self.iter = flat_view

    def __iter__(self):
        return self

    def __next__(self):
        """
        Standard iterator method, returns the index tuple and array value.

        Returns
        -------
        coords : tuple of ints
            The indices of the current iteration.
        val : scalar
            The array element of the current iteration.

        """
        # The coordinates must be read before advancing: they describe the
        # element that next() is about to return.
        position = self.iter.coords
        value = next(self.iter)
        return position, value
617
618
@set_module('numpy')
class ndindex:
    """
    An N-dimensional iterator object to index arrays.

    Given an array shape, an `ndindex` instance walks over every
    N-dimensional index of that shape. Each step yields a tuple of indices,
    with the last dimension varying fastest.

    Parameters
    ----------
    shape : ints, or a single tuple of ints
        The size of each dimension of the array can be passed as
        individual parameters or as the elements of a tuple.

    See Also
    --------
    ndenumerate, flatiter

    Examples
    --------
    Dimensions as individual arguments

    >>> for index in np.ndindex(3, 2, 1):
    ...     print(index)
    (0, 0, 0)
    (0, 1, 0)
    (1, 0, 0)
    (1, 1, 0)
    (2, 0, 0)
    (2, 1, 0)

    Same dimensions - but in a tuple ``(3, 2, 1)``

    >>> for index in np.ndindex((3, 2, 1)):
    ...     print(index)
    (0, 0, 0)
    (0, 1, 0)
    (1, 0, 0)
    (1, 1, 0)
    (2, 0, 0)
    (2, 1, 0)

    """

    def __init__(self, *shape):
        given_as_tuple = len(shape) == 1 and isinstance(shape[0], tuple)
        dims = shape[0] if given_as_tuple else shape
        # View one element through all-zero strides: the array has the
        # requested shape without allocating it.  nditer then only supplies
        # the multi-index; 'zerosize_ok' allows shapes containing 0.
        placeholder = as_strided(_nx.zeros(1), shape=dims,
                                 strides=_nx.zeros_like(dims))
        self._it = _nx.nditer(placeholder,
                              flags=['multi_index', 'zerosize_ok'],
                              order='C')

    def __iter__(self):
        return self

    def ndincr(self):
        """
        Increment the multi-dimensional index by one.

        This method is for backward compatibility only: do not use.

        .. deprecated:: 1.20.0
            This method has been advised against since numpy 1.8.0, but only
            started emitting DeprecationWarning as of this version.
        """
        # NumPy 1.20.0, 2020-09-08
        warnings.warn(
            "`ndindex.ndincr()` is deprecated, use `next(ndindex)` instead",
            DeprecationWarning, stacklevel=2)
        next(self)

    def __next__(self):
        """
        Standard iterator method, updates the index and returns the index
        tuple.

        Returns
        -------
        val : tuple of ints
            Returns a tuple containing the indices of the current
            iteration.

        """
        next(self._it)
        return self._it.multi_index
705
706
707# You can do all this with slice() plus a few special objects,
708# but there's a lot to remember. This version is simpler because
709# it uses the standard array indexing syntax.
710#
711# Written by Konrad Hinsen <hinsen@cnrs-orleans.fr>
712# last revision: 1999-7-23
713#
714# Cosmetic changes by T. Oliphant 2001
715#
716#
717
class IndexExpression:
    """
    A nicer way to build up index tuples for arrays.

    .. note::
       Use one of the two predefined instances `index_exp` or `s_`
       rather than directly using `IndexExpression`.

    For any index combination, including slicing and axis insertion,
    ``a[indices]`` is the same as ``a[np.index_exp[indices]]`` for any
    array `a`. The advantage of ``np.index_exp[indices]`` is that it can be
    written anywhere in Python code; it returns a tuple of slice objects
    that can be used to assemble complex index expressions.

    Parameters
    ----------
    maketuple : bool
        If True, always returns a tuple.

    See Also
    --------
    index_exp : Predefined instance that always returns a tuple:
       `index_exp = IndexExpression(maketuple=True)`.
    s_ : Predefined instance without tuple conversion:
       `s_ = IndexExpression(maketuple=False)`.

    Notes
    -----
    You can do all this with `slice()` plus a few special objects,
    but there's a lot to remember and this version is simpler because
    it uses the standard array indexing syntax.

    Examples
    --------
    >>> np.s_[2::2]
    slice(2, None, 2)
    >>> np.index_exp[2::2]
    (slice(2, None, 2),)

    >>> np.array([0, 1, 2, 3, 4])[np.s_[2::2]]
    array([2, 4])

    """

    def __init__(self, maketuple):
        self.maketuple = maketuple

    def __getitem__(self, item):
        # Tuples, and everything when no wrapping was requested, pass
        # through untouched; only a lone index gets wrapped.
        if isinstance(item, tuple) or not self.maketuple:
            return item
        return (item,)


index_exp = IndexExpression(maketuple=True)
s_ = IndexExpression(maketuple=False)
774
775# End contribution from Konrad.
776
777
778# The following functions complement those in twodim_base, but are
779# applicable to N-dimensions.
780
781
782def _fill_diagonal_dispatcher(a, val, wrap=None):
783 return (a,)
784
785
@array_function_dispatch(_fill_diagonal_dispatcher)
def fill_diagonal(a, val, wrap=False):
    """Fill the main diagonal of the given array of any dimensionality.

    For an array `a` with ``a.ndim >= 2``, the diagonal is the list of
    locations with indices ``a[i, ..., i]`` all identical. This function
    modifies the input array in-place, it does not return a value.

    Parameters
    ----------
    a : array, at least 2-D.
      Array whose diagonal is to be filled, it gets modified in-place.

    val : scalar or array_like
      Value(s) to write on the diagonal. If `val` is scalar, the value is
      written along the diagonal. If array-like, the flattened `val` is
      written along the diagonal, repeating if necessary to fill all
      diagonal entries.

    wrap : bool, optional
      For tall matrices in NumPy version up to 1.6.2, the
      diagonal "wrapped" after N columns. You can have this behavior
      with this option. This affects only tall matrices. Default is False.

    Raises
    ------
    ValueError
        If `a` has fewer than 2 dimensions, or has more than 2 dimensions
        that are not all of equal length.

    See also
    --------
    diag_indices, diag_indices_from

    Notes
    -----
    .. versionadded:: 1.4.0

    This functionality can be obtained via `diag_indices`, but internally
    this version uses a much faster implementation that never constructs the
    indices and uses simple slicing.

    Examples
    --------
    >>> a = np.zeros((3, 3), int)
    >>> np.fill_diagonal(a, 5)
    >>> a
    array([[5, 0, 0],
           [0, 5, 0],
           [0, 0, 5]])

    The same function can operate on a 4-D array:

    >>> a = np.zeros((3, 3, 3, 3), int)
    >>> np.fill_diagonal(a, 4)

    We only show a few blocks for clarity:

    >>> a[0, 0]
    array([[4, 0, 0],
           [0, 0, 0],
           [0, 0, 0]])
    >>> a[1, 1]
    array([[0, 0, 0],
           [0, 4, 0],
           [0, 0, 0]])
    >>> a[2, 2]
    array([[0, 0, 0],
           [0, 0, 0],
           [0, 0, 4]])

    The wrap option affects only tall matrices:

    >>> # tall matrices no wrap
    >>> a = np.zeros((5, 3), int)
    >>> np.fill_diagonal(a, 4)
    >>> a
    array([[4, 0, 0],
           [0, 4, 0],
           [0, 0, 4],
           [0, 0, 0],
           [0, 0, 0]])

    >>> # tall matrices wrap
    >>> a = np.zeros((5, 3), int)
    >>> np.fill_diagonal(a, 4, wrap=True)
    >>> a
    array([[4, 0, 0],
           [0, 4, 0],
           [0, 0, 4],
           [0, 0, 0],
           [4, 0, 0]])

    >>> # wide matrices
    >>> a = np.zeros((3, 5), int)
    >>> np.fill_diagonal(a, 4, wrap=True)
    >>> a
    array([[4, 0, 0, 0, 0],
           [0, 4, 0, 0, 0],
           [0, 0, 4, 0, 0]])

    The anti-diagonal can be filled by reversing the order of elements
    using either `numpy.flipud` or `numpy.fliplr`.

    >>> a = np.zeros((3, 3), int)
    >>> np.fill_diagonal(np.fliplr(a), [1,2,3])  # Horizontal flip
    >>> a
    array([[0, 0, 1],
           [0, 2, 0],
           [3, 0, 0]])
    >>> np.fill_diagonal(np.flipud(a), [1,2,3])  # Vertical flip
    >>> a
    array([[0, 0, 3],
           [0, 2, 0],
           [1, 0, 0]])

    Note that the order in which the diagonal is filled varies depending
    on the flip function.
    """
    if a.ndim < 2:
        raise ValueError("array must be at least 2-d")
    end = None
    if a.ndim == 2:
        # Explicit, fast formula for the common case. For 2-d arrays, we
        # accept rectangular ones: in the C-order flat view, consecutive
        # diagonal entries are ``ncols + 1`` positions apart.
        step = a.shape[1] + 1
        # Stop after ``ncols`` rows so a tall matrix does not restart the
        # diagonal below its square part (unless wrapping was requested).
        if not wrap:
            end = a.shape[1] * a.shape[1]
    else:
        # For more than d=2, the strided formula is only valid for arrays with
        # all dimensions equal, so we check first.  ``ndarray.all`` is used
        # instead of ``alltrue``, which NumPy deprecated in 1.25 and removed
        # in 2.0.
        if not (diff(a.shape) == 0).all():
            raise ValueError("All dimensions of input must be of equal length")
        step = 1 + (cumprod(a.shape[:-1])).sum()

    # Write the value out into the diagonal.
    a.flat[:end:step] = val
918
919
@set_module('numpy')
def diag_indices(n, ndim=2):
    """
    Return the indices to access the main diagonal of an array.

    The result is a tuple of indices that selects the main diagonal of an
    array `a` with ``a.ndim >= 2`` dimensions and shape (n, n, ..., n).
    When ``a.ndim = 2`` this is the ordinary diagonal; when ``a.ndim > 2``
    it is the set of indices that reaches ``a[i, i, ..., i]`` for
    ``i = [0..n-1]``.

    Parameters
    ----------
    n : int
      The size, along each dimension, of the arrays for which the returned
      indices can be used.

    ndim : int, optional
      The number of dimensions.

    See Also
    --------
    diag_indices_from

    Notes
    -----
    .. versionadded:: 1.4.0

    Examples
    --------
    Create a set of indices to access the diagonal of a (4, 4) array:

    >>> di = np.diag_indices(4)
    >>> di
    (array([0, 1, 2, 3]), array([0, 1, 2, 3]))
    >>> a = np.arange(16).reshape(4, 4)
    >>> a
    array([[ 0,  1,  2,  3],
           [ 4,  5,  6,  7],
           [ 8,  9, 10, 11],
           [12, 13, 14, 15]])
    >>> a[di] = 100
    >>> a
    array([[100,   1,   2,   3],
           [  4, 100,   6,   7],
           [  8,   9, 100,  11],
           [ 12,  13,  14, 100]])

    Now, we create indices to manipulate a 3-D array:

    >>> d3 = np.diag_indices(2, 3)
    >>> d3
    (array([0, 1]), array([0, 1]), array([0, 1]))

    And use it to set the diagonal of an array of zeros to 1:

    >>> a = np.zeros((2, 2, 2), dtype=int)
    >>> a[d3] = 1
    >>> a
    array([[[1, 0],
            [0, 0]],
           [[0, 0],
            [0, 1]]])

    """
    # A single index array repeated once per dimension: entry i of each
    # copy is i, which addresses a[i, i, ..., i].
    diagonal = arange(n)
    return (diagonal,) * ndim
987
988
989def _diag_indices_from(arr):
990 return (arr,)
991
992
@array_function_dispatch(_diag_indices_from)
def diag_indices_from(arr):
    """
    Return the indices to access the main diagonal of an n-dimensional array.

    See `diag_indices` for full details.

    Parameters
    ----------
    arr : array, at least 2-D
        Array whose main diagonal is to be indexed. All of its dimensions
        must have the same length.

    Returns
    -------
    tuple of ndarrays
        ``arr.ndim`` index arrays, each equal to ``arange(arr.shape[0])``,
        suitable for ``arr[indices]``.

    Raises
    ------
    ValueError
        If `arr` has fewer than 2 dimensions or its dimensions differ in
        length.

    See Also
    --------
    diag_indices

    Notes
    -----
    .. versionadded:: 1.4.0

    Examples
    --------
    >>> a = np.arange(16).reshape(4, 4)
    >>> di = np.diag_indices_from(a)
    >>> di
    (array([0, 1, 2, 3]), array([0, 1, 2, 3]))
    >>> a[di]
    array([ 0,  5, 10, 15])

    """

    if arr.ndim < 2:
        raise ValueError("input array must be at least 2-d")
    # diag_indices assumes an (n, n, ..., n) shape, so every dimension must
    # match.  ``ndarray.all`` replaces ``alltrue``, which NumPy deprecated in
    # 1.25 and removed in 2.0.
    if not (diff(arr.shape) == 0).all():
        raise ValueError("All dimensions of input must be of equal length")

    return diag_indices(arr.shape[0], arr.ndim)