Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/threadpoolctl.py: 33%

405 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-12-12 06:31 +0000

1"""threadpoolctl 

2 

3This module provides utilities to introspect native libraries that rely on 

4thread pools (notably BLAS and OpenMP implementations) and dynamically set the 

5maximal number of threads they can use. 

6""" 

7# License: BSD 3-Clause 

8 

9# The code to introspect dynamically loaded libraries on POSIX systems is 

10# adapted from code by Intel developer @anton-malakhov available at 

11# https://github.com/IntelPython/smp (Copyright (c) 2017, Intel Corporation) 

12# and also published under the BSD 3-Clause license 

13import os 

14import re 

15import sys 

16import ctypes 

17import textwrap 

18from typing import final 

19import warnings 

20from ctypes.util import find_library 

21from abc import ABC, abstractmethod 

22from functools import lru_cache 

23from contextlib import ContextDecorator 

24 

25__version__ = "3.2.0" 

26__all__ = [ 

27 "threadpool_limits", 

28 "threadpool_info", 

29 "ThreadpoolController", 

30 "LibController", 

31 "register", 

32] 

33 

34 

# Several OpenMP runtimes can end up loaded in the same Python process when
# compiled extensions were built with different compilers. This may cause
# runtime errors or even segfaults; libiomp (Intel ICC) and libomp (clang/llvm)
# in particular tend to crash, for instance when BLAS is called inside a
# prange. The environment variable below allows multiple OpenMP libraries to be
# loaded. It should not degrade performance: over-subscription in nested
# OpenMP sections is handled manually by temporarily reconfiguring the inner
# OpenMP runtime while under the scope of the outer parallel section.
# A value already present in the environment is left untouched.
if "KMP_DUPLICATE_LIB_OK" not in os.environ:
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

47 

# ctypes mirror of the record passed to `dl_iterate_phdr` callbacks, used to
# read information about dynamically loaded libraries. See
# https://linux.die.net/man/3/dl_iterate_phdr for more details.
# The integer widths follow the pointer size of the running interpreter.
if sys.maxsize > 2**32:
    _SYSTEM_UINT, _SYSTEM_UINT_HALF = ctypes.c_uint64, ctypes.c_uint32
else:
    _SYSTEM_UINT, _SYSTEM_UINT_HALF = ctypes.c_uint32, ctypes.c_uint16


class _dl_phdr_info(ctypes.Structure):
    """Leading fields of the C `dl_phdr_info` structure."""

    _fields_ = [
        # Base address of object
        ("dlpi_addr", _SYSTEM_UINT),
        # path to the library
        ("dlpi_name", ctypes.c_char_p),
        # pointer on dlpi_headers
        ("dlpi_phdr", ctypes.c_void_p),
        # number of elements in dlpi_phdr
        ("dlpi_phnum", _SYSTEM_UINT_HALF),
    ]

61 

62 

# The RTLD_NOLOAD flag for loading shared libraries is not defined on Windows;
# fall back to the default ctypes loading mode when `os` does not expose it.
_RTLD_NOLOAD = getattr(os, "RTLD_NOLOAD", ctypes.DEFAULT_MODE)

68 

69 

class LibController(ABC):
    """Abstract base class for the individual library controllers

    A library controller must expose the following class attributes:
        - user_api : str
            Usually the name of the library or generic specification the library
            implements, e.g. "blas" is a specification with different implementations.
        - internal_api : str
            Usually the name of the library or concrete implementation of some
            specification, e.g. "openblas" is an implementation of the "blas"
            specification.
        - filename_prefixes : tuple
            Possible prefixes of the shared library's filename that allow to
            identify the library. e.g. "libopenblas" for libopenblas.so.

    and implement the following methods: `get_num_threads`, `set_num_threads` and
    `get_version`.

    Threadpoolctl loops through all the loaded shared libraries and tries to match
    the filename of each library with the `filename_prefixes`. If a match is found, a
    controller is instantiated and a handler to the library is stored in the `dynlib`
    attribute as a `ctypes.CDLL` object. It can be used to access the necessary symbols
    of the shared library to implement the above methods.

    The following information will be exposed in the info dictionary:
      - user_api : standardized API, if any, or a copy of internal_api.
      - internal_api : implementation-specific API.
      - num_threads : the current thread limit.
      - prefix : prefix of the shared library's filename.
      - filepath : path to the loaded shared library.
      - version : version of the library (if available).

    In addition, each library controller may expose internal API specific entries. They
    must be set as attributes in the `set_additional_attributes` method.
    """

    @final
    def __init__(self, *, filepath=None, prefix=None):
        """Open a handle on the library at `filepath` and collect its metadata.

        This is not meant to be overridden by subclasses.
        """
        # Instance attributes end up in the info dict in assignment order.
        self.prefix = prefix
        self.filepath = filepath
        self.dynlib = ctypes.CDLL(filepath, mode=_RTLD_NOLOAD)
        self.version = self.get_version()
        self.set_additional_attributes()

    @final
    def info(self):
        """Return relevant info wrapped in a dict

        The dict holds the API names, the current thread limit and every
        instance attribute except the `dynlib` handle.

        This is not meant to be overridden by subclasses.
        """
        exposed = dict(
            user_api=self.user_api,
            internal_api=self.internal_api,
            num_threads=self.num_threads,
        )
        exposed.update(vars(self))
        # The ctypes handle is an implementation detail, not library info.
        del exposed["dynlib"]
        return exposed

    def set_additional_attributes(self):
        """Set additional attributes meant to be exposed in the info dict"""

    @property
    def num_threads(self):
        """Exposes the current thread limit as a dynamic property

        This is not meant to be used or overridden by subclasses.
        """
        return self.get_num_threads()

    @abstractmethod
    def get_num_threads(self):
        """Return the maximum number of threads available to use"""

    @abstractmethod
    def set_num_threads(self, num_threads):
        """Set the maximum number of threads to use"""

    @abstractmethod
    def get_version(self):
        """Return the version of the shared library"""

152 

153 

class OpenBLASController(LibController):
    """Controller class for OpenBLAS"""

    user_api = "blas"
    internal_api = "openblas"
    filename_prefixes = ("libopenblas", "libblas")
    check_symbols = ("openblas_get_num_threads", "openblas_get_num_threads64_")

    def _symbol(self, name, fallback=None):
        """Return symbol `name` of the library, else `name` + "64_", else `fallback`.

        Symbols differ when built for 64bit integers in Fortran.
        """
        lib = self.dynlib
        suffixed = getattr(lib, name + "64_", fallback)
        return getattr(lib, name, suffixed)

    def set_additional_attributes(self):
        """Expose the threading layer and detected architecture in the info dict."""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the library's thread count, or None if no getter is exported."""
        getter = self._symbol("openblas_get_num_threads", lambda: None)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the library's thread count; a no-op if no setter is exported."""
        setter = self._symbol("openblas_set_num_threads", lambda num_threads: None)
        return setter(num_threads)

    def get_version(self):
        """Return the version parsed from the OpenBLAS config string, if any."""
        # None means OpenBLAS is not loaded or version < 0.3.4, since OpenBLAS
        # did not expose its version before that.
        get_config = self._symbol("openblas_get_config")
        if get_config is None:
            return None

        get_config.restype = ctypes.c_char_p
        tokens = get_config().split()
        # The version is only read when the config string starts with the
        # literal "OpenBLAS" token.
        if tokens[0] != b"OpenBLAS":
            return None
        return tokens[1].decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of OpenBLAS"""
        get_parallel = self._symbol("openblas_get_parallel")
        if get_parallel is None:
            return "unknown"
        layers = {2: "openmp", 1: "pthreads"}
        return layers.get(get_parallel(), "disabled")

    def _get_architecture(self):
        """Return the architecture detected by OpenBLAS"""
        get_corename = self._symbol("openblas_get_corename")
        if get_corename is None:
            return None

        get_corename.restype = ctypes.c_char_p
        return get_corename().decode("utf-8")

232 

233 

class BLISController(LibController):
    """Controller class for BLIS"""

    user_api = "blas"
    internal_api = "blis"
    filename_prefixes = ("libblis", "libblas")
    check_symbols = ("bli_thread_get_num_threads",)

    def set_additional_attributes(self):
        """Expose the threading layer and detected architecture in the info dict."""
        self.threading_layer = self._get_threading_layer()
        self.architecture = self._get_architecture()

    def get_num_threads(self):
        """Return the current thread limit, or None if no getter is exported."""
        get_func = getattr(self.dynlib, "bli_thread_get_num_threads", lambda: None)
        num_threads = get_func()
        # by default BLIS is single-threaded and get_num_threads
        # returns -1. We map it to 1 for consistency with other libraries.
        return 1 if num_threads == -1 else num_threads

    def set_num_threads(self, num_threads):
        """Set the thread limit; a no-op if the library exports no setter."""
        set_func = getattr(
            self.dynlib, "bli_thread_set_num_threads", lambda num_threads: None
        )
        return set_func(num_threads)

    def get_version(self):
        """Return the BLIS version string, or None if it is not exported."""
        get_version_ = getattr(self.dynlib, "bli_info_get_version_str", None)
        if get_version_ is None:
            return None

        get_version_.restype = ctypes.c_char_p
        return get_version_().decode("utf-8")

    def _get_threading_layer(self):
        """Return the threading layer of BLIS

        Returns "openmp", "pthreads" or "disabled". A missing query symbol is
        treated as "not enabled" instead of raising, like every other symbol
        lookup of this controller, so that a library lacking it does not
        prevent the controller from being created.
        """
        openmp_enabled = getattr(
            self.dynlib, "bli_info_get_enable_openmp", lambda: False
        )
        if openmp_enabled():
            return "openmp"
        pthreads_enabled = getattr(
            self.dynlib, "bli_info_get_enable_pthreads", lambda: False
        )
        if pthreads_enabled():
            return "pthreads"
        return "disabled"

    def _get_architecture(self):
        """Return the architecture detected by BLIS"""
        bli_arch_query_id = getattr(self.dynlib, "bli_arch_query_id", None)
        bli_arch_string = getattr(self.dynlib, "bli_arch_string", None)
        if bli_arch_query_id is None or bli_arch_string is None:
            return None

        # the true restype should be BLIS' arch_t (enum) but int should work
        # for us:
        bli_arch_query_id.restype = ctypes.c_int
        bli_arch_string.restype = ctypes.c_char_p
        return bli_arch_string(bli_arch_query_id()).decode("utf-8")

287 

288 

class MKLController(LibController):
    """Controller class for MKL"""

    user_api = "blas"
    internal_api = "mkl"
    filename_prefixes = ("libmkl_rt", "mkl_rt", "libblas")
    check_symbols = ("MKL_Get_Max_Threads",)

    def set_additional_attributes(self):
        """Expose the MKL threading layer in the info dict."""
        self.threading_layer = self._get_threading_layer()

    def get_num_threads(self):
        """Return the current thread limit, or None if no getter is exported."""
        getter = getattr(self.dynlib, "MKL_Get_Max_Threads", lambda: None)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the thread limit; a no-op if the library exports no setter."""
        setter = getattr(self.dynlib, "MKL_Set_Num_Threads", lambda num_threads: None)
        return setter(num_threads)

    def get_version(self):
        """Return the MKL version, or None if the library does not expose it."""
        if not hasattr(self.dynlib, "MKL_Get_Version_String"):
            return None

        buffer_size = 200
        buffer = ctypes.create_string_buffer(buffer_size)
        self.dynlib.MKL_Get_Version_String(buffer, buffer_size)

        version = buffer.value.decode("utf-8")
        # Keep only the token following "Version " when it can be found,
        # otherwise fall back to the whole string.
        match = re.search(r"Version ([^ ]+) ", version)
        if match is not None:
            version = match.group(1)
        return version.strip()

    def _get_threading_layer(self):
        """Return the threading layer of MKL"""
        layer_names = {
            0: "intel",
            1: "sequential",
            2: "pgi",
            3: "gnu",
            4: "tbb",
            -1: "not specified",
        }
        # The function mkl_set_threading_layer returns the current threading
        # layer. Calling it with an invalid threading layer allows us to safely
        # get the threading layer
        set_threading_layer = getattr(
            self.dynlib, "MKL_Set_Threading_Layer", lambda layer: -1
        )
        current_layer = set_threading_layer(-1)
        return layer_names[current_layer]

338 

339 

class OpenMPController(LibController):
    """Controller class for OpenMP"""

    user_api = "openmp"
    internal_api = "openmp"
    filename_prefixes = ("libiomp", "libgomp", "libomp", "vcomp")

    def get_num_threads(self):
        """Return the current thread limit, or None if no getter is exported."""
        getter = getattr(self.dynlib, "omp_get_max_threads", lambda: None)
        return getter()

    def set_num_threads(self, num_threads):
        """Set the thread limit; a no-op if the library exports no setter."""
        setter = getattr(self.dynlib, "omp_set_num_threads", lambda num_threads: None)
        return setter(num_threads)

    def get_version(self):
        """Always return None."""
        # There is no way to get the version number programmatically in OpenMP.
        return None

358 

359 

# Controllers for the libraries that we'll look for in the loaded libraries.
# Third party libraries can register their own controllers.
_ALL_CONTROLLERS = [
    OpenBLASController,
    BLISController,
    MKLController,
    OpenMPController,
]

# Helpers for the doc and test names
_ALL_USER_APIS = list({ctrl.user_api for ctrl in _ALL_CONTROLLERS})
_ALL_INTERNAL_APIS = [ctrl.internal_api for ctrl in _ALL_CONTROLLERS]
_ALL_PREFIXES = list(
    {prefix for ctrl in _ALL_CONTROLLERS for prefix in ctrl.filename_prefixes}
)
_ALL_BLAS_LIBRARIES = [
    ctrl.internal_api for ctrl in _ALL_CONTROLLERS if ctrl.user_api == "blas"
]
_ALL_OPENMP_LIBRARIES = OpenMPController.filename_prefixes

374 

375 

def register(controller):
    """Register a new controller

    Parameters
    ----------
    controller : class
        Controller class exposing the `user_api`, `internal_api` and
        `filename_prefixes` class attributes.

    Raises
    ------
    AttributeError
        If one of the required class attributes is missing. In that case no
        registry is modified.
    """
    # Read every attribute before mutating anything so that a malformed
    # controller cannot leave the registries partially updated.
    user_api = controller.user_api
    internal_api = controller.internal_api
    prefixes = tuple(controller.filename_prefixes)

    _ALL_CONTROLLERS.append(controller)
    _ALL_INTERNAL_APIS.append(internal_api)

    # _ALL_USER_APIS and _ALL_PREFIXES are built from sets: keep them free of
    # duplicates when a controller reuses an already known API or prefix.
    if user_api not in _ALL_USER_APIS:
        _ALL_USER_APIS.append(user_api)
    for prefix in prefixes:
        if prefix not in _ALL_PREFIXES:
            _ALL_PREFIXES.append(prefix)

382 

383 

def _format_docstring(*args, **kwargs):
    """Return a decorator filling the `str.format` fields of a docstring.

    The decorated object is returned unchanged when it has no docstring.
    """

    def decorator(obj):
        template = obj.__doc__
        if template is None:
            return obj
        obj.__doc__ = template.format(*args, **kwargs)
        return obj

    return decorator

391 

392 

@lru_cache(maxsize=10000)
def _realpath(filepath):
    """Return `os.path.realpath(filepath)`, memoized to limit system calls."""
    resolved = os.path.realpath(filepath)
    return resolved

397 

398 

@_format_docstring(USER_APIS=list(_ALL_USER_APIS), INTERNAL_APIS=_ALL_INTERNAL_APIS)
def threadpool_info():
    """Return the maximal number of threads for each detected library.

    The result is a list with one dict per supported library that has been
    found. Each dict holds the following information:

      - "user_api" : user API. Possible values are {USER_APIS}.
      - "internal_api": internal API. Possible values are {INTERNAL_APIS}.
      - "prefix" : filename prefix of the specific implementation.
      - "filepath": path to the loaded library.
      - "version": version of the library (if available).
      - "num_threads": the current thread limit.

    In addition, each library may contain internal_api specific entries.
    """
    controller = ThreadpoolController()
    return controller.info()

416 

417 

class _ThreadpoolLimiter:
    """The guts of ThreadpoolController.limit

    Refer to the docstring of ThreadpoolController.limit for more details.

    It will only act on the library controllers held by the provided `controller`.
    Using the default constructor sets the limits right away such that it can be used as
    a callable. Setting the limits can be delayed by using the `wrap` class method such
    that it can be used as a decorator.
    """

    def __init__(self, controller, *, limits=None, user_api=None):
        # `_check_params` may query the controller, so it must be set first.
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )
        # Snapshot taken before applying the limits, used to restore them.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        self.restore_original_limits()

    @classmethod
    def wrap(cls, controller, *, limits=None, user_api=None):
        """Return an instance of this class that can be used as a decorator"""
        return _ThreadpoolLimiterDecorator(
            controller=controller, limits=limits, user_api=user_api
        )

    def restore_original_limits(self):
        """Set the limits back to their original values"""
        # `_original_info` was produced from the same `lib_controllers` list,
        # hence the positional pairing.
        for lib_controller, original_info in zip(
            self._controller.lib_controllers, self._original_info
        ):
            lib_controller.set_num_threads(original_info["num_threads"])

    # Alias of `restore_original_limits` for backward compatibility
    unregister = restore_original_limits

    def get_original_num_threads(self):
        """Original num_threads from before calling threadpool_limits

        Return a dict `{user_api: num_threads}`. The value is None when no
        library was recorded for a user API. When the libraries of a user API
        had different limits, the minimum is returned and a warning is emitted.
        """
        num_threads = {}
        warning_apis = []

        for user_api in self._user_api:
            limits = {
                lib_info["num_threads"]
                for lib_info in self._original_info
                if lib_info["user_api"] == user_api
            }

            if not limits:
                limit = None
            elif len(limits) == 1:
                (limit,) = limits
            else:
                limit = min(limits)
                warning_apis.append(user_api)

            num_threads[user_api] = limit

        if warning_apis:
            warnings.warn(
                "Multiple value possible for following user apis: "
                + ", ".join(warning_apis)
                + ". Returning the minimum."
            )

        return num_threads

    def _check_params(self, limits, user_api):
        """Suitable values for the _limits, _user_api and _prefixes attributes

        Raises ValueError for an unknown `user_api` and TypeError for an
        unsupported `limits` type.
        """

        if isinstance(limits, str) and limits == "sequential_blas_under_openmp":
            # Both parameters are chosen by the controller; the `user_api`
            # passed by the caller is ignored.
            (
                limits,
                user_api,
            ) = self._controller._get_params_for_sequential_blas_under_openmp().values()

        if limits is None or isinstance(limits, int):
            if user_api is None:
                user_api = _ALL_USER_APIS
            elif user_api in _ALL_USER_APIS:
                user_api = [user_api]
            else:
                raise ValueError(
                    f"user_api must be either in {_ALL_USER_APIS} or None. Got "
                    f"{user_api} instead."
                )

            if limits is not None:
                limits = {api: limits for api in user_api}
            prefixes = []
        else:
            if isinstance(limits, list):
                # This should be a list of dicts of library info, for
                # compatibility with the result from threadpool_info.
                limits = {
                    lib_info["prefix"]: lib_info["num_threads"] for lib_info in limits
                }
            elif isinstance(limits, ThreadpoolController):
                # To set the limits from the library controllers of a
                # ThreadpoolController object.
                limits = {
                    lib_controller.prefix: lib_controller.num_threads
                    for lib_controller in limits.lib_controllers
                }

            if not isinstance(limits, dict):
                raise TypeError(
                    "limits must either be an int, a list, a dict, or "
                    f"'sequential_blas_under_openmp'. Got {type(limits)} instead"
                )

            # With a dictionary, can set both specific limit for given
            # libraries and global limit for user_api. Fetch each separately.
            prefixes = [prefix for prefix in limits if prefix in _ALL_PREFIXES]
            user_api = [api for api in limits if api in _ALL_USER_APIS]

        return limits, user_api, prefixes

    def _set_threadpool_limits(self):
        """Change the maximal number of threads in selected thread pools.

        Acts on the library controllers matching a key of `self._limits`,
        either by prefix or by user API. Does nothing when `self._limits`
        is None.
        """
        if self._limits is None:
            return

        for lib_controller in self._controller.lib_controllers:
            # self._limits is a dict {key: num_threads} where key is either
            # a prefix or a user_api. If a library matches both, the limit
            # corresponding to the prefix is chosen.
            if lib_controller.prefix in self._limits:
                num_threads = self._limits[lib_controller.prefix]
            elif lib_controller.user_api in self._limits:
                num_threads = self._limits[lib_controller.user_api]
            else:
                continue

            if num_threads is not None:
                lib_controller.set_num_threads(num_threads)


class _ThreadpoolLimiterDecorator(_ThreadpoolLimiter, ContextDecorator):
    """Same as _ThreadpoolLimiter but to be used as a decorator"""

    def __init__(self, controller, *, limits=None, user_api=None):
        # The controller must be stored before validating the parameters:
        # `_check_params` reads `self._controller` when `limits` is
        # 'sequential_blas_under_openmp'.
        self._controller = controller
        self._limits, self._user_api, self._prefixes = self._check_params(
            limits, user_api
        )

    def __enter__(self):
        # we need to set the limits here and not in the __init__ because we want the
        # limits to be set when calling the decorated function, not when creating the
        # decorator.
        self._original_info = self._controller.info()
        self._set_threadpool_limits()
        return self

587 

588 

@_format_docstring(
    USER_APIS=", ".join('"{}"'.format(api) for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
class threadpool_limits(_ThreadpoolLimiter):
    """Change the maximal number of threads that can be used in thread pools.

    This object can be used either as a callable (the construction of this object
    limits the number of threads), as a context manager in a `with` block to
    automatically restore the original state of the controlled libraries when exiting
    the block, or as a decorator through its `wrap` method.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread level
    isolation as these libraries do not offer thread-local APIs to configure the number
    of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """

    def __init__(self, limits=None, user_api=None):
        # A fresh controller is built for every instance.
        controller = ThreadpoolController()
        super().__init__(controller, limits=limits, user_api=user_api)

    @classmethod
    def wrap(cls, limits=None, user_api=None):
        """Return a decorator applying the limits around each decorated call."""
        controller = ThreadpoolController()
        return super().wrap(controller, limits=limits, user_api=user_api)

647 

648 

649class ThreadpoolController: 

650 """Collection of LibController objects for all loaded supported libraries 

651 

652 Attributes 

653 ---------- 

654 lib_controllers : list of `LibController` objects 

655 The list of library controllers of all loaded supported libraries. 

656 """ 

657 

658 # Cache for libc under POSIX and a few system libraries under Windows. 

659 # We use a class level cache instead of an instance level cache because 

660 # it's very unlikely that a shared library will be unloaded and reloaded 

661 # during the lifetime of a program. 

662 _system_libraries = dict() 

663 

def __init__(self):
    """Collect controllers for the supported libraries, then check OpenMP."""
    # Filled in by `_load_libraries`.
    self.lib_controllers = []
    self._load_libraries()
    self._warn_if_incompatible_openmp()

668 

669 @classmethod 

670 def _from_controllers(cls, lib_controllers): 

671 new_controller = cls.__new__(cls) 

672 new_controller.lib_controllers = lib_controllers 

673 return new_controller 

674 

675 def info(self): 

676 """Return lib_controllers info as a list of dicts""" 

677 return [lib_controller.info() for lib_controller in self.lib_controllers] 

678 

679 def select(self, **kwargs): 

680 """Return a ThreadpoolController containing a subset of its current 

681 library controllers 

682 

683 It will select all libraries matching at least one pair (key, value) from kwargs 

684 where key is an entry of the library info dict (like "user_api", "internal_api", 

685 "prefix", ...) and value is the value or a list of acceptable values for that 

686 entry. 

687 

688 For instance, `ThreadpoolController().select(internal_api=["blis", "openblas"])` 

689 will select all library controllers whose internal_api is either "blis" or 

690 "openblas". 

691 """ 

692 for key, vals in kwargs.items(): 

693 kwargs[key] = [vals] if not isinstance(vals, list) else vals 

694 

695 lib_controllers = [ 

696 lib_controller 

697 for lib_controller in self.lib_controllers 

698 if any( 

699 getattr(lib_controller, key, None) in vals 

700 for key, vals in kwargs.items() 

701 ) 

702 ] 

703 

704 return ThreadpoolController._from_controllers(lib_controllers) 

705 

706 def _get_params_for_sequential_blas_under_openmp(self): 

707 """Return appropriate params to use for a sequential BLAS call in an OpenMP loop 

708 

709 This function takes into account the unexpected behavior of OpenBLAS with the 

710 OpenMP threading layer. 

711 """ 

712 if self.select( 

713 internal_api="openblas", threading_layer="openmp" 

714 ).lib_controllers: 

715 return {"limits": None, "user_api": None} 

716 return {"limits": 1, "user_api": "blas"} 

717 

@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def limit(self, *, limits=None, user_api=None):
    """Change the maximal number of threads that can be used in thread pools.

    This function returns an object that can be used either as a callable (the
    construction of this object limits the number of threads) or as a context
    manager, in a `with` block to automatically restore the original state of the
    controlled libraries when exiting the block.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    This effect is global and impacts the whole Python process. There is no thread
    level isolation as these libraries do not offer thread-local APIs to configure
    the number of threads to use in nested parallel calls.

    Parameters
    ----------
    limits : int, dict, 'sequential_blas_under_openmp' or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If 'sequential_blas_under_openmp', it will choose the appropriate `limits`
          and `user_api` parameters for the specific use case of sequential BLAS
          calls within an OpenMP parallel region. The `user_api` parameter is
          ignored.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """
    limiter = _ThreadpoolLimiter(self, limits=limits, user_api=user_api)
    return limiter

770 

@_format_docstring(
    USER_APIS=", ".join(f'"{api}"' for api in _ALL_USER_APIS),
    BLAS_LIBS=", ".join(_ALL_BLAS_LIBRARIES),
    OPENMP_LIBS=", ".join(_ALL_OPENMP_LIBRARIES),
)
def wrap(self, *, limits=None, user_api=None):
    """Change the maximal number of threads that can be used in thread pools.

    This function returns an object that can be used as a decorator.

    Set the maximal number of threads that can be used in thread pools used in
    the supported libraries to `limits`. This function works for libraries that
    are already loaded in the interpreter and can be changed dynamically.

    Parameters
    ----------
    limits : int, dict or None (default=None)
        The maximal number of threads that can be used in thread pools

        - If int, sets the maximum number of threads to `limits` for each
          library selected by `user_api`.

        - If it is a dictionary `{{key: max_threads}}`, this function sets a
          custom maximum number of threads for each `key` which can be either a
          `user_api` or a `prefix` for a specific library.

        - If None, this function does not do anything.

    user_api : {USER_APIS} or None (default=None)
        APIs of libraries to limit. Used only if `limits` is an int.

        - If "blas", it will only limit BLAS supported libraries ({BLAS_LIBS}).

        - If "openmp", it will only limit OpenMP supported libraries
          ({OPENMP_LIBS}). Note that it can affect the number of threads used
          by the BLAS libraries if they rely on OpenMP.

        - If None, this function will apply to all supported libraries.
    """
    decorator = _ThreadpoolLimiter.wrap(self, limits=limits, user_api=user_api)
    return decorator

811 

812 def __len__(self): 

813 return len(self.lib_controllers) 

814 

def _load_libraries(self):
    """Loop through loaded shared libraries and store the supported ones"""
    # Pick the enumeration strategy from the platform identifier; anything
    # that is neither macOS nor Windows goes through `dl_iterate_phdr`.
    platform = sys.platform
    if platform == "darwin":
        find_libraries = self._find_libraries_with_dyld
    elif platform == "win32":
        find_libraries = self._find_libraries_with_enum_process_module_ex
    else:
        find_libraries = self._find_libraries_with_dl_iterate_phdr
    find_libraries()

823 

824 def _find_libraries_with_dl_iterate_phdr(self): 

825 """Loop through loaded libraries and return binders on supported ones 

826 

827 This function is expected to work on POSIX system only. 

828 This code is adapted from code by Intel developer @anton-malakhov 

829 available at https://github.com/IntelPython/smp 

830 

831 Copyright (c) 2017, Intel Corporation published under the BSD 3-Clause 

832 license 

833 """ 

834 libc = self._get_libc() 

835 if not hasattr(libc, "dl_iterate_phdr"): # pragma: no cover 

836 return [] 

837 

838 # Callback function for `dl_iterate_phdr` which is called for every 

839 # library loaded in the current process until it returns 1. 

840 def match_library_callback(info, size, data): 

841 # Get the path of the current library 

842 filepath = info.contents.dlpi_name 

843 if filepath: 

844 filepath = filepath.decode("utf-8") 

845 

846 # Store the library controller if it is supported and selected 

847 self._make_controller_from_path(filepath) 

848 return 0 

849 

850 c_func_signature = ctypes.CFUNCTYPE( 

851 ctypes.c_int, # Return type 

852 ctypes.POINTER(_dl_phdr_info), 

853 ctypes.c_size_t, 

854 ctypes.c_char_p, 

855 ) 

856 c_match_library_callback = c_func_signature(match_library_callback) 

857 

858 data = ctypes.c_char_p(b"") 

859 libc.dl_iterate_phdr(c_match_library_callback, data) 

860 

861 def _find_libraries_with_dyld(self): 

862 """Loop through loaded libraries and return binders on supported ones 

863 

864 This function is expected to work on OSX system only 

865 """ 

866 libc = self._get_libc() 

867 if not hasattr(libc, "_dyld_image_count"): # pragma: no cover 

868 return [] 

869 

870 n_dyld = libc._dyld_image_count() 

871 libc._dyld_get_image_name.restype = ctypes.c_char_p 

872 

873 for i in range(n_dyld): 

874 filepath = ctypes.string_at(libc._dyld_get_image_name(i)) 

875 filepath = filepath.decode("utf-8") 

876 

877 # Store the library controller if it is supported and selected 

878 self._make_controller_from_path(filepath) 

879 

880 def _find_libraries_with_enum_process_module_ex(self): 

881 """Loop through loaded libraries and return binders on supported ones 

882 

883 This function is expected to work on windows system only. 

884 This code is adapted from code by Philipp Hagemeister @phihag available 

885 at https://stackoverflow.com/questions/17474574 

886 """ 

887 from ctypes.wintypes import DWORD, HMODULE, MAX_PATH 

888 

889 PROCESS_QUERY_INFORMATION = 0x0400 

890 PROCESS_VM_READ = 0x0010 

891 

892 LIST_LIBRARIES_ALL = 0x03 

893 

894 ps_api = self._get_windll("Psapi") 

895 kernel_32 = self._get_windll("kernel32") 

896 

897 h_process = kernel_32.OpenProcess( 

898 PROCESS_QUERY_INFORMATION | PROCESS_VM_READ, False, os.getpid() 

899 ) 

900 if not h_process: # pragma: no cover 

901 raise OSError(f"Could not open PID {os.getpid()}") 

902 

903 try: 

904 buf_count = 256 

905 needed = DWORD() 

906 # Grow the buffer until it becomes large enough to hold all the 

907 # module headers 

908 while True: 

909 buf = (HMODULE * buf_count)() 

910 buf_size = ctypes.sizeof(buf) 

911 if not ps_api.EnumProcessModulesEx( 

912 h_process, 

913 ctypes.byref(buf), 

914 buf_size, 

915 ctypes.byref(needed), 

916 LIST_LIBRARIES_ALL, 

917 ): 

918 raise OSError("EnumProcessModulesEx failed") 

919 if buf_size >= needed.value: 

920 break 

921 buf_count = needed.value // (buf_size // buf_count) 

922 

923 count = needed.value // (buf_size // buf_count) 

924 h_modules = map(HMODULE, buf[:count]) 

925 

926 # Loop through all the module headers and get the library path 

927 buf = ctypes.create_unicode_buffer(MAX_PATH) 

928 n_size = DWORD() 

929 for h_module in h_modules: 

930 # Get the path of the current module 

931 if not ps_api.GetModuleFileNameExW( 

932 h_process, h_module, ctypes.byref(buf), ctypes.byref(n_size) 

933 ): 

934 raise OSError("GetModuleFileNameEx failed") 

935 filepath = buf.value 

936 

937 # Store the library controller if it is supported and selected 

938 self._make_controller_from_path(filepath) 

939 finally: 

940 kernel_32.CloseHandle(h_process) 

941 

def _make_controller_from_path(self, filepath):
    """Create and store a controller for `filepath` if it is a supported library

    One controller is appended to `self.lib_controllers` for every class of
    `_ALL_CONTROLLERS` whose filename prefixes match the library file name.
    """
    # Required to resolve symlinks
    filepath = _realpath(filepath)
    # The comparison is made on the lower-cased name to take account of the
    # OpenMP dll case on Windows (vcomp, VCOMP, Vcomp, ...)
    filename = os.path.basename(filepath).lower()

    for controller_class in _ALL_CONTROLLERS:
        prefix = self._check_prefix(filename, controller_class.filename_prefixes)
        if prefix is None:
            # Not a library handled by this controller class.
            continue

        if prefix == "libblas":
            if not filename.endswith(".dll"):
                # libblas is ignored on platforms other than windows: there
                # might be a libblas dso coming with openblas for instance,
                # that can't be used to instantiate a pertinent LibController
                # (many symbols are missing) and would create confusion by
                # making a duplicate entry in threadpool_info.
                continue

            # BLAS libraries packaged by conda-forge on windows are all
            # renamed "libblas.dll". The actual implementation can only be
            # identified by looking for implementation specific symbols.
            libblas = ctypes.CDLL(filepath, _RTLD_NOLOAD)
            has_known_symbol = any(
                hasattr(libblas, symbol)
                for symbol in controller_class.check_symbols
            )
            if not has_known_symbol:
                continue

        # The library is supported: create and store its controller.
        self.lib_controllers.append(
            controller_class(filepath=filepath, prefix=prefix)
        )

986 

987 def _check_prefix(self, library_basename, filename_prefixes): 

988 """Return the prefix library_basename starts with 

989 

990 Return None if none matches. 

991 """ 

992 for prefix in filename_prefixes: 

993 if library_basename.startswith(prefix): 

994 return prefix 

995 return None 

996 

997 def _warn_if_incompatible_openmp(self): 

998 """Raise a warning if llvm-OpenMP and intel-OpenMP are both loaded""" 

999 prefixes = [lib_controller.prefix for lib_controller in self.lib_controllers] 

1000 msg = textwrap.dedent(""" 

1001 Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at 

1002 the same time. Both libraries are known to be incompatible and this 

1003 can cause random crashes or deadlocks on Linux when loaded in the 

1004 same Python program. 

1005 Using threadpoolctl may cause crashes or deadlocks. For more 

1006 information and possible workarounds, please see 

1007 https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md 

1008 """) 

1009 if "libomp" in prefixes and "libiomp" in prefixes: 

1010 warnings.warn(msg, RuntimeWarning) 

1011 

1012 @classmethod 

1013 def _get_libc(cls): 

1014 """Load the lib-C for unix systems.""" 

1015 libc = cls._system_libraries.get("libc") 

1016 if libc is None: 

1017 libc_name = find_library("c") 

1018 if libc_name is None: # pragma: no cover 

1019 warnings.warn( 

1020 "libc not found. The ctypes module in Python" 

1021 f" {sys.version_info.major}.{sys.version_info.minor} is maybe" 

1022 " too old for this OS.", 

1023 RuntimeWarning, 

1024 ) 

1025 return None 

1026 libc = ctypes.CDLL(libc_name, mode=_RTLD_NOLOAD) 

1027 cls._system_libraries["libc"] = libc 

1028 return libc 

1029 

1030 @classmethod 

1031 def _get_windll(cls, dll_name): 

1032 """Load a windows DLL""" 

1033 dll = cls._system_libraries.get(dll_name) 

1034 if dll is None: 

1035 dll = ctypes.WinDLL(f"{dll_name}.dll") 

1036 cls._system_libraries[dll_name] = dll 

1037 return dll 

1038 

1039 

def _main():
    """Commandline entry point: print the thread-pool information as JSON.

    The modules given with `-i/--import` are imported first, then the
    statement given with `-c/--command` is executed, and finally the result of
    `threadpool_info()` is printed on the standard output.
    """
    # NOTE: the names of the local variables below are visible from the
    # statement executed with `exec`, hence they are deliberately kept stable.
    import argparse
    import importlib
    import json
    import sys

    parser = argparse.ArgumentParser(
        description="Display thread-pool information and exit.",
        usage="python -m threadpoolctl -i numpy scipy.linalg xgboost",
    )
    parser.add_argument(
        "-i",
        "--import",
        nargs="*",
        dest="modules",
        default=(),
        help="Python modules to import before introspecting thread-pools.",
    )
    parser.add_argument(
        "-c",
        "--command",
        help="a Python statement to execute before introspecting thread-pools.",
    )

    # Without explicit arguments, argparse parses `sys.argv[1:]`.
    options = parser.parse_args()

    # A module that can't be imported is reported on stderr and skipped.
    for module in options.modules:
        try:
            importlib.import_module(module)
        except ImportError:
            print("WARNING: could not import", module, file=sys.stderr)

    # The statement comes from the command line of the user running the tool.
    if options.command:
        exec(options.command)

    print(json.dumps(threadpool_info(), indent=2))

1076 

1077 

# Run the command-line interface only when this file is executed as a script
# (e.g. ``python -m threadpoolctl``), not when it is imported as a module.
if __name__ == "__main__":
    _main()