Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/array_api_compat/common/_helpers.py: 15%

116 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-23 06:43 +0000

1""" 

2Various helper functions which are not part of the spec. 

3 

4Functions which start with an underscore are for internal use only but helpers 

5that are in __all__ are intended as additional helper functions for use by end 

6users of the compat library. 

7""" 

8from __future__ import annotations 

9 

10import sys 

11import math 

12 

def _is_numpy_array(x):
    """
    Return True if ``x`` is a ``numpy.ndarray`` or ``numpy.generic`` instance.

    NumPy is only consulted when it already appears in ``sys.modules``; if it
    does not, the answer is False and NumPy is never imported by this check.
    """
    if 'numpy' in sys.modules:
        # Safe to import here: the module has already been loaded elsewhere.
        import numpy as np

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, (np.ndarray, np.generic))

    return False

22 

def _is_cupy_array(x):
    """
    Return True if ``x`` is a ``cupy.ndarray`` or ``cupy.generic`` instance.

    CuPy is only consulted when it already appears in ``sys.modules``; if it
    does not, the answer is False and CuPy is never imported by this check.
    """
    if 'cupy' in sys.modules:
        # Safe to import here: the module has already been loaded elsewhere.
        import cupy as cp

        # TODO: Should we reject ndarray subclasses?
        return isinstance(x, (cp.ndarray, cp.generic))

    return False

32 

def _is_torch_array(x):
    """
    Return True if ``x`` is a ``torch.Tensor`` instance.

    torch is only consulted when it already appears in ``sys.modules``; if it
    does not, the answer is False and torch is never imported by this check.
    """
    if 'torch' in sys.modules:
        # Safe to import here: the module has already been loaded elsewhere.
        import torch

        # TODO: Should we reject Tensor subclasses?
        return isinstance(x, torch.Tensor)

    return False

42 

def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.

    ``x`` qualifies when it is recognized as a NumPy, CuPy or torch array, or
    when it exposes an ``__array_namespace__`` attribute.
    """
    # Library-specific checks run first, in the same order as before; the
    # generic attribute probe is the final fallback.
    for recognizes in (_is_numpy_array, _is_cupy_array, _is_torch_array):
        if recognizes(x):
            return True
    return hasattr(x, '__array_namespace__')

51 

def _check_api_version(api_version):
    """
    Validate the requested array API specification version.

    ``None`` (no preference) and ``'2021.12'`` are accepted; any other value
    raises ``ValueError``.
    """
    if api_version is None:
        return
    if api_version != '2021.12':
        raise ValueError("Only the 2021.12 version of the array API specification is currently supported")

55 

def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    `xs` should contain one or more arrays. A tuple or list found in `xs` is
    unpacked and its elements are resolved recursively with the same
    `api_version` and `_use_compat` settings.

    Typical usage is

        def your_function(x, y):
            xp = array_api_compat.array_namespace(x, y)
            # Now use xp as the array library namespace
            return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

    api_version should be the newest version of the spec that you need support
    for (currently the compat library wrapped APIs only support v2021.12).

    Parameters
    ----------
    *xs
        Arrays (or tuples/lists of arrays) whose common namespace is wanted.
    api_version
        Forwarded to ``x.__array_namespace__`` for objects that define it, and
        validated with ``_check_api_version`` for NumPy, CuPy and torch arrays.
    _use_compat
        When true, NumPy, CuPy and torch arrays map to this package's wrapper
        namespaces; otherwise they map to the ``numpy``, ``cupy`` and ``torch``
        modules themselves.

    Returns
    -------
    The single namespace shared by every input.

    Raises
    ------
    TypeError
        If an input is not a supported array type, if no namespace could be
        determined (e.g. no inputs), or if the inputs belong to more than one
        namespace.
    ValueError
        If ``api_version`` is rejected by ``_check_api_version``.
    """
    namespaces = set()
    for x in xs:
        if isinstance(x, (tuple, list)):
            # Forward api_version as well, so that nested arrays are resolved
            # and validated against the same requested spec version as
            # top-level ones (it was previously dropped on recursion).
            namespaces.add(
                array_namespace(*x, api_version=api_version, _use_compat=_use_compat)
            )
        elif hasattr(x, '__array_namespace__'):
            # NOTE(review): this branch runs before the NumPy/CuPy/torch
            # checks, so an array type that defines __array_namespace__ itself
            # never reaches the compat wrappers below -- confirm that this
            # precedence is intended.
            namespaces.add(x.__array_namespace__(api_version=api_version))
        elif _is_numpy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import numpy as numpy_namespace
                namespaces.add(numpy_namespace)
            else:
                import numpy as np
                namespaces.add(np)
        elif _is_cupy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import cupy as cupy_namespace
                namespaces.add(cupy_namespace)
            else:
                import cupy as cp
                namespaces.add(cp)
        elif _is_torch_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import torch as torch_namespace
                namespaces.add(torch_namespace)
            else:
                import torch
                namespaces.add(torch)
        else:
            # TODO: Support Python scalars?
            raise TypeError("The input is not a supported array type")

    if not namespaces:
        raise TypeError("Unrecognized array input")

    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

    # Exactly one element is left; unpack it out of the set.
    xp, = namespaces

    return xp


# backwards compatibility alias
get_namespace = array_namespace

118 

def _check_device(xp, device):
    """
    Validate ``device`` for the namespace ``xp``.

    Only the case where ``xp`` compares equal to the loaded ``numpy`` module
    is checked: the device must then be ``"cpu"`` or ``None``, otherwise
    ``ValueError`` is raised. Any other namespace passes without checks.
    """
    targets_numpy = xp == sys.modules.get('numpy')
    if targets_numpy and device not in ["cpu", None]:
        raise ValueError(f"Unsupported device for NumPy: {device!r}")

123 

124# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray 

125# or cupy.ndarray. They are not included in array objects of this library 

126# because this library just reuses the respective ndarray classes without 

127# wrapping or subclassing them. These helper functions can be used instead of 

128# the wrapper functions for libraries that need to support both NumPy/CuPy and 

129# other libraries that use devices. 

def device(x: "Array", /) -> "Device":
    """
    Hardware device the array data resides on.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the "Device Support" section of the array
        API specification). For arrays recognized as NumPy arrays this is the
        string ``"cpu"``; for anything else it is ``x.device``.
    """
    # NumPy arrays are special-cased to "cpu" instead of reading x.device.
    return "cpu" if _is_numpy_array(x) else x.device

147 

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """
    Move the CuPy array ``x`` to ``device``, optionally using ``stream``.

    Returns ``x`` itself when ``device`` equals ``x.device``, the result of
    ``x.get()`` when ``device`` is ``"cpu"``, and otherwise ``x.copy()`` made
    while ``device`` is the current device. ``stream`` may be an int or a
    ``cupy.cuda.Stream``.

    Raises ``ValueError`` for a device that is not a ``cupy.cuda.Device`` or
    for a stream of an unrecognized type.
    """
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x

    if device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()

    if not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")

    # see cupy/cupy#5985 for the reason how we handle device/stream here
    prev_device = runtime.getDevice()
    prev_stream: stream_module.Stream = None
    use_stream = stream is not None
    if use_stream:
        prev_stream = stream_module.get_current_stream()
        # stream can be an int as specified in __dlpack__, or a CuPy stream
        if isinstance(stream, int):
            stream = cp.cuda.ExternalStream(stream)
        elif not isinstance(stream, cp.cuda.Stream):
            raise ValueError('the input stream is not recognized')
        stream.use()

    try:
        runtime.setDevice(device.id)
        arr = x.copy()
    finally:
        # Always restore the previous device (and stream, if one was
        # swapped in), even when the copy fails.
        runtime.setDevice(prev_device)
        if use_stream:
            prev_stream.use()
    return arr

186 

def _torch_to_device(x, device, /, stream=None):
    """
    Move ``x`` to ``device`` by calling ``x.to(device)``.

    Passing any ``stream`` other than ``None`` raises ``NotImplementedError``.
    """
    if stream is None:
        return x.to(device)
    raise NotImplementedError

191 

def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.
    device: device
        a ``device`` object (see the "Device Support" section of the array
        API specification).
    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
        such an object would not be portable.

    Returns
    -------
    out: array
        an array with the same data and data type as ``x`` and located on the
        specified ``device``.

    .. note::
        If ``stream`` is given, the copy operation should be enqueued on the
        provided ``stream``; otherwise, the copy operation should be enqueued
        on the default stream/queue. Whether the copy is performed
        synchronously or asynchronously is implementation-dependent.
        Accordingly, if synchronization is required to guarantee data safety,
        this must be clearly explained in a conforming library's
        documentation.
    """
    if _is_numpy_array(x):
        # NumPy arrays accept no stream and only the 'cpu' device, in which
        # case the input is returned as-is.
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")

    if _is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)

    if _is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)

    # Anything else is expected to provide its own to_device() method.
    return x.to_device(device, stream=stream)

225 

def size(x):
    """
    Return the total number of elements of x

    This is the product of the entries of ``x.shape``, or ``None`` when any
    entry of the shape is ``None``.
    """
    shape = x.shape
    return None if None in shape else math.prod(shape)

233 

# Public helper API of this module; underscore-prefixed names are internal.
__all__ = [
    'is_array_api_obj',
    'array_namespace',
    'get_namespace',
    'device',
    'to_device',
    'size',
]