Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/sparse/linalg/_expm_multiply.py: 12%
281 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
1"""Compute the action of the matrix exponential."""
2from warnings import warn
4import numpy as np
6import scipy.linalg
7import scipy.sparse.linalg
8from scipy.linalg._decomp_qr import qr
9from scipy.sparse._sputils import is_pydata_spmatrix
10from scipy.sparse.linalg import aslinearoperator
11from scipy.sparse.linalg._interface import IdentityOperator
12from scipy.sparse.linalg._onenormest import onenormest
# Public API of this module: only the driver function is exported.
__all__ = ['expm_multiply']
17def _exact_inf_norm(A):
18 # A compatibility function which should eventually disappear.
19 if scipy.sparse.isspmatrix(A):
20 return max(abs(A).sum(axis=1).flat)
21 elif is_pydata_spmatrix(A):
22 return max(abs(A).sum(axis=1))
23 else:
24 return np.linalg.norm(A, np.inf)
27def _exact_1_norm(A):
28 # A compatibility function which should eventually disappear.
29 if scipy.sparse.isspmatrix(A):
30 return max(abs(A).sum(axis=0).flat)
31 elif is_pydata_spmatrix(A):
32 return max(abs(A).sum(axis=0))
33 else:
34 return np.linalg.norm(A, 1)
37def _trace(A):
38 # A compatibility function which should eventually disappear.
39 if is_pydata_spmatrix(A):
40 return A.to_scipy_sparse().trace()
41 else:
42 return A.trace()
def traceest(A, m3, seed=None):
    """Estimate `np.trace(A)` using `3*m3` matrix-vector products.

    The estimate is randomised, so repeated calls without a `seed` give
    different results.

    Parameters
    ----------
    A : LinearOperator
        Square linear operator whose trace is estimated. Only its ``shape``
        and ``matmat`` are used.
    m3 : int
        One third of the number of matrix-vector products used.
    seed : optional
        Seed passed to `numpy.random.default_rng`. Provide it to obtain
        reproducible results.

    Returns
    -------
    trace : LinearOperator.dtype
        Estimate of the trace.

    Raises
    ------
    ValueError
        If `A` is not two-dimensional and square.

    Notes
    -----
    Implements the Hutch++ algorithm of [1]_:

    1. Capture a low-rank range of `A` with `m3` random sign vectors.
    2. Take the trace exactly on that subspace.
    3. Apply a Hutchinson estimate on its orthogonal complement.

    References
    ----------
    .. [1] Meyer, Raphael A., Cameron Musco, Christopher Musco, and David P.
       Woodruff. "Hutch++: Optimal Stochastic Trace Estimation." In Symposium
       on Simplicity in Algorithms (SOSA), pp. 142-155. Society for Industrial
       and Applied Mathematics, 2021
       https://doi.org/10.1137/1.9781611976496.16
    """
    rng = np.random.default_rng(seed)
    shape = A.shape
    if len(shape) != 2 or shape[-1] != shape[-2]:
        raise ValueError("Expected A to be like a square matrix.")
    n = shape[-1]

    # Random +/-1 sketch; its image under A spans the captured subspace.
    sketch = rng.choice([-1.0, +1.0], [n, m3])
    Q, _ = qr(A.matmat(sketch), overwrite_a=True, mode='economic')
    Q_h = Q.conj().T
    low_rank_part = np.trace(Q_h @ A.matmat(Q))

    # Hutchinson probes projected onto the orthogonal complement of range(Q).
    probes = rng.choice([-1, +1], [n, m3])
    deflated = probes - Q @ (Q_h @ probes)
    residual_part = np.trace(deflated.conj().T @ A.matmat(deflated))
    return low_rank_part + residual_part / m3
92def _ident_like(A):
93 # A compatibility function which should eventually disappear.
94 if scipy.sparse.isspmatrix(A):
95 return scipy.sparse._construct.eye(A.shape[0], A.shape[1],
96 dtype=A.dtype, format=A.format)
97 elif is_pydata_spmatrix(A):
98 import sparse
99 return sparse.eye(A.shape[0], A.shape[1], dtype=A.dtype)
100 elif isinstance(A, scipy.sparse.linalg.LinearOperator):
101 return IdentityOperator(A.shape, dtype=A.dtype)
102 else:
103 return np.eye(A.shape[0], A.shape[1], dtype=A.dtype)
def expm_multiply(A, B, start=None, stop=None, num=None,
                  endpoint=None, traceA=None):
    """
    Compute the action of the matrix exponential of A on B.

    Parameters
    ----------
    A : transposable linear operator
        The operator whose exponential is of interest.
    B : ndarray
        The matrix or vector to be multiplied by the matrix exponential of A.
    start : scalar, optional
        The starting time point of the sequence.
    stop : scalar, optional
        The end time point of the sequence, unless `endpoint` is set to False.
        In that case, the sequence consists of all but the last of ``num + 1``
        evenly spaced time points, so that `stop` is excluded.
        Note that the step size changes when `endpoint` is False.
    num : int, optional
        Number of time points to use.
    endpoint : bool, optional
        If True, `stop` is the last time point. Otherwise, it is not included.
    traceA : scalar, optional
        Trace of `A`. If not given, it is estimated for linear operators and
        computed exactly for sparse matrices. It is only used to precondition
        `A`, so an approximate trace is acceptable.
        For linear operators, `traceA` should be provided to ensure performance,
        as the estimate is not guaranteed to be reliable in all cases.

        .. versionadded: 1.9.0

    Returns
    -------
    expm_A_B : ndarray
         The result of the action :math:`e^{t_k A} B`.

    Warns
    -----
    UserWarning
        If `A` is a linear operator and ``traceA=None`` (default).

    Notes
    -----
    When none of `start`, `stop`, `num` and `endpoint` is given, the action
    is computed at the single time point 1. Otherwise the time points are
    those produced by `numpy.linspace` with the same arguments.

    The output has ndim 1, 2 or 3:

    - 1 for a vector `B` at a single time point;
    - 2 for a vector `B` at several time points, or for a matrix `B` at a
      single time point;
    - 3 for a matrix `B` with several columns at several time points.

    Whenever several time points are requested, ``expm_A_B[0]`` is the
    action at the first time point, whether `B` is a vector or a matrix.

    References
    ----------
    .. [1] Awad H. Al-Mohy and Nicholas J. Higham (2011)
           "Computing the Action of the Matrix Exponential,
           with an Application to Exponential Integrators."
           SIAM Journal on Scientific Computing,
           33 (2). pp. 488-511. ISSN 1064-8275
           http://eprints.ma.man.ac.uk/1591/

    .. [2] Nicholas J. Higham and Awad H. Al-Mohy (2010)
           "Computing Matrix Functions."
           Acta Numerica,
           19. 159-208. ISSN 0962-4929
           http://eprints.ma.man.ac.uk/1451/

    Examples
    --------
    >>> import numpy as np
    >>> from scipy.sparse import csc_matrix
    >>> from scipy.sparse.linalg import expm, expm_multiply
    >>> A = csc_matrix([[1, 0], [0, 1]])
    >>> A.toarray()
    array([[1, 0],
           [0, 1]], dtype=int64)
    >>> B = np.array([np.exp(-1.), np.exp(-2.)])
    >>> B
    array([ 0.36787944,  0.13533528])
    >>> expm_multiply(A, B, start=1, stop=2, num=3, endpoint=True)
    array([[ 1.        ,  0.36787944],
           [ 1.64872127,  0.60653066],
           [ 2.71828183,  1.        ]])
    >>> expm(A).dot(B)                  # Verify 1st timestep
    array([ 1.        ,  0.36787944])
    >>> expm(1.5*A).dot(B)              # Verify 2nd timestep
    array([ 1.64872127,  0.60653066])
    >>> expm(2*A).dot(B)                # Verify 3rd timestep
    array([ 2.71828183,  1.        ])
    """
    time_args = (start, stop, num, endpoint)
    if any(arg is not None for arg in time_args):
        # At least one sequence argument given: evaluate on a linspace grid.
        X, _status = _expm_multiply_interval(A, B, start, stop, num,
                                             endpoint, traceA=traceA)
        return X
    # No sequence requested: a single evaluation at t = 1.
    return _expm_multiply_simple(A, B, traceA=traceA)
def _expm_multiply_simple(A, B, t=1.0, traceA=None, balance=False):
    """
    Compute the action of the matrix exponential at a single time point.

    Parameters
    ----------
    A : transposable linear operator
        The operator whose exponential is of interest.
    B : ndarray
        The matrix (or vector) to be multiplied by the matrix exponential
        of A.
    t : float
        A time point. It may be negative.
    traceA : scalar, optional
        Trace of `A`. If not given, it is estimated for linear operators and
        computed exactly for sparse matrices. It is only used to precondition
        `A`, so an approximate trace is acceptable.
    balance : bool
        Indicates whether or not to apply balancing. Balancing is not
        implemented; passing True raises NotImplementedError.

    Returns
    -------
    F : ndarray
        :math:`e^{t A} B`

    Raises
    ------
    ValueError
        If `A` is not square, `B` has an incompatible shape, or `B` has more
        than two dimensions.

    Notes
    -----
    This is algorithm (3.2) in Al-Mohy and Higham (2011).
    """
    if balance:
        raise NotImplementedError
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be like a square matrix')
    if A.shape[1] != B.shape[0]:
        raise ValueError('shapes of matrices A {} and B {} are incompatible'
                         .format(A.shape, B.shape))
    ident = _ident_like(A)
    is_linear_operator = isinstance(A, scipy.sparse.linalg.LinearOperator)
    n = A.shape[0]
    # n0 = number of columns of B (a vector counts as one column).
    if len(B.shape) == 1:
        n0 = 1
    elif len(B.shape) == 2:
        n0 = B.shape[1]
    else:
        raise ValueError('expected B to be like a matrix or a vector')
    # Unit roundoff of IEEE double precision, used as the tolerance.
    u_d = 2**-53
    tol = u_d
    if traceA is None:
        if is_linear_operator:
            warn("Trace of LinearOperator not available, it will be estimated."
                 " Provide `traceA` to ensure performance.", stacklevel=3)
        # m3=1 is a somewhat arbitrary choice. A more accurate trace (larger
        # m3) might speed up the exponential, but costs more to estimate.
        traceA = traceest(A, m3=1) if is_linear_operator else _trace(A)
    # Shift A by mu*I (mu = mean eigenvalue). The factor exp(t*mu) is
    # re-applied inside _expm_multiply_simple_core.
    mu = traceA / float(n)
    A = A - mu * ident
    A_1_norm = onenormest(A) if is_linear_operator else _exact_1_norm(A)
    # ||t*A||_1 == |t| * ||A||_1. Without abs(), a negative t produced a
    # negative "norm", and _fragment_3_1 then chose m_star=1, s=1.
    tA_1_norm = abs(t) * A_1_norm
    if tA_1_norm == 0:
        m_star, s = 0, 1
    else:
        ell = 2
        norm_info = LazyOperatorNormInfo(t*A, A_1_norm=tA_1_norm, ell=ell)
        m_star, s = _fragment_3_1(norm_info, n0, tol, ell=ell)
    return _expm_multiply_simple_core(A, B, t, mu, m_star, s, tol, balance)
def _expm_multiply_simple_core(A, B, t, mu, m_star, s, tol=None, balance=False):
    """
    Apply ``exp(t*(A + mu*I))`` to `B` by scaled, truncated Taylor steps.

    Parameters
    ----------
    A : linear operator or matrix
        The (already shifted) operator; only ``A.dot`` is used.
    B : ndarray
        Input vector or matrix.
    t : float
        Time point.
    mu : scalar
        Shift; each of the `s` steps is multiplied by ``exp(t*mu/s)``.
    m_star : int
        Maximum number of Taylor terms per step.
    s : int
        Number of scaling steps.
    tol : float, optional
        Early-termination tolerance; defaults to ``2**-53``.
    balance : bool
        Not implemented; True raises NotImplementedError.

    Returns
    -------
    F : ndarray
        The approximated action.
    """
    if balance:
        raise NotImplementedError
    if tol is None:
        tol = 2 ** -53
    step_factor = np.exp(t*mu / float(s))
    F = B
    for _ in range(s):
        prev_norm = _exact_inf_norm(B)
        for j in range(1, m_star + 1):
            # B becomes the next Taylor term (t/s)^j A^j B_0 / j!.
            B = (t / float(s*j)) * A.dot(B)
            term_norm = _exact_inf_norm(B)
            F = F + B
            # Stop once two consecutive terms are negligible relative to F.
            if prev_norm + term_norm <= tol * _exact_inf_norm(F):
                break
            prev_norm = term_norm
        F = step_factor * F
        B = F
    return F
# Table of theta_m values that helps to compute bounds.
# Key m is a Taylor degree. The callers (_fragment_3_1, _compute_cost_div_m,
# _condition_3_13) divide a 1-norm (or an alpha_p estimate) by theta_m to
# get the number of scaling steps s for that degree.
# These values seem to have been difficult to calculate, involving symbolic
# manipulation of equations, followed by numerical root finding.
# NOTE(review): presumably tabulated for double precision, matching the
# tol = 2**-53 used by the callers -- confirm against the references.
_theta = {
    # The first 30 values are from table A.3 of Computing Matrix Functions.
    1: 2.29e-16,
    2: 2.58e-8,
    3: 1.39e-5,
    4: 3.40e-4,
    5: 2.40e-3,
    6: 9.07e-3,
    7: 2.38e-2,
    8: 5.00e-2,
    9: 8.96e-2,
    10: 1.44e-1,
    11: 2.14e-1,
    12: 3.00e-1,
    13: 4.00e-1,
    14: 5.14e-1,
    15: 6.41e-1,
    16: 7.81e-1,
    17: 9.31e-1,
    18: 1.09,
    19: 1.26,
    20: 1.44,
    21: 1.62,
    22: 1.82,
    23: 2.01,
    24: 2.22,
    25: 2.43,
    26: 2.64,
    27: 2.86,
    28: 3.08,
    29: 3.31,
    30: 3.54,
    # The rest are from table 3.1 of
    # Computing the Action of the Matrix Exponential.
    35: 4.7,
    40: 6.0,
    45: 7.2,
    50: 8.5,
    55: 9.9,
}
350def _onenormest_matrix_power(A, p,
351 t=2, itmax=5, compute_v=False, compute_w=False):
352 """
353 Efficiently estimate the 1-norm of A^p.
355 Parameters
356 ----------
357 A : ndarray
358 Matrix whose 1-norm of a power is to be computed.
359 p : int
360 Non-negative integer power.
361 t : int, optional
362 A positive parameter controlling the tradeoff between
363 accuracy versus time and memory usage.
364 Larger values take longer and use more memory
365 but give more accurate output.
366 itmax : int, optional
367 Use at most this many iterations.
368 compute_v : bool, optional
369 Request a norm-maximizing linear operator input vector if True.
370 compute_w : bool, optional
371 Request a norm-maximizing linear operator output vector if True.
373 Returns
374 -------
375 est : float
376 An underestimate of the 1-norm of the sparse matrix.
377 v : ndarray, optional
378 The vector such that ||Av||_1 == est*||v||_1.
379 It can be thought of as an input to the linear operator
380 that gives an output with particularly large norm.
381 w : ndarray, optional
382 The vector Av which has relatively large 1-norm.
383 It can be thought of as an output of the linear operator
384 that is relatively large in norm compared to the input.
386 """
387 #XXX Eventually turn this into an API function in the _onenormest module,
388 #XXX and remove its underscore,
389 #XXX but wait until expm_multiply goes into scipy.
390 from scipy.sparse.linalg._onenormest import onenormest
391 return onenormest(aslinearoperator(A) ** p)
class LazyOperatorNormInfo:
    """
    Lazily computed 1-norm information about an operator.

    Holds the exact 1-norm of the operator (given or computed on demand) and
    caches estimates ``d_p ~= ||A^p||_1^(1/p)`` as they are requested. All
    reported quantities are multiplied by an adjustable scale factor.
    The notation follows Computing the Action (2011). This class is
    specialized enough that it is unlikely to be useful outside this module.
    """

    def __init__(self, A, A_1_norm=None, ell=2, scale=1):
        """
        Store the operator and any norm information already known.

        Parameters
        ----------
        A : linear operator
            The operator of interest.
        A_1_norm : float, optional
            The exact 1-norm of A; computed on first use if omitted.
        ell : int, optional
            A technical parameter controlling norm estimation quality; passed
            as the third argument of ``_onenormest_matrix_power``.
        scale : int, optional
            If specified, report the norms of scale*A instead of A.
        """
        self._A = A
        self._A_1_norm = A_1_norm
        self._ell = ell
        self._scale = scale
        self._d = {}  # cache: power p -> unscaled d_p estimate

    def set_scale(self, scale):
        """
        Change the factor applied to every reported norm.
        """
        self._scale = scale

    def onenorm(self):
        """
        Return the (scaled) exact 1-norm, computing it if necessary.
        """
        if self._A_1_norm is None:
            self._A_1_norm = _exact_1_norm(self._A)
        return self._scale * self._A_1_norm

    def d(self, p):
        """
        Return the scaled estimate of d_p(A) ~= || A^p ||^(1/p) (1-norm),
        caching the unscaled value.
        """
        try:
            d_p = self._d[p]
        except KeyError:
            d_p = _onenormest_matrix_power(self._A, p, self._ell) ** (1.0 / p)
            self._d[p] = d_p
        return self._scale * d_p

    def alpha(self, p):
        """
        Return max(d(p), d(p+1)).
        """
        lower, upper = self.d(p), self.d(p + 1)
        return max(lower, upper)
def _compute_cost_div_m(m, p, norm_info):
    """
    A helper function for computing bounds (equation (3.10)).

    It measures cost as the number of required matrix products.

    Parameters
    ----------
    m : int
        A valid key of _theta.
    p : int
        A matrix power.
    norm_info : LazyOperatorNormInfo
        Information about 1-norms of related operators.

    Returns
    -------
    cost_div_m : int
        Required number of matrix products divided by m, i.e.
        ``ceil(alpha_p / theta_m)``.
    """
    ratio = norm_info.alpha(p) / _theta[m]
    return int(np.ceil(ratio))
481def _compute_p_max(m_max):
482 """
483 Compute the largest positive integer p such that p*(p-1) <= m_max + 1.
485 Do this in a slightly dumb way, but safe and not too slow.
487 Parameters
488 ----------
489 m_max : int
490 A count related to bounds.
492 """
493 sqrt_m_max = np.sqrt(m_max)
494 p_low = int(np.floor(sqrt_m_max))
495 p_high = int(np.ceil(sqrt_m_max + 1))
496 return max(p for p in range(p_low, p_high+1) if p*(p-1) <= m_max + 1)
def _fragment_3_1(norm_info, n0, tol, m_max=55, ell=2):
    """
    A helper function for the _expm_multiply_* functions.

    Parameters
    ----------
    norm_info : LazyOperatorNormInfo
        Information about norms of certain linear operators of interest.
    n0 : int
        Number of columns in the _expm_multiply_* B matrix.
    tol : float
        Expected to be
        :math:`2^{-24}` for single precision or
        :math:`2^{-53}` for double precision.
        (Not used by the current implementation.)
    m_max : int
        A value related to a bound; must be a key of ``_theta``.
    ell : int
        The number of columns used in the 1-norm approximation.
        This is usually taken to be small, maybe between 1 and 5.

    Returns
    -------
    best_m : int
        Related to bounds for error control.
    best_s : int
        Amount of scaling (at least 1).

    Raises
    ------
    ValueError
        If `ell` is less than 1.

    Notes
    -----
    This is code fragment (3.1) in Al-Mohy and Higham (2011).
    The discussion of default values for m_max and ell
    is given between the definitions of equation (3.11)
    and the definition of equation (3.12).
    """
    if ell < 1:
        raise ValueError('expected ell to be a positive integer')
    best_m = None
    best_s = None
    if _condition_3_13(norm_info.onenorm(), n0, m_max, ell):
        # First case of (3.11): minimise m*ceil(||A||_1/theta_m) over
        # 1 <= m <= m_max. Degrees above m_max are skipped so that a custom
        # m_max is honoured.
        A_1_norm = norm_info.onenorm()
        for m, theta in _theta.items():
            if m > m_max:
                continue
            s = int(np.ceil(A_1_norm / theta))
            if best_m is None or m * s < best_m * best_s:
                best_m = m
                best_s = s
    else:
        # Second case of equation (3.11), based on the alpha_p estimates.
        for p in range(2, _compute_p_max(m_max) + 1):
            for m in range(p*(p-1)-1, m_max+1):
                if m in _theta:
                    s = _compute_cost_div_m(m, p, norm_info)
                    if best_m is None or m * s < best_m * best_s:
                        best_m = m
                        best_s = s
    # A zero norm yields s == 0; at least one step is always taken.
    best_s = max(best_s, 1)
    return best_m, best_s
def _condition_3_13(A_1_norm, n0, m_max, ell):
    """
    Check condition (3.13) of Al-Mohy and Higham (2011).

    Parameters
    ----------
    A_1_norm : float
        The precomputed 1-norm of A.
    n0 : int
        Number of columns in the _expm_multiply_* B matrix.
    m_max : int
        A value related to a bound; must be a key of ``_theta``.
    ell : int
        The number of columns used in the 1-norm approximation.
        This is usually taken to be small, maybe between 1 and 5.

    Returns
    -------
    value : bool
        True when ``A_1_norm`` does not exceed the bound of (3.13).
    """
    p_max = _compute_p_max(m_max)
    # Bound = [rhs of (3.12): 2*ell*p_max*(p_max+3)] * theta_{m_max}/(n0*m_max)
    bound = 2 * ell * p_max * (p_max + 3) * (_theta[m_max] / float(n0 * m_max))
    return A_1_norm <= bound
def _expm_multiply_interval(A, B, start=None, stop=None, num=None,
                            endpoint=None, traceA=None, balance=False,
                            status_only=False):
    """
    Compute the action of the matrix exponential at multiple time points.

    Parameters
    ----------
    A : transposable linear operator
        The operator whose exponential is of interest.
    B : ndarray
        The matrix to be multiplied by the matrix exponential of A.
    start : scalar, optional
        The starting time point of the sequence.
    stop : scalar, optional
        The end time point of the sequence, unless `endpoint` is set to False.
        In that case, the sequence consists of all but the last of ``num + 1``
        evenly spaced time points, so that `stop` is excluded.
        Note that the step size changes when `endpoint` is False.
        `stop` may be smaller than `start`.
    num : int, optional
        Number of time points to use (at least two).
    endpoint : bool, optional
        If True, `stop` is the last time point. Otherwise, it is not included.
    traceA : scalar, optional
        Trace of `A`. If not given, it is estimated for linear operators and
        computed exactly for sparse matrices. It is only used to precondition
        `A`, so an approximate trace is acceptable.
    balance : bool
        Indicates whether or not to apply balancing. Not implemented; True
        raises NotImplementedError.
    status_only : bool
        A flag that is set to True for some debugging and testing operations.
        When True, only the integer status is returned.

    Returns
    -------
    F : ndarray
        :math:`e^{t_k A} B`
    status : int
        An integer status for testing and debugging: 0 if q <= s,
        1 if q % s == 0, and 2 otherwise.

    Raises
    ------
    ValueError
        If `A` is not square, the shapes of `A` and `B` are incompatible,
        `B` has more than two dimensions, or fewer than two time points are
        requested.

    Notes
    -----
    This is algorithm (5.2) in Al-Mohy and Higham (2011).

    There seems to be a typo, where line 15 of the algorithm should be
    moved to line 6.5 (between lines 6 and 7).
    """
    if balance:
        raise NotImplementedError
    if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
        raise ValueError('expected A to be like a square matrix')
    if A.shape[1] != B.shape[0]:
        raise ValueError('shapes of matrices A {} and B {} are incompatible'
                         .format(A.shape, B.shape))
    ident = _ident_like(A)
    is_linear_operator = isinstance(A, scipy.sparse.linalg.LinearOperator)
    n = A.shape[0]
    # n0 = number of columns of B (a vector counts as one column).
    if len(B.shape) == 1:
        n0 = 1
    elif len(B.shape) == 2:
        n0 = B.shape[1]
    else:
        raise ValueError('expected B to be like a matrix or a vector')
    # Unit roundoff of IEEE double precision, used as the tolerance.
    u_d = 2**-53
    tol = u_d
    if traceA is None:
        if is_linear_operator:
            warn("Trace of LinearOperator not available, it will be estimated."
                 " Provide `traceA` to ensure performance.", stacklevel=3)
        # m3=5 is a somewhat arbitrary choice. A more accurate trace (larger
        # m3) might speed up the exponential, but estimation is also costly.
        # An educated guess would need to consider the number of time points.
        traceA = traceest(A, m3=5) if is_linear_operator else _trace(A)
    mu = traceA / float(n)

    # Get the linspace samples, attempting to preserve the linspace defaults.
    linspace_kwargs = {'retstep': True}
    if num is not None:
        linspace_kwargs['num'] = num
    if endpoint is not None:
        linspace_kwargs['endpoint'] = endpoint
    samples, step = np.linspace(start, stop, **linspace_kwargs)

    # Convert the linspace output to the notation used by the publication.
    nsamples = len(samples)
    if nsamples < 2:
        raise ValueError('at least two time points are required')
    q = nsamples - 1
    h = step
    t_0 = samples[0]
    t_q = samples[q]

    # Define the output ndarray.
    # Use an ndim=3 shape, such that the last two indices
    # are the ones that may be involved in level 3 BLAS operations.
    X_shape = (nsamples,) + B.shape
    X = np.empty(X_shape, dtype=np.result_type(A.dtype, B.dtype, float))
    t = t_q - t_0
    A = A - mu * ident
    A_1_norm = onenormest(A) if is_linear_operator else _exact_1_norm(A)
    ell = 2
    # ||t*A||_1 == |t| * ||A||_1. t is negative when stop < start. A
    # negative "norm" made _fragment_3_1 (and core_0) choose m_star=1, s=1,
    # i.e. a single first-order Taylor step.
    tA_1_norm = abs(t) * A_1_norm
    norm_info = LazyOperatorNormInfo(t*A, A_1_norm=tA_1_norm, ell=ell)
    if tA_1_norm == 0:
        m_star, s = 0, 1
    else:
        m_star, s = _fragment_3_1(norm_info, n0, tol, ell=ell)

    # Compute the expm action up to the initial time point.
    # NOTE(review): m_star/s were chosen for the interval length |t|, not for
    # |t_0|; accuracy may suffer when |t_0| is much larger than |t|.
    X[0] = _expm_multiply_simple_core(A, B, t_0, mu, m_star, s)

    # Compute the expm action at the rest of the time points.
    if q <= s:
        if status_only:
            return 0
        return _expm_multiply_interval_core_0(A, X,
                h, mu, q, norm_info, tol, ell, n0)
    elif q % s == 0:
        if status_only:
            return 1
        return _expm_multiply_interval_core_1(A, X,
                h, mu, m_star, s, q, tol)
    else:
        if status_only:
            return 2
        return _expm_multiply_interval_core_2(A, X,
                h, mu, m_star, s, q, tol)
def _expm_multiply_interval_core_0(A, X, h, mu, q, norm_info, tol, ell, n0):
    """
    A helper function, for the case q <= s.

    Fills X[1..q] by stepping from each time point to the next with
    _expm_multiply_simple_core, using m_star and s re-derived for a single
    step of length t/q. Returns ``(X, 0)``.
    """
    if norm_info.onenorm() != 0:
        # Temporarily view the operator as (t/q)*A to size one step, then
        # restore the original scale.
        norm_info.set_scale(1./q)
        m_star, s = _fragment_3_1(norm_info, n0, tol, ell=ell)
        norm_info.set_scale(1)
    else:
        m_star, s = 0, 1

    for k in range(1, q + 1):
        X[k] = _expm_multiply_simple_core(A, X[k - 1], h, mu, m_star, s)
    return X, 0
def _expm_multiply_interval_core_1(A, X, h, mu, m_star, s, q, tol):
    """
    A helper function, for the case q > s and q % s == 0.

    The q steps are split into s blocks of d = q // s steps. Within a block
    starting at Z = X[i*d], K[p] holds (h*A)^p Z / p!. Each point
    X[i*d + k] is then the truncated Taylor sum of sum_p k^p K[p], times
    exp(k*h*mu). Returns ``(X, 1)``.
    """
    d = q // s
    input_shape = X.shape[1:]
    K_shape = (m_star + 1, ) + input_shape
    K = np.empty(K_shape, dtype=X.dtype)
    for i in range(s):
        Z = X[i*d]
        K[0] = Z
        # Highest p for which K[p] is already valid for the current Z.
        # Advancing it avoids recomputing every A.dot for each k (it was
        # never updated before).
        high_p = 0
        for k in range(1, d+1):
            F = K[0]
            c1 = _exact_inf_norm(F)
            for p in range(1, m_star+1):
                if p > high_p:
                    K[p] = h * A.dot(K[p-1]) / float(p)
                    high_p = p
                coeff = float(pow(k, p))
                F = F + coeff * K[p]
                inf_norm_K_p_1 = _exact_inf_norm(K[p])
                c2 = coeff * inf_norm_K_p_1
                # Stop when two consecutive terms are negligible relative to F.
                if c1 + c2 <= tol * _exact_inf_norm(F):
                    break
                c1 = c2
            X[k + i*d] = np.exp(k*h*mu) * F
    return X, 1
def _expm_multiply_interval_core_2(A, X, h, mu, m_star, s, q, tol):
    """
    A helper function, for the case q > s and q % s > 0.

    The q steps are split into j blocks of d = q // s steps, plus a final
    partial block of r = q - d*j steps. Within a block starting at
    Z = X[i*d], K[p] holds (h*A)^p Z / p!, computed lazily. Each point
    X[i*d + k] is the truncated Taylor sum of sum_p k^p K[p], times
    exp(k*h*mu). Returns ``(X, 2)``.
    """
    d = q // s
    j = q // d
    r = q - d * j
    input_shape = X.shape[1:]
    K_shape = (m_star + 1, ) + input_shape
    K = np.empty(K_shape, dtype=X.dtype)
    for i in range(j + 1):
        Z = X[i*d]
        K[0] = Z
        # Highest p for which K[p] is already valid for the current Z.
        high_p = 0
        # Full blocks have d steps; the last (i == j) has the r leftover
        # steps. If r == 0, that last iteration does no work.
        if i < j:
            effective_d = d
        else:
            effective_d = r
        for k in range(1, effective_d+1):
            F = K[0]
            c1 = _exact_inf_norm(F)
            for p in range(1, m_star+1):
                # Extend the cached K terms only when a new power is needed.
                if p == high_p + 1:
                    K[p] = h * A.dot(K[p-1]) / float(p)
                    high_p = p
                coeff = float(pow(k, p))
                F = F + coeff * K[p]
                inf_norm_K_p_1 = _exact_inf_norm(K[p])
                c2 = coeff * inf_norm_K_p_1
                # Stop when two consecutive terms are negligible relative to F.
                if c1 + c2 <= tol * _exact_inf_norm(F):
                    break
                c1 = c2
            # Re-apply the exp(mu*t) factor removed by the A - mu*I shift.
            X[k + i*d] = np.exp(k*h*mu) * F
    return X, 2