Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/_lib/_finite_differences.py: 7%
46 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-12-12 06:31 +0000
1from numpy import arange, newaxis, hstack, prod, array
def _central_diff_weights(Np, ndiv=1):
    """
    Return weights for an Np-point central derivative.

    Assumes equally-spaced function points.

    If weights are in the vector w, then the derivative is
    ``(w[0] * f(x-ho*dx) + ... + w[-1] * f(x+ho*dx)) / dx**ndiv``,
    where ``ho = Np // 2``.

    Parameters
    ----------
    Np : int
        Number of points for the central derivative. Must be odd and at
        least ``ndiv + 1``.
    ndiv : int, optional
        Order of the derivative (number of differentiations). Must be
        non-negative. Default is 1.

    Returns
    -------
    w : ndarray
        Weights for an Np-point central derivative. Its size is `Np`.

    Raises
    ------
    ValueError
        If ``ndiv`` is negative, ``Np < ndiv + 1``, or ``Np`` is even.

    Notes
    -----
    The weights are ``ndiv!`` times row ``ndiv`` of the inverse of the
    Vandermonde matrix ``X[i, k] = x_i**k`` on the integer offsets
    ``x_i = -ho, ..., ho``. That matrix becomes ill-conditioned quickly,
    so the result can be inaccurate for a large number of points.

    Examples
    --------
    We can calculate a derivative value of a function.

    >>> def f(x):
    ...     return 2 * x**2 + 3
    >>> x = 3.0  # derivative point
    >>> h = 0.1  # differential step
    >>> Np = 3  # point number for central derivative
    >>> weights = _central_diff_weights(Np)  # weights for first derivative
    >>> vals = [f(x + (i - Np // 2) * h) for i in range(Np)]
    >>> approx = sum(w * v for (w, v) in zip(weights, vals)) / h
    >>> round(float(approx), 8)
    12.0

    This matches the analytical solution f'(x) = 4x, so f'(3) = 12: a
    3-point central difference is exact for a quadratic, up to rounding.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Finite_difference

    """
    if ndiv < 0:
        # A negative order would silently index X^{-1} from the end.
        raise ValueError("The derivative order must be non-negative.")
    if Np < ndiv + 1:
        raise ValueError(
            "Number of points must be at least the derivative order + 1."
        )
    if Np % 2 == 0:
        raise ValueError("The number of points must be odd.")
    import math

    from scipy import linalg

    ho = Np >> 1
    # Column vector of the (float) integer offsets -ho, ..., ho.
    x = arange(-ho, ho + 1.0)[:, newaxis]
    # Vandermonde matrix X[i, k] = x_i**k, built by broadcasting in one step
    # rather than growing it column by column with repeated hstack copies.
    X = x ** arange(Np)
    # Row `ndiv` of X^{-1} maps samples to the ndiv-th Taylor coefficient;
    # scaling by ndiv! turns it into derivative weights. math.factorial works
    # on exact Python ints, whereas numpy's prod over a fixed-width integer
    # arange silently overflows (from 13! with int32, 21! with int64).
    w = math.factorial(ndiv) * linalg.inv(X)[ndiv]
    return w
def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
    """
    Find the nth derivative of a function at a point.

    The derivative is approximated by an `order`-point central difference
    stencil with spacing `dx`, centred on `x0`.

    Parameters
    ----------
    func : function
        Input function, called as ``func(x, *args)``.
    x0 : float
        The point at which the nth derivative is found.
    dx : float, optional
        Spacing between neighbouring sample points. Default is 1.0.
    n : int, optional
        Order of the derivative. Default is 1.
    args : tuple, optional
        Extra positional arguments passed to `func` after ``x``.
    order : int, optional
        Number of points to use; must be odd and at least ``n + 1``.
        Default is 3.

    Returns
    -------
    val
        ``sum_k w[k] * func(x0 + (k - order // 2) * dx, *args)`` divided by
        the product of `n` copies of `dx`.

    Raises
    ------
    ValueError
        If ``order < n + 1`` or if ``order`` is even.

    Notes
    -----
    Decreasing the step size too small can result in round-off error.
    `func` is evaluated at every stencil point, including the centre point
    whose weight is zero for odd-order derivatives.

    Examples
    --------
    >>> def f(x):
    ...     return x**3 + x**2
    >>> round(float(_derivative(f, 1.0, dx=1e-6)), 6)
    5.0

    """
    if order < n + 1:
        raise ValueError(
            "'order' (the number of points used to compute the derivative), "
            "must be at least the derivative order 'n' + 1."
        )
    if order % 2 == 0:
        raise ValueError(
            "'order' (the number of points used to compute the derivative) "
            "must be odd."
        )

    # Closed-form stencils as (points, numerators, denominator), tried in
    # this order; any other (n, order) combination is solved numerically.
    first_derivative = (
        (3, [-1, 0, 1], 2.0),
        (5, [1, -8, 0, 8, -1], 12.0),
        (7, [-1, 9, -45, 0, 45, -9, 1], 60.0),
        (9, [3, -32, 168, -672, 0, 672, -168, 32, -3], 840.0),
    )
    second_derivative = (
        (3, [1, -2.0, 1], 1.0),
        (5, [-1, 16, -30, 16, -1], 12.0),
        (7, [2, -27, 270, -490, 270, -27, 2], 180.0),
        (9, [-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9], 5040.0),
    )

    if n == 1:
        solve_order, stencils = 1, first_derivative
    elif n == 2:
        solve_order, stencils = 2, second_derivative
    else:
        solve_order, stencils = n, ()

    weights = next(
        (array(numer) / denom for npts, numer, denom in stencils if order == npts),
        None,
    )
    if weights is None:
        weights = _central_diff_weights(order, solve_order)

    half = order >> 1
    total = 0.0
    for idx, weight in enumerate(weights):
        total += weight * func(x0 + (idx - half) * dx, *args)
    # Product of n copies of dx (elementwise if dx is an array).
    return total / prod((dx,) * n, axis=0)