Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/scipy/_lib/_finite_differences.py: 7%

46 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1from numpy import arange, newaxis, hstack, prod, array 

2 

3 

def _central_diff_weights(Np, ndiv=1):
    """
    Return weights for an Np-point central derivative.

    Assumes equally-spaced function points.

    If the weights are in the vector w and ``ho = Np // 2``, then the
    ``ndiv``-th derivative of ``f`` at ``x`` is approximated by
    ``(w[0] * f(x-ho*dx) + ... + w[-1] * f(x+ho*dx)) / dx**ndiv``.

    Parameters
    ----------
    Np : int
        Number of points for the central derivative. Must be odd.
    ndiv : int, optional
        Order of the derivative. Must be non-negative. Default is 1.

    Returns
    -------
    w : ndarray
        Weights for an Np-point central derivative. Its size is `Np`.

    Raises
    ------
    ValueError
        If `ndiv` is negative, if ``Np < ndiv + 1``, or if `Np` is even.

    Notes
    -----
    The weights come from inverting the Vandermonde matrix of the integer
    offsets ``-ho, ..., ho``. This matrix becomes badly conditioned as `Np`
    grows, so the result can be inaccurate for a large number of points.

    Examples
    --------
    We can calculate a derivative value of a function.

    >>> def f(x):
    ...     return 2 * x**2 + 3
    >>> x = 3.0  # derivative point
    >>> h = 0.1  # differential step
    >>> Np = 3  # point number for central derivative
    >>> weights = _central_diff_weights(Np)  # weights for first derivative
    >>> vals = [f(x + (i - Np//2) * h) for i in range(Np)]
    >>> round(sum(w * v for (w, v) in zip(weights, vals)) / h, 8)
    12.0

    This matches the analytical solution: f'(x) = 4x, so f'(3) = 12.
    The sample points must be centred on ``x``, i.e. offset by ``Np//2``.
    Offsetting by ``Np/2`` would shift the stencil by half a step.

    References
    ----------
    .. [1] https://en.wikipedia.org/wiki/Finite_difference

    """
    if ndiv < 0:
        # A negative ndiv would otherwise index inv(X) from the end and
        # silently return weights for the wrong derivative order.
        raise ValueError("The derivative order must be non-negative.")
    if Np < ndiv + 1:
        raise ValueError(
            "Number of points must be at least the derivative order + 1."
        )
    if Np % 2 == 0:
        raise ValueError("The number of points must be odd.")
    import math

    from scipy import linalg

    ho = Np >> 1
    # Integer offsets -ho, ..., ho of the sample points, in units of dx.
    x = arange(-ho, ho + 1.0)
    # Vandermonde matrix X[i, k] = x[i]**k, built with one broadcast
    # instead of repeatedly hstack-ing columns (which recopies X each time).
    X = x[:, newaxis] ** arange(Np)
    # The weights satisfy sum_i w[i] * x[i]**k == ndiv! * (k == ndiv) for
    # k < Np, i.e. X.T @ w == ndiv! * e_ndiv, so w is row `ndiv` of inv(X)
    # scaled by ndiv!.  math.factorial is exact, unlike
    # prod(arange(1, ndiv + 1)), which silently overflows int64 for ndiv > 20.
    w = float(math.factorial(ndiv)) * linalg.inv(X)[ndiv]
    return w

67 

68 

def _derivative(func, x0, dx=1.0, n=1, args=(), order=3):
    """
    Estimate the nth derivative of `func` at `x0` with central differences.

    `func` is sampled at the `order` equally spaced points
    ``x0 + j*dx`` for ``j = -(order//2), ..., order//2``. The samples are
    combined with central-difference weights, and the result is divided
    by ``dx**n``.

    Parameters
    ----------
    func : function
        Function to differentiate, called as ``func(x, *args)``.
    x0 : float
        The point at which the nth derivative is estimated.
    dx : float, optional
        Spacing between the sample points. Default is 1.0.
    n : int, optional
        Order of the derivative. Default is 1.
    args : tuple, optional
        Extra positional arguments forwarded to `func`. Default is ``()``.
    order : int, optional
        Number of sample points to use. Must be odd and at least ``n + 1``.
        Default is 3.

    Returns
    -------
    Estimate of the nth derivative of `func` at `x0`.

    Raises
    ------
    ValueError
        If ``order < n + 1`` or if `order` is even.

    Notes
    -----
    Decreasing the step size too small can result in round-off error.

    For ``n`` in {1, 2} with ``order`` in {3, 5, 7, 9}, tabulated weights
    are used. Every other combination gets its weights from
    `_central_diff_weights`.

    Examples
    --------
    >>> def f(x):
    ...     return x**3 + x**2
    >>> _derivative(f, 1.0, dx=1e-6)
    4.9999999999217337

    """
    if order < n + 1:
        raise ValueError(
            "'order' (the number of points used to compute the derivative), "
            "must be at least the derivative order 'n' + 1."
        )
    if order % 2 == 0:
        raise ValueError(
            "'order' (the number of points used to compute the derivative) "
            "must be odd."
        )

    # Precomputed stencils for the common low-order cases, stored as
    # (derivative order, number of points, numerators, common denominator).
    # Using them avoids the matrix inversion done by _central_diff_weights.
    stencils = (
        (1, 3, (-1, 0, 1), 2.0),
        (1, 5, (1, -8, 0, 8, -1), 12.0),
        (1, 7, (-1, 9, -45, 0, 45, -9, 1), 60.0),
        (1, 9, (3, -32, 168, -672, 0, 672, -168, 32, -3), 840.0),
        (2, 3, (1, -2.0, 1), 1.0),
        (2, 5, (-1, 16, -30, 16, -1), 12.0),
        (2, 7, (2, -27, 270, -490, 270, -27, 2), 180.0),
        (2, 9, (-9, 128, -1008, 8064, -14350, 8064, -1008, 128, -9), 5040.0),
    )
    weights = None
    for deriv, npts, numerators, denominator in stencils:
        if n == deriv and order == npts:
            weights = array(numerators) / denominator
            break
    if weights is None:
        weights = _central_diff_weights(order, n)

    half = order >> 1
    total = 0.0
    # Offsets run -half..half, so sample j sits at x0 + j*dx, paired with
    # the matching weight (the stencil length equals `order`).
    for offset, weight in zip(range(-half, half + 1), weights):
        total += weight * func(x0 + offset * dx, *args)
    return total / prod((dx,) * n, axis=0)