Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/opt_einsum/contract.py: 11%
315 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""
2Contains the primary optimization and contraction routines.
3"""
5from collections import namedtuple
6from decimal import Decimal
8import numpy as np
10from . import backends, blas, helpers, parser, paths, sharing
# Public names of this module. ``contract_expression`` is defined below and
# documented as public API, but was previously missing from ``__all__``.
__all__ = [
    "contract_path",
    "contract",
    "format_const_einsum_str",
    "ContractExpression",
    "shape_only",
    "contract_expression",
]
class PathInfo(object):
    """A printable object to contain information about a contraction path.

    Attributes
    ----------
    naive_cost : Decimal
        The estimate FLOP cost of a naive einsum contraction.
    opt_cost : Decimal
        The estimate FLOP cost of this optimized contraction path.
    speedup : Decimal
        The ratio ``naive_cost / opt_cost``.
    largest_intermediate : Decimal
        The number of elements in the largest intermediate array that will be
        produced during the contraction.
    """
    def __init__(self, contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
                 opt_cost, size_list, size_dict):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.path = path
        self.indices = indices
        self.scale_list = scale_list
        # Decimal rather than float: FLOP counts for big expressions can
        # overflow/lose precision in double precision.
        self.naive_cost = Decimal(naive_cost)
        self.opt_cost = Decimal(opt_cost)
        self.speedup = self.naive_cost / self.opt_cost
        self.size_list = size_list
        self.size_dict = size_dict

        # Reconstruct each input operand's shape from the per-index sizes.
        self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(',')]
        self.eq = "{}->{}".format(input_subscripts, output_subscript)
        self.largest_intermediate = Decimal(max(size_list))

    def __repr__(self):
        # Return the path along with a nice string representation
        header = ("scaling", "BLAS", "current", "remaining")

        # NOTE(review): the label spacing in these literals may have been
        # collapsed by the HTML/coverage extraction this file passed through -
        # compare against upstream opt_einsum before relying on exact output.
        path_print = [
            " Complete contraction: {}\n".format(self.eq), " Naive scaling: {}\n".format(len(self.indices)),
            " Optimized scaling: {}\n".format(max(self.scale_list)), " Naive FLOP count: {:.3e}\n".format(
                self.naive_cost), " Optimized FLOP count: {:.3e}\n".format(self.opt_cost),
            " Theoretical speedup: {:.3e}\n".format(self.speedup),
            " Largest intermediate: {:.3e} elements\n".format(self.largest_intermediate), "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header), "-" * 80
        ]

        for n, contraction in enumerate(self.contraction_list):
            inds, idx_rm, einsum_str, remaining, do_blas = contraction

            # ``remaining`` is None when the expression was too large for
            # contract_path to store the intermediate term lists.
            if remaining is not None:
                remaining_str = ",".join(remaining) + "->" + self.output_subscript
            else:
                remaining_str = "..."
            # Width padding for the last column so long einsum strings
            # don't push it out of alignment.
            size_remaining = max(0, 56 - max(22, len(einsum_str)))

            path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str, size_remaining)
            path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))

        return "".join(path_print)
74def _choose_memory_arg(memory_limit, size_list):
75 if memory_limit == 'max_input':
76 return max(size_list)
78 if memory_limit is None:
79 return None
81 if memory_limit < 1:
82 if memory_limit == -1:
83 return None
84 else:
85 raise ValueError("Memory limit must be larger than 0, or -1")
87 return int(memory_limit)
# Keyword arguments accepted by ``contract_path`` (``einsum_call`` is an
# internal flag used by ``contract``). NOTE(review): 'path' is accepted here
# but never read in ``contract_path`` below - presumably a legacy alias for
# 'optimize'; confirm before removing.
_VALID_CONTRACT_KWARGS = {'optimize', 'path', 'memory_limit', 'einsum_call', 'use_blas', 'shapes'}
def contract_path(*operands, **kwargs):
    """
    Find a contraction order 'path', without performing the contraction.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : str, list or bool, optional (default: ``auto``)
        Choose the type of path.

        - if a list is given uses this as the path.
        - ``'optimal'`` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number
          of terms in the contraction.
        - ``'branch-all'`` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - ``'branch-2'`` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales
          exponentially with the number of terms in the contraction.
        - ``'greedy'`` An algorithm that heuristically chooses the best pair
          contraction at each step.
        - ``'auto'`` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.

    use_blas : bool
        Use BLAS functions or not
    memory_limit : int, optional (default: None)
        Maximum number of elements allowed in intermediate arrays.
    shapes : bool, optional
        Whether ``contract_path`` should assume arrays (the default) or array
        shapes have been supplied.

    Returns
    -------
    path : list of tuples
        The einsum path
    PathInfo : str
        A printable object containing various information about the path found.

    Raises
    ------
    TypeError
        If an unknown keyword argument is supplied.
    ValueError
        If an operand's rank does not match its subscripts, or two operands
        disagree (beyond broadcasting) about an index's size.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should
    be contracted first, the result of this contraction is then appended to
    the end of the contraction list.

    Examples
    --------
    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c)
    >>> print(path_info[0])
    [(1, 2), (0, 1)]
    """

    # Make sure all keywords are valid
    unknown_kwargs = set(kwargs) - _VALID_CONTRACT_KWARGS
    if len(unknown_kwargs):
        raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))

    path_type = kwargs.pop('optimize', 'auto')

    memory_limit = kwargs.pop('memory_limit', None)
    shapes = kwargs.pop('shapes', False)

    # Hidden option, only einsum should call this
    einsum_call_arg = kwargs.pop("einsum_call", False)
    use_blas = kwargs.pop('use_blas', True)

    # Python side parsing
    input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands)

    # Build a few useful list and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    if shapes:
        # ``shapes=True``: operands ARE the shapes, no arrays involved.
        input_shps = operands
    else:
        input_shps = [x.shape for x in operands]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    size_dict = {}
    for tnum, term in enumerate(input_list):
        sh = input_shps[tnum]

        if len(sh) != len(term):
            raise ValueError("Einstein sum subscript '{}' does not contain the "
                             "correct number of indices for operand {}.".format(input_list[tnum], tnum))
        for cnum, char in enumerate(term):
            dim = int(sh[cnum])

            if char in size_dict:
                # For broadcasting cases we always want the largest dim size
                if size_dict[char] == 1:
                    size_dict[char] = dim
                elif dim not in (1, size_dict[char]):
                    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
                                     "terms ({}).".format(char, tnum, size_dict[char], dim))
            else:
                size_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [helpers.compute_size_by_dict(term, size_dict) for term in input_list + [output_subscript]]
    memory_arg = _choose_memory_arg(memory_limit, size_list)

    num_ops = len(input_list)

    # Compute naive cost
    # This isnt quite right, need to look into exactly how einsum does this
    # indices_in_input = input_subscripts.replace(',', '')

    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

    # Compute the path
    if not isinstance(path_type, (str, paths.PathOptimizer)):
        # Custom path supplied
        path = path_type
    elif num_ops <= 2:
        # Nothing to be optimized - note this branch also bypasses a custom
        # PathOptimizer for one- or two-term contractions.
        path = [tuple(range(num_ops))]
    elif isinstance(path_type, paths.PathOptimizer):
        # Custom path optimizer supplied
        path = path_type(input_sets, output_set, size_dict, memory_arg)
    else:
        path_optimizer = paths.get_path_fn(path_type)
        path = path_optimizer(input_sets, output_set, size_dict, memory_arg)

    cost_list = []
    scale_list = []
    size_list = []
    contraction_list = []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left - popping in descending
        # order keeps the earlier positions valid.
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract_tuple

        # Compute cost, scale, and size
        cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))

        tmp_inputs = [input_list.pop(x) for x in contract_inds]
        tmp_shapes = [input_shps.pop(x) for x in contract_inds]

        if use_blas:
            do_blas = blas.can_blas(tmp_inputs, out_inds, idx_removed, tmp_shapes)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # use tensordot order to minimize transpositions
            all_input_inds = "".join(tmp_inputs)
            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))

        shp_result = parser.find_output_shape(tmp_inputs, tmp_shapes, idx_result)

        # The intermediate becomes a new (appended) operand.
        input_list.append(idx_result)
        input_shps.append(shp_result)

        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        # for large expressions saving the remaining terms at each step can
        # incur a large memory footprint - and also be messy to print
        if len(input_list) <= 20:
            remaining = tuple(input_list)
        else:
            remaining = None

        contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list)

    if einsum_call_arg:
        # Internal entry point used by ``contract``: return the parsed
        # operands plus the raw contraction recipe.
        return operands, contraction_list

    path_print = PathInfo(contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
                          opt_cost, size_list, size_dict)

    return path, path_print
@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
    """Base einsum, but with pre-parse for valid characters if a string is given.
    """
    einsum_fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))

    # Interleaved (non-string) input: hand everything straight through.
    if not isinstance(operands[0], str):
        return einsum_fn(*operands, **kwargs)

    einsum_str, operands = operands[0], operands[1:]

    # Do we need to temporarily map indices into [a-z,A-Z] range?
    if not parser.has_valid_einsum_chars_only(einsum_str):
        # Pin down the output explicitly first so its index order survives
        # the character conversion below.
        if '->' not in einsum_str:
            einsum_str += '->' + parser.find_output_str(einsum_str)
        einsum_str = parser.convert_to_valid_einsum_chars(einsum_str)

    return einsum_fn(einsum_str, *operands, **kwargs)
356def _default_transpose(x, axes):
357 # most libraries implement a method version
358 return x.transpose(axes)
@sharing.transpose_cache_wrap
def _transpose(x, axes, backend='numpy'):
    """Base transpose, dispatched to ``backend`` with a method-call fallback.
    """
    transpose_fn = backends.get_func('transpose', backend, _default_transpose)
    return transpose_fn(x, axes)
@sharing.tensordot_cache_wrap
def _tensordot(x, y, axes, backend='numpy'):
    """Base tensordot, dispatched to the requested ``backend``.
    """
    tensordot_fn = backends.get_func('tensordot', backend)
    return tensordot_fn(x, y, axes=axes)
# Rewrite einsum to handle different cases
def contract(*operands, **kwargs):
    """
    contract(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', use_blas=True, optimize=True, memory_limit=None, backend='numpy')

    Evaluates the Einstein summation convention on the operands. A drop in
    replacement for NumPy's einsum function that optimizes the order of
    contraction to reduce overall scaling at the cost of several intermediate
    arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    out : array_like
        A output array in which set the resulting output.
    dtype : str
        The dtype of the given contraction, see np.einsum.
    order : str
        The order of the resulting contraction, see np.einsum.
    casting : str
        The casting procedure for operations of different dtype, see np.einsum.
    use_blas : bool
        Do you use BLAS for valid operations, may use extra memory for more
        intermediates.
    optimize : str, list or bool, optional (default: ``auto``)
        Choose the type of path. ``True`` means ``'auto'``; ``False`` runs a
        plain un-optimized einsum. Other options include an explicit path
        list, ``'optimal'``, ``'dp'``, ``'greedy'``, ``'random-greedy'``,
        ``'random-greedy-128'``, ``'branch-all'``, ``'branch-2'``, ``'auto'``
        and ``'auto-hq'`` - see ``contract_path`` for details.
    memory_limit : {None, int, 'max_input'} (default: None)
        Give the upper bound of the largest intermediate tensor contract will
        build: None or -1 means no limit, 'max_input' bounds by the largest
        input tensor, and a positive integer is an explicit element count.
        Note that imposing a limit can make contractions exponentially slower
        to perform.
    backend : str, optional (default: ``auto``)
        Which library to use to perform the required ``tensordot``,
        ``transpose`` and ``einsum`` calls. Should match the types of arrays
        supplied, see :func:`contract_expression` for generating expressions
        which convert numpy arrays to and from the backend library
        automatically.

    Returns
    -------
    out : array_like
        The result of the einsum expression.

    Raises
    ------
    TypeError
        If an unknown keyword argument is supplied (only checked when
        ``optimize`` is not ``False``).

    Notes
    -----
    This function should produce a result identical to that of NumPy's einsum
    function. The primary difference is ``contract`` will attempt to form
    intermediates which reduce the overall scaling of the given einsum
    contraction. See :func:`opt_einsum.contract_path` for examples.
    """
    optimize_arg = kwargs.pop('optimize', True)
    if optimize_arg is True:
        optimize_arg = 'auto'

    valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
    einsum_kwargs = {k: v for (k, v) in kwargs.items() if k in valid_einsum_kwargs}

    # If no optimization, run pure einsum
    # (note: unknown kwargs are not validated on this early-return path)
    if optimize_arg is False:
        return _einsum(*operands, **einsum_kwargs)

    # Grab non-einsum kwargs
    use_blas = kwargs.pop('use_blas', True)
    memory_limit = kwargs.pop('memory_limit', None)
    backend = kwargs.pop('backend', 'auto')
    # Internal flags used by ``contract_expression``.
    gen_expression = kwargs.pop('_gen_expression', False)
    constants_dict = kwargs.pop('_constants_dict', {})

    # Make sure remaining keywords are valid for einsum
    unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))

    if gen_expression:
        # keep the raw equation string before contract_path consumes operands
        full_str = operands[0]

    # Build the contraction list and operand
    operands, contraction_list = contract_path(*operands,
                                               optimize=optimize_arg,
                                               memory_limit=memory_limit,
                                               einsum_call=True,
                                               use_blas=use_blas)

    # check if performing contraction or just building expression
    if gen_expression:
        return ContractExpression(full_str, contraction_list, constants_dict, **einsum_kwargs)

    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
def infer_backend(x):
    """Infer the backend library of ``x`` from the root of its class's module."""
    module_root, _, _ = x.__class__.__module__.partition('.')
    return module_root
def parse_backend(arrays, backend):
    """Find out what backend we should use, dispatching based on the first
    array if ``backend='auto'`` is specified.
    """
    # an explicit backend always wins
    if backend != 'auto':
        return backend
    backend = infer_backend(arrays[0])

    # some arrays will be defined in modules that don't implement tensordot
    # etc. so instead default to numpy
    if not backends.has_tensordot(backend):
        return 'numpy'

    return backend
def _core_contract(operands, contraction_list, backend='auto', evaluate_constants=False, **einsum_kwargs):
    """Inner loop used to perform an actual contraction given the output
    from a ``contract_path(..., einsum_call=True)`` call.

    Parameters
    ----------
    operands : list of array_like
        The tensors to contract; this list is consumed (popped) and appended
        to in place as the contraction proceeds.
    contraction_list : list of tuple
        Pairwise contraction steps ``(inds, idx_rm, einsum_str, remaining,
        blas_flag)`` as produced by ``contract_path``.
    backend : str, optional
        Library used for the ``tensordot``/``transpose``/``einsum`` calls,
        or ``'auto'`` to dispatch on the first operand.
    evaluate_constants : bool, optional
        If True, return early at the first step involving a non-constant
        (``None``) operand - the constants pre-pass of a
        ``ContractExpression``.
    einsum_kwargs :
        Extra keywords (e.g. ``out``, ``dtype``) forwarded to einsum calls.
    """

    # Special handling if out is specified
    out_array = einsum_kwargs.pop('out', None)
    specified_out = out_array is not None
    backend = parse_backend(operands, backend)

    # try and do as much as possible without einsum if not available
    no_einsum = not backends.has_einsum(backend)

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, _, blas_flag = contraction

        # check if we are performing the pre-pass of an expression with constants,
        # if so, break out upon finding first non-constant (None) operand
        if evaluate_constants and any(operands[x] is None for x in inds):
            return operands, contraction_list[num:]

        # ``inds`` are sorted descending by contract_path, so popping keeps
        # the remaining positions valid.
        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot (check if should prefer einsum, but only if available)
        if blas_flag and ('EINSUM' not in blas_flag or no_einsum):

            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            # index order tensordot will naturally produce
            tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in idx_rm:
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)

            # Build a new view if needed
            if (tensor_result != results_index) or handle_out:
                transpose = tuple(map(tensor_result.index, results_index))
                new_view = _transpose(new_view, axes=transpose, backend=backend)

            if handle_out:
                out_array[:] = new_view

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                einsum_kwargs["out"] = out_array

            # Do the contraction
            new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out_array
    else:
        return operands[0]
def format_const_einsum_str(einsum_str, constants):
    """Add brackets to the constant terms in ``einsum_str``, merging adjacent
    constants into a single bracketed group. For example:

        >>> format_const_einsum_str('ab,bc,cd->ad', [0, 2])
        '[ab],bc,[cd]->ad'

        >>> format_const_einsum_str('ab,bc,cd->ad', [0, 1])
        '[ab,bc],cd->ad'

    No-op if there are no constants.

    Parameters
    ----------
    einsum_str : str
        The equation, with or without an explicit ``'->'`` output.
    constants : sequence of int
        Positions of the constant terms on the comma-separated left hand side.

    Returns
    -------
    str
        The annotated equation.
    """
    # NOTE: the previous docstring example showed 'bc,[ab,cd]->ad' for
    # constants [0, 2], which is not what this function produces - term
    # order is always preserved; only *adjacent* constants are merged.
    if not constants:
        return einsum_str

    if "->" in einsum_str:
        lhs, rhs = einsum_str.split('->')
        arrow = "->"
    else:
        lhs, rhs, arrow = einsum_str, "", ""

    wrapped_terms = ["[{}]".format(t) if i in constants else t for i, t in enumerate(lhs.split(','))]

    formatted_einsum_str = "{}{}{}".format(','.join(wrapped_terms), arrow, rhs)

    # merge adjacent constants
    formatted_einsum_str = formatted_einsum_str.replace("],[", ',')
    return formatted_einsum_str
class ContractExpression:
    """Helper class for storing an explicit ``contraction_list`` which can
    then be repeatedly called solely with the array arguments.

    Parameters
    ----------
    contraction : str
        The full einsum equation this expression evaluates.
    contraction_list : list of tuple
        The pairwise contraction steps from ``contract_path``.
    constants_dict : dict of int to array_like
        Mapping of argument position to constant array, for arguments fixed
        at build time.
    einsum_kwargs :
        Extra keywords (``dtype``, ``order``, ``casting``, ...) forwarded to
        every einsum call.
    """
    def __init__(self, contraction, contraction_list, constants_dict, **einsum_kwargs):
        self.contraction_list = contraction_list
        self.einsum_kwargs = einsum_kwargs
        # display form with constant terms bracketed, e.g. '[ab],bc->ac'
        self.contraction = format_const_einsum_str(contraction, constants_dict.keys())

        # need to know _full_num_args to parse constants with, and num_args to call with
        self._full_num_args = contraction.count(',') + 1
        self.num_args = self._full_num_args - len(constants_dict)

        # likewise need to know full contraction list
        self._full_contraction_list = contraction_list

        self._constants_dict = constants_dict
        self._evaluated_constants = {}    # per-backend evaluated constant operands
        self._backend_expressions = {}    # per-backend compiled expressions

    def evaluate_constants(self, backend='auto'):
        """Convert any constant operands to the correct backend form, and
        perform as many contractions as possible to create a new list of
        operands, stored in ``self._evaluated_constants[backend]``. This also
        makes sure ``self.contraction_list`` only contains the remaining,
        non-const operations.
        """
        # prepare a list of operands, with `None` for non-consts
        tmp_const_ops = [self._constants_dict.get(i, None) for i in range(self._full_num_args)]
        backend = parse_backend(tmp_const_ops, backend)

        # get the new list of operands with constant operations performed, and remaining contractions
        try:
            new_ops, new_contraction_list = backends.evaluate_constants(backend, tmp_const_ops, self)
        except KeyError:
            # no backend-specific constant evaluator: fall back to a normal
            # partial contraction pass through __call__
            new_ops, new_contraction_list = self(*tmp_const_ops, backend=backend, evaluate_constants=True)

        self._evaluated_constants[backend] = new_ops
        self.contraction_list = new_contraction_list

    def _get_evaluated_constants(self, backend):
        """Retrieve or generate the cached list of constant operators (mixed
        in with None representing non-consts) and the remaining contraction
        list.
        """
        try:
            return self._evaluated_constants[backend]
        except KeyError:
            self.evaluate_constants(backend)
            return self._evaluated_constants[backend]

    def _get_backend_expression(self, arrays, backend):
        # Retrieve, or build and cache, a backend-specific compiled expression
        # using ``arrays`` as templates.
        try:
            return self._backend_expressions[backend]
        except KeyError:
            fn = backends.build_expression(backend, arrays, self)
            self._backend_expressions[backend] = fn
            return fn

    def _contract(self, arrays, out=None, backend='auto', evaluate_constants=False):
        """The normal, core contraction.
        """
        # the full list is needed for the constants pre-pass, the reduced
        # list thereafter
        contraction_list = self._full_contraction_list if evaluate_constants else self.contraction_list

        return _core_contract(list(arrays),
                              contraction_list,
                              out=out,
                              backend=backend,
                              evaluate_constants=evaluate_constants,
                              **self.einsum_kwargs)

    def _contract_with_conversion(self, arrays, out, backend, evaluate_constants=False):
        """Special contraction, i.e., contraction with a different backend
        but converting to and from that backend. Retrieves or generates a
        cached expression using ``arrays`` as templates, then calls it
        with ``arrays``.

        If ``evaluate_constants=True``, perform a partial contraction that
        prepares the constant tensors and operations with the right backend.
        """
        # convert consts to correct type & find reduced contraction list
        if evaluate_constants:
            return backends.evaluate_constants(backend, arrays, self)

        result = self._get_backend_expression(arrays, backend)(*arrays)

        if out is not None:
            out[()] = result
            return out

        return result

    def __call__(self, *arrays, **kwargs):
        """Evaluate this expression with a set of arrays.

        Parameters
        ----------
        arrays : seq of array
            The arrays to supply as input to the expression.
        out : array, optional (default: ``None``)
            If specified, output the result into this array.
        backend : str, optional (default: ``numpy``)
            Perform the contraction with this backend library. If numpy arrays
            are supplied then try to convert them to and from the correct
            backend array type.

        Raises
        ------
        ValueError
            If unknown keywords are supplied, the wrong number of arrays is
            given, or the underlying contraction fails.
        """
        out = kwargs.pop('out', None)
        backend = kwargs.pop('backend', 'auto')
        backend = parse_backend(arrays, backend)
        evaluate_constants = kwargs.pop('evaluate_constants', False)

        if kwargs:
            raise ValueError("The only valid keyword arguments to a `ContractExpression` "
                             "call are `out=` or `backend=`. Got: {}.".format(kwargs))

        correct_num_args = self._full_num_args if evaluate_constants else self.num_args

        if len(arrays) != correct_num_args:
            # Fix: report the count actually expected for *this* call.
            # Previously this always printed ``self.num_args``, which was
            # misleading during a constants pre-pass (where the full argument
            # count is required).
            raise ValueError("This `ContractExpression` takes exactly {} array arguments "
                             "but received {}.".format(correct_num_args, len(arrays)))

        if self._constants_dict and not evaluate_constants:
            # fill in the missing non-constant terms with newly supplied arrays
            ops_var, ops_const = iter(arrays), self._get_evaluated_constants(backend)
            ops = [next(ops_var) if op is None else op for op in ops_const]
        else:
            ops = arrays

        try:
            # Check if the backend requires special preparation / calling
            # but also ignore non-numpy arrays -> assume user wants same type back
            if backends.has_backend(backend) and all(isinstance(x, np.ndarray) for x in arrays):
                return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

            return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)

        except ValueError as err:
            original_msg = str(err.args) if err.args else ""
            msg = ("Internal error while evaluating `ContractExpression`. Note that few checks are performed"
                   " - the number and rank of the array arguments must match the original expression. "
                   "The internal error was: '{}'".format(original_msg), )
            err.args = msg
            raise

    def __repr__(self):
        if self._constants_dict:
            constants_repr = ", constants={}".format(sorted(self._constants_dict))
        else:
            constants_repr = ""
        return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)

    def __str__(self):
        s = [self.__repr__()]
        for i, c in enumerate(self.contraction_list):
            s.append("\n {}. ".format(i + 1))
            s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
        if self.einsum_kwargs:
            s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
        return "".join(s)
# Minimal stand-in for an ndarray: exposes nothing but a ``shape`` attribute.
Shaped = namedtuple('Shaped', ['shape'])


def shape_only(shape):
    """Dummy ``numpy.ndarray`` which has a shape only - for generating
    contract expressions.
    """
    return Shaped(shape=shape)
def contract_expression(subscripts, *shapes, **kwargs):
    """Generate a reusable expression for a given contraction with
    specific shapes, which can, for example, be cached.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    shapes : sequence of integer tuples
        Shapes of the arrays to optimize the contraction for. Positions
        listed in ``constants`` hold the actual constant array instead of a
        shape.
    constants : sequence of int, optional
        The indices of any constant arguments in ``shapes``. Constant parts
        of the contraction are computed once per backend and then reused
        between calls (e.g. kept on the GPU for GPU-enabled backends).
    kwargs :
        Passed on to ``contract_path`` or ``einsum``. See ``contract``.

    Returns
    -------
    expr : ContractExpression
        Callable with signature ``expr(*arrays, out=None, backend='numpy')``
        where the array's shapes should match ``shapes``.

    Notes
    -----
    - Supply ``out`` and ``backend`` when *calling* the generated expression,
      not when building it.
    - The expression works with any arrays of the same rank as the original
      shapes, though it may no longer be optimal if the sizes differ.

    Examples
    --------
    >>> expr = contract_expression("ab,bc->ac", (3, 4), (4, 5))
    >>> a, b = np.random.rand(3, 4), np.random.rand(4, 5)
    >>> np.allclose(expr(a, b), a @ b)
    True
    """
    if not kwargs.get('optimize', True):
        raise ValueError("Can only generate expressions for optimized contractions.")

    for arg in ('out', 'backend'):
        if kwargs.get(arg, None) is not None:
            raise ValueError("'{}' should only be specified when calling a "
                             "`ContractExpression`, not when building it.".format(arg))

    # support interleaved (array, indices, array, indices, ...) input
    if not isinstance(subscripts, str):
        subscripts, shapes = parser.convert_interleaved_input((subscripts, ) + shapes)

    kwargs['_gen_expression'] = True

    # split the arguments into constant arrays and plain shapes
    const_positions = kwargs.pop('constants', ())
    kwargs['_constants_dict'] = {pos: shapes[pos] for pos in const_positions}

    # apart from the constants, stand in dummy shape-only 'arrays'
    dummy_arrays = [arg if num in const_positions else shape_only(arg) for num, arg in enumerate(shapes)]

    return contract(subscripts, *dummy_arrays, **kwargs)