Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.9/dist-packages/numexpr/utils.py: 43%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

150 statements  

1################################################################### 

2# Numexpr - Fast numerical array expression evaluator for NumPy. 

3# 

4# License: MIT 

5# Author: See AUTHORS.txt 

6# 

7# See LICENSE.txt and LICENSES/*.txt for details about copyright and 

8# rights to use. 

9#################################################################### 

10 

11import logging 

12log = logging.getLogger(__name__) 

13 

14import os 

15import subprocess 

16import contextvars 

17 

18from numexpr.interpreter import _set_num_threads, _get_num_threads, MAX_THREADS 

19from numexpr import use_vml 

20from . import version 

21 

22if use_vml: 

23 from numexpr.interpreter import ( 

24 _get_vml_version, _set_vml_accuracy_mode, _set_vml_num_threads, 

25 _get_vml_num_threads) 

26 

27 

def get_vml_version():
    """
    Return the VML/MKL library version as reported by the C extension.

    Returns None when numexpr was built without VML support.
    """
    # Without VML there is no library to query.
    return _get_vml_version() if use_vml else None

36 

37 

def set_vml_accuracy_mode(mode):
    """
    Set the accuracy mode for VML operations and return the previous mode.

    Accepted values for `mode`:
      - 'high': high accuracy mode (HA), <1 least significant bit
      - 'low': low accuracy mode (LA), typically 1-2 least significant bits
      - 'fast': enhanced performance mode (EP)
      - None: mode settings are ignored

    This call is equivalent to `vmlSetMode()` in the VML library; see the
    "accuracy modes" section of Intel's MKL/oneMKL VML documentation for
    details.

    Returns the previous accuracy setting as 'low', 'high' or 'fast', or
    None if the previous code maps to none of those names. Also returns None
    (without validating `mode`) when numexpr was built without VML support.

    Raises ValueError, when VML is available, if `mode` is not one of the
    accepted values.
    """
    if not use_vml:
        return None

    mode_codes = {None: 0, 'low': 1, 'high': 2, 'fast': 3}
    # Membership is tested against a list (equality, not hashing) so that
    # unhashable arguments still produce the ValueError below.
    if mode not in list(mode_codes):
        raise ValueError(
            "mode argument must be one of: None, 'high', 'low', 'fast'")

    previous_code = _set_vml_accuracy_mode(mode_codes.get(mode, 0))
    code_names = {1: 'low', 2: 'high', 3: 'fast'}
    return code_names.get(previous_code)

67 

68 

def set_vml_num_threads(nthreads):
    """
    Suggest a maximum number of threads for VML operations.

    Equivalent to calling
    ``mkl_domain_set_num_threads(nthreads, MKL_DOMAIN_VML)`` in MKL; see the
    MKL threading-control documentation for details.

    Does nothing (and returns None) when numexpr was built without VML.
    """
    if not use_vml:
        return
    _set_vml_num_threads(nthreads)

84 

def get_vml_num_threads():
    """
    Return the maximum number of threads VML operations may use.

    Equivalent to ``mkl_domain_get_max_threads(MKL_DOMAIN_VML)`` in MKL; see
    the MKL threading-control documentation for details.

    Returns None when numexpr was built without VML support.
    """
    if not use_vml:
        return None
    return _get_vml_num_threads()

100 

def set_num_threads(nthreads):
    """
    Set the number of threads numexpr uses for its operations.

    Returns the previous thread count. This return value is flagged as
    DEPRECATED, so do not rely on it in new code.

    The initial value is chosen by `_init_num_threads()`, which may use
    `detect_number_of_cores()` and several environment variables.
    """
    return _set_num_threads(nthreads)

112 

def get_num_threads():
    """Return how many threads numexpr is currently configured to use."""
    current_nthreads = _get_num_threads()
    return current_nthreads

118 

def _env_var_is_set(name):
    """Return True if environment variable `name` exists and is non-empty."""
    # An empty value is commonly used to "unset" a variable, so treat it the
    # same as an absent one.
    return os.environ.get(name, '') != ''


def _init_num_threads():
    """
    Choose and apply the initial number of threads used by the virtual machine.

    The thread-pool size (``MAX_THREADS``) comes from the C extension. This
    function only decides how many threads are requested initially:

    1. If 'sparc' appears in ``version.platform_machine``, 1 thread (threading
       problems have been reported there).
    2. Otherwise ``NUMEXPR_NUM_THREADS``, if set and non-empty.
    3. Otherwise ``OMP_NUM_THREADS``, if set and non-empty.
    4. Otherwise ``MAX_THREADS`` if ``NUMEXPR_MAX_THREADS`` is set; if not,
       the detected core count capped at 16 (informational messages are
       logged in this case).

    Returns
    -------
    int
        The number of threads passed to `set_num_threads`. The C extension
        performs its own checks against ``MAX_THREADS``.

    Raises
    ------
    ValueError
        If ``NUMEXPR_NUM_THREADS`` or ``OMP_NUM_THREADS`` is non-empty but is
        not an integer.
    """
    # Any platform-specific short-circuits
    if 'sparc' in version.platform_machine:
        log.warning('The number of threads have been set to 1 because problems related '
                    'to threading have been reported on some sparc machine. '
                    'The number of threads can be changed using the "set_num_threads" '
                    'function.')
        set_num_threads(1)
        return 1

    env_configured = (_env_var_is_set('NUMEXPR_MAX_THREADS')
                      or _env_var_is_set('OMP_NUM_THREADS'))
    if env_configured:
        # The user has configured NumExpr in the expected way, so suppress logs.
        n_cores = MAX_THREADS
    else:
        # The user has not set 'NUMEXPR_MAX_THREADS', so likely they have not
        # configured NumExpr as desired, so we emit info logs. Core detection
        # (which may spawn `sysctl`) is only needed on this path.
        n_cores = detect_number_of_cores()
        if n_cores > MAX_THREADS:
            log.info('Note: detected %d virtual cores but NumExpr set to maximum of %d, '
                     'check "NUMEXPR_MAX_THREADS" environment variable.',
                     n_cores, MAX_THREADS)
        if n_cores > 16:
            # Back in 2019, 8 threads would be considered safe for performance.
            # We are in 2024 now, so adjusting.
            log.info('Note: NumExpr detected %d cores but "NUMEXPR_MAX_THREADS" '
                     'not set, so enforcing safe limit of 16.', n_cores)
            n_cores = 16

    # Now check 'NUMEXPR_NUM_THREADS' or 'OMP_NUM_THREADS' to set the actual
    # number of threads used.
    if _env_var_is_set('NUMEXPR_NUM_THREADS'):
        requested_threads = int(os.environ['NUMEXPR_NUM_THREADS'])
    elif _env_var_is_set('OMP_NUM_THREADS'):
        requested_threads = int(os.environ['OMP_NUM_THREADS'])
    else:
        requested_threads = n_cores
        if not env_configured:
            log.info('NumExpr defaulting to %d threads.', n_cores)

    # The C-extension function performs its own checks against `MAX_THREADS`
    set_num_threads(requested_threads)
    return requested_threads

167 

168 

def detect_number_of_cores():
    """
    Detect the number of online CPU cores on this system. Cribbed from pp.

    Probes, in order:

    * ``os.sysconf("SC_NPROCESSORS_ONLN")`` when that name is available
      (Linux & Unix);
    * ``sysctl -n hw.ncpu`` when ``os.sysconf`` exists but lacks that name
      (macOS);
    * the ``NUMBER_OF_PROCESSORS`` environment variable (Windows).

    Returns the first positive integer found, or 1 if no probe yields one.
    """
    # Linux, Unix and MacOS:
    if hasattr(os, "sysconf"):
        if "SC_NPROCESSORS_ONLN" in os.sysconf_names:
            # Linux & Unix:
            ncpus = os.sysconf("SC_NPROCESSORS_ONLN")
            if isinstance(ncpus, int) and ncpus > 0:
                return ncpus
        else:  # OSX:
            # Best effort: a missing or failing `sysctl`, or unparsable
            # output, must not raise; fall through to the remaining probes.
            try:
                ncpus = int(subprocess.check_output(
                    ["sysctl", "-n", "hw.ncpu"], stderr=subprocess.DEVNULL))
            except (OSError, subprocess.CalledProcessError, ValueError):
                pass
            else:
                if ncpus > 0:
                    return ncpus
    # Windows:
    try:
        ncpus = int(os.environ.get("NUMBER_OF_PROCESSORS", ""))
    except ValueError:
        pass
    else:
        if ncpus > 0:
            return ncpus
    return 1  # Default

190 

191 

def detect_number_of_threads():
    """
    DEPRECATED: use `_init_num_threads` instead.

    Return a thread count taken from the environment and capped at
    ``MAX_THREADS``. Sources, in order:

    * ``NUMEXPR_NUM_THREADS``, if it parses as an integer;
    * ``OMP_NUM_THREADS``, if it parses as an integer;
    * `detect_number_of_cores()`.

    If this is modified, please update the note in: https://github.com/pydata/numexpr/wiki/Numexpr-Users-Guide
    """
    # The replacement defined in this module is `_init_num_threads`; point
    # users at that name rather than a non-existent `init_num_threads`.
    log.warning('Deprecated, use `_init_num_threads` instead.')
    for var in ('NUMEXPR_NUM_THREADS', 'OMP_NUM_THREADS'):
        try:
            nthreads = int(os.environ.get(var, ''))
        except ValueError:
            continue  # unset, empty or non-numeric: try the next source
        break
    else:
        nthreads = detect_number_of_cores()

    # Check that we don't surpass the MAX_THREADS in interpreter.cpp
    return min(nthreads, MAX_THREADS)

210 

211 

class CacheDict(dict):
    """
    A dictionary that prevents itself from growing too much.

    When it already holds at least `maxentries` items, assigning a *new* key
    first evicts about 10% of `maxentries` (at least one entry), oldest
    inserted first. For ``maxentries >= 1`` the size therefore never exceeds
    `maxentries`.

    Only item assignment (``d[key] = value``) is bounded. Other dict methods
    that add entries (e.g. ``update``, ``setdefault``) do not go through
    `__setitem__` and so bypass this check.
    """

    def __init__(self, maxentries):
        self.maxentries = maxentries
        super().__init__()

    def __setitem__(self, key, value):
        # Protection against growing the cache too much. Overwriting an
        # existing key does not grow the dict, so it never triggers eviction.
        if key not in self and len(self) >= self.maxentries:
            # Evict ~10% of the entries, oldest-inserted first (dicts keep
            # insertion order). Always evict at least one; otherwise caches
            # with fewer than 10 slots would grow without bound.
            entries_to_remove = max(1, self.maxentries // 10)
            for k in list(self)[:entries_to_remove]:
                super().__delitem__(k)
        super().__setitem__(key, value)

229 

230 

class ContextDict:
    """
    A dictionary-like store whose contents are scoped to the current
    `contextvars` context.

    Every mutating method copies the current dict and installs the copy with
    `ContextVar.set`, so a dict object already handed out is never modified
    in place by this class. Note that `all()` returns the live dict itself;
    callers should not mutate it.
    """

    def __init__(self):
        # The shared default dict is never mutated by this class: every
        # writer (set/delete/update) copies it first, and clear() replaces it.
        self._context_data = contextvars.ContextVar('context_data', default={})

    def set(self, key=None, value=None, **kwargs):
        """Store `value` under `key` (unless `key` is None) plus any kwargs."""
        data = self._context_data.get().copy()

        if key is not None:
            data[key] = value

        for k, v in kwargs.items():
            data[k] = v

        self._context_data.set(data)

    def get(self, key, default=None):
        """Return the value for `key`, or `default` if it is missing."""
        data = self._context_data.get()
        return data.get(key, default)

    def delete(self, key):
        """Remove `key` if present; missing keys are silently ignored."""
        data = self._context_data.get().copy()
        if key in data:
            del data[key]
        self._context_data.set(data)

    def clear(self):
        """Replace the current context's contents with an empty dict."""
        self._context_data.set({})

    def all(self):
        """Return the current context's underlying dict (not a copy)."""
        return self._context_data.get()

    def update(self, *args, **kwargs):
        """
        Update from a mapping or an iterable of key/value pairs, then from
        keyword arguments, mirroring `dict.update`.

        Raises TypeError if more than one positional argument is given.
        """
        data = self._context_data.get().copy()

        if args:
            if len(args) > 1:
                raise TypeError(f"update() takes at most 1 positional argument ({len(args)} given)")
            other = args[0]
            if isinstance(other, dict):
                data.update(other)
            elif hasattr(other, 'keys'):
                # Any other mapping (e.g. another ContextDict): iterating it
                # yields keys, not (key, value) pairs, so look values up.
                for k in other.keys():
                    data[k] = other[k]
            else:
                for k, v in other:
                    data[k] = v

        data.update(kwargs)
        self._context_data.set(data)

    def keys(self):
        return self._context_data.get().keys()

    def values(self):
        return self._context_data.get().values()

    def items(self):
        return self._context_data.get().items()

    def __getitem__(self, key):
        # Unlike `get`, follow the mapping protocol: a missing key raises
        # KeyError rather than returning None, so callers can distinguish
        # "absent" from "stored None".
        return self._context_data.get()[key]

    def __setitem__(self, key, value):
        self.set(key, value)

    def __delitem__(self, key):
        self.delete(key)

    def __contains__(self, key):
        return key in self._context_data.get()

    def __len__(self):
        return len(self._context_data.get())

    def __iter__(self):
        return iter(self._context_data.get())

    def __repr__(self):
        return repr(self._context_data.get())