Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/scipy/_lib/array_api_compat/common/_helpers.py: 15%

114 statements  

« prev     ^ index     » next       coverage.py v7.4.1, created at 2024-02-14 06:37 +0000

1""" 

2Various helper functions which are not part of the spec. 

3 

4Functions which start with an underscore are for internal use only but helpers 

5that are in __all__ are intended as additional helper functions for use by end 

6users of the compat library. 

7""" 

8from __future__ import annotations 

9 

10import sys 

11import math 

12 

13def _is_numpy_array(x): 

14 # Avoid importing NumPy if it isn't already 

15 if 'numpy' not in sys.modules: 

16 return False 

17 

18 import numpy as np 

19 

20 # TODO: Should we reject ndarray subclasses? 

21 return isinstance(x, (np.ndarray, np.generic)) 

22 

23def _is_cupy_array(x): 

24 # Avoid importing NumPy if it isn't already 

25 if 'cupy' not in sys.modules: 

26 return False 

27 

28 import cupy as cp 

29 

30 # TODO: Should we reject ndarray subclasses? 

31 return isinstance(x, (cp.ndarray, cp.generic)) 

32 

33def _is_torch_array(x): 

34 # Avoid importing torch if it isn't already 

35 if 'torch' not in sys.modules: 

36 return False 

37 

38 import torch 

39 

40 # TODO: Should we reject ndarray subclasses? 

41 return isinstance(x, torch.Tensor) 

42 

def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.

    NumPy, CuPy and PyTorch arrays are recognised directly. Any other object
    qualifies if it exposes an ``__array_namespace__`` attribute.
    """
    if _is_numpy_array(x) or _is_cupy_array(x) or _is_torch_array(x):
        return True
    return hasattr(x, '__array_namespace__')

51 

52def _check_api_version(api_version): 

53 if api_version is not None and api_version != '2021.12': 

54 raise ValueError("Only the 2021.12 version of the array API specification is currently supported") 

55 

def _namespace_for_array(x, api_version, use_compat):
    """Return the namespace for one array ``x`` (helper for ``array_namespace``)."""
    if _is_numpy_array(x):
        _check_api_version(api_version)
        if use_compat:
            from .. import numpy as numpy_namespace
            return numpy_namespace
        import numpy as np
        return np

    if _is_cupy_array(x):
        _check_api_version(api_version)
        if use_compat:
            from .. import cupy as cupy_namespace
            return cupy_namespace
        import cupy as cp
        return cp

    if _is_torch_array(x):
        _check_api_version(api_version)
        if use_compat:
            from .. import torch as torch_namespace
            return torch_namespace
        import torch
        return torch

    if hasattr(x, '__array_namespace__'):
        # The array provides its own namespace; let it handle api_version.
        return x.__array_namespace__(api_version=api_version)

    # TODO: Support Python scalars?
    raise TypeError(f"{type(x).__name__} is not a supported array type")


def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    `xs` should contain one or more arrays, all belonging to the same library.

    Typical usage is

        def your_function(x, y):
            xp = array_api_compat.array_namespace(x, y)
            # Now use xp as the array library namespace
            return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

    api_version should be the newest version of the spec that you need support
    for (currently the compat library wrapped APIs only support v2021.12).

    For NumPy, CuPy and PyTorch arrays, the compat wrapper namespace is returned
    when ``_use_compat`` is true; otherwise the bare library module is returned.

    Raises ``TypeError`` if no arrays are given, if an argument is not a
    supported array type, or if the arrays come from different namespaces.
    """
    # Inputs are resolved left to right, so the first unsupported one raises.
    namespaces = {_namespace_for_array(x, api_version, _use_compat) for x in xs}

    if not namespaces:
        raise TypeError("Unrecognized array input")
    if len(namespaces) > 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

    (xp,) = namespaces
    return xp


# backwards compatibility alias
get_namespace = array_namespace

116 

117def _check_device(xp, device): 

118 if xp == sys.modules.get('numpy'): 

119 if device not in ["cpu", None]: 

120 raise ValueError(f"Unsupported device for NumPy: {device!r}") 

121 

# ``device`` is not available on numpy.ndarray, and ``to_device()`` is not
# available on numpy.ndarray or cupy.ndarray. This library reuses those
# ndarray classes directly rather than wrapping or subclassing them, so it
# cannot add these members to the array objects. The helper functions below
# can be used instead of the wrapper functions by libraries that need to
# support both NumPy/CuPy and other libraries that use devices.
def device(x: "Array", /) -> "Device":
    """
    Hardware device the array data resides on.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the "Device Support" section of the array API specification).
        NumPy arrays always report ``"cpu"``; for anything else the array's
        own ``device`` attribute is returned.
    """
    return "cpu" if _is_numpy_array(x) else x.device

145 

# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """
    Copy the CuPy array ``x`` to ``device`` (CuPy has no ``to_device`` yet).

    ``device`` may be one of:

    - ``x.device``: ``x`` itself is returned, without a copy.
    - ``"cpu"``: the result of ``x.get()`` is returned.
    - a ``cupy.cuda.Device``: a copy made on that device is returned.

    Any other value raises ``ValueError``.

    ``stream`` may be an int (wrapped in ``cupy.cuda.ExternalStream``) or a
    ``cupy.cuda.Stream``; anything else raises ``ValueError``. The previously
    current device, and the previous stream when one was given, are restored
    afterwards.
    """
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x
    if device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()
    if not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")

    # see cupy/cupy#5985 for the reason how we handle device/stream here
    original_device = runtime.getDevice()
    original_stream = None
    switch_stream = stream is not None
    if switch_stream:
        original_stream = stream_module.get_current_stream()
        # stream can be an int as specified in __dlpack__, or a CuPy stream
        if isinstance(stream, int):
            stream = cp.cuda.ExternalStream(stream)
        elif not isinstance(stream, cp.cuda.Stream):
            raise ValueError('the input stream is not recognized')
        stream.use()

    try:
        runtime.setDevice(device.id)
        copied = x.copy()
    finally:
        # Restore the caller's device/stream even if the copy failed.
        runtime.setDevice(original_device)
        if switch_stream:
            original_stream.use()
    return copied

184 

185def _torch_to_device(x, device, /, stream=None): 

186 if stream is not None: 

187 raise NotImplementedError 

188 return x.to(device) 

189 

def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array":
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.
    device: device
        a ``device`` object (see the "Device Support" section of the array API specification).
    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported in ``array.__dlpack__``, implementations may choose to support any library-specific stream object with the caveat that any code using such an object would not be portable.

    Returns
    -------
    out: array
        an array with the same data and data type as ``x`` and located on the specified ``device``.

    Raises
    ------
    ValueError
        For NumPy input, if ``stream`` is given or ``device`` is not ``"cpu"``.

    .. note::
       If ``stream`` is given, the copy operation should be enqueued on the provided ``stream``; otherwise, the copy operation should be enqueued on the default stream/queue. Whether the copy is performed synchronously or asynchronously is implementation-dependent. Accordingly, if synchronization is required to guarantee data safety, this must be clearly explained in a conforming library's documentation.
    """
    # The NumPy check must come first (see the dispatch order below).
    if _is_numpy_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            # Already on the CPU: no copy is made.
            return x
        raise ValueError(f"Unsupported device {device!r}")

    if _is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)

    if _is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)

    # Otherwise defer to the array's own ``to_device`` method.
    return x.to_device(device, stream=stream)

223 

def size(x):
    """
    Return the total number of elements of x.

    The result is the product of ``x.shape``, so a 0-d array has size 1.
    If any dimension is ``None`` (unknown), ``None`` is returned instead.
    """
    dims = x.shape
    return None if None in dims else math.prod(dims)

231 

# Public helpers re-exported for end users of the compat library.
__all__ = [
    'is_array_api_obj',
    'array_namespace',
    'get_namespace',
    'device',
    'to_device',
    'size',
]