Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/path_random.py: 1%

165 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-10-05 06:32 +0000

1""" 

2Support for random optimizers, including the random-greedy path. 

3""" 

4 

5import functools 

6import heapq 

7import math 

8import numbers 

9import time 

10from collections import deque 

11 

12from . import helpers, paths 

13 

# random.choices was introduced in python 3.6
try:
    from random import choices as random_choices
    from random import seed as random_seed
except ImportError:
    import numpy as np

    def random_choices(population, weights):
        """Fallback sampler mimicking ``random.choices`` using numpy -
        returns a length-1 sequence drawn with the given relative weights.
        """
        total = sum(weights)
        return np.random.choice(population, p=[w / total for w in weights], size=1)

    random_seed = np.random.seed

26 

# explicit public API of this module
__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]

28 

29 

class RandomOptimizer(paths.PathOptimizer):
    """Base class for running any random path finder that benefits
    from repeated calling, possibly in a parallel fashion. Custom random
    optimizers should subclass this, and the ``setup`` method should be
    implemented with the following signature::

        def setup(self, inputs, output, size_dict):
            # custom preparation here ...
            return trial_fn, trial_args

    Where ``trial_fn`` itself should have the signature::

        def trial_fn(r, *trial_args):
            # custom computation of path here
            return ssa_path, cost, size

    Where ``r`` is the run number and could for example be used to seed a
    random number generator. See ``RandomGreedy`` for an example.


    Parameters
    ----------
    max_repeats : int, optional
        The maximum number of repeat trials to have.
    max_time : float, optional
        The maximum amount of time to run the algorithm for.
    minimize : {'flops', 'size'}, optional
        Whether to favour paths that minimize the total estimated flop-count or
        the size of the largest intermediate created.
    parallel : {bool, int, or executor-pool like}, optional
        Whether to parallelize the random trials, by default ``False``. If
        ``True``, use a ``concurrent.futures.ProcessPoolExecutor`` with the same
        number of processes as cores. If an integer is specified, use that many
        processes instead. Finally, you can supply a custom executor-pool which
        should have an API matching that of the python 3 standard library
        module ``concurrent.futures``. Namely, a ``submit`` method that returns
        ``Future`` objects, themselves with ``result`` and ``cancel`` methods.
    pre_dispatch : int, optional
        If running in parallel, how many jobs to pre-dispatch so as to avoid
        submitting all jobs at once. Should also be more than twice the number
        of workers to avoid under-subscription. Default: 128.

    Attributes
    ----------
    path : list[tuple[int]]
        The best path found so far.
    costs : list[int]
        The list of each trial's costs found so far.
    sizes : list[int]
        The list of each trial's largest intermediate size so far.

    See Also
    --------
    RandomGreedy
    """

    def __init__(self, max_repeats=32, max_time=None, minimize='flops', parallel=False, pre_dispatch=128):

        if minimize not in ('flops', 'size'):
            raise ValueError("`minimize` should be one of {'flops', 'size'}.")

        self.max_repeats = max_repeats
        self.max_time = max_time
        self.minimize = minimize
        # comparison function used to decide whether a trial beats the best
        self.better = paths.get_better_fn(minimize)
        self.parallel = parallel
        self.pre_dispatch = pre_dispatch

        # per-trial history of all costs and sizes seen so far
        self.costs = []
        self.sizes = []
        self.best = {'flops': float('inf'), 'size': float('inf')}

        # offset so repeated __call__s continue the trial numbering
        self._repeats_start = 0

    @property
    def path(self):
        """The best path found so far.

        Raises ``KeyError`` if no trial has completed yet (no 'ssa_path'
        recorded in ``self.best``).
        """
        return paths.ssa_to_linear(self.best['ssa_path'])

    @property
    def parallel(self):
        return self._parallel

    @parallel.setter
    def parallel(self, parallel):
        # shutdown any previous executor if we are managing it
        if getattr(self, '_managing_executor', False):
            self._executor.shutdown()

        self._parallel = parallel
        self._managing_executor = False

        if parallel is False:
            self._executor = None
            return

        if parallel is True:
            from concurrent.futures import ProcessPoolExecutor
            self._executor = ProcessPoolExecutor()
            self._managing_executor = True
            return

        if isinstance(parallel, numbers.Number):
            from concurrent.futures import ProcessPoolExecutor
            self._executor = ProcessPoolExecutor(parallel)
            self._managing_executor = True
            return

        # assume a pool-executor has been supplied
        self._executor = parallel

    def _gen_results_parallel(self, repeats, trial_fn, args):
        """Lazily generate results from an executor without submitting all jobs at once.
        """
        self._futures = deque()

        # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
        # yield any results, then do both in tandem, before draining the queue.
        # BUGFIX: previously, once the queue was full, the repeat that
        # triggered a yield was never submitted - silently dropping every
        # other trial beyond the first ``pre_dispatch``. Always submit first,
        # then yield once we are over capacity.
        for r in repeats:
            self._futures.append(self._executor.submit(trial_fn, r, *args))
            if len(self._futures) >= self.pre_dispatch:
                yield self._futures.popleft().result()

        while self._futures:
            yield self._futures.popleft().result()

    def _cancel_futures(self):
        # best-effort: cancel any still-queued jobs (running ones can't be
        # cancelled). getattr-guard in case no parallel generator ever ran.
        if self._executor is not None:
            for f in getattr(self, '_futures', ()):
                f.cancel()

    def setup(self, inputs, output, size_dict):
        """Subclasses must return ``(trial_fn, trial_args)`` - see class docstring."""
        raise NotImplementedError

    def __call__(self, inputs, output, size_dict, memory_limit):
        """Run up to ``max_repeats`` trials (stopping early after ``max_time``
        seconds if set) and return the best path found so far.

        ``memory_limit`` is accepted for interface compatibility but is not
        used by this optimizer.
        """
        self._check_args_against_first_call(inputs, output, size_dict)

        # start a timer? use a monotonic clock so wall-clock jumps can't
        # spuriously trigger (or defer) the timeout
        if self.max_time is not None:
            t0 = time.monotonic()

        trial_fn, trial_args = self.setup(inputs, output, size_dict)

        # continue numbering trials across repeated calls so seeds differ
        r_start = self._repeats_start + len(self.costs)
        r_stop = r_start + self.max_repeats
        repeats = range(r_start, r_stop)

        # create the trials lazily
        if self._executor is not None:
            trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
        else:
            trials = (trial_fn(r, *trial_args) for r in repeats)

        # assess the trials
        for ssa_path, cost, size in trials:

            # keep track of all costs and sizes
            self.costs.append(cost)
            self.sizes.append(size)

            # check if we have found a new best
            found_new_best = self.better(cost, size, self.best['flops'], self.best['size'])

            if found_new_best:
                self.best['flops'] = cost
                self.best['size'] = size
                self.best['ssa_path'] = ssa_path

            # check if we have run out of time
            if (self.max_time is not None) and (time.monotonic() > t0 + self.max_time):
                break

        self._cancel_futures()
        return self.path

    def __del__(self):
        # if we created the parallel pool-executor, shut it down
        if getattr(self, '_managing_executor', False):
            self._executor.shutdown()

210 

211 

def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
    """A contraction 'chooser' that weights possible contractions using a
    Boltzmann distribution. Explicitly, given costs ``c_i`` (with ``c_0`` the
    smallest), the relative weights, ``w_i``, are computed as:

        w_i = exp( -(c_i - c_0) / temperature)

    Additionally, if ``rel_temperature`` is set, scale ``temperature`` by
    ``abs(c_0)`` to account for likely fluctuating cost magnitudes during the
    course of a contraction.

    Parameters
    ----------
    queue : list
        The heapified list of candidate contractions.
    remaining : dict[str, int]
        Mapping of remaining inputs' indices to the ssa id.
    temperature : float, optional
        When choosing a possible contraction, its relative probability will be
        proportional to ``exp(-cost / temperature)``. Thus the larger
        ``temperature`` is, the further random paths will stray from the normal
        'greedy' path. Conversely, if set to zero, only paths with exactly the
        same cost as the best at each step will be explored.
    rel_temperature : bool, optional
        Whether to normalize the ``temperature`` at each step to the scale of
        the best cost. This is generally beneficial as the magnitude of costs
        can vary significantly throughout a contraction.
    nbranch : int, optional
        How many potential paths to calculate probability for and choose from
        at each step.

    Returns
    -------
    cost, k1, k2, k12
    """
    # pop up to ``nbranch`` candidates that are still valid, i.e. whose
    # operands have not already been contracted away
    choices = []
    while queue and len(choices) < nbranch:
        candidate = heapq.heappop(queue)
        _, k1, k2, _ = candidate
        if (k1 in remaining) and (k2 in remaining):
            choices.append(candidate)

    # no live candidates at all / single candidate - nothing to sample over
    if not choices:
        return None
    if len(choices) == 1:
        return choices[0]

    # heap entries sort on a cost tuple - its first element is the raw cost
    costs = [candidate[0][0] for candidate in choices]
    cmin = costs[0]

    # adjust by the overall scale to account for fluctuating absolute costs
    if rel_temperature:
        temperature *= max(1, abs(cmin))

    # compute relative probability for each potential contraction
    if temperature == 0.0:
        energies = [1 if c == cmin else 0 for c in costs]
    else:
        # shift by cmin for numerical reasons
        energies = [math.exp((cmin - c) / temperature) for c in costs]

    # randomly choose a contraction based on energies
    chosen, = random_choices(range(len(choices)), weights=energies)
    selected = choices.pop(chosen)

    # put the other choices back in the heap
    for other in choices:
        heapq.heappush(queue, other)

    return selected

284 

285 

def ssa_path_compute_cost(ssa_path, inputs, output, size_dict):
    """Compute the flops and max size of an ssa path.
    """
    fs_inputs = [frozenset(term) for term in inputs]
    fs_output = frozenset(output)
    remaining = set(range(len(fs_inputs)))
    total_cost = 0
    max_size = 0

    for i, j in ssa_path:
        # cost of contracting inputs i and j, and the resulting indices
        k12, flops12 = paths.calc_k12_flops(fs_inputs, fs_output, remaining, i, j, size_dict)

        # the new intermediate replaces both operands
        remaining.discard(i)
        remaining.discard(j)
        remaining.add(len(fs_inputs))
        fs_inputs.append(k12)

        total_cost += flops12
        size12 = helpers.compute_size_by_dict(k12, size_dict)
        if size12 > max_size:
            max_size = size12

    return total_cost, max_size

305 

306 

def _trial_greedy_ssa_path_and_cost(r, inputs, output, size_dict, choose_fn, cost_fn):
    """A single, repeatable, greedy trial run. Returns ``ssa_path`` and cost.
    """
    # trial 0 always reproduces the deterministic standard greedy approach
    chooser = None if r == 0 else choose_fn

    # seed with the trial number so each run is repeatable
    random_seed(r)

    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, chooser, cost_fn)
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)

    return ssa_path, cost, size

320 

321 

class RandomGreedy(RandomOptimizer):
    """A random path optimizer that repeatedly samples stochastic variations
    of the greedy contraction path.

    Parameters
    ----------
    cost_fn : callable, optional
        A function that returns a heuristic 'cost' of a potential contraction
        with which to sort candidates. Should have signature
        ``cost_fn(size12, size1, size2, k12, k1, k2)``.
    temperature : float, optional
        When choosing a possible contraction, its relative probability will be
        proportional to ``exp(-cost / temperature)``. Thus the larger
        ``temperature`` is, the further random paths will stray from the normal
        'greedy' path. Conversely, if set to zero, only paths with exactly the
        same cost as the best at each step will be explored.
    rel_temperature : bool, optional
        Whether to normalize the ``temperature`` at each step to the scale of
        the best cost. This is generally beneficial as the magnitude of costs
        can vary significantly throughout a contraction. If False, the
        algorithm will end up branching when the absolute cost is low, but
        stick to the 'greedy' path when the cost is high - this can also be
        beneficial.
    nbranch : int, optional
        How many potential paths to calculate probability for and choose from
        at each step.
    kwargs
        Supplied to RandomOptimizer.

    See Also
    --------
    RandomOptimizer
    """

    def __init__(self, cost_fn='memory-removed-jitter', temperature=1.0, rel_temperature=True, nbranch=8, **kwargs):
        super().__init__(**kwargs)
        self.cost_fn = cost_fn
        self.temperature = temperature
        self.rel_temperature = rel_temperature
        self.nbranch = nbranch

    @property
    def choose_fn(self):
        """The function that chooses which contraction to take - make this a
        property so that ``temperature`` and ``nbranch`` etc. can be updated
        between runs.
        """
        # a single branch degenerates to the plain greedy chooser
        if self.nbranch == 1:
            return None

        return functools.partial(
            thermal_chooser,
            nbranch=self.nbranch,
            temperature=self.temperature,
            rel_temperature=self.rel_temperature,
        )

    def setup(self, inputs, output, size_dict):
        """Return the greedy trial function and the arguments to call it with."""
        trial_args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
        return _trial_greedy_ssa_path_and_cost, trial_args

379 

380 

def random_greedy(inputs, output, idx_dict, memory_limit=None, **optimizer_kwargs):
    """Find a contraction path by sampling many random-greedy trials.

    Convenience wrapper: constructs a ``RandomGreedy`` optimizer from
    ``optimizer_kwargs`` and runs it once on the given contraction.

    Parameters
    ----------
    inputs : list
        The indices of each input tensor.
    output :
        The indices of the output.
    idx_dict : dict
        Mapping of indices to their dimension sizes.
    memory_limit : optional
        Forwarded to the optimizer (accepted by ``RandomOptimizer.__call__``
        but not used by it).
    optimizer_kwargs
        Supplied to ``RandomGreedy`` (e.g. ``max_repeats``, ``temperature``,
        ``parallel`` - see ``RandomGreedy`` and ``RandomOptimizer``).

    Returns
    -------
    list[tuple[int]]
        The best contraction path found.
    """
    optimizer = RandomGreedy(**optimizer_kwargs)
    return optimizer(inputs, output, idx_dict, memory_limit)


# variant pre-configured for a more thorough search of 128 trials
random_greedy_128 = functools.partial(random_greedy, max_repeats=128)