Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/path_random.py: 1%
165 statements
coverage.py v7.3.2, created at 2023-10-05 06:32 +0000
1"""
2Support for random optimizers, including the random-greedy path.
3"""
5import functools
6import heapq
7import math
8import numbers
9import time
10from collections import deque
12from . import helpers, paths
14# random.choices was introduced in python 3.6
15try:
16 from random import choices as random_choices
17 from random import seed as random_seed
18except ImportError:
19 import numpy as np
21 def random_choices(population, weights):
22 norm = sum(weights)
23 return np.random.choice(population, p=[w / norm for w in weights], size=1)
25 random_seed = np.random.seed
27__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]


class RandomOptimizer(paths.PathOptimizer):
    """Base class for running any random path finder that benefits
    from repeated calling, possibly in a parallel fashion. Custom random
    optimizers should subclass this, and the ``setup`` method should be
    implemented with the following signature::

        def setup(self, inputs, output, size_dict):
            # custom preparation here ...
            return trial_fn, trial_args

    Where ``trial_fn`` itself should have the signature::

        def trial_fn(r, *trial_args):
            # custom computation of path here
            return ssa_path, cost, size

    Where ``r`` is the run number and could, for example, be used to seed a
    random number generator. See ``RandomGreedy`` for a full example, and the
    illustrative sketch following this class for a minimal one.

    Parameters
    ----------
    max_repeats : int, optional
        The maximum number of repeat trials to have.
    max_time : float, optional
        The maximum amount of time to run the algorithm for.
    minimize : {'flops', 'size'}, optional
        Whether to favour paths that minimize the total estimated flop-count or
        the size of the largest intermediate created.
    parallel : {bool, int, or executor-pool like}, optional
        Whether to parallelize the random trials, by default ``False``. If
        ``True``, use a ``concurrent.futures.ProcessPoolExecutor`` with the same
        number of processes as cores. If an integer is specified, use that many
        processes instead. Finally, you can supply a custom executor-pool which
        should have an API matching that of the python 3 standard library
        module ``concurrent.futures``. Namely, a ``submit`` method that returns
        ``Future`` objects, themselves with ``result`` and ``cancel`` methods.
    pre_dispatch : int, optional
        If running in parallel, how many jobs to pre-dispatch so as to avoid
        submitting all jobs at once. Should also be more than twice the number
        of workers to avoid under-subscription. Default: 128.

    Attributes
    ----------
    path : list[tuple[int]]
        The best path found so far.
    costs : list[int]
        The list of each trial's costs found so far.
    sizes : list[int]
        The list of each trial's largest intermediate size found so far.

    See Also
    --------
    RandomGreedy
    """
    def __init__(self, max_repeats=32, max_time=None, minimize='flops', parallel=False, pre_dispatch=128):

        if minimize not in ('flops', 'size'):
            raise ValueError("`minimize` should be one of {'flops', 'size'}.")

        self.max_repeats = max_repeats
        self.max_time = max_time
        self.minimize = minimize
        self.better = paths.get_better_fn(minimize)
        self.parallel = parallel
        self.pre_dispatch = pre_dispatch

        self.costs = []
        self.sizes = []
        self.best = {'flops': float('inf'), 'size': float('inf')}

        self._repeats_start = 0

    @property
    def path(self):
        """The best path found so far.
        """
        return paths.ssa_to_linear(self.best['ssa_path'])

    @property
    def parallel(self):
        return self._parallel

    @parallel.setter
    def parallel(self, parallel):
        # shutdown any previous executor if we are managing it
        if getattr(self, '_managing_executor', False):
            self._executor.shutdown()

        self._parallel = parallel
        self._managing_executor = False

        if parallel is False:
            self._executor = None
            return

        if parallel is True:
            from concurrent.futures import ProcessPoolExecutor
            self._executor = ProcessPoolExecutor()
            self._managing_executor = True
            return

        if isinstance(parallel, numbers.Number):
            from concurrent.futures import ProcessPoolExecutor
            self._executor = ProcessPoolExecutor(parallel)
            self._managing_executor = True
            return

        # assume a pool-executor has been supplied
        self._executor = parallel

    def _gen_results_parallel(self, repeats, trial_fn, args):
        """Lazily generate results from an executor without submitting all jobs at once.
        """
        self._futures = deque()

        # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
        # yield any results, then do both in tandem, before draining the queue
        for r in repeats:
            if len(self._futures) < self.pre_dispatch:
                self._futures.append(self._executor.submit(trial_fn, r, *args))
                continue
            yield self._futures.popleft().result()

        while self._futures:
            yield self._futures.popleft().result()

    def _cancel_futures(self):
        if self._executor is not None:
            for f in self._futures:
                f.cancel()

    def setup(self, inputs, output, size_dict):
        raise NotImplementedError

    def __call__(self, inputs, output, size_dict, memory_limit):
        self._check_args_against_first_call(inputs, output, size_dict)

        # start a timer?
        if self.max_time is not None:
            t0 = time.time()

        trial_fn, trial_args = self.setup(inputs, output, size_dict)

        r_start = self._repeats_start + len(self.costs)
        r_stop = r_start + self.max_repeats
        repeats = range(r_start, r_stop)

        # create the trials lazily
        if self._executor is not None:
            trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
        else:
            trials = (trial_fn(r, *trial_args) for r in repeats)

        # assess the trials
        for ssa_path, cost, size in trials:

            # keep track of all costs and sizes
            self.costs.append(cost)
            self.sizes.append(size)

            # check if we have found a new best
            found_new_best = self.better(cost, size, self.best['flops'], self.best['size'])

            if found_new_best:
                self.best['flops'] = cost
                self.best['size'] = size
                self.best['ssa_path'] = ssa_path

            # check if we have run out of time
            if (self.max_time is not None) and (time.time() > t0 + self.max_time):
                break

        self._cancel_futures()
        return self.path

    def __del__(self):
        # if we created the parallel pool-executor, shut it down
        if getattr(self, '_managing_executor', False):
            self._executor.shutdown()
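

# Illustrative sketch (not part of the original opt_einsum API): a minimal
# ``RandomOptimizer`` subclass wired up exactly as the docstring above
# describes. ``_example_trial_fn`` and ``_ExampleJitteredGreedy`` are
# hypothetical names used only to show the expected ``setup``/``trial_fn``
# signatures; ``RandomGreedy`` below is the real, full-featured version.

def _example_trial_fn(r, inputs, output, size_dict):
    # ``r`` is the trial number - use it to seed the randomness so that each
    # repeat can explore a different jittered-greedy path
    random_seed(r)
    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, cost_fn='memory-removed-jitter')
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)
    return ssa_path, cost, size


class _ExampleJitteredGreedy(RandomOptimizer):

    def setup(self, inputs, output, size_dict):
        # no preparation needed - just pass the static problem description
        # through as ``trial_args``
        return _example_trial_fn, (inputs, output, size_dict)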


def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
    """A contraction 'chooser' that weights possible contractions using a
    Boltzmann distribution. Explicitly, given costs ``c_i`` (with ``c_0`` the
    smallest), the relative weights, ``w_i``, are computed as:

        w_i = exp( -(c_i - c_0) / temperature)

    Additionally, if ``rel_temperature`` is set, scale ``temperature`` by
    ``abs(c_0)`` to account for likely fluctuating cost magnitudes during the
    course of a contraction.

    Parameters
    ----------
    queue : list
        The heapified list of candidate contractions.
    remaining : dict[str, int]
        Mapping of remaining inputs' indices to the ssa id.
    temperature : float, optional
        When choosing a possible contraction, its relative probability will be
        proportional to ``exp(-cost / temperature)``. Thus the larger
        ``temperature`` is, the further random paths will stray from the normal
        'greedy' path. Conversely, if set to zero, only paths with exactly the
        same cost as the best at each step will be explored.
    rel_temperature : bool, optional
        Whether to normalize the ``temperature`` at each step to the scale of
        the best cost. This is generally beneficial as the magnitude of costs
        can vary significantly throughout a contraction.
    nbranch : int, optional
        How many of the best candidate contractions to compute probabilities
        for and choose from at each step.

    Returns
    -------
    cost, k1, k2, k12
    """
    n = 0
    choices = []
    while queue and n < nbranch:
        cost, k1, k2, k12 = heapq.heappop(queue)
        if k1 not in remaining or k2 not in remaining:
            continue  # candidate is obsolete
        choices.append((cost, k1, k2, k12))
        n += 1

    if n == 0:
        return None
    if n == 1:
        return choices[0]

    costs = [choice[0][0] for choice in choices]
    cmin = costs[0]

    # adjust by the overall scale to account for fluctuating absolute costs
    if rel_temperature:
        temperature *= max(1, abs(cmin))

    # compute relative probability for each potential contraction
    if temperature == 0.0:
        energies = [1 if c == cmin else 0 for c in costs]
    else:
        # shift by cmin for numerical reasons
        energies = [math.exp(-(c - cmin) / temperature) for c in costs]

    # randomly choose a contraction based on energies
    chosen, = random_choices(range(n), weights=energies)
    cost, k1, k2, k12 = choices.pop(chosen)

    # put the other choices back in the heap
    for other in choices:
        heapq.heappush(queue, other)

    return cost, k1, k2, k12
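

# Worked example of the weighting above, with illustrative numbers only: given
# candidate costs [10, 12, 20], ``temperature=1`` and ``rel_temperature=False``,
# the shifted energies are exp(0), exp(-2) and exp(-10), i.e. roughly
# 1.0, 0.135 and 0.000045, so the greedy choice is taken most of the time while
# near-optimal alternatives are still sampled occasionally. With
# ``rel_temperature=True`` the effective temperature becomes
# ``1 * max(1, abs(10)) = 10``, giving energies of roughly 1.0, 0.82 and 0.37,
# i.e. a much flatter, more exploratory distribution.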


def ssa_path_compute_cost(ssa_path, inputs, output, size_dict):
    """Compute the flops and max size of an ssa path.
    """
    inputs = list(map(frozenset, inputs))
    output = frozenset(output)
    remaining = set(range(len(inputs)))
    total_cost = 0
    max_size = 0

    for i, j in ssa_path:
        k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict)
        remaining.discard(i)
        remaining.discard(j)
        remaining.add(len(inputs))
        inputs.append(k12)
        total_cost += flops12
        max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))

    return total_cost, max_size


def _trial_greedy_ssa_path_and_cost(r, inputs, output, size_dict, choose_fn, cost_fn):
    """A single, repeatable, greedy trial run. Returns the ``ssa_path`` along
    with its flop cost and largest intermediate size.
    """
    if r == 0:
        # always start with the standard greedy approach
        choose_fn = None

    random_seed(r)

    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)

    return ssa_path, cost, size


class RandomGreedy(RandomOptimizer):
    """A random-greedy path optimizer: it repeatedly samples greedy-style
    paths, choosing each contraction thermally from the best few candidates,
    and keeps the best path found.

    Parameters
    ----------
    cost_fn : callable, optional
        A function that returns a heuristic 'cost' of a potential contraction
        with which to sort candidates. Should have signature
        ``cost_fn(size12, size1, size2, k12, k1, k2)``.
    temperature : float, optional
        When choosing a possible contraction, its relative probability will be
        proportional to ``exp(-cost / temperature)``. Thus the larger
        ``temperature`` is, the further random paths will stray from the normal
        'greedy' path. Conversely, if set to zero, only paths with exactly the
        same cost as the best at each step will be explored.
    rel_temperature : bool, optional
        Whether to normalize the ``temperature`` at each step to the scale of
        the best cost. This is generally beneficial as the magnitude of costs
        can vary significantly throughout a contraction. If False, the
        algorithm will end up branching when the absolute cost is low, but
        stick to the 'greedy' path when the cost is high - this can also be
        beneficial.
    nbranch : int, optional
        How many potential paths to calculate probabilities for and choose
        from at each step.
    kwargs
        Supplied to ``RandomOptimizer``.

    See Also
    --------
    RandomOptimizer
    """
    def __init__(self, cost_fn='memory-removed-jitter', temperature=1.0, rel_temperature=True, nbranch=8, **kwargs):
        self.cost_fn = cost_fn
        self.temperature = temperature
        self.rel_temperature = rel_temperature
        self.nbranch = nbranch
        super().__init__(**kwargs)

    @property
    def choose_fn(self):
        """The function that chooses which contraction to take - make this a
        property so that ``temperature`` and ``nbranch`` etc. can be updated
        between runs.
        """
        if self.nbranch == 1:
            return None

        return functools.partial(thermal_chooser,
                                 temperature=self.temperature,
                                 nbranch=self.nbranch,
                                 rel_temperature=self.rel_temperature)

    def setup(self, inputs, output, size_dict):
        fn = _trial_greedy_ssa_path_and_cost
        args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
        return fn, args


def random_greedy(inputs, output, idx_dict, memory_limit=None, **optimizer_kwargs):
    """Convenience function that computes a contraction path with a fresh
    ``RandomGreedy`` optimizer; ``optimizer_kwargs`` are forwarded to its
    constructor.
    """
    optimizer = RandomGreedy(**optimizer_kwargs)
    return optimizer(inputs, output, idx_dict, memory_limit)


random_greedy_128 = functools.partial(random_greedy, max_repeats=128)
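

if __name__ == "__main__":
    # Illustrative demo, not part of the original module: run the random-greedy
    # search directly on a tiny, made-up contraction. The index names and sizes
    # below are arbitrary example values. Because this module uses relative
    # imports, run it as a package module, e.g. ``python -m opt_einsum.path_random``.
    example_inputs = [{'a', 'b'}, {'b', 'c'}, {'c', 'd'}]
    example_output = {'a', 'd'}
    example_sizes = {'a': 4, 'b': 8, 'c': 8, 'd': 4}

    optimizer = RandomGreedy(max_repeats=8, temperature=1.0)
    # setting ``optimizer.parallel = 4`` would instead run trials in a 4-process pool
    path = optimizer(example_inputs, example_output, example_sizes, None)
    print("best path:", path)
    print("best flops:", optimizer.best['flops'])
    print("best size:", optimizer.best['size'])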