1""" 

2Contains the primary optimization and contraction routines. 

3""" 

4 

5from collections import namedtuple 

6from decimal import Decimal 

7 

8import numpy as np 

9 

10from . import backends, blas, helpers, parser, paths, sharing 

11 

12__all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only"] 

13 

14 

class PathInfo(object):
    """A printable object to contain information about a contraction path.

    Attributes
    ----------
    naive_cost : int
        The estimated FLOP cost of a naive einsum contraction.
    opt_cost : int
        The estimated FLOP cost of this optimized contraction path.
    largest_intermediate : int
        The number of elements in the largest intermediate array that will be
        produced during the contraction.
    """
    def __init__(self, contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
                 opt_cost, size_list, size_dict):
        self.contraction_list = contraction_list
        self.input_subscripts = input_subscripts
        self.output_subscript = output_subscript
        self.path = path
        self.indices = indices
        self.scale_list = scale_list
        self.naive_cost = Decimal(naive_cost)
        self.opt_cost = Decimal(opt_cost)
        self.speedup = self.naive_cost / self.opt_cost
        self.size_list = size_list
        self.size_dict = size_dict

        self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(',')]
        self.eq = "{}->{}".format(input_subscripts, output_subscript)
        self.largest_intermediate = Decimal(max(size_list))

    def __repr__(self):
        # Return the path along with a nice string representation
        header = ("scaling", "BLAS", "current", "remaining")

        path_print = [
            "  Complete contraction:  {}\n".format(self.eq), "         Naive scaling:  {}\n".format(len(self.indices)),
            "     Optimized scaling:  {}\n".format(max(self.scale_list)), "      Naive FLOP count:  {:.3e}\n".format(
                self.naive_cost), "  Optimized FLOP count:  {:.3e}\n".format(self.opt_cost),
            "   Theoretical speedup:  {:.3e}\n".format(self.speedup),
            "  Largest intermediate:  {:.3e} elements\n".format(self.largest_intermediate), "-" * 80 + "\n",
            "{:>6} {:>11} {:>22} {:>37}\n".format(*header), "-" * 80
        ]

        for n, contraction in enumerate(self.contraction_list):
            inds, idx_rm, einsum_str, remaining, do_blas = contraction

            if remaining is not None:
                remaining_str = ",".join(remaining) + "->" + self.output_subscript
            else:
                remaining_str = "..."
            size_remaining = max(0, 56 - max(22, len(einsum_str)))

            path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str, size_remaining)
            path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))

        return "".join(path_print)


def _choose_memory_arg(memory_limit, size_list):
    if memory_limit == 'max_input':
        return max(size_list)

    if memory_limit is None:
        return None

    if memory_limit < 1:
        if memory_limit == -1:
            return None
        else:
            raise ValueError("Memory limit must be larger than 0, or -1")

    return int(memory_limit)
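
# A minimal sketch of how the memory-limit argument resolves (illustrative
# only, derived directly from the branches above):
#
#     _choose_memory_arg('max_input', [100, 50, 20])  # -> 100 (largest input)
#     _choose_memory_arg(None, [100, 50, 20])         # -> None (no limit)
#     _choose_memory_arg(-1, [100, 50, 20])           # -> None (no limit)
#     _choose_memory_arg(5000, [100, 50, 20])         # -> 5000 (explicit cap)
#     _choose_memory_arg(0, [100, 50, 20])            # raises ValueError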


_VALID_CONTRACT_KWARGS = {'optimize', 'path', 'memory_limit', 'einsum_call', 'use_blas', 'shapes'}


def contract_path(*operands, **kwargs):
    """
    Find a contraction order 'path', without performing the contraction.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    optimize : str, list or bool, optional (default: ``auto``)
        Choose the type of path.

        - if a list is given, uses this as the path.
        - ``'optimal'`` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number of
          terms in the contraction.
        - ``'branch-all'`` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - ``'branch-2'`` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales exponentially
          with the number of terms in the contraction.
        - ``'greedy'`` An algorithm that heuristically chooses the best pair
          contraction at each step.
        - ``'auto'`` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.

    use_blas : bool
        Use BLAS functions or not.
    memory_limit : int, optional (default: None)
        Maximum number of elements allowed in intermediate arrays.
    shapes : bool, optional
        Whether ``contract_path`` should assume arrays (the default) or array
        shapes have been supplied.

    Returns
    -------
    path : list of tuples
        The einsum path.
    PathInfo : str
        A printable object containing various information about the path found.

    Notes
    -----
    The resulting path indicates which terms of the input contraction should be
    contracted first; the result of this contraction is then appended to the
    end of the contraction list.

    Examples
    --------

    We can begin with a chain dot example. In this case, it is optimal to
    contract the ``b`` and ``c`` tensors first, as represented by the first
    element of the path, ``(1, 2)``. The resulting tensor is added to the end
    of the contraction and the remaining contraction, ``(0, 1)``, is then
    executed.

    >>> a = np.random.rand(2, 2)
    >>> b = np.random.rand(2, 5)
    >>> c = np.random.rand(5, 2)
    >>> path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c)
    >>> print(path_info[0])
    [(1, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ij,jk,kl->il
             Naive scaling:  4
         Optimized scaling:  3
          Naive FLOP count:  1.600e+02
      Optimized FLOP count:  5.600e+01
       Theoretical speedup:  2.857
      Largest intermediate:  4.000e+00 elements
    -------------------------------------------------------------------------
    scaling                  current                                remaining
    -------------------------------------------------------------------------
       3                  kl,jk->jl                                ij,jl->il
       3                  jl,ij->il                                   il->il


    A more complex index transformation example.

    >>> I = np.random.rand(10, 10, 10, 10)
    >>> C = np.random.rand(10, 10)
    >>> path_info = oe.contract_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C)

    >>> print(path_info[0])
    [(0, 2), (0, 3), (0, 2), (0, 1)]
    >>> print(path_info[1])
      Complete contraction:  ea,fb,abcd,gc,hd->efgh
             Naive scaling:  8
         Optimized scaling:  5
          Naive FLOP count:  8.000e+08
      Optimized FLOP count:  8.000e+05
       Theoretical speedup:  1000.000
      Largest intermediate:  1.000e+04 elements
    --------------------------------------------------------------------------
    scaling                  current                                remaining
    --------------------------------------------------------------------------
       5               abcd,ea->bcde                      fb,gc,hd,bcde->efgh
       5               bcde,fb->cdef                         gc,hd,cdef->efgh
       5               cdef,gc->defg                            hd,defg->efgh
       5               defg,hd->efgh                               efgh->efgh
    """

    # Make sure all keywords are valid
    unknown_kwargs = set(kwargs) - _VALID_CONTRACT_KWARGS
    if len(unknown_kwargs):
        raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))

    path_type = kwargs.pop('optimize', 'auto')

    memory_limit = kwargs.pop('memory_limit', None)
    shapes = kwargs.pop('shapes', False)

    # Hidden option, only einsum should call this
    einsum_call_arg = kwargs.pop("einsum_call", False)
    use_blas = kwargs.pop('use_blas', True)

    # Python side parsing
    input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands)

    # Build a few useful lists and sets
    input_list = input_subscripts.split(',')
    input_sets = [set(x) for x in input_list]
    if shapes:
        input_shps = operands
    else:
        input_shps = [x.shape for x in operands]
    output_set = set(output_subscript)
    indices = set(input_subscripts.replace(',', ''))

    # Get length of each unique dimension and ensure all dimensions are correct
    size_dict = {}
    for tnum, term in enumerate(input_list):
        sh = input_shps[tnum]

        if len(sh) != len(term):
            raise ValueError("Einstein sum subscript '{}' does not contain the "
                             "correct number of indices for operand {}.".format(input_list[tnum], tnum))
        for cnum, char in enumerate(term):
            dim = int(sh[cnum])

            if char in size_dict:
                # For broadcasting cases we always want the largest dim size
                if size_dict[char] == 1:
                    size_dict[char] = dim
                elif dim not in (1, size_dict[char]):
                    raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
                                     "terms ({}).".format(char, tnum, size_dict[char], dim))
            else:
                size_dict[char] = dim

    # Compute size of each input array plus the output array
    size_list = [helpers.compute_size_by_dict(term, size_dict) for term in input_list + [output_subscript]]
    memory_arg = _choose_memory_arg(memory_limit, size_list)

    num_ops = len(input_list)

    # Compute naive cost
    # This isn't quite right; need to look into exactly how einsum does this.
    # indices_in_input = input_subscripts.replace(',', '')

    inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
    naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)

    # Compute the path
    if not isinstance(path_type, (str, paths.PathOptimizer)):
        # Custom path supplied
        path = path_type
    elif num_ops <= 2:
        # Nothing to be optimized
        path = [tuple(range(num_ops))]
    elif isinstance(path_type, paths.PathOptimizer):
        # Custom path optimizer supplied
        path = path_type(input_sets, output_set, size_dict, memory_arg)
    else:
        path_optimizer = paths.get_path_fn(path_type)
        path = path_optimizer(input_sets, output_set, size_dict, memory_arg)

    cost_list = []
    scale_list = []
    size_list = []
    contraction_list = []

    # Build contraction tuple (positions, gemm, einsum_str, remaining)
    for cnum, contract_inds in enumerate(path):
        # Make sure we remove inds from right to left
        contract_inds = tuple(sorted(list(contract_inds), reverse=True))

        contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
        out_inds, input_sets, idx_removed, idx_contract = contract_tuple

        # Compute cost, scale, and size
        cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict)
        cost_list.append(cost)
        scale_list.append(len(idx_contract))
        size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))

        tmp_inputs = [input_list.pop(x) for x in contract_inds]
        tmp_shapes = [input_shps.pop(x) for x in contract_inds]

        if use_blas:
            do_blas = blas.can_blas(tmp_inputs, out_inds, idx_removed, tmp_shapes)
        else:
            do_blas = False

        # Last contraction
        if (cnum - len(path)) == -1:
            idx_result = output_subscript
        else:
            # use tensordot order to minimize transpositions
            all_input_inds = "".join(tmp_inputs)
            idx_result = "".join(sorted(out_inds, key=all_input_inds.find))

        shp_result = parser.find_output_shape(tmp_inputs, tmp_shapes, idx_result)

        input_list.append(idx_result)
        input_shps.append(shp_result)

        einsum_str = ",".join(tmp_inputs) + "->" + idx_result

        # for large expressions saving the remaining terms at each step can
        # incur a large memory footprint - and also be messy to print
        if len(input_list) <= 20:
            remaining = tuple(input_list)
        else:
            remaining = None

        contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
        contraction_list.append(contraction)

    opt_cost = sum(cost_list)

    if einsum_call_arg:
        return operands, contraction_list

    path_print = PathInfo(contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
                          opt_cost, size_list, size_dict)

    return path, path_print


@sharing.einsum_cache_wrap
def _einsum(*operands, **kwargs):
    """Base einsum, but with pre-parse for valid characters if a string is given.
    """
    fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))

    if not isinstance(operands[0], str):
        return fn(*operands, **kwargs)

    einsum_str, operands = operands[0], operands[1:]

    # Do we need to temporarily map indices into [a-z,A-Z] range?
    if not parser.has_valid_einsum_chars_only(einsum_str):

        # Explicitly find output str first so as to maintain order
        if '->' not in einsum_str:
            einsum_str += '->' + parser.find_output_str(einsum_str)

        einsum_str = parser.convert_to_valid_einsum_chars(einsum_str)

    return fn(einsum_str, *operands, **kwargs)
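
# Illustrative only: subscripts outside the valid einsum character range are
# remapped before reaching the backend. Assuming a numpy backend, a call like
#
#     _einsum('αβ,βγ', x, y)
#
# first gains an explicit output term ('->αγ' here), then the whole string is
# mapped onto valid characters (e.g. 'ab,bc->ac') by
# parser.convert_to_valid_einsum_chars before being dispatched.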


def _default_transpose(x, axes):
    # most libraries implement a method version
    return x.transpose(axes)


@sharing.transpose_cache_wrap
def _transpose(x, axes, backend='numpy'):
    """Base transpose.
    """
    fn = backends.get_func('transpose', backend, _default_transpose)
    return fn(x, axes)


@sharing.tensordot_cache_wrap
def _tensordot(x, y, axes, backend='numpy'):
    """Base tensordot.
    """
    fn = backends.get_func('tensordot', backend)
    return fn(x, y, axes=axes)


# Rewrite einsum to handle different cases
def contract(*operands, **kwargs):
    """
    contract(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', use_blas=True, optimize=True, memory_limit=None, backend='numpy')

    Evaluates the Einstein summation convention on the operands. A drop-in
    replacement for NumPy's einsum function that optimizes the order of contraction
    to reduce overall scaling at the cost of several intermediate arrays.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    *operands : list of array_like
        These are the arrays for the operation.
    out : array_like
        An output array in which to set the resulting output.
    dtype : str
        The dtype of the given contraction, see np.einsum.
    order : str
        The order of the resulting contraction, see np.einsum.
    casting : str
        The casting procedure for operations of different dtype, see np.einsum.
    use_blas : bool
        Whether to use BLAS for valid operations; may use extra memory for more intermediates.
    optimize : str, list or bool, optional (default: ``auto``)
        Choose the type of path.

        - if a list is given, uses this as the path.
        - ``'optimal'`` An algorithm that explores all possible ways of
          contracting the listed tensors. Scales factorially with the number of
          terms in the contraction.
        - ``'dp'`` A faster (but essentially optimal) algorithm that uses
          dynamic programming to exhaustively search all contraction paths
          without outer-products.
        - ``'greedy'`` A cheap algorithm that heuristically chooses the best
          pairwise contraction at each step. Scales linearly in the number of
          terms in the contraction.
        - ``'random-greedy'`` Run a randomized version of the greedy algorithm
          32 times and pick the best path.
        - ``'random-greedy-128'`` Run a randomized version of the greedy
          algorithm 128 times and pick the best path.
        - ``'branch-all'`` An algorithm like optimal but that restricts itself
          to searching 'likely' paths. Still scales factorially.
        - ``'branch-2'`` An even more restricted version of 'branch-all' that
          only searches the best two options at each step. Scales exponentially
          with the number of terms in the contraction.
        - ``'auto'`` Choose the best of the above algorithms whilst aiming to
          keep the path finding time below 1ms.
        - ``'auto-hq'`` Aim for a high quality contraction, choosing the best
          of the above algorithms whilst aiming to keep the path finding time
          below 1sec.

    memory_limit : {None, int, 'max_input'} (default: None)
        Give the upper bound of the largest intermediate tensor contract will build.

        - None or -1 means there is no limit
        - 'max_input' means the limit is set as largest input tensor
        - a positive integer is taken as an explicit limit on the number of elements

        The default is None. Note that imposing a limit can make contractions
        exponentially slower to perform.
    backend : str, optional (default: ``auto``)
        Which library to use to perform the required ``tensordot``, ``transpose``
        and ``einsum`` calls. Should match the types of arrays supplied. See
        :func:`contract_expression` for generating expressions which convert
        numpy arrays to and from the backend library automatically.

    Returns
    -------
    out : array_like
        The result of the einsum expression.

    Notes
    -----
    This function should produce a result identical to that of NumPy's einsum
    function. The primary difference is that ``contract`` will attempt to form
    intermediates which reduce the overall scaling of the given einsum contraction.
    By default the worst intermediate formed will be equal to that of the largest
    input array. For large einsum expressions with many input arrays this can
    provide arbitrarily large (1000 fold+) speed improvements.

    For contractions with just two tensors this function will attempt to use
    NumPy's built-in BLAS functionality to ensure that the given operation is
    performed optimally. When NumPy is linked to a threaded BLAS, potential
    speedups are on the order of 20-100 for a six core machine.

    Examples
    --------

    See :func:`opt_einsum.contract_path` or :func:`numpy.einsum`

    """
    optimize_arg = kwargs.pop('optimize', True)
    if optimize_arg is True:
        optimize_arg = 'auto'

    valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
    einsum_kwargs = {k: v for (k, v) in kwargs.items() if k in valid_einsum_kwargs}

    # If no optimization, run pure einsum
    if optimize_arg is False:
        return _einsum(*operands, **einsum_kwargs)

    # Grab non-einsum kwargs
    use_blas = kwargs.pop('use_blas', True)
    memory_limit = kwargs.pop('memory_limit', None)
    backend = kwargs.pop('backend', 'auto')
    gen_expression = kwargs.pop('_gen_expression', False)
    constants_dict = kwargs.pop('_constants_dict', {})

    # Make sure remaining keywords are valid for einsum
    unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
    if len(unknown_kwargs):
        raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))

    if gen_expression:
        full_str = operands[0]

    # Build the contraction list and operand
    operands, contraction_list = contract_path(*operands,
                                               optimize=optimize_arg,
                                               memory_limit=memory_limit,
                                               einsum_call=True,
                                               use_blas=use_blas)

    # check if performing contraction or just building expression
    if gen_expression:
        return ContractExpression(full_str, contraction_list, constants_dict, **einsum_kwargs)

    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
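
# A minimal usage sketch (illustrative; mirrors the docstring above):
#
#     import numpy as np
#     import opt_einsum as oe
#
#     x, y = np.random.rand(10, 10), np.random.rand(10, 10)
#     z = oe.contract('ij,jk->ik', x, y)
#     assert np.allclose(z, np.einsum('ij,jk->ik', x, y))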


def infer_backend(x):
    return x.__class__.__module__.split('.')[0]


def parse_backend(arrays, backend):
    """Find out what backend we should use, dispatching based on the first
    array if ``backend='auto'`` is specified.
    """
    if backend != 'auto':
        return backend
    backend = infer_backend(arrays[0])

    # some arrays will be defined in modules that don't implement tensordot
    # etc. so instead default to numpy
    if not backends.has_tensordot(backend):
        return 'numpy'

    return backend
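
# Illustrative only: inference keys off the array class's module, so
#
#     infer_backend(np.ones(2))  # -> 'numpy'
#
# and e.g. a torch.Tensor would yield 'torch' - used as the backend only if it
# implements tensordot, otherwise parse_backend falls back to 'numpy'.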


def _core_contract(operands, contraction_list, backend='auto', evaluate_constants=False, **einsum_kwargs):
    """Inner loop used to perform an actual contraction given the output
    from a ``contract_path(..., einsum_call=True)`` call.
    """

    # Special handling if out is specified
    out_array = einsum_kwargs.pop('out', None)
    specified_out = out_array is not None
    backend = parse_backend(operands, backend)

    # try and do as much as possible without einsum if not available
    no_einsum = not backends.has_einsum(backend)

    # Start contraction loop
    for num, contraction in enumerate(contraction_list):
        inds, idx_rm, einsum_str, _, blas_flag = contraction

        # check if we are performing the pre-pass of an expression with constants;
        # if so, break out upon finding the first non-constant (None) operand
        if evaluate_constants and any(operands[x] is None for x in inds):
            return operands, contraction_list[num:]

        tmp_operands = [operands.pop(x) for x in inds]

        # Do we need to deal with the output?
        handle_out = specified_out and ((num + 1) == len(contraction_list))

        # Call tensordot (check whether we should prefer einsum, but only if it is available)
        if blas_flag and ('EINSUM' not in blas_flag or no_einsum):

            # Checks have already been handled
            input_str, results_index = einsum_str.split('->')
            input_left, input_right = input_str.split(',')

            tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

            # Find indices to contract over
            left_pos, right_pos = [], []
            for s in idx_rm:
                left_pos.append(input_left.find(s))
                right_pos.append(input_right.find(s))

            # Contract!
            new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)

            # Build a new view if needed
            if (tensor_result != results_index) or handle_out:

                transpose = tuple(map(tensor_result.index, results_index))
                new_view = _transpose(new_view, axes=transpose, backend=backend)

                if handle_out:
                    out_array[:] = new_view

        # Call einsum
        else:
            # If out was specified
            if handle_out:
                einsum_kwargs["out"] = out_array

            # Do the contraction
            new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)

        # Append new items and dereference what we can
        operands.append(new_view)
        del tmp_operands, new_view

    if specified_out:
        return out_array
    else:
        return operands[0]

def format_const_einsum_str(einsum_str, constants):
    """Add brackets to the constant terms in ``einsum_str``. For example:

    >>> format_const_einsum_str('ab,bc,cd->ad', [0, 2])
    '[ab],bc,[cd]->ad'

    No-op if there are no constants.
    """
    if not constants:
        return einsum_str

    if "->" in einsum_str:
        lhs, rhs = einsum_str.split('->')
        arrow = "->"
    else:
        lhs, rhs, arrow = einsum_str, "", ""

    wrapped_terms = ["[{}]".format(t) if i in constants else t for i, t in enumerate(lhs.split(','))]

    formatted_einsum_str = "{}{}{}".format(','.join(wrapped_terms), arrow, rhs)

    # merge adjacent constants
    formatted_einsum_str = formatted_einsum_str.replace("],[", ',')
    return formatted_einsum_str
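
# Illustrative only: the final replace() merges adjacent constant terms into a
# single bracketed group, e.g.
#
#     format_const_einsum_str('ab,bc,cd->ad', [0, 1])  # -> '[ab,bc],cd->ad'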


class ContractExpression:
    """Helper class for storing an explicit ``contraction_list`` which can
    then be repeatedly called solely with the array arguments.
    """
    def __init__(self, contraction, contraction_list, constants_dict, **einsum_kwargs):
        self.contraction_list = contraction_list
        self.einsum_kwargs = einsum_kwargs
        self.contraction = format_const_einsum_str(contraction, constants_dict.keys())

        # need to know _full_num_args to parse constants with, and num_args to call with
        self._full_num_args = contraction.count(',') + 1
        self.num_args = self._full_num_args - len(constants_dict)

        # likewise need to know full contraction list
        self._full_contraction_list = contraction_list

        self._constants_dict = constants_dict
        self._evaluated_constants = {}
        self._backend_expressions = {}

    def evaluate_constants(self, backend='auto'):
        """Convert any constant operands to the correct backend form, and
        perform as many contractions as possible to create a new list of
        operands, stored in ``self._evaluated_constants[backend]``. This also
        makes sure ``self.contraction_list`` only contains the remaining,
        non-const operations.
        """
        # prepare a list of operands, with `None` for non-consts
        tmp_const_ops = [self._constants_dict.get(i, None) for i in range(self._full_num_args)]
        backend = parse_backend(tmp_const_ops, backend)

        # get the new list of operands with constant operations performed, and remaining contractions
        try:
            new_ops, new_contraction_list = backends.evaluate_constants(backend, tmp_const_ops, self)
        except KeyError:
            new_ops, new_contraction_list = self(*tmp_const_ops, backend=backend, evaluate_constants=True)

        self._evaluated_constants[backend] = new_ops
        self.contraction_list = new_contraction_list

    def _get_evaluated_constants(self, backend):
        """Retrieve or generate the cached list of constant operands (mixed
        in with None representing non-consts) and the remaining contraction
        list.
        """
        try:
            return self._evaluated_constants[backend]
        except KeyError:
            self.evaluate_constants(backend)
            return self._evaluated_constants[backend]

    def _get_backend_expression(self, arrays, backend):
        try:
            return self._backend_expressions[backend]
        except KeyError:
            fn = backends.build_expression(backend, arrays, self)
            self._backend_expressions[backend] = fn
            return fn

    def _contract(self, arrays, out=None, backend='auto', evaluate_constants=False):
        """The normal, core contraction.
        """
        contraction_list = self._full_contraction_list if evaluate_constants else self.contraction_list

        return _core_contract(list(arrays),
                              contraction_list,
                              out=out,
                              backend=backend,
                              evaluate_constants=evaluate_constants,
                              **self.einsum_kwargs)

    def _contract_with_conversion(self, arrays, out, backend, evaluate_constants=False):
        """Special contraction, i.e., contraction with a different backend
        but converting to and from that backend. Retrieves or generates a
        cached expression using ``arrays`` as templates, then calls it
        with ``arrays``.

        If ``evaluate_constants=True``, perform a partial contraction that
        prepares the constant tensors and operations with the right backend.
        """
        # convert consts to correct type & find reduced contraction list
        if evaluate_constants:
            return backends.evaluate_constants(backend, arrays, self)

        result = self._get_backend_expression(arrays, backend)(*arrays)

        if out is not None:
            out[()] = result
            return out

        return result

    def __call__(self, *arrays, **kwargs):
        """Evaluate this expression with a set of arrays.

        Parameters
        ----------
        arrays : seq of array
            The arrays to supply as input to the expression.
        out : array, optional (default: ``None``)
            If specified, output the result into this array.
        backend : str, optional (default: ``auto``)
            Perform the contraction with this backend library. If numpy arrays
            are supplied then try to convert them to and from the correct
            backend array type.
        """
        out = kwargs.pop('out', None)
        backend = kwargs.pop('backend', 'auto')
        backend = parse_backend(arrays, backend)
        evaluate_constants = kwargs.pop('evaluate_constants', False)

        if kwargs:
            raise ValueError("The only valid keyword arguments to a `ContractExpression` "
                             "call are `out=` or `backend=`. Got: {}.".format(kwargs))

        correct_num_args = self._full_num_args if evaluate_constants else self.num_args

        if len(arrays) != correct_num_args:
            raise ValueError("This `ContractExpression` takes exactly {} array arguments "
                             "but received {}.".format(self.num_args, len(arrays)))

        if self._constants_dict and not evaluate_constants:
            # fill in the missing non-constant terms with newly supplied arrays
            ops_var, ops_const = iter(arrays), self._get_evaluated_constants(backend)
            ops = [next(ops_var) if op is None else op for op in ops_const]
        else:
            ops = arrays

        try:
            # Check if the backend requires special preparation / calling
            # but also ignore non-numpy arrays -> assume user wants same type back
            if backends.has_backend(backend) and all(isinstance(x, np.ndarray) for x in arrays):
                return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

            return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)

        except ValueError as err:
            original_msg = str(err.args) if err.args else ""
            msg = ("Internal error while evaluating `ContractExpression`. Note that few checks are performed"
                   " - the number and rank of the array arguments must match the original expression. "
                   "The internal error was: '{}'".format(original_msg), )
            err.args = msg
            raise

    def __repr__(self):
        if self._constants_dict:
            constants_repr = ", constants={}".format(sorted(self._constants_dict))
        else:
            constants_repr = ""
        return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)

    def __str__(self):
        s = [self.__repr__()]
        for i, c in enumerate(self.contraction_list):
            s.append("\n  {}.  ".format(i + 1))
            s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
        if self.einsum_kwargs:
            s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
        return "".join(s)


Shaped = namedtuple('Shaped', ['shape'])


def shape_only(shape):
    """Dummy ``numpy.ndarray`` which has a shape only - for generating
    contract expressions.
    """
    return Shaped(shape)
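
# Illustrative only: ``shape_only((3, 4))`` stands in for a real array when
# building an expression, since only its ``.shape`` attribute is inspected at
# build time (e.g. by ``contract_path``).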


def contract_expression(subscripts, *shapes, **kwargs):
    """Generate a reusable expression for a given contraction with
    specific shapes, which can, for example, be cached.

    Parameters
    ----------
    subscripts : str
        Specifies the subscripts for summation.
    shapes : sequence of integer tuples
        Shapes of the arrays to optimize the contraction for.
    constants : sequence of int, optional
        The indices of any constant arguments in ``shapes``, in which case the
        actual array should be supplied at that position rather than just a
        shape. If these are specified, then constant parts of the contraction
        between calls will be reused. Additionally, if, for example, a
        GPU-enabled backend is used, then the constant tensors will be kept on
        the GPU, minimizing transfers.
    kwargs :
        Passed on to ``contract_path`` or ``einsum``. See ``contract``.

    Returns
    -------
    expr : ContractExpression
        Callable with signature ``expr(*arrays, out=None, backend='numpy')``
        where the arrays' shapes should match ``shapes``.

    Notes
    -----
    - The `out` keyword argument should be supplied to the generated expression
      rather than this function.
    - The `backend` keyword argument should also be supplied to the generated
      expression. If numpy arrays are supplied, if possible they will be
      converted to and back from the correct backend array type.
    - The generated expression will work with any arrays which have
      the same rank (number of dimensions) as the original shapes; however, if
      the actual sizes are different, the expression may no longer be optimal.
    - Constant operations will be computed upon the first call with a particular
      backend, then subsequently reused.

    Examples
    --------

    Basic usage:

    >>> expr = contract_expression("ab,bc->ac", (3, 4), (4, 5))
    >>> a, b = np.random.rand(3, 4), np.random.rand(4, 5)
    >>> c = expr(a, b)
    >>> np.allclose(c, a @ b)
    True

    Supply ``a`` as a constant:

    >>> expr = contract_expression("ab,bc->ac", a, (4, 5), constants=[0])
    >>> expr
    <ContractExpression('[ab],bc->ac', constants=[0])>

    >>> c = expr(b)
    >>> np.allclose(c, a @ b)
    True

    """
    if not kwargs.get('optimize', True):
        raise ValueError("Can only generate expressions for optimized contractions.")

    for arg in ('out', 'backend'):
        if kwargs.get(arg, None) is not None:
            raise ValueError("'{}' should only be specified when calling a "
                             "`ContractExpression`, not when building it.".format(arg))

    if not isinstance(subscripts, str):
        subscripts, shapes = parser.convert_interleaved_input((subscripts, ) + shapes)

    kwargs['_gen_expression'] = True

    # build dict of constant indices mapped to arrays
    constants = kwargs.pop('constants', ())
    constants_dict = {i: shapes[i] for i in constants}
    kwargs['_constants_dict'] = constants_dict

    # apart from constant arguments, make dummy arrays
    dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)]

    return contract(subscripts, *dummy_arrays, **kwargs)