Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/contract.py: 12%
323 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:41 +0000
1"""
2Contains the primary optimization and contraction routines.
3"""
from collections import namedtuple
from decimal import Decimal
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from . import backends, blas, helpers, parser, paths, sharing
from .typing import ArrayIndexType, ArrayType, ContractionListType, PathType
# Public API of this module.
# NOTE(review): `contract_expression` is defined below but not listed here —
# confirm whether that omission is intentional.
__all__ = [
    "contract_path",
    "contract",
    "format_const_einsum_str",
    "ContractExpression",
    "shape_only",
]
class PathInfo:
    """A printable object to contain information about a contraction path.

    **Attributes:**

    - **naive_cost** - *(int)* The estimate FLOP cost of a naive einsum contraction.
    - **opt_cost** - *(int)* The estimate FLOP cost of this optimized contraction path.
    - **largest_intermediate** - *(int)* The number of elements in the largest intermediate array that will be produced during the contraction.
    """

    def __init__(
        self,
        contraction_list: ContractionListType,
        input_subscripts: str,
        output_subscript: str,
        indices: ArrayIndexType,
        path: PathType,
        scale_list: Sequence[int],
        naive_cost: int,
        opt_cost: int,
        size_list: Sequence[int],
        size_dict: Dict[str, int],
    ):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.path = path
        self.indices = indices
        self.scale_list = scale_list
        # Costs are held as Decimal: FLOP counts for large contractions can be
        # huge, and Decimal supports both exact division (for the speedup
        # ratio) and the "{:.3e}" formatting used in __repr__.
        self.naive_cost = Decimal(naive_cost)
        self.opt_cost = Decimal(opt_cost)
        self.speedup = self.naive_cost / self.opt_cost
        self.size_list = size_list
        self.size_dict = size_dict

        # Reconstruct the shape of each input operand from its subscript labels.
        self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(",")]
        self.eq = "{}->{}".format(input_subscripts, output_subscript)
        self.largest_intermediate = Decimal(max(size_list))

    def __repr__(self) -> str:
        # Return the path along with a nice string representation
        header = ("scaling", "BLAS", "current", "remaining")

        path_print = [
            "  Complete contraction:  {}\n".format(self.eq),
            "         Naive scaling:  {}\n".format(len(self.indices)),
            "     Optimized scaling:  {}\n".format(max(self.scale_list)),
            "      Naive FLOP count:  {:.3e}\n".format(self.naive_cost),
            "  Optimized FLOP count:  {:.3e}\n".format(self.opt_cost),
            "   Theoretical speedup:  {:.3e}\n".format(self.speedup),
            "  Largest intermediate:  {:.3e} elements\n".format(self.largest_intermediate),
            "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header),
            "-" * 80,
        ]

        # One row per pairwise contraction: its scaling, the BLAS kernel used
        # (if any), the einsum string performed, and the terms still remaining.
        for n, contraction in enumerate(self.contraction_list):
            _, _, einsum_str, remaining, do_blas = contraction

            if remaining is not None:
                remaining_str = ",".join(remaining) + "->" + self.output_subscript
            else:
                # remaining terms were dropped for very large expressions
                remaining_str = "..."
            # Pad the "remaining" column so it lines up with the header.
            size_remaining = max(0, 56 - max(22, len(einsum_str)))

            path_run = (
                self.scale_list[n],
                do_blas,
                einsum_str,
                remaining_str,
                size_remaining,
            )
            path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))

        return "".join(path_print)
99def _choose_memory_arg(memory_limit: int, size_list: List[int]) -> Optional[int]:
100 if memory_limit == "max_input":
101 return max(size_list)
103 if memory_limit is None:
104 return None
106 if memory_limit < 1:
107 if memory_limit == -1:
108 return None
109 else:
110 raise ValueError("Memory limit must be larger than 0, or -1")
112 return int(memory_limit)
# The complete set of keyword arguments accepted by `contract_path`; anything
# else passed in is rejected with a TypeError.
_VALID_CONTRACT_KWARGS = {
    "optimize",
    "path",
    "memory_limit",
    "einsum_call",
    "use_blas",
    "shapes",
}
def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]:
    """
    Find a contraction order `path`, without performing the contraction.

    **Parameters:**

    - **subscripts** - *(str)* Specifies the subscripts for summation.
    - **\\*operands** - *(list of array_like)* these are the arrays for the operation.
    - **use_blas** - *(bool)* Do you use BLAS for valid operations, may use extra memory for more intermediates.
    - **optimize** - *(str, list or bool, optional (default: `auto`))* Choose the type of path.

        - if a list is given uses this as the path.
        - `'optimal'` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number of
          terms in the contraction.
        - `'dp'` A faster (but essentially optimal) algorithm that uses
          dynamic programming to exhaustively search all contraction paths
          without outer-products.
        - `'greedy'` An cheap algorithm that heuristically chooses the best
          pairwise contraction at each step. Scales linearly in the number of
          terms in the contraction.
        - `'random-greedy'` Run a randomized version of the greedy algorithm
          32 times and pick the best path.
        - `'random-greedy-128'` Run a randomized version of the greedy
          algorithm 128 times and pick the best path.
        - `'branch-all'` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - `'branch-2'` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales exponentially
          with the number of terms in the contraction.
        - `'auto'` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.
        - `'auto-hq'` Aim for a high quality contraction, choosing the best
          of the above algorithms whilst aiming to keep the path finding time
          below 1sec.

    - **memory_limit** - *({None, int, 'max_input'} (default: `None`))* - Give the upper bound of the largest intermediate tensor contract will build.

        - None or -1 means there is no limit
        - `max_input` means the limit is set as largest input tensor
        - a positive integer is taken as an explicit limit on the number of elements

        The default is None. Note that imposing a limit can make contractions
        exponentially slower to perform.

    - **shapes** - *(bool, optional)* Whether ``contract_path`` should assume arrays (the default) or array shapes have been supplied.

    **Returns:**

    - **path** - *(list of tuples)* The einsum path
    - **PathInfo** - *(str)* A printable object containing various information about the path found.

    **Notes:**

    The resulting path indicates which terms of the input contraction should be
    contracted first, the result of this contraction is then appended to the end of
    the contraction list.

    **Examples:**

    We can begin with a chain dot example. In this case, it is optimal to
    contract the b and c tensors represented by the first element of the path (1,
    2). The resulting tensor is added to the end of the contraction and the
    remaining contraction, `(0, 1)`, is then executed.

    ```python
    a = np.random.rand(2, 2)
    b = np.random.rand(2, 5)
    c = np.random.rand(5, 2)
    path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c)
    print(path_info[0])
    #> [(1, 2), (0, 1)]
    print(path_info[1])
    #>   Complete contraction:  ij,jk,kl->il
    #>          Naive scaling:  4
    #>      Optimized scaling:  3
    #>       Naive FLOP count:  1.600e+02
    #>   Optimized FLOP count:  5.600e+01
    #>    Theoretical speedup:  2.857
    #>   Largest intermediate:  4.000e+00 elements
    #> -------------------------------------------------------------------------
    #> scaling                  current                                remaining
    #> -------------------------------------------------------------------------
    #>    3                   kl,jk->jl                                ij,jl->il
    #>    3                   jl,ij->il                                   il->il
    ```

    A more complex index transformation example.

    ```python
    I = np.random.rand(10, 10, 10, 10)
    C = np.random.rand(10, 10)
    path_info = oe.contract_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C)

    print(path_info[0])
    #> [(0, 2), (0, 3), (0, 2), (0, 1)]
    print(path_info[1])
    #>   Complete contraction:  ea,fb,abcd,gc,hd->efgh
    #>          Naive scaling:  8
    #>      Optimized scaling:  5
    #>       Naive FLOP count:  8.000e+08
    #>   Optimized FLOP count:  8.000e+05
    #>    Theoretical speedup:  1000.000
    #>   Largest intermediate:  1.000e+04 elements
    #> --------------------------------------------------------------------------
    #> scaling                  current                                remaining
    #> --------------------------------------------------------------------------
    #>    5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
    #>    5               bcde,fb->cdef                         gc,hd,cdef->efgh
    #>    5               cdef,gc->defg                            hd,defg->efgh
    #>    5               defg,hd->efgh                               efgh->efgh
    ```
    """

    # Make sure all keywords are valid
    unknown_kwargs = set(kwargs) - _VALID_CONTRACT_KWARGS
    if len(unknown_kwargs):
        raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))

    path_type = kwargs.pop("optimize", "auto")

    memory_limit = kwargs.pop("memory_limit", None)
    shapes = kwargs.pop("shapes", False)

    # Hidden option, only einsum should call this
    einsum_call_arg = kwargs.pop("einsum_call", False)
    use_blas = kwargs.pop("use_blas", True)

    # Python side parsing
    input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands_, shapes=shapes)

    # Build a few useful list and sets
    input_list = input_subscripts.split(",")
    input_sets = [frozenset(x) for x in input_list]
    if shapes:
        # `shapes=True`: the operands *are* the shapes
        input_shapes = operands
    else:
        input_shapes = [x.shape for x in operands]
    output_set = frozenset(output_subscript)
    indices = frozenset(input_subscripts.replace(",", ""))

    # Get length of each unique dimension and ensure all dimensions are correct
    size_dict: Dict[str, int] = {}
    for tnum, term in enumerate(input_list):
        sh = input_shapes[tnum]

        if len(sh) != len(term):
            raise ValueError(
                "Einstein sum subscript '{}' does not contain the "
                "correct number of indices for operand {}.".format(input_list[tnum], tnum)
            )
        for cnum, char in enumerate(term):
            dim = int(sh[cnum])

            if char in size_dict:
                # For broadcasting cases we always want the largest dim size
                if size_dict[char] == 1:
                    size_dict[char] = dim
                elif dim not in (1, size_dict[char]):
                    raise ValueError(
                        "Size of label '{}' for operand {} ({}) does not match previous "
                        "terms ({}).".format(char, tnum, size_dict[char], dim)
                    )
            else:
                size_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [helpers.compute_size_by_dict(term, size_dict) for term in input_list + [output_subscript]]
    memory_arg = _choose_memory_arg(memory_limit, size_list)

    num_ops = len(input_list)

    # Compute naive cost
    # This is not quite right, need to look into exactly how einsum does this
    # indices_in_input = input_subscripts.replace(',', '')

    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

    # Compute the path
    if not isinstance(path_type, (str, paths.PathOptimizer)):
        # Custom path supplied
        path = path_type
    elif num_ops <= 2:
        # Nothing to be optimized (note: this also bypasses any supplied
        # PathOptimizer for 1- and 2-term contractions)
        path = [tuple(range(num_ops))]
    elif isinstance(path_type, paths.PathOptimizer):
        # Custom path optimizer supplied
        path = path_type(input_sets, output_set, size_dict, memory_arg)
    else:
        path_optimizer = paths.get_path_fn(path_type)
        path = path_optimizer(input_sets, output_set, size_dict, memory_arg)

    cost_list = []
    scale_list = []
    size_list = []
    contraction_list = []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract_tuple

        # Compute cost, scale, and size
        cost = helpers.flop_count(idx_contract, bool(idx_removed), len(contract_inds), size_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))

        # Pop the contracted terms (descending index order keeps positions valid)
        tmp_inputs = [input_list.pop(x) for x in contract_inds]
        tmp_shapes = [input_shapes.pop(x) for x in contract_inds]

        if use_blas:
            do_blas = blas.can_blas(tmp_inputs, "".join(out_inds), idx_removed, tmp_shapes)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # use tensordot order to minimize transpositions
            all_input_inds = "".join(tmp_inputs)
            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))

        shp_result = parser.find_output_shape(tmp_inputs, tmp_shapes, idx_result)

        # The intermediate result becomes a new "input" for later steps
        input_list.append(idx_result)
        input_shapes.append(shp_result)

        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        # for large expressions saving the remaining terms at each step can
        # incur a large memory footprint - and also be messy to print
        if len(input_list) <= 20:
            remaining: Optional[Tuple[str, ...]] = tuple(input_list)
        else:
            remaining = None

        contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list)

    if einsum_call_arg:
        # Hidden fast-path for `contract`: hand back the parsed operands and
        # the contraction recipe instead of a PathInfo.
        return operands, contraction_list  # type: ignore

    path_print = PathInfo(
        contraction_list,
        input_subscripts,
        output_subscript,
        indices,
        path,
        scale_list,
        naive_cost,
        opt_cost,
        size_list,
        size_dict,
    )

    return path, path_print
@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
    """Dispatch einsum to the requested backend, first sanitizing the
    subscript string (when one is supplied) so it only contains characters
    that einsum implementations accept.
    """
    backend_fn = backends.get_func("einsum", kwargs.pop("backend", "numpy"))

    first = operands[0]
    if not isinstance(first, str):
        # interleaved-input form: nothing to sanitize, pass straight through
        return backend_fn(*operands, **kwargs)

    eq, arrays = first, operands[1:]

    # Do we need to temporarily map indices into the [a-z,A-Z] range?
    if not parser.has_valid_einsum_chars_only(eq):
        # make the output explicit first so the implicit output order survives
        # the character remapping
        if "->" not in eq:
            eq = eq + "->" + parser.find_output_str(eq)
        eq = parser.convert_to_valid_einsum_chars(eq)

    return backend_fn(eq, *arrays, **kwargs)
413def _default_transpose(x: ArrayType, axes: Tuple[int, ...]) -> ArrayType:
414 # most libraries implement a method version
415 return x.transpose(axes)
@sharing.transpose_cache_wrap
def _transpose(x: ArrayType, axes: Tuple[int, ...], backend: str = "numpy") -> ArrayType:
    """Transpose ``x`` with the chosen backend, falling back to the array's
    own ``transpose`` method when the backend does not supply one.
    """
    return backends.get_func("transpose", backend, _default_transpose)(x, axes)
@sharing.tensordot_cache_wrap
def _tensordot(x: ArrayType, y: ArrayType, axes: Tuple[int, ...], backend: str = "numpy") -> ArrayType:
    """Perform ``tensordot(x, y)`` over ``axes`` using the chosen backend."""
    return backends.get_func("tensordot", backend)(x, y, axes=axes)
# Rewrite einsum to handle different cases
def contract(*operands_: Any, **kwargs: Any) -> ArrayType:
    """
    Evaluates the Einstein summation convention on the operands. A drop in
    replacement for NumPy's einsum function that optimizes the order of contraction
    to reduce overall scaling at the cost of several intermediate arrays.

    **Parameters:**

    - **subscripts** - *(str)* Specifies the subscripts for summation.
    - **\\*operands** - *(list of array_like)* these are the arrays for the operation.
    - **out** - *(array_like)* An output array in which to set the resulting output.
    - **dtype** - *(str)* The dtype of the given contraction, see np.einsum.
    - **order** - *(str)* The order of the resulting contraction, see np.einsum.
    - **casting** - *(str)* The casting procedure for operations of different dtype, see np.einsum.
    - **use_blas** - *(bool)* Do you use BLAS for valid operations, may use extra memory for more intermediates.
    - **optimize** - *(str, list or bool, optional (default: ``auto``))* Choose the type of path.

        - if a list is given uses this as the path.
        - `'optimal'` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number of
          terms in the contraction.
        - `'dp'` A faster (but essentially optimal) algorithm that uses
          dynamic programming to exhaustively search all contraction paths
          without outer-products.
        - `'greedy'` An cheap algorithm that heuristically chooses the best
          pairwise contraction at each step. Scales linearly in the number of
          terms in the contraction.
        - `'random-greedy'` Run a randomized version of the greedy algorithm
          32 times and pick the best path.
        - `'random-greedy-128'` Run a randomized version of the greedy
          algorithm 128 times and pick the best path.
        - `'branch-all'` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - `'branch-2'` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales exponentially
          with the number of terms in the contraction.
        - `'auto'` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.
        - `'auto-hq'` Aim for a high quality contraction, choosing the best
          of the above algorithms whilst aiming to keep the path finding time
          below 1sec.

    - **memory_limit** - *({None, int, 'max_input'} (default: `None`))* - Give the upper bound of the largest intermediate tensor contract will build.

        - None or -1 means there is no limit
        - `max_input` means the limit is set as largest input tensor
        - a positive integer is taken as an explicit limit on the number of elements

        The default is None. Note that imposing a limit can make contractions
        exponentially slower to perform.

    - **backend** - *(str, optional (default: ``auto``))* Which library to use to perform the required ``tensordot``, ``transpose``
      and ``einsum`` calls. Should match the types of arrays supplied, See
      :func:`contract_expression` for generating expressions which convert
      numpy arrays to and from the backend library automatically.

    **Returns:**

    - **out** - *(array_like)* The result of the einsum expression.

    **Notes:**

    This function should produce a result identical to that of NumPy's einsum
    function. The primary difference is ``contract`` will attempt to form
    intermediates which reduce the overall scaling of the given einsum contraction.
    By default the worst intermediate formed will be equal to that of the largest
    input array. For large einsum expressions with many input arrays this can
    provide arbitrarily large (1000 fold+) speed improvements.

    For contractions with just two tensors this function will attempt to use
    NumPy's built-in BLAS functionality to ensure that the given operation is
    performed optimally. When NumPy is linked to a threaded BLAS, potential
    speedups are on the order of 20-100 for a six core machine.
    """
    optimize_arg = kwargs.pop("optimize", True)
    if optimize_arg is True:
        optimize_arg = "auto"

    valid_einsum_kwargs = ["out", "dtype", "order", "casting"]
    einsum_kwargs = {k: v for (k, v) in kwargs.items() if k in valid_einsum_kwargs}

    # If no optimization, run pure einsum
    if optimize_arg is False:
        return _einsum(*operands_, **einsum_kwargs)

    # Grab non-einsum kwargs
    use_blas = kwargs.pop("use_blas", True)
    memory_limit = kwargs.pop("memory_limit", None)
    backend = kwargs.pop("backend", "auto")
    # Hidden options used by `contract_expression` only:
    gen_expression = kwargs.pop("_gen_expression", False)
    constants_dict = kwargs.pop("_constants_dict", {})

    # Make sure remaining keywords are valid for einsum
    unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))

    if gen_expression:
        # when building an expression the first operand is the subscript string
        full_str = operands_[0]

    # Build the contraction list and operand
    operands: Sequence[ArrayType]
    contraction_list: ContractionListType
    operands, contraction_list = contract_path(  # type: ignore
        *operands_, optimize=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas
    )

    # check if performing contraction or just building expression
    if gen_expression:
        return ContractExpression(full_str, contraction_list, constants_dict, **einsum_kwargs)

    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
546@lru_cache(None)
547def _infer_backend_class_cached(cls: type) -> str:
548 return cls.__module__.split(".")[0]
def infer_backend(x: Any) -> str:
    """Return the backend name for array ``x`` - the top-level module of its
    class, looked up via the cached helper.
    """
    klass = x.__class__
    return _infer_backend_class_cached(klass)
def parse_backend(arrays: Sequence[ArrayType], backend: Optional[str]) -> str:
    """Work out which backend to use: an explicitly requested backend is
    returned as-is, otherwise (``'auto'`` or ``None``) dispatch on the class
    of the first array, falling back to numpy for array types whose module
    does not implement ``tensordot`` and friends.
    """
    if backend not in ("auto", None):
        return backend

    inferred = infer_backend(arrays[0])

    # some arrays will be defined in modules that don't implement tensordot
    # etc. so instead default to numpy
    return inferred if backends.has_tensordot(inferred) else "numpy"
def _core_contract(
    operands_: Sequence[ArrayType],
    contraction_list: ContractionListType,
    backend: Optional[str] = "auto",
    evaluate_constants: bool = False,
    **einsum_kwargs: Any,
) -> ArrayType:
    """Inner loop used to perform an actual contraction given the output
    from a ``contract_path(..., einsum_call=True)`` call.
    """

    # Special handling if out is specified
    out_array = einsum_kwargs.pop("out", None)
    specified_out = out_array is not None

    # copy so pops below don't mutate the caller's sequence
    operands = list(operands_)
    backend = parse_backend(operands, backend)

    # try and do as much as possible without einsum if not available
    no_einsum = not backends.has_einsum(backend)

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, _, blas_flag = contraction

        # check if we are performing the pre-pass of an expression with constants,
        # if so, break out upon finding first non-constant (None) operand
        if evaluate_constants and any(operands[x] is None for x in inds):
            return operands, contraction_list[num:]

        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot (check if should prefer einsum, but only if available)
        if blas_flag and ("EINSUM" not in blas_flag or no_einsum):  # type: ignore

            # Checks have already been handled
            input_str, results_index = einsum_str.split("->")
            input_left, input_right = input_str.split(",")

            # indices surviving the contraction, in tensordot's natural order
            tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

            if idx_rm:
                # Find indices to contract over
                left_pos, right_pos = [], []
                for s in idx_rm:
                    left_pos.append(input_left.find(s))
                    right_pos.append(input_right.find(s))

                # Construct the axes tuples in a canonical order
                axes = tuple(zip(*sorted(zip(left_pos, right_pos))))
            else:
                # Ensure axes is always pair of tuples
                axes = ((), ())

            # Contract!
            new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)

            # Build a new view if needed
            if (tensor_result != results_index) or handle_out:
                transpose = tuple(map(tensor_result.index, results_index))
                new_view = _transpose(new_view, axes=transpose, backend=backend)

            if handle_out:
                out_array[:] = new_view

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                einsum_kwargs["out"] = out_array

            # Do the contraction
            new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out_array
    else:
        return operands[0]
def format_const_einsum_str(einsum_str: str, constants: Iterable[int]) -> str:
    """Add brackets to the constant terms in ``einsum_str``. For example:

        >>> format_const_einsum_str('ab,bc,cd->ad', [0, 2])
        '[ab],bc,[cd]->ad'

    Term order is preserved; only *adjacent* constant terms are merged into a
    single bracket, e.g. constants ``[0, 1]`` give ``'[ab,bc],cd->ad'``.
    (The previous docstring example ``'bc,[ab,cd]->ad'`` was wrong - this
    function never reorders terms.)

    No-op if there are no constants.
    """
    if not constants:
        return einsum_str

    # split off the explicit output, if any, so only the inputs get wrapped
    if "->" in einsum_str:
        lhs, rhs = einsum_str.split("->")
        arrow = "->"
    else:
        lhs, rhs, arrow = einsum_str, "", ""

    wrapped_terms = ["[{}]".format(t) if i in constants else t for i, t in enumerate(lhs.split(","))]

    formatted_einsum_str = "{}{}{}".format(",".join(wrapped_terms), arrow, rhs)

    # merge adjacent constants
    formatted_einsum_str = formatted_einsum_str.replace("],[", ",")
    return formatted_einsum_str
class ContractExpression:
    """Helper class for storing an explicit ``contraction_list`` which can
    then be repeatedly called solely with the array arguments.
    """

    def __init__(
        self,
        contraction: str,
        contraction_list: ContractionListType,
        constants_dict: Dict[int, ArrayType],
        **einsum_kwargs: Any,
    ):
        self.contraction_list = contraction_list
        self.einsum_kwargs = einsum_kwargs
        # display form with constant terms wrapped in brackets
        self.contraction = format_const_einsum_str(contraction, constants_dict.keys())

        # need to know _full_num_args to parse constants with, and num_args to call with
        self._full_num_args = contraction.count(",") + 1
        self.num_args = self._full_num_args - len(constants_dict)

        # likewise need to know full contraction list
        self._full_contraction_list = contraction_list

        self._constants_dict = constants_dict
        # per-backend caches, populated lazily on first use
        self._evaluated_constants: Dict[str, Any] = {}
        self._backend_expressions: Dict[str, Any] = {}

    def evaluate_constants(self, backend: Optional[str] = "auto") -> None:
        """Convert any constant operands to the correct backend form, and
        perform as many contractions as possible to create a new list of
        operands, stored in ``self._evaluated_constants[backend]``. This also
        makes sure ``self.contraction_list`` only contains the remaining,
        non-const operations.
        """
        # prepare a list of operands, with `None` for non-consts
        tmp_const_ops = [self._constants_dict.get(i, None) for i in range(self._full_num_args)]
        backend = parse_backend(tmp_const_ops, backend)

        # get the new list of operands with constant operations performed, and remaining contractions
        try:
            new_ops, new_contraction_list = backends.evaluate_constants(backend, tmp_const_ops, self)
        except KeyError:
            # backend has no special constant handling - run the generic pre-pass
            new_ops, new_contraction_list = self(*tmp_const_ops, backend=backend, evaluate_constants=True)

        self._evaluated_constants[backend] = new_ops
        self.contraction_list = new_contraction_list

    def _get_evaluated_constants(self, backend: str) -> List[Optional[ArrayType]]:
        """Retrieve or generate the cached list of constant operators (mixed
        in with None representing non-consts) and the remaining contraction
        list.
        """
        try:
            return self._evaluated_constants[backend]
        except KeyError:
            self.evaluate_constants(backend)
            return self._evaluated_constants[backend]

    def _get_backend_expression(self, arrays: Sequence[ArrayType], backend: str) -> Any:
        # Retrieve, or build and cache, the backend-specific callable
        # expression generated from ``arrays`` as templates.
        try:
            return self._backend_expressions[backend]
        except KeyError:
            fn = backends.build_expression(backend, arrays, self)
            self._backend_expressions[backend] = fn
            return fn

    def _contract(
        self,
        arrays: Sequence[ArrayType],
        out: Optional[ArrayType] = None,
        backend: Optional[str] = "auto",
        evaluate_constants: bool = False,
    ) -> ArrayType:
        """The normal, core contraction."""
        # the constants pre-pass must walk the *full* contraction list
        contraction_list = self._full_contraction_list if evaluate_constants else self.contraction_list

        return _core_contract(
            list(arrays),
            contraction_list,
            out=out,
            backend=backend,
            evaluate_constants=evaluate_constants,
            **self.einsum_kwargs,
        )

    def _contract_with_conversion(
        self,
        arrays: Sequence[ArrayType],
        out: Optional[ArrayType],
        backend: str,
        evaluate_constants: bool = False,
    ) -> ArrayType:
        """Special contraction, i.e., contraction with a different backend
        but converting to and from that backend. Retrieves or generates a
        cached expression using ``arrays`` as templates, then calls it
        with ``arrays``.

        If ``evaluate_constants=True``, perform a partial contraction that
        prepares the constant tensors and operations with the right backend.
        """
        # convert consts to correct type & find reduced contraction list
        if evaluate_constants:
            return backends.evaluate_constants(backend, arrays, self)

        result = self._get_backend_expression(arrays, backend)(*arrays)

        if out is not None:
            out[()] = result
            return out

        return result

    def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType:
        """Evaluate this expression with a set of arrays.

        Parameters
        ----------
        arrays : seq of array
            The arrays to supply as input to the expression.
        out : array, optional (default: ``None``)
            If specified, output the result into this array.
        backend : str, optional  (default: ``numpy``)
            Perform the contraction with this backend library. If numpy arrays
            are supplied then try to convert them to and from the correct
            backend array type.
        """
        out = kwargs.pop("out", None)
        backend = parse_backend(arrays, kwargs.pop("backend", "auto"))
        # internal flag used by `evaluate_constants` - not part of the public API
        evaluate_constants = kwargs.pop("evaluate_constants", False)

        if kwargs:
            raise ValueError(
                "The only valid keyword arguments to a `ContractExpression` "
                "call are `out=` or `backend=`. Got: {}.".format(kwargs)
            )

        correct_num_args = self._full_num_args if evaluate_constants else self.num_args

        if len(arrays) != correct_num_args:
            raise ValueError(
                "This `ContractExpression` takes exactly {} array arguments "
                "but received {}.".format(self.num_args, len(arrays))
            )

        if self._constants_dict and not evaluate_constants:
            # fill in the missing non-constant terms with newly supplied arrays
            ops_var, ops_const = iter(arrays), self._get_evaluated_constants(backend)
            ops: Sequence[ArrayType] = [next(ops_var) if op is None else op for op in ops_const]
        else:
            ops = arrays

        try:
            # Check if the backend requires special preparation / calling
            # but also ignore non-numpy arrays -> assume user wants same type back
            if backends.has_backend(backend) and all(infer_backend(x) == "numpy" for x in arrays):
                return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

            return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)

        except ValueError as err:
            original_msg = str(err.args) if err.args else ""
            msg = (
                "Internal error while evaluating `ContractExpression`. Note that few checks are performed"
                " - the number and rank of the array arguments must match the original expression. "
                "The internal error was: '{}'".format(original_msg),
            )
            err.args = msg
            raise

    def __repr__(self) -> str:
        if self._constants_dict:
            constants_repr = ", constants={}".format(sorted(self._constants_dict))
        else:
            constants_repr = ""
        return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)

    def __str__(self) -> str:
        # repr plus a numbered list of the pairwise contraction steps
        s = [self.__repr__()]
        for i, c in enumerate(self.contraction_list):
            s.append("\n  {}.  ".format(i + 1))
            s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
        if self.einsum_kwargs:
            s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
        return "".join(s)
# A stand-in "array" exposing only a ``.shape`` attribute.
Shaped = namedtuple("Shaped", ["shape"])


def shape_only(shape: PathType) -> Shaped:
    """Create a dummy object that mimics a ``numpy.ndarray`` by carrying a
    shape and nothing else - used to build contract expressions without
    allocating real arrays.
    """
    return Shaped(shape=shape)
def contract_expression(subscripts: str, *shapes: PathType, **kwargs: Any) -> Any:
    """Generate a reusable expression for a given contraction with
    specific shapes, which can, for example, be cached.

    **Parameters:**

    - **subscripts** - *(str)* Specifies the subscripts for summation.
    - **shapes** - *(sequence of integer tuples)* Shapes of the arrays to optimize the contraction for.
    - **constants** - *(sequence of int, optional)* Indices in `shapes` that are constant arrays rather than
      shapes; constant parts of the contraction are computed once per backend
      and reused between calls (and, e.g., kept on the GPU for GPU backends).
    - **kwargs** - Passed on to `contract_path` or `einsum`. See `contract`.

    **Returns:**

    - **expr** - *(ContractExpression)* Callable with signature `expr(*arrays, out=None, backend='numpy')` where the array's shapes should match `shapes`.

    **Notes:**

    - Supply `out` and `backend` to the *generated expression*, not here.
    - The expression works for any arrays of matching rank, though different
      actual sizes may make the cached path sub-optimal.

    **Examples:**

    ```python
    expr = contract_expression("ab,bc->ac", (3, 4), (4, 5))
    a, b = np.random.rand(3, 4), np.random.rand(4, 5)
    c = expr(a, b)
    np.allclose(c, a @ b)
    #> True
    ```

    Supply `a` as a constant:

    ```python
    expr = contract_expression("ab,bc->ac", a, (4, 5), constants=[0])
    expr
    #> <ContractExpression('[ab],bc->ac', constants=[0])>

    c = expr(b)
    np.allclose(c, a @ b)
    #> True
    ```
    """
    if not kwargs.get("optimize", True):
        raise ValueError("Can only generate expressions for optimized contractions.")

    for arg in ("out", "backend"):
        if kwargs.get(arg, None) is not None:
            raise ValueError(
                "'{}' should only be specified when calling a "
                "`ContractExpression`, not when building it.".format(arg)
            )

    # allow interleaved (array, indices, array, indices, ...) input
    if not isinstance(subscripts, str):
        subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes)

    kwargs["_gen_expression"] = True

    # map each constant's position to the actual array supplied there
    constants = kwargs.pop("constants", ())
    kwargs["_constants_dict"] = {i: shapes[i] for i in constants}

    # apart from constant arguments, make dummy shape-only arrays
    dummy_arrays = []
    for position, entry in enumerate(shapes):
        dummy_arrays.append(entry if position in constants else shape_only(entry))

    return contract(subscripts, *dummy_arrays, **kwargs)