1"""
2Support for random optimizers, including the random-greedy path.
3"""
5import functools
6import heapq
7import math
8import numbers
9import time
10from collections import deque
11from random import choices as random_choices
12from random import seed as random_seed
13from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
15from . import helpers, paths
16from .typing import ArrayIndexType, ArrayType, PathType
18__all__ = ["RandomGreedy", "random_greedy", "random_greedy_128"]


class RandomOptimizer(paths.PathOptimizer):
    """Base class for running any random path finder that benefits
    from repeated calling, possibly in a parallel fashion. Custom random
    optimizers should subclass this, and the `setup` method should be
    implemented with the following signature:

    ```python
    def setup(self, inputs, output, size_dict):
        # custom preparation here ...
        return trial_fn, trial_args
    ```

    Where `trial_fn` itself should have the signature:

    ```python
    def trial_fn(r, *trial_args):
        # custom computation of path here
        return ssa_path, cost, size
    ```

    Where `r` is the run number and could for example be used to seed a
    random number generator. See `RandomGreedy` for an example.

    **Parameters:**

    - **max_repeats** - *(int, optional)* The maximum number of repeat trials to have.
    - **max_time** - *(float, optional)* The maximum amount of time to run the algorithm for.
    - **minimize** - *({'flops', 'size'}, optional)* Whether to favour paths that minimize the total estimated flop-count or
      the size of the largest intermediate created.
    - **parallel** - *({bool, int, or executor-pool like}, optional)* Whether to parallelize the random trials, by default `False`. If
      `True`, use a `concurrent.futures.ProcessPoolExecutor` with the same
      number of processes as cores. If an integer is specified, use that many
      processes instead. Finally, you can supply a custom executor-pool which
      should have an API matching that of the python 3 standard library
      module `concurrent.futures`. Namely, a `submit` method that returns
      `Future` objects, themselves with `result` and `cancel` methods.
    - **pre_dispatch** - *(int, optional)* If running in parallel, how many jobs to pre-dispatch so as to avoid
      submitting all jobs at once. Should also be more than twice the number
      of workers to avoid under-subscription. Default: 128.

    **Attributes:**

    - **path** - *(list[tuple[int]])* The best path found so far.
    - **costs** - *(list[int])* The list of each trial's costs found so far.
    - **sizes** - *(list[int])* The list of each trial's largest intermediate size so far.
    """

    def __init__(
        self,
        max_repeats: int = 32,
        max_time: Optional[float] = None,
        minimize: str = "flops",
        parallel: bool = False,
        pre_dispatch: int = 128,
    ):

        if minimize not in ("flops", "size"):
            raise ValueError("`minimize` should be one of {'flops', 'size'}.")

        self.max_repeats = max_repeats
        self.max_time = max_time
        self.minimize = minimize
        self.better = paths.get_better_fn(minimize)
        self._parallel = False
        self.parallel = parallel
        self.pre_dispatch = pre_dispatch

        self.costs: List[int] = []
        self.sizes: List[int] = []
        self.best: Dict[str, Any] = {"flops": float("inf"), "size": float("inf")}

        self._repeats_start = 0
        self._executor: Any
        self._futures: Any

    @property
    def path(self) -> PathType:
        """The best path found so far."""
        return paths.ssa_to_linear(self.best["ssa_path"])

    @property
    def parallel(self) -> bool:
        return self._parallel

    @parallel.setter
    def parallel(self, parallel: bool) -> None:
        # shutdown any previous executor if we are managing it
        if getattr(self, "_managing_executor", False):
            self._executor.shutdown()

        self._parallel = parallel
        self._managing_executor = False

        if parallel is False:
            self._executor = None
            return

        if parallel is True:
            from concurrent.futures import ProcessPoolExecutor

            self._executor = ProcessPoolExecutor()
            self._managing_executor = True
            return

        if isinstance(parallel, numbers.Number):
            from concurrent.futures import ProcessPoolExecutor

            self._executor = ProcessPoolExecutor(parallel)
            self._managing_executor = True
            return

        # assume a pool-executor has been supplied
        self._executor = parallel

    def _gen_results_parallel(self, repeats: Iterable[int], trial_fn: Any, args: Any) -> Generator[Any, None, None]:
        """Lazily generate results from an executor without submitting all jobs at once."""
        self._futures = deque()

        # the idea here is to submit at least ``pre_dispatch`` jobs *before* we
        # yield any results, then do both in tandem, before draining the queue
        for r in repeats:
            if len(self._futures) < self.pre_dispatch:
                self._futures.append(self._executor.submit(trial_fn, r, *args))
                continue
            yield self._futures.popleft().result()

        while self._futures:
            yield self._futures.popleft().result()

    def _cancel_futures(self) -> None:
        if self._executor is not None:
            for f in self._futures:
                f.cancel()

    def setup(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> Tuple[Any, Any]:
        raise NotImplementedError

    def __call__(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
        memory_limit: Optional[int] = None,
    ) -> PathType:
        self._check_args_against_first_call(inputs, output, size_dict)

        # start a timer?
        if self.max_time is not None:
            t0 = time.time()

        trial_fn, trial_args = self.setup(inputs, output, size_dict)

        r_start = self._repeats_start + len(self.costs)
        r_stop = r_start + self.max_repeats
        repeats = range(r_start, r_stop)

        # create the trials lazily
        if self._executor is not None:
            trials = self._gen_results_parallel(repeats, trial_fn, trial_args)
        else:
            trials = (trial_fn(r, *trial_args) for r in repeats)

        # assess the trials
        for ssa_path, cost, size in trials:

            # keep track of all costs and sizes
            self.costs.append(cost)
            self.sizes.append(size)

            # check if we have found a new best
            found_new_best = self.better(cost, size, self.best["flops"], self.best["size"])

            if found_new_best:
                self.best["flops"] = cost
                self.best["size"] = size
                self.best["ssa_path"] = ssa_path

            # check if we have run out of time
            if (self.max_time is not None) and (time.time() > t0 + self.max_time):
                break

        self._cancel_futures()
        return self.path

    def __del__(self):
        # if we created the parallel pool-executor, shut it down
        if getattr(self, "_managing_executor", False):
            self._executor.shutdown()
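

# Illustrative sketch (not part of the library): a minimal subclass showing the
# `setup`/`trial_fn` contract documented in the RandomOptimizer docstring. The
# class and function names here are hypothetical; the only assumption is that
# `setup` returns a trial function plus its extra arguments, and that
# `trial_fn(r, *trial_args)` returns ``(ssa_path, cost, size)``.
def _example_random_optimizer_subclass():
    class _ExampleJitteredGreedy(RandomOptimizer):
        """Mirror what `RandomGreedy.setup` does with the jittered cost
        heuristic, purely to illustrate the required signatures."""

        def setup(self, inputs, output, size_dict):
            trial_fn = _trial_greedy_ssa_path_and_cost
            trial_args = (inputs, output, size_dict, None, "memory-removed-jitter")
            return trial_fn, trial_args

    return _ExampleJitteredGreedy(max_repeats=8)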


def thermal_chooser(queue, remaining, nbranch=8, temperature=1, rel_temperature=True):
    """A contraction 'chooser' that weights possible contractions using a
    Boltzmann distribution. Explicitly, given costs `c_i` (with `c_0` the
    smallest), the relative weights, `w_i`, are computed as:

    $$w_i = exp( -(c_i - c_0) / temperature)$$

    Additionally, if `rel_temperature` is set, scale `temperature` by
    `abs(c_0)` to account for likely fluctuating cost magnitudes during the
    course of a contraction.

    **Parameters:**

    - **queue** - *(list)* The heapified list of candidate contractions.
    - **remaining** - *(dict[str, int])* Mapping of remaining inputs' indices to the ssa id.
    - **temperature** - *(float, optional)* When choosing a possible contraction, its relative probability will be
      proportional to `exp(-cost / temperature)`. Thus the larger
      `temperature` is, the further random paths will stray from the normal
      'greedy' path. Conversely, if set to zero, only paths with exactly the
      same cost as the best at each step will be explored.
    - **rel_temperature** - *(bool, optional)* Whether to normalize the `temperature` at each step to the scale of
      the best cost. This is generally beneficial as the magnitude of costs
      can vary significantly throughout a contraction.
    - **nbranch** - *(int, optional)* How many potential paths to calculate probability for and choose from at each step.

    **Returns:**

    - **cost**
    - **k1**
    - **k2**
    - **k12**
    """
    n = 0
    choices = []
    while queue and n < nbranch:
        cost, k1, k2, k12 = heapq.heappop(queue)
        if k1 not in remaining or k2 not in remaining:
            continue  # candidate is obsolete
        choices.append((cost, k1, k2, k12))
        n += 1

    if n == 0:
        return None
    if n == 1:
        return choices[0]

    costs = [choice[0][0] for choice in choices]
    cmin = costs[0]

    # adjust by the overall scale to account for fluctuating absolute costs
    if rel_temperature:
        temperature *= max(1, abs(cmin))

    # compute relative probability for each potential contraction
    if temperature == 0.0:
        energies = [1 if c == cmin else 0 for c in costs]
    else:
        # shift by cmin for numerical reasons
        energies = [math.exp(-(c - cmin) / temperature) for c in costs]

    # randomly choose a contraction based on energies
    (chosen,) = random_choices(range(n), weights=energies)
    cost, k1, k2, k12 = choices.pop(chosen)

    # put the other choices back in the heap
    for other in choices:
        heapq.heappush(queue, other)

    return cost, k1, k2, k12
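

# Illustrative sketch (not part of the library): the Boltzmann weighting used by
# `thermal_chooser`, applied to some made-up candidate costs. With temperature=1
# and costs [0, 1, 2], the relative weights are exp(0), exp(-1), exp(-2), so the
# cheapest contraction is the most likely, but not certain, to be picked.
def _example_thermal_weights(costs=(0.0, 1.0, 2.0), temperature=1.0):
    cmin = min(costs)
    # same shift-by-cmin trick used above for numerical stability
    energies = [math.exp(-(c - cmin) / temperature) for c in costs]
    (chosen,) = random_choices(range(len(costs)), weights=energies)
    return energies, chosen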


def ssa_path_compute_cost(
    ssa_path: PathType,
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
) -> Tuple[int, int]:
    """Compute the flops and max size of an ssa path."""
    inputs = list(map(frozenset, inputs))  # type: ignore
    output = frozenset(output)
    remaining = set(range(len(inputs)))
    total_cost = 0
    max_size = 0

    for i, j in ssa_path:
        k12, flops12 = paths.calc_k12_flops(inputs, output, remaining, i, j, size_dict)  # type: ignore
        remaining.discard(i)
        remaining.discard(j)
        remaining.add(len(inputs))
        inputs.append(k12)
        total_cost += flops12
        max_size = max(max_size, helpers.compute_size_by_dict(k12, size_dict))

    return total_cost, max_size
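

# Illustrative sketch (not part of the library): scoring a hand-written SSA path
# for the chain contraction "ab,bc,cd->ad". The index sizes are made up; the
# point is that ssa_path_compute_cost returns the total flop count and the
# largest intermediate produced along the path.
def _example_ssa_path_cost():
    inputs = [frozenset("ab"), frozenset("bc"), frozenset("cd")]
    output = frozenset("ad")
    size_dict = {"a": 2, "b": 3, "c": 4, "d": 5}
    # contract the first two inputs, then contract the result (ssa id 3) with the third
    ssa_path = [(0, 1), (2, 3)]
    return ssa_path_compute_cost(ssa_path, inputs, output, size_dict)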


def _trial_greedy_ssa_path_and_cost(
    r: int,
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    size_dict: Dict[str, int],
    choose_fn: Any,
    cost_fn: Any,
) -> Tuple[PathType, int, int]:
    """A single, repeatable, greedy trial run. **Returns:** ``ssa_path``, cost and size."""
    if r == 0:
        # always start with the standard greedy approach
        choose_fn = None

    random_seed(r)

    ssa_path = paths.ssa_greedy_optimize(inputs, output, size_dict, choose_fn, cost_fn)
    cost, size = ssa_path_compute_cost(ssa_path, inputs, output, size_dict)

    return ssa_path, cost, size


class RandomGreedy(RandomOptimizer):
    """

    **Parameters:**

    - **cost_fn** - *(callable, optional)* A function that returns a heuristic 'cost' of a potential contraction
      with which to sort candidates. Should have signature
      `cost_fn(size12, size1, size2, k12, k1, k2)`.
    - **temperature** - *(float, optional)* When choosing a possible contraction, its relative probability will be
      proportional to `exp(-cost / temperature)`. Thus the larger
      `temperature` is, the further random paths will stray from the normal
      'greedy' path. Conversely, if set to zero, only paths with exactly the
      same cost as the best at each step will be explored.
    - **rel_temperature** - *(bool, optional)* Whether to normalize the `temperature` at each step to the scale of
      the best cost. This is generally beneficial as the magnitude of costs
      can vary significantly throughout a contraction. If False, the
      algorithm will end up branching when the absolute cost is low, but
      stick to the 'greedy' path when the cost is high - this can also be
      beneficial.
    - **nbranch** - *(int, optional)* How many potential paths to calculate probability for and choose from at each step.
    - **kwargs** - Supplied to RandomOptimizer.
    """

    def __init__(
        self,
        cost_fn: str = "memory-removed-jitter",
        temperature: float = 1.0,
        rel_temperature: bool = True,
        nbranch: int = 8,
        **kwargs: Any,
    ):
        self.cost_fn = cost_fn
        self.temperature = temperature
        self.rel_temperature = rel_temperature
        self.nbranch = nbranch
        super().__init__(**kwargs)

    @property
    def choose_fn(self) -> Any:
        """The function that chooses which contraction to take - make this a
        property so that ``temperature`` and ``nbranch`` etc. can be updated
        between runs.
        """
        if self.nbranch == 1:
            return None

        return functools.partial(
            thermal_chooser,
            temperature=self.temperature,
            nbranch=self.nbranch,
            rel_temperature=self.rel_temperature,
        )

    def setup(
        self,
        inputs: List[ArrayIndexType],
        output: ArrayIndexType,
        size_dict: Dict[str, int],
    ) -> Tuple[Any, Any]:
        fn = _trial_greedy_ssa_path_and_cost
        args = (inputs, output, size_dict, self.choose_fn, self.cost_fn)
        return fn, args
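

# Illustrative sketch (not part of the library): using a RandomGreedy instance as
# the `optimize` argument of `opt_einsum.contract_path`. The einsum equation and
# array shapes are made up, and `numpy` is assumed to be available.
def _example_random_greedy_usage():
    import numpy as np
    import opt_einsum as oe

    optimizer = RandomGreedy(max_repeats=16, temperature=1.0, parallel=False)
    arrays = [np.random.rand(4, 8), np.random.rand(8, 16), np.random.rand(16, 4)]
    path, info = oe.contract_path("ab,bc,cd->ad", *arrays, optimize=optimizer)
    # once it has run, the best trial is also kept on the optimizer itself
    return path, optimizer.best["flops"]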


def random_greedy(
    inputs: List[ArrayIndexType],
    output: ArrayIndexType,
    idx_dict: Dict[str, int],
    memory_limit: Optional[int] = None,
    **optimizer_kwargs: Any,
) -> PathType:
    """Compute a random-greedy contraction path by instantiating a
    `RandomGreedy` optimizer with ``optimizer_kwargs`` and calling it once."""
    optimizer = RandomGreedy(**optimizer_kwargs)
    return optimizer(inputs, output, idx_dict, memory_limit)


random_greedy_128 = functools.partial(random_greedy, max_repeats=128)
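

# Illustrative sketch (not part of the library): calling the functional wrappers
# directly with pre-parsed index sets rather than going through contract_path.
# The inputs, output and index sizes below are made up.
def _example_functional_wrappers():
    inputs = [frozenset("ab"), frozenset("bc"), frozenset("cd")]
    output = frozenset("ad")
    size_dict = {"a": 2, "b": 3, "c": 4, "d": 5}
    path = random_greedy(inputs, output, size_dict, max_repeats=8)
    path_128 = random_greedy_128(inputs, output, size_dict)
    return path, path_128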