Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/path_random.py: 22%

163 statements  

coverage.py v7.3.1, created at 2023-09-25 06:41 +0000

1""" 

2Support for random optimizers, including the random-greedy path. 

3""" 

4 

5import functools 

6import heapq 

7import math 

8import numbers 

9import time 

10from collections import deque 

11from random import choices as random_choices 

12from random import seed as random_seed 

13from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple 

14 

15from . import helpers, paths 

16from .typing import ArrayIndexType, ArrayType, PathType 

17 

18__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"] 


class RandomOptimizer(paths.PathOptimizer):
    """Base class for running any random path finder that benefits
    from repeated calling, possibly in a parallel fashion. Custom random
    optimizers should subclass this, and the `setup` method should be
    implemented with the following signature:

    ```python
    def setup(self, inputs, output, size_dict):
        # custom preparation here ...
        return trial_fn, trial_args
    ```

    Where `trial_fn` itself should have the signature:

    ```python
    def trial_fn(r, *trial_args):
        # custom computation of path here
        return ssa_path, cost, size
    ```

    Where `r` is the run number and could, for example, be used to seed a
    random number generator. See `RandomGreedy` for an example.

    **Parameters:**

    - **max_repeats** - *(int, optional)* The maximum number of repeat trials to have.
    - **max_time** - *(float, optional)* The maximum amount of time to run the algorithm for.
    - **minimize** - *({'flops', 'size'}, optional)* Whether to favour paths that minimize the total estimated flop-count or
      the size of the largest intermediate created.
    - **parallel** - *({bool, int, or executor-pool like}, optional)* Whether to parallelize the random trials, by default `False`. If
      `True`, use a `concurrent.futures.ProcessPoolExecutor` with the same
      number of processes as cores. If an integer is specified, use that many
      processes instead. Finally, you can supply a custom executor-pool which
      should have an API matching that of the Python 3 standard library
      module `concurrent.futures`. Namely, a `submit` method that returns
      `Future` objects, themselves with `result` and `cancel` methods.
    - **pre_dispatch** - *(int, optional)* If running in parallel, how many jobs to pre-dispatch so as to avoid
      submitting all jobs at once. Should also be more than twice the number
      of workers to avoid under-subscription. Default: 128.

    **Attributes:**

    - **path** - *(list[tuple[int]])* The best path found so far.
    - **costs** - *(list[int])* The cost found by each trial so far.
    - **sizes** - *(list[int])* The largest intermediate size found by each trial so far.
    """

    def __init__(
        self,
        max_repeats: int = 32,
        max_time: Optional[float] = None,
        minimize: str = "flops",
        parallel: bool = False,
        pre_dispatch: int = 128,
    ):

        if minimize not in ("flops", "size"):
            raise ValueError("`minimize` should be one of {'flops', 'size'}.")

        self.max_repeats = max_repeats
        self.max_time = max_time
        self.minimize = minimize
        self.better = paths.get_better_fn(minimize)
        self._parallel = False
        self.parallel = parallel
        self.pre_dispatch = pre_dispatch

        self.costs: List[int] = []
        self.sizes: List[int] = []
        self.best: Dict[str, Any] = {"flops": float("inf"), "size": float("inf")}

        self._repeats_start = 0
        self._executor: Any
        self._futures: Any

    @property
    def path(self) -> PathType:
        """The best path found so far."""
        return paths.ssa_to_linear(self.best["ssa_path"])

    @property
    def parallel(self) -> bool:
        return self._parallel

    @parallel.setter
    def parallel(self, parallel: bool) -> None:
        # shutdown any previous executor if we are managing it
        if getattr(self, "_managing_executor", False):
            self._executor.shutdown()

        self._parallel = parallel
        self._managing_executor = False

        if parallel is False:
            self._executor = None
            return

        if parallel is True:
            from concurrent.futures import ProcessPoolExecutor

            self._executor = ProcessPoolExecutor()
            self._managing_executor = True
            return

        if isinstance(parallel, numbers.Number):
            from concurrent.futures import ProcessPoolExecutor

            self._executor = ProcessPoolExecutor(parallel)
            self._managing_executor = True
            return

        # assume a pool-executor has been supplied
        self._executor = parallel

    def _gen_results_parallel(self, repeats: Iterable[int], trial_fn: Any, args: Any) -> Generator[Any, None, None]:
        """Lazily generate results from an executor without submitting all jobs at once."""
        self._futures = deque()

        # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
        # yield any results, then do both in tandem, before draining the queue
        for r in repeats:
            if len(self._futures) < self.pre_dispatch:
                self._futures.append(self._executor.submit(trial_fn, r, *args))
                continue
            yield self._futures.popleft().result()

        while self._futures:
            yield self._futures.popleft().result()

    def _cancel_futures(self) -> None:
        if self._executor is not None:
            for f in self._futures:
                f.cancel()

    def setup(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> Tuple[Any, Any]:
        raise NotImplementedError

    def __call__(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
        memory_limit: Optional[int] = None,
    ) -> PathType:
        """Run a batch of trials and return the best path found so far."""
        self._check_args_against_first_call(inputs, output, size_dict)

        # start a timer?
        if self.max_time is not None:
            t0 = time.time()

        trial_fn, trial_args = self.setup(inputs, output, size_dict)

        r_start = self._repeats_start + len(self.costs)
        r_stop = r_start + self.max_repeats
        repeats = range(r_start, r_stop)

        # create the trials lazily
        if self._executor is not None:
            trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
        else:
            trials = (trial_fn(r, *trial_args) for r in repeats)

        # assess the trials
        for ssa_path, cost, size in trials:

            # keep track of all costs and sizes
            self.costs.append(cost)
            self.sizes.append(size)

            # check if we have found a new best
            found_new_best = self.better(cost, size, self.best["flops"], self.best["size"])

            if found_new_best:
                self.best["flops"] = cost
                self.best["size"] = size
                self.best["ssa_path"] = ssa_path

            # check if we have run out of time
            if (self.max_time is not None) and (time.time() > t0 + self.max_time):
                break

        self._cancel_futures()
        return self.path

    def __del__(self):
        # if we created the parallel pool-executor, shut it down
        if getattr(self, "_managing_executor", False):
            self._executor.shutdown()
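

# Illustrative sketch, not part of the original module: one way the
# `setup` / `trial_fn` protocol documented on `RandomOptimizer` above might
# be implemented by a custom subclass.  The names `_example_random_pairs_trial`
# and `_ExampleRandomPairs` are hypothetical and exist only for demonstration.
def _example_random_pairs_trial(r, inputs, output, size_dict):
    """Trial function: contract randomly chosen pairs, seeded by the run number ``r``."""
    import random

    rng = random.Random(r)
    remaining = list(range(len(inputs)))
    ssa_path = []
    next_id = len(inputs)
    while len(remaining) > 1:
        # pick any two remaining tensors and contract them
        i, j = sorted(rng.sample(remaining, 2))
        ssa_path.append((i, j))
        remaining.remove(i)
        remaining.remove(j)
        remaining.append(next_id)
        next_id += 1
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
    return ssa_path, cost, size


class _ExampleRandomPairs(RandomOptimizer):
    """Hypothetical optimizer: ``setup`` returns the trial function and its arguments."""

    def setup(self, inputs, output, size_dict):
        return _example_random_pairs_trial, (inputs, output, size_dict)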


def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
    """A contraction 'chooser' that weights possible contractions using a
    Boltzmann distribution. Explicitly, given costs `c_i` (with `c_0` the
    smallest), the relative weights, `w_i`, are computed as:

    $$w_i = exp(-(c_i - c_0) / temperature)$$

    Additionally, if `rel_temperature` is set, scale `temperature` by
    `abs(c_0)` to account for likely fluctuating cost magnitudes during the
    course of a contraction.

    **Parameters:**

    - **queue** - *(list)* The heapified list of candidate contractions.
    - **remaining** - *(dict[str, int])* Mapping of remaining inputs' indices to the ssa id.
    - **temperature** - *(float, optional)* When choosing a possible contraction, its relative probability will be
      proportional to `exp(-cost / temperature)`. Thus the larger
      `temperature` is, the further random paths will stray from the normal
      'greedy' path. Conversely, if set to zero, only paths with exactly the
      same cost as the best at each step will be explored.
    - **rel_temperature** - *(bool, optional)* Whether to normalize the `temperature` at each step to the scale of
      the best cost. This is generally beneficial as the magnitude of costs
      can vary significantly throughout a contraction.
    - **nbranch** - *(int, optional)* How many potential paths to calculate probability for and choose from at each step.

    **Returns:**

    - **cost**
    - **k1**
    - **k2**
    - **k12**
    """
    n = 0
    choices = []
    while queue and n < nbranch:
        cost, k1, k2, k12 = heapq.heappop(queue)
        if k1 not in remaining or k2 not in remaining:
            continue  # candidate is obsolete
        choices.append((cost, k1, k2, k12))
        n += 1

    if n == 0:
        return None
    if n == 1:
        return choices[0]

    # extract the numeric cost (the first element of each cost entry)
    costs = [choice[0][0] for choice in choices]
    cmin = costs[0]

    # adjust by the overall scale to account for fluctuating absolute costs
    if rel_temperature:
        temperature *= max(1, abs(cmin))

    # compute relative probability for each potential contraction
    if temperature == 0.0:
        energies = [1 if c == cmin else 0 for c in costs]
    else:
        # shift by cmin for numerical reasons
        energies = [math.exp(-(c - cmin) / temperature) for c in costs]

    # randomly choose a contraction based on energies
    (chosen,) = random_choices(range(n), weights=energies)
    cost, k1, k2, k12 = choices.pop(chosen)

    # put the other choices back in the heap
    for other in choices:
        heapq.heappush(queue, other)

    return cost, k1, k2, k12
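

# Illustrative sketch, not part of the original module: the Boltzmann
# weighting used by `thermal_chooser`, worked through for a few example
# costs.  With costs (10, 12, 20) and temperature=1, `rel_temperature`
# rescales the temperature to abs(c_0) = 10, giving relative weights
# exp(-(c - 10) / 10) ~= (1.0, 0.82, 0.37): cheaper candidates are favoured
# but more expensive ones still have a chance of being chosen.  The helper
# name `_example_thermal_weights` is hypothetical.
def _example_thermal_weights(costs=(10, 12, 20), temperature=1.0, rel_temperature=True):
    cmin = min(costs)
    if rel_temperature:
        temperature *= max(1, abs(cmin))
    if temperature == 0.0:
        # zero temperature: only candidates tied with the best cost are eligible
        return [1 if c == cmin else 0 for c in costs]
    return [math.exp(-(c - cmin) / temperature) for c in costs]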


def ssa_path_compute_cost(
    ssa_path: PathType,
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
) -> Tuple[int, int]:
    """Compute the flops and max size of an ssa path."""
    inputs = list(map(frozenset, inputs))  # type: ignore
    output = frozenset(output)
    remaining = set(range(len(inputs)))
    total_cost = 0
    max_size = 0

    for i, j in ssa_path:
        k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict)  # type: ignore
        remaining.discard(i)
        remaining.discard(j)
        remaining.add(len(inputs))
        inputs.append(k12)
        total_cost += flops12
        max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))

    return total_cost, max_size
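

# Illustrative sketch, not part of the original module: computing the flop
# count and largest intermediate of a simple ssa path for the matrix-chain
# contraction "ab,bc,cd->ad".  Contracting ssa ids (0, 1) first creates the
# intermediate "ac" with ssa id 3, which is then contracted with "cd"
# (ssa id 2).  The helper name `_example_ssa_cost` is hypothetical.
def _example_ssa_cost():
    inputs = [frozenset("ab"), frozenset("bc"), frozenset("cd")]
    output = frozenset("ad")
    size_dict = {d: 10 for d in "abcd"}
    return ssa_path_compute_cost([(0, 1), (3, 2)], inputs, output, size_dict)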


def _trial_greedy_ssa_path_and_cost(
    r: int,
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    choose_fn: Any,
    cost_fn: Any,
) -> Tuple[PathType, int, int]:
    """A single, repeatable, greedy trial run. Returns the ``ssa_path``, its cost and its largest intermediate size."""
    if r == 0:
        # always start with the standard greedy approach
        choose_fn = None

    random_seed(r)

    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)

    return ssa_path, cost, size


class RandomGreedy(RandomOptimizer):
    """A random-greedy path optimizer: repeatedly samples stochastic variants
    of the greedy path and keeps the best one found.

    **Parameters:**

    - **cost_fn** - *(callable, optional)* A function that returns a heuristic 'cost' of a potential contraction
      with which to sort candidates. Should have signature
      `cost_fn(size12, size1, size2, k12, k1, k2)`.
    - **temperature** - *(float, optional)* When choosing a possible contraction, its relative probability will be
      proportional to `exp(-cost / temperature)`. Thus the larger
      `temperature` is, the further random paths will stray from the normal
      'greedy' path. Conversely, if set to zero, only paths with exactly the
      same cost as the best at each step will be explored.
    - **rel_temperature** - *(bool, optional)* Whether to normalize the `temperature` at each step to the scale of
      the best cost. This is generally beneficial as the magnitude of costs
      can vary significantly throughout a contraction. If False, the
      algorithm will end up branching when the absolute cost is low, but
      stick to the 'greedy' path when the cost is high - this can also be
      beneficial.
    - **nbranch** - *(int, optional)* How many potential paths to calculate probability for and choose from at each step.
    - **kwargs** - Supplied to `RandomOptimizer`.
    """

    def __init__(
        self,
        cost_fn: str = "memory-removed-jitter",
        temperature: float = 1.0,
        rel_temperature: bool = True,
        nbranch: int = 8,
        **kwargs: Any,
    ):
        self.cost_fn = cost_fn
        self.temperature = temperature
        self.rel_temperature = rel_temperature
        self.nbranch = nbranch
        super().__init__(**kwargs)

    @property
    def choose_fn(self) -> Any:
        """The function that chooses which contraction to take - make this a
        property so that `temperature` and `nbranch` etc. can be updated
        between runs.
        """
        if self.nbranch == 1:
            return None

        return functools.partial(
            thermal_chooser,
            temperature=self.temperature,
            nbranch=self.nbranch,
            rel_temperature=self.rel_temperature,
        )

    def setup(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> Tuple[Any, Any]:
        fn = _trial_greedy_ssa_path_and_cost
        args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
        return fn, args


def random_greedy(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    idx_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
    **optimizer_kwargs: Any,
) -> ArrayType:
    """Construct a `RandomGreedy` optimizer with `optimizer_kwargs`, run it once, and return the best path found."""
    optimizer = RandomGreedy(**optimizer_kwargs)
    return optimizer(inputs, output, idx_dict, memory_limit)


random_greedy_128 = functools.partial(random_greedy, max_repeats=128)
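

# Illustrative usage sketch, not part of the original module: the
# random-greedy optimizer is typically driven through ``opt_einsum.contract_path``
# by passing either the string 'random-greedy' or a ``RandomGreedy`` instance
# as the ``optimize`` argument.  The expression and array shapes below are
# arbitrary examples.
#
#     import numpy as np
#     import opt_einsum as oe
#
#     arrays = [np.random.rand(8, 8) for _ in range(6)]
#     optimizer = oe.RandomGreedy(max_repeats=64, max_time=1.0, minimize="flops")
#     path, info = oe.contract_path("ab,bc,cd,de,ef,fa->", *arrays, optimize=optimizer)
#
#     # the best path, flop count and largest intermediate found so far
#     print(optimizer.path, optimizer.best["flops"], optimizer.best["size"])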