"""
Contains the primary optimization and contraction routines.
"""

from collections import namedtuple
from decimal import Decimal
from functools import lru_cache
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from . import backends, blas, helpers, parser, paths, sharing
from .typing import ArrayIndexType, ArrayType, ContractionListType, PathType

__all__ = [
    "contract_path",
    "contract",
    "format_const_einsum_str",
    "ContractExpression",
    "shape_only",
]


class PathInfo:
    """A printable object to contain information about a contraction path.

    **Attributes:**

    - **naive_cost** - *(int)* The estimated FLOP cost of a naive einsum contraction.
    - **opt_cost** - *(int)* The estimated FLOP cost of this optimized contraction path.
    - **largest_intermediate** - *(int)* The number of elements in the largest intermediate array that will be produced during the contraction.
    """

    def __init__(
        self,
        contraction_list: ContractionListType,
        input_subscripts: str,
        output_subscript: str,
        indices: ArrayIndexType,
        path: PathType,
        scale_list: Sequence[int],
        naive_cost: int,
        opt_cost: int,
        size_list: Sequence[int],
        size_dict: Dict[str, int],
    ):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.path = path
        self.indices = indices
        self.scale_list = scale_list
        self.naive_cost = Decimal(naive_cost)
        self.opt_cost = Decimal(opt_cost)
        self.speedup = self.naive_cost / self.opt_cost
        self.size_list = size_list
        self.size_dict = size_dict

        self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(",")]
        self.eq = "{}->{}".format(input_subscripts, output_subscript)
        self.largest_intermediate = Decimal(max(size_list))

    def __repr__(self) -> str:
        # Return the path along with a nice string representation
        header = ("scaling", "BLAS", "current", "remaining")

        path_print = [
            "  Complete contraction:  {}\n".format(self.eq),
            "         Naive scaling:  {}\n".format(len(self.indices)),
            "     Optimized scaling:  {}\n".format(max(self.scale_list)),
            "      Naive FLOP count:  {:.3e}\n".format(self.naive_cost),
            "  Optimized FLOP count:  {:.3e}\n".format(self.opt_cost),
            "   Theoretical speedup:  {:.3e}\n".format(self.speedup),
            "  Largest intermediate:  {:.3e} elements\n".format(self.largest_intermediate),
            "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header),
            "-" * 80,
        ]

        for n, contraction in enumerate(self.contraction_list):
            _, _, einsum_str, remaining, do_blas = contraction

            if remaining is not None:
                remaining_str = ",".join(remaining) + "->" + self.output_subscript
            else:
                remaining_str = "..."
            size_remaining = max(0, 56 - max(22, len(einsum_str)))

            path_run = (
                self.scale_list[n],
                do_blas,
                einsum_str,
                remaining_str,
                size_remaining,
            )
            path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))

        return "".join(path_print)


def _choose_memory_arg(memory_limit: Union[int, str, None], size_list: List[int]) -> Optional[int]:
    if memory_limit == "max_input":
        return max(size_list)

    if memory_limit is None:
        return None

    if memory_limit < 1:
        if memory_limit == -1:
            return None
        else:
            raise ValueError("Memory limit must be larger than 0, or -1")

    return int(memory_limit)
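
# A quick sketch of the accepted forms (illustrative only):
#   _choose_memory_arg("max_input", [6, 20, 10]) -> 20
#   _choose_memory_arg(None, sizes)              -> None  (no limit)
#   _choose_memory_arg(-1, sizes)                -> None  (no limit)
#   _choose_memory_arg(5000, sizes)              -> 5000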


_VALID_CONTRACT_KWARGS = {
    "optimize",
    "path",
    "memory_limit",
    "einsum_call",
    "use_blas",
    "shapes",
}


def contract_path(*operands_: Any, **kwargs: Any) -> Tuple[PathType, PathInfo]:
    """
    Find a contraction order `path`, without performing the contraction.

    **Parameters:**

    - **subscripts** - *(str)* Specifies the subscripts for summation.
    - **\\*operands** - *(list of array_like)* These are the arrays for the operation.
    - **use_blas** - *(bool)* Whether to use BLAS for valid operations; this may use extra memory for more intermediates.
    - **optimize** - *(str, list or bool, optional (default: `auto`))* Choose the type of path.

        - if a list is given, uses this as the path.
        - `'optimal'` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number of
          terms in the contraction.
        - `'dp'` A faster (but essentially optimal) algorithm that uses
          dynamic programming to exhaustively search all contraction paths
          without outer-products.
        - `'greedy'` A cheap algorithm that heuristically chooses the best
          pairwise contraction at each step. Scales linearly in the number of
          terms in the contraction.
        - `'random-greedy'` Run a randomized version of the greedy algorithm
          32 times and pick the best path.
        - `'random-greedy-128'` Run a randomized version of the greedy
          algorithm 128 times and pick the best path.
        - `'branch-all'` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - `'branch-2'` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales exponentially
          with the number of terms in the contraction.
        - `'auto'` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.
        - `'auto-hq'` Aim for a high quality contraction, choosing the best
          of the above algorithms whilst aiming to keep the path finding time
          below 1sec.

    - **memory_limit** - *({None, int, 'max_input'} (default: `None`))* Give the upper bound of the largest intermediate tensor contract will build.

        - None or -1 means there is no limit
        - `max_input` means the limit is set as largest input tensor
        - a positive integer is taken as an explicit limit on the number of elements

        The default is None. Note that imposing a limit can make contractions
        exponentially slower to perform.

    - **shapes** - *(bool, optional)* Whether ``contract_path`` should assume arrays (the default) or array shapes have been supplied.

    **Returns:**

    - **path** - *(list of tuples)* The einsum path
    - **PathInfo** - *(str)* A printable object containing various information about the path found.

    **Notes:**

    The resulting path indicates which terms of the input contraction should be
    contracted first; the result of this contraction is then appended to the end of
    the contraction list.

    **Examples:**

    We can begin with a chain dot example. In this case, it is optimal to
    contract the `b` and `c` tensors first, as represented by the first element of
    the path, `(1, 2)`. The resulting tensor is added to the end of the contraction
    and the remaining contraction, `(0, 1)`, is then executed.

    ```python
    a = np.random.rand(2, 2)
    b = np.random.rand(2, 5)
    c = np.random.rand(5, 2)
    path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c)
    print(path_info[0])
    #> [(1, 2), (0, 1)]
    print(path_info[1])
    #>   Complete contraction:  ij,jk,kl->il
    #>          Naive scaling:  4
    #>      Optimized scaling:  3
    #>       Naive FLOP count:  1.600e+02
    #>   Optimized FLOP count:  5.600e+01
    #>    Theoretical speedup:  2.857
    #>   Largest intermediate:  4.000e+00 elements
    #> -------------------------------------------------------------------------
    #> scaling                  current                                remaining
    #> -------------------------------------------------------------------------
    #>    3                   kl,jk->jl                                ij,jl->il
    #>    3                   jl,ij->il                                   il->il
    ```

    A more complex index transformation example.

    ```python
    I = np.random.rand(10, 10, 10, 10)
    C = np.random.rand(10, 10)
    path_info = oe.contract_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C)

    print(path_info[0])
    #> [(0, 2), (0, 3), (0, 2), (0, 1)]
    print(path_info[1])
    #>   Complete contraction:  ea,fb,abcd,gc,hd->efgh
    #>          Naive scaling:  8
    #>      Optimized scaling:  5
    #>       Naive FLOP count:  8.000e+08
    #>   Optimized FLOP count:  8.000e+05
    #>    Theoretical speedup:  1000.000
    #>   Largest intermediate:  1.000e+04 elements
    #> --------------------------------------------------------------------------
    #> scaling                  current                                remaining
    #> --------------------------------------------------------------------------
    #>    5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
    #>    5               bcde,fb->cdef                         gc,hd,cdef->efgh
    #>    5               cdef,gc->defg                            hd,defg->efgh
    #>    5               defg,hd->efgh                               efgh->efgh
    ```
    """

    # Make sure all keywords are valid
    unknown_kwargs = set(kwargs) - _VALID_CONTRACT_KWARGS
    if len(unknown_kwargs):
        raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))

    path_type = kwargs.pop("optimize", "auto")

    memory_limit = kwargs.pop("memory_limit", None)
    shapes = kwargs.pop("shapes", False)

    # Hidden option, only einsum should call this
    einsum_call_arg = kwargs.pop("einsum_call", False)
    use_blas = kwargs.pop("use_blas", True)

    # Python side parsing
    input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands_, shapes=shapes)

    # Build a few useful lists and sets
    input_list = input_subscripts.split(",")
    input_sets = [frozenset(x) for x in input_list]
    if shapes:
        input_shapes = operands
    else:
        input_shapes = [x.shape for x in operands]
    output_set = frozenset(output_subscript)
    indices = frozenset(input_subscripts.replace(",", ""))

    # Get length of each unique dimension and ensure all dimensions are correct
    size_dict: Dict[str, int] = {}
    for tnum, term in enumerate(input_list):
        sh = input_shapes[tnum]

        if len(sh) != len(term):
            raise ValueError(
                "Einstein sum subscript '{}' does not contain the "
                "correct number of indices for operand {}.".format(input_list[tnum], tnum)
            )
        for cnum, char in enumerate(term):
            dim = int(sh[cnum])

            if char in size_dict:
                # For broadcasting cases we always want the largest dim size
                if size_dict[char] == 1:
                    size_dict[char] = dim
                elif dim not in (1, size_dict[char]):
                    raise ValueError(
                        "Size of label '{}' for operand {} ({}) does not match previous "
                        "terms ({}).".format(char, tnum, size_dict[char], dim)
                    )
            else:
                size_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [helpers.compute_size_by_dict(term, size_dict) for term in input_list + [output_subscript]]
    memory_arg = _choose_memory_arg(memory_limit, size_list)

    num_ops = len(input_list)

    # Compute naive cost
    # This is not quite right, need to look into exactly how einsum does this
    # indices_in_input = input_subscripts.replace(',', '')

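    # If the index occurrences across all inputs outnumber the unique indices,
    # some index is shared between terms, i.e. the expression contains an
    # inner product.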
    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

    # Compute the path
    if not isinstance(path_type, (str, paths.PathOptimizer)):
        # Custom path supplied
        path = path_type
    elif num_ops <= 2:
        # Nothing to be optimized
        path = [tuple(range(num_ops))]
    elif isinstance(path_type, paths.PathOptimizer):
        # Custom path optimizer supplied
        path = path_type(input_sets, output_set, size_dict, memory_arg)
    else:
        path_optimizer = paths.get_path_fn(path_type)
        path = path_optimizer(input_sets, output_set, size_dict, memory_arg)

    cost_list = []
    scale_list = []
    size_list = []
    contraction_list = []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract_tuple

        # Compute cost, scale, and size
        cost = helpers.flop_count(idx_contract, bool(idx_removed), len(contract_inds), size_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))

        tmp_inputs = [input_list.pop(x) for x in contract_inds]
        tmp_shapes = [input_shapes.pop(x) for x in contract_inds]

        if use_blas:
            do_blas = blas.can_blas(tmp_inputs, "".join(out_inds), idx_removed, tmp_shapes)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # use tensordot order to minimize transpositions
            all_input_inds = "".join(tmp_inputs)
            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
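            # For example, contracting tmp_inputs ['abcd', 'ea'] over 'a'
            # leaves out_inds {'b', 'c', 'd', 'e'}; ordering by first
            # appearance in 'abcdea' gives idx_result 'bcde'.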

        shp_result = parser.find_output_shape(tmp_inputs, tmp_shapes, idx_result)

        input_list.append(idx_result)
        input_shapes.append(shp_result)

        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        # for large expressions saving the remaining terms at each step can
        # incur a large memory footprint - and also be messy to print
        if len(input_list) <= 20:
            remaining: Optional[Tuple[str, ...]] = tuple(input_list)
        else:
            remaining = None

        contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list)

    if einsum_call_arg:
        return operands, contraction_list  # type: ignore

    path_print = PathInfo(
        contraction_list,
        input_subscripts,
        output_subscript,
        indices,
        path,
        scale_list,
        naive_cost,
        opt_cost,
        size_list,
        size_dict,
    )

    return path, path_print


@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
    """Base einsum, but with pre-parse for valid characters if a string is given."""
    fn = backends.get_func("einsum", kwargs.pop("backend", "numpy"))

    if not isinstance(operands[0], str):
        return fn(*operands, **kwargs)

    einsum_str, operands = operands[0], operands[1:]

    # Do we need to temporarily map indices into [a-z,A-Z] range?
    if not parser.has_valid_einsum_chars_only(einsum_str):

        # Explicitly find output str first so as to maintain order
        if "->" not in einsum_str:
            einsum_str += "->" + parser.find_output_str(einsum_str)

        einsum_str = parser.convert_to_valid_einsum_chars(einsum_str)

    return fn(einsum_str, *operands, **kwargs)
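
# Note (illustrative): when the equation contains symbols outside the valid
# einsum character set, the pre-parse above first makes the output explicit
# and then remaps every index into the [a-zA-Z] range, so an equation written
# with, say, unicode indices reaches the backend as an equivalent
# plain-ASCII equation; the operands themselves are untouched.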


def _default_transpose(x: ArrayType, axes: Tuple[int, ...]) -> ArrayType:
    # most libraries implement a method version
    return x.transpose(axes)


@sharing.transpose_cache_wrap
def _transpose(x: ArrayType, axes: Tuple[int, ...], backend: str = "numpy") -> ArrayType:
    """Base transpose."""
    fn = backends.get_func("transpose", backend, _default_transpose)
    return fn(x, axes)


@sharing.tensordot_cache_wrap
def _tensordot(x: ArrayType, y: ArrayType, axes: Tuple[int, ...], backend: str = "numpy") -> ArrayType:
    """Base tensordot."""
    fn = backends.get_func("tensordot", backend)
    return fn(x, y, axes=axes)


# Rewrite einsum to handle different cases
def contract(*operands_: Any, **kwargs: Any) -> ArrayType:
    """
    Evaluates the Einstein summation convention on the operands. A drop-in
    replacement for NumPy's einsum function that optimizes the order of contraction
    to reduce overall scaling at the cost of several intermediate arrays.

    **Parameters:**

    - **subscripts** - *(str)* Specifies the subscripts for summation.
    - **\\*operands** - *(list of array_like)* These are the arrays for the operation.
    - **out** - *(array_like)* An output array in which to set the resulting output.
    - **dtype** - *(str)* The dtype of the given contraction, see np.einsum.
    - **order** - *(str)* The order of the resulting contraction, see np.einsum.
    - **casting** - *(str)* The casting procedure for operations of different dtype, see np.einsum.
    - **use_blas** - *(bool)* Whether to use BLAS for valid operations; this may use extra memory for more intermediates.
    - **optimize** - *(str, list or bool, optional (default: ``auto``))* Choose the type of path.

        - if a list is given, uses this as the path.
        - `'optimal'` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number of
          terms in the contraction.
        - `'dp'` A faster (but essentially optimal) algorithm that uses
          dynamic programming to exhaustively search all contraction paths
          without outer-products.
        - `'greedy'` A cheap algorithm that heuristically chooses the best
          pairwise contraction at each step. Scales linearly in the number of
          terms in the contraction.
        - `'random-greedy'` Run a randomized version of the greedy algorithm
          32 times and pick the best path.
        - `'random-greedy-128'` Run a randomized version of the greedy
          algorithm 128 times and pick the best path.
        - `'branch-all'` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - `'branch-2'` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales exponentially
          with the number of terms in the contraction.
        - `'auto'` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.
        - `'auto-hq'` Aim for a high quality contraction, choosing the best
          of the above algorithms whilst aiming to keep the path finding time
          below 1sec.

    - **memory_limit** - *({None, int, 'max_input'} (default: `None`))* Give the upper bound of the largest intermediate tensor contract will build.

        - None or -1 means there is no limit
        - `max_input` means the limit is set as largest input tensor
        - a positive integer is taken as an explicit limit on the number of elements

        The default is None. Note that imposing a limit can make contractions
        exponentially slower to perform.

    - **backend** - *(str, optional (default: ``auto``))* Which library to use to perform the required ``tensordot``, ``transpose``
      and ``einsum`` calls. Should match the types of arrays supplied. See
      :func:`contract_expression` for generating expressions which convert
      numpy arrays to and from the backend library automatically.

    **Returns:**

    - **out** - *(array_like)* The result of the einsum expression.

    **Notes:**

    This function should produce a result identical to that of NumPy's einsum
    function. The primary difference is ``contract`` will attempt to form
    intermediates which reduce the overall scaling of the given einsum contraction.
    By default the worst intermediate formed will be equal to that of the largest
    input array. For large einsum expressions with many input arrays this can
    provide arbitrarily large (1000 fold+) speed improvements.

    For contractions with just two tensors this function will attempt to use
    NumPy's built-in BLAS functionality to ensure that the given operation is
    performed optimally. When NumPy is linked to a threaded BLAS, potential
    speedups are on the order of 20-100 for a six core machine.
    """
    optimize_arg = kwargs.pop("optimize", True)
    if optimize_arg is True:
        optimize_arg = "auto"

    valid_einsum_kwargs = ["out", "dtype", "order", "casting"]
    einsum_kwargs = {k: v for (k, v) in kwargs.items() if k in valid_einsum_kwargs}

    # If no optimization, run pure einsum
    if optimize_arg is False:
        return _einsum(*operands_, **einsum_kwargs)

    # Grab non-einsum kwargs
    use_blas = kwargs.pop("use_blas", True)
    memory_limit = kwargs.pop("memory_limit", None)
    backend = kwargs.pop("backend", "auto")
    gen_expression = kwargs.pop("_gen_expression", False)
    constants_dict = kwargs.pop("_constants_dict", {})

    # Make sure remaining keywords are valid for einsum
    unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))

    if gen_expression:
        full_str = operands_[0]

    # Build the contraction list and operands
    operands: Sequence[ArrayType]
    contraction_list: ContractionListType
    operands, contraction_list = contract_path(  # type: ignore
        *operands_, optimize=optimize_arg, memory_limit=memory_limit, einsum_call=True, use_blas=use_blas
    )

    # check if performing contraction or just building expression
    if gen_expression:
        return ContractExpression(full_str, contraction_list, constants_dict, **einsum_kwargs)

    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)


@lru_cache(None)
def _infer_backend_class_cached(cls: type) -> str:
    return cls.__module__.split(".")[0]


def infer_backend(x: Any) -> str:
    return _infer_backend_class_cached(x.__class__)


def parse_backend(arrays: Sequence[ArrayType], backend: Optional[str]) -> str:
    """Find out what backend we should use, dispatching based on the first
    array if ``backend='auto'`` is specified.
    """
    if (backend != "auto") and (backend is not None):
        return backend
    backend = infer_backend(arrays[0])

    # some arrays will be defined in modules that don't implement tensordot
    # etc. so instead default to numpy
    if not backends.has_tensordot(backend):
        return "numpy"

    return backend
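
# A quick sketch (illustrative only):
#   parse_backend([some_numpy_array], "auto") -> "numpy"  (inferred from array)
#   parse_backend(arrays, "torch")            -> "torch"  (explicit choice wins)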


def _core_contract(
    operands_: Sequence[ArrayType],
    contraction_list: ContractionListType,
    backend: Optional[str] = "auto",
    evaluate_constants: bool = False,
    **einsum_kwargs: Any,
) -> ArrayType:
    """Inner loop used to perform an actual contraction given the output
    from a ``contract_path(..., einsum_call=True)`` call.
    """

    # Special handling if out is specified
    out_array = einsum_kwargs.pop("out", None)
    specified_out = out_array is not None

    operands = list(operands_)
    backend = parse_backend(operands, backend)

    # try and do as much as possible without einsum if not available
    no_einsum = not backends.has_einsum(backend)

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, _, blas_flag = contraction

        # check if we are performing the pre-pass of an expression with constants,
        # if so, break out upon finding first non-constant (None) operand
        if evaluate_constants and any(operands[x] is None for x in inds):
            return operands, contraction_list[num:]

        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))
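        # (only the final contraction is allowed to write into ``out``)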

        # Call tensordot (check if should prefer einsum, but only if available)
        if blas_flag and ("EINSUM" not in blas_flag or no_einsum):  # type: ignore

            # Checks have already been handled
            input_str, results_index = einsum_str.split("->")
            input_left, input_right = input_str.split(",")

            tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

            if idx_rm:
                # Find indices to contract over
                left_pos, right_pos = [], []
                for s in idx_rm:
                    left_pos.append(input_left.find(s))
                    right_pos.append(input_right.find(s))

                # Construct the axes tuples in a canonical order
                axes = tuple(zip(*sorted(zip(left_pos, right_pos))))
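                # For example, for einsum_str 'abc,bcd->ad' with idx_rm
                # {'b', 'c'} this yields axes ((1, 2), (0, 1)): contract axes
                # 1 and 2 of the left operand with axes 0 and 1 of the right.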
            else:
                # Ensure axes is always pair of tuples
                axes = ((), ())

            # Contract!
            new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)

            # Build a new view if needed
            if (tensor_result != results_index) or handle_out:

                transpose = tuple(map(tensor_result.index, results_index))
                new_view = _transpose(new_view, axes=transpose, backend=backend)

            if handle_out:
                out_array[:] = new_view

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                einsum_kwargs["out"] = out_array

            # Do the contraction
            new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out_array
    else:
        return operands[0]


def format_const_einsum_str(einsum_str: str, constants: Iterable[int]) -> str:
    """Add brackets to the constant terms in ``einsum_str``. For example:

    >>> format_const_einsum_str('ab,bc,cd->ad', [0, 2])
    '[ab],bc,[cd]->ad'

    No-op if there are no constants.
    """
    if not constants:
        return einsum_str

    if "->" in einsum_str:
        lhs, rhs = einsum_str.split("->")
        arrow = "->"
    else:
        lhs, rhs, arrow = einsum_str, "", ""

    wrapped_terms = ["[{}]".format(t) if i in constants else t for i, t in enumerate(lhs.split(","))]

    formatted_einsum_str = "{}{}{}".format(",".join(wrapped_terms), arrow, rhs)

    # merge adjacent constants
    formatted_einsum_str = formatted_einsum_str.replace("],[", ",")
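    # e.g. '[ab],[bc],cd->ad' becomes '[ab,bc],cd->ad'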
    return formatted_einsum_str


class ContractExpression:
    """Helper class for storing an explicit ``contraction_list`` which can
    then be repeatedly called solely with the array arguments.
    """

    def __init__(
        self,
        contraction: str,
        contraction_list: ContractionListType,
        constants_dict: Dict[int, ArrayType],
        **einsum_kwargs: Any,
    ):
        self.contraction_list = contraction_list
        self.einsum_kwargs = einsum_kwargs
        self.contraction = format_const_einsum_str(contraction, constants_dict.keys())

        # need to know _full_num_args to parse constants with, and num_args to call with
        self._full_num_args = contraction.count(",") + 1
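        # (e.g. 'ab,bc,cd->ad' counts 3 arguments)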
        self.num_args = self._full_num_args - len(constants_dict)

        # likewise need to know full contraction list
        self._full_contraction_list = contraction_list

        self._constants_dict = constants_dict
        self._evaluated_constants: Dict[str, Any] = {}
        self._backend_expressions: Dict[str, Any] = {}

    def evaluate_constants(self, backend: Optional[str] = "auto") -> None:
        """Convert any constant operands to the correct backend form, and
        perform as many contractions as possible to create a new list of
        operands, stored in ``self._evaluated_constants[backend]``. This also
        makes sure ``self.contraction_list`` only contains the remaining,
        non-const operations.
        """
        # prepare a list of operands, with `None` for non-consts
        tmp_const_ops = [self._constants_dict.get(i, None) for i in range(self._full_num_args)]
        backend = parse_backend(tmp_const_ops, backend)

        # get the new list of operands with constant operations performed, and remaining contractions
        try:
            new_ops, new_contraction_list = backends.evaluate_constants(backend, tmp_const_ops, self)
        except KeyError:
            new_ops, new_contraction_list = self(*tmp_const_ops, backend=backend, evaluate_constants=True)

        self._evaluated_constants[backend] = new_ops
        self.contraction_list = new_contraction_list

    def _get_evaluated_constants(self, backend: str) -> List[Optional[ArrayType]]:
        """Retrieve or generate the cached list of constant operators (mixed
        in with None representing non-consts) and the remaining contraction
        list.
        """
        try:
            return self._evaluated_constants[backend]
        except KeyError:
            self.evaluate_constants(backend)
            return self._evaluated_constants[backend]

    def _get_backend_expression(self, arrays: Sequence[ArrayType], backend: str) -> Any:
        try:
            return self._backend_expressions[backend]
        except KeyError:
            fn = backends.build_expression(backend, arrays, self)
            self._backend_expressions[backend] = fn
            return fn

    def _contract(
        self,
        arrays: Sequence[ArrayType],
        out: Optional[ArrayType] = None,
        backend: Optional[str] = "auto",
        evaluate_constants: bool = False,
    ) -> ArrayType:
        """The normal, core contraction."""
        contraction_list = self._full_contraction_list if evaluate_constants else self.contraction_list

        return _core_contract(
            list(arrays),
            contraction_list,
            out=out,
            backend=backend,
            evaluate_constants=evaluate_constants,
            **self.einsum_kwargs,
        )

    def _contract_with_conversion(
        self,
        arrays: Sequence[ArrayType],
        out: Optional[ArrayType],
        backend: str,
        evaluate_constants: bool = False,
    ) -> ArrayType:
        """Special contraction, i.e., contraction with a different backend
        but converting to and from that backend. Retrieves or generates a
        cached expression using ``arrays`` as templates, then calls it
        with ``arrays``.

        If ``evaluate_constants=True``, perform a partial contraction that
        prepares the constant tensors and operations with the right backend.
        """
        # convert consts to correct type & find reduced contraction list
        if evaluate_constants:
            return backends.evaluate_constants(backend, arrays, self)

        result = self._get_backend_expression(arrays, backend)(*arrays)

        if out is not None:
            out[()] = result
            return out

        return result

    def __call__(self, *arrays: ArrayType, **kwargs: Any) -> ArrayType:
        """Evaluate this expression with a set of arrays.

        Parameters
        ----------
        arrays : seq of array
            The arrays to supply as input to the expression.
        out : array, optional (default: ``None``)
            If specified, output the result into this array.
        backend : str, optional (default: ``numpy``)
            Perform the contraction with this backend library. If numpy arrays
            are supplied then try to convert them to and from the correct
            backend array type.
        """
        out = kwargs.pop("out", None)
        backend = parse_backend(arrays, kwargs.pop("backend", "auto"))
        evaluate_constants = kwargs.pop("evaluate_constants", False)

        if kwargs:
            raise ValueError(
                "The only valid keyword arguments to a `ContractExpression` "
                "call are `out=` or `backend=`. Got: {}.".format(kwargs)
            )

        correct_num_args = self._full_num_args if evaluate_constants else self.num_args

        if len(arrays) != correct_num_args:
            raise ValueError(
                "This `ContractExpression` takes exactly {} array arguments "
                "but received {}.".format(correct_num_args, len(arrays))
            )

        if self._constants_dict and not evaluate_constants:
            # fill in the missing non-constant terms with newly supplied arrays
            ops_var, ops_const = iter(arrays), self._get_evaluated_constants(backend)
            ops: Sequence[ArrayType] = [next(ops_var) if op is None else op for op in ops_const]
        else:
            ops = arrays

        try:
            # Check if the backend requires special preparation / calling
            # but also ignore non-numpy arrays -> assume user wants same type back
            if backends.has_backend(backend) and all(infer_backend(x) == "numpy" for x in arrays):
                return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

            return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)

        except ValueError as err:
            original_msg = str(err.args) if err.args else ""
            msg = (
                "Internal error while evaluating `ContractExpression`. Note that few checks are performed"
                " - the number and rank of the array arguments must match the original expression. "
                "The internal error was: '{}'".format(original_msg),
            )
            err.args = msg
            raise

    def __repr__(self) -> str:
        if self._constants_dict:
            constants_repr = ", constants={}".format(sorted(self._constants_dict))
        else:
            constants_repr = ""
        return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)

    def __str__(self) -> str:
        s = [self.__repr__()]
        for i, c in enumerate(self.contraction_list):
            s.append("\n  {}.  ".format(i + 1))
            s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
        if self.einsum_kwargs:
            s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
        return "".join(s)


Shaped = namedtuple("Shaped", ["shape"])


def shape_only(shape: PathType) -> Shaped:
    """Dummy ``numpy.ndarray`` which has a shape only - for generating
    contract expressions.
    """
    return Shaped(shape)
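
# e.g. shape_only((3, 4)).shape == (3, 4) - just enough structure for
# ``contract_expression`` below to plan a path without allocating data.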


def contract_expression(subscripts: str, *shapes: PathType, **kwargs: Any) -> Any:
    """Generate a reusable expression for a given contraction with
    specific shapes, which can, for example, be cached.

    **Parameters:**

    - **subscripts** - *(str)* Specifies the subscripts for summation.
    - **shapes** - *(sequence of integer tuples)* Shapes of the arrays to optimize the contraction for.
    - **constants** - *(sequence of int, optional)* The indices of any constant arguments in `shapes`, in which case the
      actual array should be supplied at that position rather than just a
      shape. If these are specified, then constant parts of the contraction
      between calls will be reused. Additionally, if a GPU-enabled backend is
      used for example, then the constant tensors will be kept on the GPU,
      minimizing transfers.
    - **kwargs** - Passed on to `contract_path` or `einsum`. See `contract`.

    **Returns:**

    - **expr** - *(ContractExpression)* Callable with signature `expr(*arrays, out=None, backend='numpy')` where the arrays' shapes should match `shapes`.

    **Notes:**

    - The `out` keyword argument should be supplied to the generated expression
      rather than this function.
    - The `backend` keyword argument should also be supplied to the generated
      expression. If numpy arrays are supplied, if possible they will be
      converted to and back from the correct backend array type.
    - The generated expression will work with any arrays which have
      the same rank (number of dimensions) as the original shapes, however, if
      the actual sizes are different, the expression may no longer be optimal.
    - Constant operations will be computed upon the first call with a particular
      backend, then subsequently reused.

    **Examples:**

    Basic usage:

    ```python
    expr = contract_expression("ab,bc->ac", (3, 4), (4, 5))
    a, b = np.random.rand(3, 4), np.random.rand(4, 5)
    c = expr(a, b)
    np.allclose(c, a @ b)
    #> True
    ```

    Supply `a` as a constant:

    ```python
    expr = contract_expression("ab,bc->ac", a, (4, 5), constants=[0])
    expr
    #> <ContractExpression('[ab],bc->ac', constants=[0])>

    c = expr(b)
    np.allclose(c, a @ b)
    #> True
    ```

    """
    if not kwargs.get("optimize", True):
        raise ValueError("Can only generate expressions for optimized contractions.")

    for arg in ("out", "backend"):
        if kwargs.get(arg, None) is not None:
            raise ValueError(
                "'{}' should only be specified when calling a "
                "`ContractExpression`, not when building it.".format(arg)
            )

    if not isinstance(subscripts, str):
        subscripts, shapes = parser.convert_interleaved_input((subscripts,) + shapes)

    kwargs["_gen_expression"] = True

    # build dict of constant indices mapped to arrays
    constants = kwargs.pop("constants", ())
    constants_dict = {i: shapes[i] for i in constants}
    kwargs["_constants_dict"] = constants_dict

    # apart from constant arguments, make dummy arrays
    dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)]

    return contract(subscripts, *dummy_arrays, **kwargs)
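

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API): a minimal demo,
# assuming numpy is installed, that exercises the docstring examples above.
# Run this module directly to execute it.
if __name__ == "__main__":
    import numpy as np

    a = np.random.rand(2, 2)
    b = np.random.rand(2, 5)
    c = np.random.rand(5, 2)

    # Plan a contraction order without performing the contraction.
    path, info = contract_path("ij,jk,kl->il", a, b, c)
    print(path)  # e.g. [(1, 2), (0, 1)]

    # Perform the optimized contraction and check it against plain matmul.
    result = contract("ij,jk,kl->il", a, b, c)
    assert np.allclose(result, a @ b @ c)

    # Build a reusable expression from shapes alone, then call it with arrays.
    expr = contract_expression("ab,bc->ac", (2, 5), (5, 2))
    assert np.allclose(expr(b, c), b @ c)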