Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/special_math_ops.py: 26%

427 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Arithmetic Operations that don't fit into math_ops due to dependencies.

To avoid circular dependencies, some math_ops should go here.
"""

import collections
import functools
import re
import string

import numpy as np
import opt_einsum

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_special_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.

  Given one-dimensional $z = [z_1,...,z_K]$, we define

  $$Beta(z) = \frac{\prod_j \Gamma(z_j)}{\Gamma(\sum_j z_j)},$$

  where $\Gamma$ is the gamma function.

  And for $n + 1$ dimensional $x$ with shape $[N_1, ..., N_n, K]$, we define

  $$lbeta(x)[i_1, ..., i_n] = \log{|Beta(x[i_1, ..., i_n, :])|}.$$

  In other words, the last dimension is treated as the $z$ vector.

  Note that if $z = [u, v]$, then

  $$Beta(z) = \frac{\Gamma(u)\Gamma(v)}{\Gamma(u + v)}
    = \int_0^1 t^{u-1} (1 - t)^{v-1} \mathrm{d}t,$$

  which defines the traditional bivariate beta function.

  If the last dimension is empty, we follow the convention that the sum over
  the empty set is zero, and the product is one.

  Args:
    x: A rank `n + 1` `Tensor`, `n >= 0` with type `float`, or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of \\(|Beta(x)|\\) reducing along the last dimension.
  """
  # In the event that the last dimension has zero entries, we return -inf.
  # This is consistent with the convention that the sum over the empty set is
  # zero, and the product is one.
  # This is standard. See https://en.wikipedia.org/wiki/Empty_set.
  with ops.name_scope(name, 'lbeta', [x]):
    x = ops.convert_to_tensor(x, name='x')

    # Note reduce_sum([]) = 0.
    log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])

    # Note lgamma(0) = infinity, so if x = [] then
    #   log_gamma_sum_x = lgamma(sum([])) = lgamma(0) = infinity, and
    #   log_prod_gamma_x = reduce_sum([]) = 0,
    # so result = -infinity.
    sum_x = math_ops.reduce_sum(x, axis=[-1])
    log_gamma_sum_x = math_ops.lgamma(sum_x)
    result = log_prod_gamma_x - log_gamma_sum_x

    return result
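
# A quick sanity check of the definition above (a sketch, not library code):
# for z = [1., 2.], Beta(z) = Gamma(1) * Gamma(2) / Gamma(3) = 1/2, so
#
#   tf.math.lbeta([1., 2.])  # -> ln(0.5) ~= -0.6931472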


@tf_export('math.special.dawsn')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def dawsn(x, name=None):
  """Computes Dawson's integral of `x` element-wise.

  Dawson's integral is defined as `exp(-x**2)` times the integral of
  `exp(t**2)` from `0` to `x`, with the domain of definition all real numbers.

  Dawson's function is odd.

  >>> tf.math.special.dawsn([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5380795, -0.4244364,  0.4244364,  0.5380795], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.dawsn
  @end_compatibility
  """
  with ops.name_scope(name, 'dawsn', [x]):
    return gen_special_math_ops.dawsn(x)


@tf_export('math.special.expint')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def expint(x, name=None):
  """Computes the Exponential integral of `x` element-wise.

  The Exponential integral is defined as the integral of `exp(t) / t` from
  `-inf` to `x`, with the domain of definition all positive real numbers.

  >>> tf.math.special.expint([1., 1.1, 2.1, 4.1]).numpy()
  array([ 1.8951179,  2.1673784,  5.3332353, 21.048464 ], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.expi
  @end_compatibility
  """
  with ops.name_scope(name, 'expint', [x]):
    return gen_special_math_ops.expint(x)


@tf_export('math.special.fresnel_cos')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None):
  """Computes Fresnel's cosine integral of `x` element-wise.

  The Fresnel cosine integral is defined as the integral of `cos(t^2)` from
  `0` to `x`, with the domain of definition all real numbers.

  The Fresnel cosine integral is odd.

  >>> tf.math.special.fresnel_cos([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.7798934 , -0.09999753,  0.09999753,  0.7798934 ], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel second output.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_cos', [x]):
    return gen_special_math_ops.fresnel_cos(x)


@tf_export('math.special.fresnel_sin')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None):
  """Computes Fresnel's sine integral of `x` element-wise.

  The Fresnel sine integral is defined as the integral of `sin(t^2)` from
  `0` to `x`, with the domain of definition all real numbers.

  >>> tf.math.special.fresnel_sin([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.43825912, -0.00052359,  0.00052359,  0.43825912], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel first output.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_sin', [x]):
    return gen_special_math_ops.fresnel_sin(x)


@tf_export('math.special.spence')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def spence(x, name=None):
  """Computes Spence's integral of `x` element-wise.

  Spence's integral is defined as the integral of `log(t) / (1 - t)` from
  `1` to `x`, with the domain of definition all non-negative real numbers.

  >>> tf.math.special.spence([0.5, 1., 2., 3.]).numpy()
  array([ 0.58224034,  0.        , -0.82246685, -1.4367464 ], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.spence
  @end_compatibility
  """
  with ops.name_scope(name, 'spence', [x]):
    return gen_special_math_ops.spence(x)
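
# Reference values for the doctest above (a note, not library code):
# spence(1.) = 0 exactly, and spence(0.) = pi**2 / 6 ~= 1.6449341, the
# classical dilogarithm value at the lower endpoint of the domain.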


@tf_export('math.bessel_i0', 'math.special.bessel_i0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
  """Computes the Bessel i0 function of `x` element-wise.

  Modified Bessel function of the first kind of order 0.

  It is preferable to use the numerically stabler function `i0e(x)` instead.

  >>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy()
  array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0', [x]):
    return gen_special_math_ops.bessel_i0(x)


@tf_export('math.bessel_i0e', 'math.special.bessel_i0e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i0e(x, name=None):
  """Computes the Bessel i0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 0.

  >>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy()
  array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0e', [x]):
    return gen_special_math_ops.bessel_i0e(x)
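
# Why i0e is the "numerically stabler" choice (a note, not library code):
# the scaled function satisfies i0e(x) = exp(-abs(x)) * i0(x), so it stays
# bounded where i0 overflows for large |x|. E.g. for the doctest values,
# exp(-1.) * 1.26606588 ~= 0.46575961.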


@tf_export('math.bessel_i1', 'math.special.bessel_i1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
  """Computes the Bessel i1 function of `x` element-wise.

  Modified Bessel function of the first kind of order 1.

  It is preferable to use the numerically stabler function `i1e(x)` instead.

  >>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5651591 , -0.25789431,  0.25789431,  0.5651591 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1', [x]):
    return gen_special_math_ops.bessel_i1(x)


@tf_export('math.bessel_i1e', 'math.special.bessel_i1e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i1e(x, name=None):
  """Computes the Bessel i1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 1.

  >>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.20791042, -0.15642083,  0.15642083,  0.20791042], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1e', [x]):
    return gen_special_math_ops.bessel_i1e(x)


@tf_export('math.special.bessel_k0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k0(x, name=None):
  """Computes the Bessel k0 function of `x` element-wise.

  Modified Bessel function of the second kind of order 0.

  It is preferable to use the numerically stabler function `k0e(x)` instead.

  >>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy()
  array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0', [x]):
    return gen_special_math_ops.bessel_k0(x)


@tf_export('math.special.bessel_k0e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k0e(x, name=None):
  """Computes the Bessel k0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 0.

  >>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy()
  array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0e', [x]):
    return gen_special_math_ops.bessel_k0e(x)


@tf_export('math.special.bessel_k1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k1(x, name=None):
  """Computes the Bessel k1 function of `x` element-wise.

  Modified Bessel function of the second kind of order 1.

  It is preferable to use the numerically stabler function `k1e(x)` instead.

  >>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy()
  array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1', [x]):
    return gen_special_math_ops.bessel_k1(x)


@tf_export('math.special.bessel_k1e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k1e(x, name=None):
  """Computes the Bessel k1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 1.

  >>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy()
  array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1e', [x]):
    return gen_special_math_ops.bessel_k1e(x)


@tf_export('math.special.bessel_j0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_j0(x, name=None):
  """Computes the Bessel j0 function of `x` element-wise.

  Bessel function of the first kind of order 0.

  >>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy()
  array([ 0.93846981,  0.76519769,  0.22389078, -0.39714981], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j0', [x]):
    return gen_special_math_ops.bessel_j0(x)


@tf_export('math.special.bessel_j1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_j1(x, name=None):
  """Computes the Bessel j1 function of `x` element-wise.

  Bessel function of the first kind of order 1.

  >>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy()
  array([ 0.24226846,  0.44005059,  0.57672481, -0.06604333], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j1', [x]):
    return gen_special_math_ops.bessel_j1(x)


@tf_export('math.special.bessel_y0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_y0(x, name=None):
  """Computes the Bessel y0 function of `x` element-wise.

  Bessel function of the second kind of order 0.

  >>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy()
  array([-0.44451873,  0.08825696,  0.51037567, -0.01694074], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y0', [x]):
    return gen_special_math_ops.bessel_y0(x)


@tf_export('math.special.bessel_y1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_y1(x, name=None):
  """Computes the Bessel y1 function of `x` element-wise.

  Bessel function of the second kind of order 1.

  >>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy()
  array([-1.47147239, -0.78121282, -0.10703243,  0.39792571], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y1', [x]):
    return gen_special_math_ops.bessel_y1(x)


@ops.RegisterGradient('XlaEinsum')
def _einsum_grad(op, grad):
  equation = op.get_attr('equation')
  if isinstance(equation, bytes):
    equation = equation.decode()

  inputs, output = equation.split('->')
  left, right = inputs.split(',')

  return [
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[1],
          equation='{},{}->{}'.format(output, right, left),
          name=None),
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[0],
          equation='{},{}->{}'.format(output, left, right),
          name=None)
  ]
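
# Worked example of the gradient rule above (derived from the code): for the
# forward equation 'ij,jk->ik', the two returned cotangents are
# xla_einsum(grad, op.inputs[1], equation='ik,jk->ij') and
# xla_einsum(grad, op.inputs[0], equation='ik,ij->jk'); the output labels
# take the place of the operand being differentiated.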


def _enclosing_tpu_context():
  # pylint: disable=protected-access
  context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while context is not None and not isinstance(
      context, control_flow_ops.XLAControlFlowContext):
    context = context.outer_context
  return context


@tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs):
  r"""Tensor contraction over specified indices and outer product.

  Einsum allows defining Tensors by defining their element-wise computation.
  This computation is defined by `equation`, a shorthand form based on
  Einstein summation. As an example, consider multiplying two matrices A and B
  to form a matrix C. The elements of C are given by:

  $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$

  or

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding einsum `equation` is:

  ```
  ij,jk->ik
  ```

  In general, to convert the element-wise equation into the `equation` string,
  use the following procedure (intermediate strings for the matrix
  multiplication example are provided in parentheses):

  1. remove variable names, brackets, and commas, (`ik = sum_j ij * jk`)
  2. replace "*" with ",", (`ik = sum_j ij , jk`)
  3. drop summation signs, and (`ik = ij, jk`)
  4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`)

  Note: If the output indices are not specified, repeated indices are summed.
  So `ij,jk->ik` can be simplified to `ij,jk`.

  Many common operations can be expressed in this way. For example:

  **Matrix multiplication**

  >>> m0 = tf.random.normal(shape=[2, 3])
  >>> m1 = tf.random.normal(shape=[3, 5])
  >>> e = tf.einsum('ij,jk->ik', m0, m1)
  >>> # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  Repeated indices are summed if the output indices are not specified.

  >>> e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  **Dot product**

  >>> u = tf.random.normal(shape=[5])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
  >>> print(e.shape)
  ()

  **Outer product**

  >>> u = tf.random.normal(shape=[3])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
  >>> print(e.shape)
  (3, 5)

  **Transpose**

  >>> m = tf.ones([2, 3])
  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
  >>> print(e.shape)
  (3, 2)

  **Diag**

  >>> m = tf.reshape(tf.range(9), [3,3])
  >>> diag = tf.einsum('ii->i', m)
  >>> print(diag.shape)
  (3,)

  **Trace**

  >>> # Repeated indices are summed.
  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
  >>> assert trace == sum(diag)
  >>> print(trace.shape)
  ()

  **Batch matrix multiplication**

  >>> s = tf.random.normal(shape=[7,5,3])
  >>> t = tf.random.normal(shape=[7,3,2])
  >>> e = tf.einsum('bij,bjk->bik', s, t)
  >>> # output[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  >>> print(e.shape)
  (7, 5, 2)

  This method does not support broadcasting on named axes. All axes with
  matching labels should have the same length. If you have length-1 axes,
  use `tf.squeeze` or `tf.reshape` to eliminate them.

  To write code that is agnostic to the number of indices in the input
  use an ellipsis. The ellipsis is a placeholder for "whatever other indices
  fit here".

  For example, to perform a NumPy-style broadcasting batch matrix
  multiplication where the matrix multiply acts on the last two axes of the
  input, use:

  >>> s = tf.random.normal(shape=[11, 7, 5, 3])
  >>> t = tf.random.normal(shape=[11, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Einsum **will** broadcast over axes covered by the ellipsis.

  >>> s = tf.random.normal(shape=[11, 1, 5, 3])
  >>> t = tf.random.normal(shape=[1, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    **kwargs:
      - optimize: Optimization strategy to use to find contraction path using
        opt_einsum. Must be 'greedy', 'optimal', 'branch-2', 'branch-all' or
        'auto'. (optional, default: 'greedy').
      - name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs or their shapes are inconsistent with `equation`.
  """
  return _einsum_v2(equation, *inputs, **kwargs)


def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy implementation of einsum without using EinsumOp."""
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError(
        f'Invalid keyword arguments for this function: '
        f'{", ".join([format(key) for key in sorted(list(kwargs.keys()))])}.'
        f' Expected: name.')
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.shape for x in inputs]
    input_axis_labels, output_axis_labels = (
        _einsum_v1_parse_and_resolve_equation(equation, input_shapes))

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              f'Subscript not supported: the axis {a} appears more than once'
              f' in {input_labels}.')
    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            f'Falling back to exponential-space implementation of einsum()'
            f' because index {a} is summed over more than two inputs.')
        return _exponential_space_einsum_v1(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1], input_axis_labels[0] + ',' +
          input_axis_labels[1] + '->' + output_axis_labels)
    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in range(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_v1_reduction(temp, temp_axis_labels,
                                                    inputs[i + 1],
                                                    input_axis_labels[i + 1],
                                                    axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError(
          f'Invalid equation: {equation}. The computed and specified output '
          f'labels do not match: {temp_axis_labels} vs {output_axis_labels}.')

    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)
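
# Illustration of the left-to-right reduction loop above (derived from the
# code): for 'ab,bc,cd->ad', the loop first reduces ('ab', 'bc') with
# axes_to_sum = {'b'}, yielding a partial result labeled 'ac', then reduces
# ('ac', 'cd') with axes_to_sum = {'c'}; any labels still absent from the
# output would be summed out afterwards via reduce_sum.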


def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
  """Helper for einsum() that splits/resolves inputs & outputs.

  Args:
    equation: Equation string given as argument to einsum().
    input_shapes: List of the shapes of all inputs given to einsum().

  Returns:
    input_axis_labels, output_axis_labels where:
      input_axis_labels: List of length len(input_shapes) of strings
        representing the character label for each dimension of each given
        input, resolving any broadcast (...) axes.
      output_axis_labels: A string of character labels for each axis of the
        output tensor, filling in missing output subscripts and broadcast
        axes.

  Raises:
    ValueError: If the equation is in an incorrect format, an incorrect number
      of inputs is given, or broadcast axes "..." or output axes could not be
      resolved.
  """
  equation = equation.replace(' ', '')
  match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)
  if not match:
    raise ValueError(f'Indices have incorrect format. Received: {equation}.')

  input_axis_labels = match.group(1).split(',')
  output_axis_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_axis_labels):
    raise ValueError(
        f'Got {len(input_shapes)} arguments for equation "{equation}", '
        f'expecting {len(input_axis_labels)}.')

  # Resolve Ellipsis.
  # Assign axes labels for unspecified dimensions in inputs. Labels are taken
  # from unused labels. Follow numpy einsum broadcasting conventions for
  # tensors of different length and unlabeled output.
  ellipsis_axes = ''
  if '...' in equation:
    unused = ''.join(
        c for c in string.ascii_letters if c not in ''.join(input_axis_labels))
    for i, ax in enumerate(input_axis_labels):
      if '...' in ax:
        parts = ax.split('...')
        if len(parts) != 2:
          raise ValueError(f'Unable to resolve ellipsis. '
                           f'Excess number found: {len(parts)-1} vs 1.')
        if input_shapes[i].ndims is None:
          raise ValueError('Unable to statically infer ellipsis axes. The '
                           'input shape has dynamic dimensionality.')
        n = input_shapes[i].ndims - len(''.join(parts))
        if n < 0:
          raise ValueError('Ellipses lengths do not match.')
        if len(unused) < n:
          raise ValueError(
              'Unable to resolve ellipsis, too many distinct labels.')
        replace_axes = unused[-n:] if n > 0 else ''
        input_axis_labels[i] = input_axis_labels[i].replace('...',
                                                            replace_axes)
        if len(replace_axes) > len(ellipsis_axes):
          ellipsis_axes = replace_axes

  if any('.' in ax for ax in input_axis_labels):
    raise ValueError(
        f'Period "." found outside of ellipsis in input {input_axis_labels}.')

  if output_axis_labels is not None:
    output_axis_labels = output_axis_labels.replace('...', ellipsis_axes)
    if '.' in output_axis_labels:
      raise ValueError(f'Period "." found outside of ellipsis in output '
                       f'{output_axis_labels}.')

  if output_axis_labels is None:
    # Infer the output subscripts if not given, assume alphabetical order,
    # but always place ellipsis axes before given labels.
    axis_labels = set(''.join(input_axis_labels)) - set(ellipsis_axes)
    indices = ''.join(sorted(axis_labels))
    counts = {ax: 0 for ax in indices}
    for axes_ in input_axis_labels:
      for ax in axes_:
        if ax not in ellipsis_axes:
          counts[ax] += 1

    output_axis_labels = ellipsis_axes + ''.join(
        sorted(ax for ax in axis_labels if counts[ax] == 1))

  return input_axis_labels, output_axis_labels
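
# Example of the resolution above (derived from the code, for illustration):
# for the equation 'ij,jk' with two fully known matrix shapes, the parsed
# input labels are ['ij', 'jk'] and, since no '->' is given, the inferred
# output labels are 'ik' -- the labels appearing exactly once, in sorted
# order, with any resolved ellipsis axes placed first.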


def _einsum_v1_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
  """Helper for einsum() that computes the result of a two-argument einsum().

  Args:
    t0: a `Tensor`
    t0_axis_labels: a string of axis labels. This string's length must equal
      the rank of t0.
    t1: a `Tensor`
    t1_axis_labels: a string of axis labels. This string's length must equal
      the rank of t1.
    axes_to_sum: set of labels of axes to be summed over

  Returns:
    A `Tensor` whose elements are obtained by summing, over all axes in
    `axes_to_sum`, the products of the corresponding elements of `t0` and
    `t1`.

    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
    axes_to_sum == {j,k}, this will return a tensor x where

      out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]

  Raises:
    ValueError: if the rank of `t0` does not match the length of
      `t0_axis_labels`, or that of `t1` does not match the length of
      `t1_axis_labels`.
  """
  if len(t0_axis_labels) != len(t0.shape):
    raise ValueError(
        f'Tensor `t0` of rank {len(t0.shape)} does not match einsum reduction '
        f'of length {len(t0_axis_labels)}.')
  if len(t1_axis_labels) != len(t1.shape):
    raise ValueError(
        f'Tensor `t1` of rank {len(t1.shape)} does not match einsum reduction '
        f'of length {len(t1_axis_labels)}.')

  # This function computes the result of a two-argument einsum() using batch
  # matrix multiplication. This involves
  #   1. transposing t0 and t1 so that axes are in the correct order for
  #      batch matrix multiplication, and
  #   2. reshaping t0 and t1 so that they are both of rank 3.

  # First, we divide axes into three groups:
  #   * "preserved" axes are present in both inputs and the output
  #   * "summed" axes are present in both inputs but not the output
  #   * "broadcast" axes are present in exactly one input and the output
  #
  # As an example, if the einsum is abijk,acjkl->abcil, then "a" is a
  # preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are
  # summed axes.
  assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum)
  preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
  broadcast_axes = {}
  for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]):
    broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum

  # Reorder the axes so that:
  #   1. preserved axes come first in both inputs
  #   2. in input 0, broadcast axes come next, followed by summed axes
  #   3. in input 1, summed axes come next, followed by broadcast axes
  def sort_key(input_index, a):
    if a in preserved_axes:
      return (-1, a)
    elif ((input_index == 0 and a in broadcast_axes[0]) or
          (input_index == 1 and a in axes_to_sum)):
      return (0, a)
    else:
      return (1, a)

  axis_labels = [t0_axis_labels, t1_axis_labels]
  sorted_axes = [
      sorted(sym_list, key=lambda a: sort_key(i, a))
      for i, sym_list in enumerate(axis_labels)
  ]
  inputs = [t0, t1]
  for i, axes_str in enumerate(axis_labels):
    perm = [axes_str.find(a) for a in sorted_axes[i]]
    inputs[i] = _transpose_if_necessary(inputs[i], perm)
  t0, t1 = inputs

  if not axes_to_sum:
    # In the special case where there are no axes to sum over, reduce to mul()
    # rather than to batch matrix multiplication.
    for _ in broadcast_axes[1]:
      t0 = array_ops.expand_dims(t0, -1)
    for _ in broadcast_axes[0]:
      t1 = array_ops.expand_dims(t1, len(preserved_axes))
    product = math_ops.multiply(t0, t1)
    product_axes = sorted_axes[0] + sorted_axes[1][len(preserved_axes):]
    return product, ''.join(product_axes)
  else:
    # Reduce to matmul().

    # Reshape both inputs so as to combine multiple broadcast axes
    # into a single axis, and combine multiple summed axes into a
    # single axis.

    t0_shape = _get_shape(t0)
    num_broadcast_elements_t0 = _total_size(
        t0_shape[len(preserved_axes):-len(axes_to_sum)])
    num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
    new_shape = (
        t0_shape[:len(preserved_axes)] +
        [num_broadcast_elements_t0, num_summed_elements])
    t0 = _reshape_if_necessary(t0, new_shape)

    t1_shape = _get_shape(t1)
    num_broadcast_elements_t1 = _total_size(
        t1_shape[len(preserved_axes) + len(axes_to_sum):])
    new_shape = (
        t1_shape[:len(preserved_axes)] +
        [num_summed_elements, num_broadcast_elements_t1])
    t1 = _reshape_if_necessary(t1, new_shape)

    product = math_ops.matmul(t0, t1)

    # Undo compaction of broadcast axes
    uncompacted_shape = (
        t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
        t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
    product = _reshape_if_necessary(product, uncompacted_shape)

    product_axes = (
        sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
        sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])

    return product, ''.join(product_axes)


def _transpose_if_necessary(tensor, perm):
  """Like transpose(), but avoids creating a new tensor if possible."""
  if perm != list(range(len(perm))):
    return array_ops.transpose(tensor, perm=perm)
  else:
    return tensor


def _reshape_if_necessary(tensor, new_shape):
  """Like reshape(), but avoids creating a new tensor if possible."""
  # Accept None as an alias for -1 in new_shape.
  new_shape = tuple(-1 if x is None else x for x in new_shape)
  cur_shape = tuple(x.value for x in tensor.shape.dims)
  if (len(new_shape) == len(cur_shape) and
      all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
          for d0, d1 in zip(cur_shape, new_shape))):
    return tensor
  else:
    return array_ops.reshape(tensor, new_shape)


def _get_shape(tensor):
  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
  if necessary to ensure that the returned value contains no unknown value."""

  shape = tensor.shape.as_list()
  none_indices = [i for i, d in enumerate(shape) if d is None]
  if none_indices:
    # Query the shape if shape contains None values.
    shape_tensor = array_ops.shape(tensor)
    for i in none_indices:
      shape[i] = shape_tensor[i]
  return shape
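
# For example (illustrative, derived from the code): given a tensor with
# static shape [2, None, 3], _get_shape returns [2, shape_tensor[1], 3],
# mixing Python ints with scalar Tensors so that downstream reshapes keep
# as much static shape information as possible.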


def _total_size(shape_values):
  """Given a list of tensor shape values, returns the total size.

  If shape_values contains tensor values (which are results of
  array_ops.shape), then it returns a scalar tensor.
  If not, it returns an integer.
  """

  result = 1
  for val in shape_values:
    result *= val
  return result
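
# For example, _total_size([2, 3, 4]) returns the Python int 24; if any entry
# is a scalar Tensor (as produced by _get_shape for unknown dimensions), the
# running multiplication promotes the result to a scalar Tensor.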


def _exponential_space_einsum_v1(equation, *inputs):
  """Fallback implementation that supports summing an index over > 2 inputs."""
  inputs = list(inputs)
  input_shapes = [x.shape for x in inputs]
  idx_in, idx_out = _einsum_v1_parse_and_resolve_equation(
      equation, input_shapes)

  idx_all = set(''.join(idx_in) + idx_out)
  indices = ''.join(sorted(idx_all))

  missing_idx = set(idx_out).difference(idx_all)
  if missing_idx:
    raise ValueError(f'Unknown output axes: {missing_idx}.')

  axis_order = {}
  for ax in indices:
    if ax not in idx_out:
      axis_order[ax] = len(axis_order)
  for ax in idx_out:
    axis_order[ax] = len(axis_order)

  # Transpose inputs so axes are in order.
  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
    if input_.shape.ndims != len(axes_):
      raise ValueError(
          f'Input {i} with axes {axes_} has incorrect number of dimensions '
          f'(expected {len(axes_)}, got {input_.shape.ndims}).')

    sorted_idx = sorted(axes_, key=axis_order.get)

    if len(set(axes_)) != len(axes_):
      raise ValueError(
          f'Subscript not supported: an axis appears more than once: {axes_}.')

    if list(axes_) != sorted_idx:
      permuted = [axes_.find(ax) for ax in sorted_idx]
      inputs[i] = array_ops.transpose(input_, permuted)
      idx_in[i] = sorted_idx

  reduction_idx = []
  shapes = [[dim if dim else -1
             for dim in tensor.shape.as_list()]
            for tensor in inputs]

  # Validate shapes for broadcasting.
  for j, ax in enumerate(sorted(idx_all, key=axis_order.get)):
    dims = []
    for i, idx in enumerate(idx_in):
      if ax not in idx:
        shapes[i].insert(j, 1)
      else:
        dim = shapes[i][j]
        if isinstance(dim, int) and dim > 1:
          dims.append(dim)

    if len(set(dims)) > 1:
      raise ValueError(f'Dimension mismatch on axis: {ax}. '
                       f'Found {len(set(dims))}, expected 1.')

    if ax not in idx_out:
      reduction_idx.append(j)

  # Reshape, multiply.
  expanded_inputs = [
      array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
  ]
  expanded_output = 1
  for input_ in expanded_inputs:
    expanded_output *= input_

  # Contract.
  return math_ops.reduce_sum(expanded_output, reduction_idx)


def _einsum_v2(equation, *inputs, **kwargs):
  """Implementation of einsum utilizing opt_einsum and EinsumOp."""
  name = kwargs.pop('name', None)
  optimize = kwargs.pop('optimize', 'greedy')
  if kwargs:
    raise TypeError(
        f'Invalid keyword arguments for einsum: {", ".join(kwargs)}. '
        f'Valid arguments: name, optimize.')

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = []
    for operand in inputs:
      if isinstance(operand.shape, tensor_shape.TensorShape):
        input_shapes.append(operand.shape.as_list() if operand.shape else None)
      else:
        input_shapes.append(list(operand.shape))
    # Validate and sanitize the equation and resolve static input shapes, as
    # opt_einsum requires that all shapes be a tuple of positive integers.
    # Also remove ellipsis from the equation as opt_einsum will replace them
    # with named labels. Then broadcasting between different shapes or ranks
    # wouldn't work. (E.g. [1, 1, 2] wouldn't broadcast with [3, 1].)
    resolved_equation, resolved_input_shapes, ellipsis_label = (
        _einsum_v2_parse_and_resolve_equation(equation, input_shapes))

    if len(inputs) <= 2:  # No need to call opt_einsum.
      # Replace back ellipses that were removed for opt_einsum.
      if ellipsis_label:
        resolved_equation = resolved_equation.replace(ellipsis_label, '...')
      return gen_linalg_ops.einsum(inputs, resolved_equation)

    # Send fully specified shapes to opt_einsum, since it cannot handle unknown
    # dimensions. For unknown dimensions, we guess that the dimension equals 1.
    # Instead of creating Tensors or NumPy arrays with the specified shape,
    # create a dummy `shaped` object with a `shape` property.
    shaped = collections.namedtuple('shaped', ['shape'])
    shaped_inputs = tuple(
        [shaped(tuple(shape)) for shape in resolved_input_shapes])
    # opt_einsum breaks down an n-ary einsum operation into n-1 binary einsums.
    # Obtain the sequence of equations and the indices of operands involved in
    # each einsum operation.
    indices_and_equations = _get_opt_einsum_contract_path(
        resolved_equation, shaped_inputs, optimize)
    for operand_indices, binary_equation in indices_and_equations:
      if ellipsis_label:
        # Replace back ellipses that were removed for opt_einsum.
        binary_equation = binary_equation.replace(ellipsis_label, '...')
      operands = list(map(inputs.pop, operand_indices))
      inputs.append(gen_linalg_ops.einsum(operands, binary_equation))
    return inputs[0]


def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
  """Returns the (memoized) result of opt_einsum.contract_path."""
  # Note: We use einsum_call=True, which is an internal api for opt_einsum,
  # to get the contraction path without having opt_einsum perform the actual
  # contractions.
  _, contractions = opt_einsum.contract_path(
      equation,
      *shaped_inputs_tuple,
      optimize=optimize,
      einsum_call=True,
      use_blas=True)
  # Return a tuple so that the cached value is not mutable.
  indices_and_equations = tuple([(expr[0], expr[2]) for expr in contractions])
  return indices_and_equations


# Cache the possibly expensive opt_einsum.contract_path call using lru_cache
# from the Python 3 standard library.
_get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
    _get_opt_einsum_contract_path)
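
# Illustration (not a guaranteed path; the exact decomposition depends on
# opt_einsum and the optimize strategy): for a three-operand chain such as
# 'ab,bc,cd->ad', contract_path typically decomposes the computation into two
# binary einsums, e.g. 'ab,bc->ac' followed by 'ac,cd->ad'. Each entry of the
# returned tuple pairs the indices of the operands to pop with the binary
# equation to run on them, which is exactly how _einsum_v2 consumes it.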


def _einsum_v2_parse_and_resolve_equation(equation, input_shapes):
  """Helper which validates the einsum equation and resolves input shapes."""
  resolved_equation = equation.replace(' ', '')
  ellipsis_label = None
  if '...' in equation:
    # Replace ellipsis ('...') with '0' for (a) ease of parsing and (b) to
    # prevent opt_einsum from resolving them into named labels, as it doesn't
    # support broadcasting.
    ellipsis_label = '0'
    if ellipsis_label in resolved_equation:
      raise ValueError(
          f'Invalid character "{ellipsis_label}" in equation: {equation}.')
    resolved_equation = resolved_equation.replace('...', ellipsis_label)

  # Ensure there are no non-alphanumeric characters in the equation, including
  # periods (`.`) outside of ellipses. This is not a hard requirement, except
  # that we use the special character '0' for ellipsis.
  allowed_labels = 'a-zA-Z'
  if ellipsis_label:
    allowed_labels += ellipsis_label
  match = re.match('^([{0},]*)(->[{0}]*)?$'.format(allowed_labels),
                   resolved_equation)
  if not match:
    raise ValueError(
        'Subscripts have incorrect format: {}'.format(resolved_equation))
  input_labels = match.group(1).split(',')
  output_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_labels):
    raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
        len(input_shapes), equation, len(input_labels)))

  # Special case: if there is no '->', then we create output subscripts from
  # labels appearing only once.
  if '->' not in resolved_equation:
    label_counts = collections.Counter(match.group(1))
    output_labels = ''.join([
        x for x in sorted(list(label_counts))
        if x != ',' and label_counts[x] == 1
    ])
    resolved_equation += '->' + output_labels
  # Validate output_labels.
  if output_labels and len(set(output_labels)) != len(output_labels):
    raise ValueError(
        'Output subscripts contain a label appearing more than once: '
        '{}'.format(equation))
  input_label_set = set(match.group(1))
  for label in output_labels:
    if label != ellipsis_label and label not in input_label_set:
      raise ValueError('Output subscripts contain the label {} not present '
                       'in the input subscripts.'.format(label))
  if ellipsis_label and output_labels:
    num_output_ellipses = output_labels.count(ellipsis_label)
    if num_output_ellipses > 1:
      raise ValueError(
          'Output subscripts contain multiple ellipses: {}'.format(equation))

  # Early return if <= 2 inputs. Resolved shapes are not needed.
  if len(input_shapes) <= 2:
    return resolved_equation, None, ellipsis_label

  # Create a map from axis labels to known dimensions. This is used to infer
  # unknown dimensions if a known dimension also has the same label.
  label_to_dim = collections.defaultdict(lambda: 1)
  for i, (labels, shape) in enumerate(zip(input_labels, input_shapes)):
    if shape is None:
      continue
    ellipsis_start = labels.find(ellipsis_label) if ellipsis_label else -1
    if ellipsis_start != -1:  # This input contains an ellipsis.
      if ellipsis_start != labels.rfind(ellipsis_label):
        raise ValueError(f'Too many ellipses in input label '
                         f'{labels.replace(ellipsis_label, "...")}.')
      if len(labels) > len(shape) + 1:
        raise ValueError('Too many named labels in {}th subscript string of'
                         ' equation {} for input shape {}'.format(
                             i, equation, shape))
      ellipsis_end = ellipsis_start + len(shape) + 1 - len(labels)
      shape[ellipsis_start:ellipsis_end] = ([
          np.prod(
              list(filter(None, shape[ellipsis_start:ellipsis_end])),
              dtype=np.int64)
      ])
    else:
      # This input does not contain an ellipsis.
      if len(labels) != len(shape):
        raise ValueError(
            'Number of named labels in input #{} of equation {} '
            'must be equal to the number of dimensions in shape {}'.format(
                i, equation, shape))
      for dim, label in zip(shape, labels):
        if dim is not None:
          label_to_dim[label] = max(label_to_dim[label], dim)

  resolved_shapes = []
  for labels in input_labels:
    resolved_shapes.append([label_to_dim[label] for label in labels])
  return resolved_equation, resolved_shapes, ellipsis_label
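
# Worked example of the resolution above (illustrative, derived from the
# code): for the equation '...ij,jk->...ik', the ellipsis is rewritten with
# the sentinel label '0', giving '0ij,jk->0ik'. With only two inputs the
# resolved shapes come back as None (the EinsumOp handles them directly),
# and the caller in _einsum_v2 maps '0' back to '...' before building the op.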