Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/special_math_ops.py: 26%
427 statements · coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Arithmetic Operations that don't fit into math_ops due to dependencies.
17To avoid circular dependencies, some math_ops should go here.
18"""
import collections
import functools
import re
import string

import numpy as np
import opt_einsum

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_special_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.

  Given one-dimensional $z = [z_1,...,z_K]$, we define

  $$Beta(z) = \frac{\prod_j \Gamma(z_j)}{\Gamma(\sum_j z_j)},$$

  where $\Gamma$ is the gamma function.

  And for $n + 1$ dimensional $x$ with shape $[N_1, ..., N_n, K]$, we define

  $$lbeta(x)[i_1, ..., i_n] = \log{|Beta(x[i_1, ..., i_n, :])|}.$$

  In other words, the last dimension is treated as the $z$ vector.

  Note that if $z = [u, v]$, then

  $$Beta(z) = \frac{\Gamma(u)\Gamma(v)}{\Gamma(u + v)}
  = \int_0^1 t^{u-1} (1 - t)^{v-1} \mathrm{d}t,$$

  which defines the traditional bivariate beta function.

  If the last dimension is empty, we follow the convention that the sum over
  the empty set is zero, and the product is one.
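
  For example (a small illustration; output shown at float32 precision,
  since $Beta([1, 2]) = 1/2$ and $Beta([3, 4]) = 1/60$):

  >>> x = tf.constant([[1., 2.], [3., 4.]])
  >>> tf.math.lbeta(x).numpy()
  array([-0.6931472, -4.0943446], dtype=float32)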

  Args:
    x: A rank `n + 1` `Tensor`, `n >= 0` with type `float` or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of \\(|Beta(x)|\\) reducing along the last dimension.
  """
  # In the event that the last dimension has zero entries, we return -inf.
  # This is consistent with the convention that the sum over the empty set is
  # 0, and the product is 1.
  # This is standard. See https://en.wikipedia.org/wiki/Empty_set.
  with ops.name_scope(name, 'lbeta', [x]):
    x = ops.convert_to_tensor(x, name='x')

    # Note reduce_sum([]) = 0.
    log_prod_gamma_x = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])

    # Note lgamma(0) = infinity, so if x = []
    #  log_gamma_sum_x = lgamma(0) = infinity, and
    #  log_prod_gamma_x = reduce_sum([]) = 0,
    # so result = -infinity
    sum_x = math_ops.reduce_sum(x, axis=[-1])
    log_gamma_sum_x = math_ops.lgamma(sum_x)
    result = log_prod_gamma_x - log_gamma_sum_x

    return result


@tf_export('math.special.dawsn')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def dawsn(x, name=None):
  """Computes Dawson's integral of `x` element-wise.

  Dawson's integral is defined as `exp(-x**2)` times the integral of
  `exp(t**2)` from `0` to `x`, with the domain of definition all real numbers.

  Dawson's function is odd.

  >>> tf.math.special.dawsn([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5380795, -0.4244364, 0.4244364, 0.5380795], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.dawsn
  @end_compatibility
  """
  with ops.name_scope(name, 'dawsn', [x]):
    return gen_special_math_ops.dawsn(x)


@tf_export('math.special.expint')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def expint(x, name=None):
  """Computes the Exponential integral of `x` element-wise.

  The Exponential integral is defined as the integral of `exp(t) / t` from
  `-inf` to `x`, with the domain of definition all positive real numbers.

  >>> tf.math.special.expint([1., 1.1, 2.1, 4.1]).numpy()
  array([ 1.8951179, 2.1673784, 5.3332353, 21.048464], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.expi
  @end_compatibility
  """
  with ops.name_scope(name, 'expint', [x]):
    return gen_special_math_ops.expint(x)


@tf_export('math.special.fresnel_cos')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None):
  """Computes Fresnel's cosine integral of `x` element-wise.

  The Fresnel cosine integral is defined as the integral of `cos(t^2)` from
  `0` to `x`, with the domain of definition all real numbers.

  The Fresnel cosine integral is odd.

  >>> tf.math.special.fresnel_cos([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.7798934 , -0.09999753, 0.09999753, 0.7798934 ], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel second output.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_cos', [x]):
    return gen_special_math_ops.fresnel_cos(x)


@tf_export('math.special.fresnel_sin')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None):
  """Computes Fresnel's sine integral of `x` element-wise.

  The Fresnel sine integral is defined as the integral of `sin(t^2)` from
  `0` to `x`, with the domain of definition all real numbers.

  >>> tf.math.special.fresnel_sin([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.43825912, -0.00052359, 0.00052359, 0.43825912], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel first output.
  @end_compatibility
  """
  with ops.name_scope(name, 'fresnel_sin', [x]):
    return gen_special_math_ops.fresnel_sin(x)


@tf_export('math.special.spence')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def spence(x, name=None):
  """Computes Spence's integral of `x` element-wise.

  Spence's integral is defined as the integral of `log(t) / (1 - t)` from
  `1` to `x`, with the domain of definition all non-negative real numbers.

  >>> tf.math.special.spence([0.5, 1., 2., 3.]).numpy()
  array([ 0.58224034, 0. , -0.82246685, -1.4367464], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.spence
  @end_compatibility
  """
  with ops.name_scope(name, 'spence', [x]):
    return gen_special_math_ops.spence(x)


@tf_export('math.bessel_i0', 'math.special.bessel_i0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
  """Computes the Bessel i0 function of `x` element-wise.

  Modified Bessel function of the first kind of order 0.

  It is preferable to use the numerically more stable function `i0e(x)`
  instead.

  >>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy()
  array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0', [x]):
    return gen_special_math_ops.bessel_i0(x)


@tf_export('math.bessel_i0e', 'math.special.bessel_i0e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i0e(x, name=None):
  """Computes the Bessel i0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 0:
  `i0e(x) = exp(-abs(x)) * i0(x)`.

  >>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy()
  array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i0e', [x]):
    return gen_special_math_ops.bessel_i0e(x)


@tf_export('math.bessel_i1', 'math.special.bessel_i1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
  """Computes the Bessel i1 function of `x` element-wise.

  Modified Bessel function of the first kind of order 1.

  It is preferable to use the numerically more stable function `i1e(x)`
  instead.

  >>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5651591 , -0.25789431, 0.25789431, 0.5651591 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1', [x]):
    return gen_special_math_ops.bessel_i1(x)


@tf_export('math.bessel_i1e', 'math.special.bessel_i1e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i1e(x, name=None):
  """Computes the Bessel i1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 1:
  `i1e(x) = exp(-abs(x)) * i1(x)`.

  >>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.20791042, -0.15642083, 0.15642083, 0.20791042], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_i1e', [x]):
    return gen_special_math_ops.bessel_i1e(x)


@tf_export('math.special.bessel_k0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k0(x, name=None):
  """Computes the Bessel k0 function of `x` element-wise.

  Modified Bessel function of the second kind of order 0.

  It is preferable to use the numerically more stable function `k0e(x)`
  instead.

  >>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy()
  array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0', [x]):
    return gen_special_math_ops.bessel_k0(x)


@tf_export('math.special.bessel_k0e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k0e(x, name=None):
  """Computes the Bessel k0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 0:
  `k0e(x) = exp(x) * k0(x)`.

  >>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy()
  array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k0e', [x]):
    return gen_special_math_ops.bessel_k0e(x)


@tf_export('math.special.bessel_k1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k1(x, name=None):
  """Computes the Bessel k1 function of `x` element-wise.

  Modified Bessel function of the second kind of order 1.

  It is preferable to use the numerically more stable function `k1e(x)`
  instead.

  >>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy()
  array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1', [x]):
    return gen_special_math_ops.bessel_k1(x)


@tf_export('math.special.bessel_k1e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k1e(x, name=None):
  """Computes the Bessel k1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 1:
  `k1e(x) = exp(x) * k1(x)`.

  >>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy()
  array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1e
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_k1e', [x]):
    return gen_special_math_ops.bessel_k1e(x)


@tf_export('math.special.bessel_j0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_j0(x, name=None):
  """Computes the Bessel j0 function of `x` element-wise.

  Bessel function of the first kind of order 0.

  >>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy()
  array([ 0.93846981, 0.76519769, 0.22389078, -0.39714981], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j0', [x]):
    return gen_special_math_ops.bessel_j0(x)


@tf_export('math.special.bessel_j1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_j1(x, name=None):
  """Computes the Bessel j1 function of `x` element-wise.

  Bessel function of the first kind of order 1.

  >>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy()
  array([ 0.24226846, 0.44005059, 0.57672481, -0.06604333], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_j1', [x]):
    return gen_special_math_ops.bessel_j1(x)


@tf_export('math.special.bessel_y0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_y0(x, name=None):
  """Computes the Bessel y0 function of `x` element-wise.

  Bessel function of the second kind of order 0.

  >>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy()
  array([-0.44451873, 0.08825696, 0.51037567, -0.01694074], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y0
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y0', [x]):
    return gen_special_math_ops.bessel_y0(x)


@tf_export('math.special.bessel_y1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_y1(x, name=None):
  """Computes the Bessel y1 function of `x` element-wise.

  Bessel function of the second kind of order 1.

  >>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy()
  array([-1.47147239, -0.78121282, -0.10703243, 0.39792571], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y1
  @end_compatibility
  """
  with ops.name_scope(name, 'bessel_y1', [x]):
    return gen_special_math_ops.bessel_y1(x)


@ops.RegisterGradient('XlaEinsum')
def _einsum_grad(op, grad):
  equation = op.get_attr('equation')
  if isinstance(equation, bytes):
    equation = equation.decode()
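
  # Gradient rule applied below: for C = einsum('left,right->output', A, B),
  # dA = einsum('output,right->left', dC, B) and
  # dB = einsum('output,left->right', dC, A). This presumes the swapped
  # equations are themselves valid einsum specifications.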

  inputs, output = equation.split('->')
  left, right = inputs.split(',')

  return [
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[1],
          equation='{},{}->{}'.format(output, right, left),
          name=None),
      gen_xla_ops.xla_einsum(
          grad,
          op.inputs[0],
          equation='{},{}->{}'.format(output, left, right),
          name=None)
  ]


def _enclosing_tpu_context():
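  """Returns the innermost XLAControlFlowContext enclosing the current scope.

  Returns None if the default graph's control flow is not inside one.
  """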
  # pylint: disable=protected-access
  context = ops.get_default_graph()._get_control_flow_context()
  # pylint: enable=protected-access
  while context is not None and not isinstance(
      context, control_flow_ops.XLAControlFlowContext):
    context = context.outer_context
  return context


@tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs):
  r"""Tensor contraction over specified indices and outer product.

  Einsum allows defining Tensors by defining their element-wise computation.
  This computation is defined by `equation`, a shorthand form based on Einstein
  summation. As an example, consider multiplying two matrices A and B to form a
  matrix C. The elements of C are given by:

  $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$

  or

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding einsum `equation` is:

  ```
  ij,jk->ik
  ```

  In general, to convert the element-wise equation into the `equation` string,
  use the following procedure (intermediate strings for the matrix
  multiplication example are provided in parentheses):

  1. remove variable names, brackets, and commas, (`ik = sum_j ij * jk`)
  2. replace "*" with ",", (`ik = sum_j ij , jk`)
  3. drop summation signs, and (`ik = ij, jk`)
  4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`)

  Note: If the output indices are not specified, repeated indices are summed.
  So `ij,jk->ik` can be simplified to `ij,jk`.

  Many common operations can be expressed in this way. For example:

  **Matrix multiplication**

  >>> m0 = tf.random.normal(shape=[2, 3])
  >>> m1 = tf.random.normal(shape=[3, 5])
  >>> e = tf.einsum('ij,jk->ik', m0, m1)
  >>> # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  Repeated indices are summed if the output indices are not specified.

  >>> e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  **Dot product**

  >>> u = tf.random.normal(shape=[5])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
  >>> print(e.shape)
  ()

  **Outer product**

  >>> u = tf.random.normal(shape=[3])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
  >>> print(e.shape)
  (3, 5)

  **Transpose**

  >>> m = tf.ones([2, 3])
  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
  >>> print(e.shape)
  (3, 2)

  **Diag**

  >>> m = tf.reshape(tf.range(9), [3, 3])
  >>> diag = tf.einsum('ii->i', m)
  >>> print(diag.shape)
  (3,)

  **Trace**

  >>> # Repeated indices are summed.
  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
  >>> assert trace == sum(diag)
  >>> print(trace.shape)
  ()

  **Batch matrix multiplication**

  >>> s = tf.random.normal(shape=[7, 5, 3])
  >>> t = tf.random.normal(shape=[7, 3, 2])
  >>> e = tf.einsum('bij,bjk->bik', s, t)
  >>> # output[a,i,k] = sum_j s[a,i,j] * t[a, j, k]
  >>> print(e.shape)
  (7, 5, 2)

  This method does not support broadcasting on named axes. All axes with
  matching labels should have the same length. If you have length-1 axes,
  use `tf.squeeze` or `tf.reshape` to eliminate them.

  To write code that is agnostic to the number of indices in the input,
  use an ellipsis. The ellipsis is a placeholder for "whatever other indices
  fit here".

  For example, to perform a NumPy-style broadcasting batch matrix
  multiplication where the matrix multiply acts on the last two axes of the
  input, use:

  >>> s = tf.random.normal(shape=[11, 7, 5, 3])
  >>> t = tf.random.normal(shape=[11, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Einsum **will** broadcast over axes covered by the ellipsis.

  >>> s = tf.random.normal(shape=[11, 1, 5, 3])
  >>> t = tf.random.normal(shape=[1, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    **kwargs:
      - optimize: Optimization strategy to use to find contraction path using
          opt_einsum. Must be 'greedy', 'optimal', 'branch-2', 'branch-all' or
          'auto'. (optional, default: 'greedy').
      - name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - the number of inputs or their shapes are inconsistent with `equation`.
  """
  return _einsum_v2(equation, *inputs, **kwargs)
762 return _einsum_v2(equation, *inputs, **kwargs)


def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy implementation of einsum without using EinsumOp."""
  name = kwargs.pop('name', None)
  if kwargs:
    raise TypeError(
        f'Invalid keyword arguments for this function: '
        f'{", ".join([format(key) for key in sorted(list(kwargs.keys()))])}.'
        f' Expected: name.')
  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = [x.shape for x in inputs]
    input_axis_labels, output_axis_labels = (
        _einsum_v1_parse_and_resolve_equation(equation, input_shapes))

    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    for a in axis_labels:
      for input_labels in input_axis_labels:
        if (len(input_axis_labels) == 1 and input_labels.count(a) == 2 and
            input_labels == input_labels[::-1] and '->' not in equation):
          return math_ops.trace(inputs[0])
        if input_labels.count(a) > 1:
          raise ValueError(
              f'Subscript not supported: the axis {a} appears more than once'
              f' in {input_labels}.')
    for a in axis_labels:
      input_count = sum(1 for s in input_axis_labels if a in s)
      if input_count > 2 and a not in output_axis_labels:
        logging.warn(
            f'Falling back to exponential-space implementation of einsum()'
            f' because index {a} is summed over more than two inputs.')
        return _exponential_space_einsum_v1(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2-input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1], input_axis_labels[0] + ',' +
          input_axis_labels[1] + '->' + output_axis_labels)
    temp = inputs[0]
    temp_axis_labels = input_axis_labels[0]
    for i in range(len(inputs) - 1):
      axes_to_sum = (
          set(temp_axis_labels) &
          set(input_axis_labels[i + 1]) - set(output_axis_labels))
      temp, temp_axis_labels = _einsum_v1_reduction(temp, temp_axis_labels,
                                                    inputs[i + 1],
                                                    input_axis_labels[i + 1],
                                                    axes_to_sum)

    missing_indices = set(temp_axis_labels) - set(output_axis_labels)
    if missing_indices:
      axis = [
          i for i, a in enumerate(temp_axis_labels)
          if a not in output_axis_labels
      ]
      temp = math_ops.reduce_sum(temp, axis=axis)
      temp_axis_labels = ''.join(
          a for a in temp_axis_labels if a in output_axis_labels)
    if sorted(temp_axis_labels) != sorted(output_axis_labels):
      raise ValueError(
          f'Invalid equation: {equation}. The computed and specified output '
          f'labels do not match: {temp_axis_labels} vs {output_axis_labels}.')

    perm = [temp_axis_labels.index(a) for a in output_axis_labels]
    return _transpose_if_necessary(temp, perm)


def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
  """Helper for einsum() that splits/resolves inputs & outputs.

  Args:
    equation: Equation string given as argument to einsum().
    input_shapes: List of the shapes of all inputs given to einsum()

  Returns:
    input_axis_labels, output_axis_labels where:
      input_axis_labels: List of length len(input_shapes) of strings
        representing the character label for each dimension of each given
        input, resolving any broadcast (...) axes,
      output_axis_labels: A string of character labels for each axis of the
        output tensor, filling in missing output subscripts and broadcast axes.

  Raises:
    ValueError: If the equation is in an incorrect format, an incorrect number
      of inputs is given, or broadcast axes "..." or output axes could not be
      resolved.
  """
  equation = equation.replace(' ', '')
  match = re.match('^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)
  if not match:
    raise ValueError(f'Indices have incorrect format. Received: {equation}.')

  input_axis_labels = match.group(1).split(',')
  output_axis_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_axis_labels):
    raise ValueError(
        f'Got {len(input_shapes)} arguments for equation "{equation}", '
        f'expecting {len(input_axis_labels)}.')

  # Resolve Ellipsis
  # Assign axes labels for unspecified dimensions in inputs. Labels taken
  # from unused labels. Follow numpy einsum broadcasting conventions for
  # tensors of different length and unlabeled output.
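  # For example (labels here are illustrative): with equation '...ij,...jk'
  # and input ranks 4 and 3, the first '...' resolves to the last two unused
  # letters and the second to the last one, so the replacements share a
  # suffix and the trailing broadcast axes line up; the longest replacement
  # becomes the output's ellipsis axes.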
  ellipsis_axes = ''
  if '...' in equation:
    unused = ''.join(
        c for c in string.ascii_letters if c not in ''.join(input_axis_labels))
    for i, ax in enumerate(input_axis_labels):
      if '...' in ax:
        parts = ax.split('...')
        if len(parts) != 2:
          raise ValueError(f'Unable to resolve ellipsis. '
                           f'Excess number found: {len(parts)-1} vs 1.')
        if input_shapes[i].ndims is None:
          raise ValueError('Unable to statically infer ellipsis axes. The '
                           'input shapes have a dynamic dimensionality.')
        n = input_shapes[i].ndims - len(''.join(parts))
        if n < 0:
          raise ValueError('Ellipses lengths do not match.')
        if len(unused) < n:
          raise ValueError(
              'Unable to resolve ellipsis, too many distinct labels.')
        replace_axes = unused[-n:] if n > 0 else ''
        input_axis_labels[i] = input_axis_labels[i].replace('...',
                                                            replace_axes)
        if len(replace_axes) > len(ellipsis_axes):
          ellipsis_axes = replace_axes

    if any('.' in ax for ax in input_axis_labels):
      raise ValueError(
          f'Period "." found outside of ellipsis in input {input_axis_labels}.')

    if output_axis_labels is not None:
      output_axis_labels = output_axis_labels.replace('...', ellipsis_axes)
      if '.' in output_axis_labels:
        raise ValueError(f'Period "." found outside of ellipsis in output '
                         f'{output_axis_labels}.')

  if output_axis_labels is None:
    # Infer the output subscripts if not given, assume alphabetical order,
    # but always place ellipsis axes before given.
    axis_labels = set(''.join(input_axis_labels)) - set(ellipsis_axes)
    indices = ''.join(sorted(axis_labels))
    counts = {ax: 0 for ax in indices}
    for axes_ in input_axis_labels:
      for ax in axes_:
        if ax not in ellipsis_axes:
          counts[ax] += 1

    output_axis_labels = ellipsis_axes + ''.join(
        sorted(ax for ax in axis_labels if counts[ax] == 1))

  return input_axis_labels, output_axis_labels


def _einsum_v1_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
  """Helper for einsum() that computes the result of a two-argument einsum().

  Args:
    t0: a `Tensor`
    t0_axis_labels: a string of axis labels. This string's length must equal
      the rank of t0.
    t1: a `Tensor`
    t1_axis_labels: a string of axis labels. This string's length must equal
      the rank of t1.
    axes_to_sum: set of labels of axes to be summed over

  Returns:
    A `Tensor` whose elements are obtained by summing the products of
    corresponding elements of `t0` and `t1` over all axes in `axes_to_sum`.

    For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
    axes_to_sum == {j,k}, this will return a tensor x where

      out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]

  Raises:
    ValueError: if the rank of `t0` does not match the length of
      `t0_axis_labels`, or that of `t1` does not match the length of
      `t1_axis_labels`.
  """
  if len(t0_axis_labels) != len(t0.shape):
    raise ValueError(
        f'Tensor `t0` of rank {len(t0.shape)} does not match einsum reduction '
        f'of length {len(t0_axis_labels)}.')
  if len(t1_axis_labels) != len(t1.shape):
    raise ValueError(
        f'Tensor `t1` of rank {len(t1.shape)} does not match einsum reduction '
        f'of length {len(t1_axis_labels)}.')

  # This function computes the result of a two-argument einsum() using batch
  # matrix multiplication. This involves
  # 1. transposing t0 and t1 so that axes are in the correct order for
  #    batch matrix multiplication, and
  # 2. reshaping t0 and t1 so that they are both of rank 3.

  # First, we divide axes into three groups:
  #  * "preserved" axes are present in both inputs and the output
  #  * "summed" axes are present in both inputs but not the output
  #  * "broadcast" axes are present in exactly one input and the output
  #
  # As an example, if the einsum is abijk,acjkl->abcil, then "a" is a
  # preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are
  # summed axes.
  assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum)
  preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
  broadcast_axes = {}
  for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]):
    broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum

  # Reorder the axes so that:
  # 1. preserved axes come first in both inputs
  # 2. in input 0, broadcast axes come next, followed by summed axes
  # 3. in input 1, summed axes come next, followed by broadcast axes
  def sort_key(input_index, a):
    if a in preserved_axes:
      return (-1, a)
    elif ((input_index == 0 and a in broadcast_axes[0]) or
          (input_index == 1 and a in axes_to_sum)):
      return (0, a)
    else:
      return (1, a)

  axis_labels = [t0_axis_labels, t1_axis_labels]
  sorted_axes = [
      sorted(sym_list, key=lambda a: sort_key(i, a))
      for i, sym_list in enumerate(axis_labels)
  ]
  inputs = [t0, t1]
  for i, axes_str in enumerate(axis_labels):
    perm = [axes_str.find(a) for a in sorted_axes[i]]
    inputs[i] = _transpose_if_necessary(inputs[i], perm)
  t0, t1 = inputs

  if not axes_to_sum:
    # In the special case where there are no axes to sum over, reduce to mul()
    # rather than to batch matrix multiplication.
    for _ in broadcast_axes[1]:
      t0 = array_ops.expand_dims(t0, -1)
    for _ in broadcast_axes[0]:
      t1 = array_ops.expand_dims(t1, len(preserved_axes))
    product = math_ops.multiply(t0, t1)
    product_axes = sorted_axes[0] + sorted_axes[1][len(preserved_axes):]
    return product, ''.join(product_axes)
  else:
    # Reduce to matmul().

    # Reshape both inputs so as to combine multiple broadcast axes
    # into a single axis, and combine multiple summed axes into a
    # single axis.
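    # For example (shapes here are illustrative): with preserved axes 'a',
    # broadcast axes 'bi' on t0 and 'cl' on t1, and summed axes 'jk', t0 is
    # reshaped to [A, B*I, J*K] and t1 to [A, J*K, C*L], so the matmul below
    # produces [A, B*I, C*L].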
    t0_shape = _get_shape(t0)
    num_broadcast_elements_t0 = _total_size(
        t0_shape[len(preserved_axes):-len(axes_to_sum)])
    num_summed_elements = _total_size(t0_shape[-len(axes_to_sum):])
    new_shape = (
        t0_shape[:len(preserved_axes)] +
        [num_broadcast_elements_t0, num_summed_elements])
    t0 = _reshape_if_necessary(t0, new_shape)

    t1_shape = _get_shape(t1)
    num_broadcast_elements_t1 = _total_size(
        t1_shape[len(preserved_axes) + len(axes_to_sum):])
    new_shape = (
        t1_shape[:len(preserved_axes)] +
        [num_summed_elements, num_broadcast_elements_t1])
    t1 = _reshape_if_necessary(t1, new_shape)

    product = math_ops.matmul(t0, t1)

    # Undo compaction of broadcast axes
    uncompacted_shape = (
        t0_shape[:len(preserved_axes) + len(broadcast_axes[0])] +
        t1_shape[len(t1_shape) - len(broadcast_axes[1]):])
    product = _reshape_if_necessary(product, uncompacted_shape)

    product_axes = (
        sorted_axes[0][:len(preserved_axes) + len(broadcast_axes[0])] +
        sorted_axes[1][len(sorted_axes[1]) - len(broadcast_axes[1]):])

    return product, ''.join(product_axes)


def _transpose_if_necessary(tensor, perm):
  """Like transpose(), but avoids creating a new tensor if possible."""
  if perm != list(range(len(perm))):
    return array_ops.transpose(tensor, perm=perm)
  else:
    return tensor


def _reshape_if_necessary(tensor, new_shape):
  """Like reshape(), but avoids creating a new tensor if possible."""
  # Accept None as an alias for -1 in new_shape.
  new_shape = tuple(-1 if x is None else x for x in new_shape)
  cur_shape = tuple(x.value for x in tensor.shape.dims)
  if (len(new_shape) == len(cur_shape) and
      all(not isinstance(d1, ops.Tensor) and (d0 == d1 or d1 == -1)
          for d0, d1 in zip(cur_shape, new_shape))):
    return tensor
  else:
    return array_ops.reshape(tensor, new_shape)


def _get_shape(tensor):
  """Like get_shape().as_list(), but explicitly queries the shape of a tensor
  if necessary to ensure that the returned value contains no unknown value."""
  shape = tensor.shape.as_list()
  none_indices = [i for i, d in enumerate(shape) if d is None]
  if none_indices:
    # Query the shape if shape contains None values
    shape_tensor = array_ops.shape(tensor)
    for i in none_indices:
      shape[i] = shape_tensor[i]
  return shape


def _total_size(shape_values):
  """Given list of tensor shape values, returns total size.

  If shape_values contains tensor values (which are results of
  array_ops.shape), then it returns a scalar tensor.
  If not, it returns an integer."""
  result = 1
  for val in shape_values:
    result *= val
  return result


def _exponential_space_einsum_v1(equation, *inputs):
  """Fallback implementation that supports summing an index over > 2 inputs."""
  inputs = list(inputs)
  input_shapes = [x.shape for x in inputs]
  idx_in, idx_out = _einsum_v1_parse_and_resolve_equation(
      equation, input_shapes)

  idx_all = set(''.join(idx_in) + idx_out)
  indices = ''.join(sorted(idx_all))

  # Check against the input labels only; checking against idx_all (which
  # already contains idx_out) would never find a missing axis.
  missing_idx = set(idx_out).difference(set(''.join(idx_in)))
  if missing_idx:
    raise ValueError(f'Unknown output axes: {missing_idx}.')

  axis_order = {}
  for ax in indices:
    if ax not in idx_out:
      axis_order[ax] = len(axis_order)
  for ax in idx_out:
    axis_order[ax] = len(axis_order)
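
  # Non-output axes are ordered first here, so once every input is expanded
  # to this common axis order, the reduce_sum at the end (over the leading,
  # non-output positions) leaves the remaining dimensions in output order.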
  # transpose inputs so axes are in order
  for i, (input_, axes_) in enumerate(zip(inputs, idx_in)):
    if input_.shape.ndims != len(axes_):
      raise ValueError(
          f'Input {i} with axes {axes_} has incorrect number of dimensions '
          f'(expected {len(axes_)}, got {input_.shape.ndims}).')

    sorted_idx = sorted(axes_, key=axis_order.get)

    if len(set(axes_)) != len(axes_):
      raise ValueError(
          f'Subscript not supported: an axis appears more than once: {axes_}.')

    if list(axes_) != sorted_idx:
      permuted = [axes_.find(ax) for ax in sorted_idx]
      inputs[i] = array_ops.transpose(input_, permuted)
      idx_in[i] = sorted_idx

  reduction_idx = []
  shapes = [[dim if dim else -1
             for dim in tensor.shape.as_list()]
            for tensor in inputs]

  # validate shapes for broadcasting
  for j, ax in enumerate(sorted(idx_all, key=axis_order.get)):
    dims = []
    for i, idx in enumerate(idx_in):
      if ax not in idx:
        shapes[i].insert(j, 1)
      else:
        dim = shapes[i][j]
        if isinstance(dim, int) and dim > 1:
          dims.append(dim)

    if len(set(dims)) > 1:
      raise ValueError(f'Dimension mismatch on axis: {ax}. '
                       f'Found {len(set(dims))}, expected 1.')

    if ax not in idx_out:
      reduction_idx.append(j)

  # reshape, multiply
  expanded_inputs = [
      array_ops.reshape(input_, shape) for input_, shape in zip(inputs, shapes)
  ]
  expanded_output = 1
  for input_ in expanded_inputs:
    expanded_output *= input_

  # contract
  return math_ops.reduce_sum(expanded_output, reduction_idx)


def _einsum_v2(equation, *inputs, **kwargs):
  """Implementation of einsum utilizing opt_einsum and EinsumOp."""
  name = kwargs.pop('name', None)
  optimize = kwargs.pop('optimize', 'greedy')
  if kwargs:
    raise TypeError(
        f'Invalid keyword arguments for einsum: {", ".join(kwargs)}. '
        f'Valid arguments: name, optimize.')

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_shapes = []
    for operand in inputs:
      if isinstance(operand.shape, tensor_shape.TensorShape):
        input_shapes.append(operand.shape.as_list() if operand.shape else None)
      else:
        input_shapes.append(list(operand.shape))
    # Validate and sanitize the equation and resolve static input shapes, as
    # opt_einsum requires that all shapes be a tuple of positive integers.
    # Also remove ellipses from the equation, as opt_einsum would replace them
    # with named labels, after which broadcasting between different shapes or
    # ranks wouldn't work. (E.g. [1, 1, 2] wouldn't broadcast with [3, 1]).
    resolved_equation, resolved_input_shapes, ellipsis_label = (
        _einsum_v2_parse_and_resolve_equation(equation, input_shapes))

    if len(inputs) <= 2:  # No need to call opt_einsum.
      # Replace back ellipses that were removed for opt_einsum.
      if ellipsis_label:
        resolved_equation = resolved_equation.replace(ellipsis_label, '...')
      return gen_linalg_ops.einsum(inputs, resolved_equation)

    # Send fully specified shapes to opt_einsum, since it cannot handle unknown
    # dimensions. For unknown dimensions, we guess that the dimension equals 1.
    # Instead of creating Tensors or NumPy arrays with the specified shape,
    # create a dummy `shaped` object with a `shape` property.
    shaped = collections.namedtuple('shaped', ['shape'])
    shaped_inputs = tuple(
        [shaped(tuple(shape)) for shape in resolved_input_shapes])
    # opt_einsum breaks down an n-ary einsum operation into n-1 binary einsums.
    # Obtain the sequence of equations and the indices of operands involved in
    # each einsum operation.
    indices_and_equations = _get_opt_einsum_contract_path(
        resolved_equation, shaped_inputs, optimize)
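    # For example (a hypothetical contraction path): for 'ab,bc,cd->ad',
    # opt_einsum might first contract operands 0 and 1 via 'ab,bc->ac'; the
    # popped operands are replaced by the intermediate result appended at the
    # end of `inputs`, and the next step contracts that result with the rest.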
    for operand_indices, binary_equation in indices_and_equations:
      if ellipsis_label:
        # Replace back ellipses that were removed for opt_einsum.
        binary_equation = binary_equation.replace(ellipsis_label, '...')
      operands = list(map(inputs.pop, operand_indices))
      inputs.append(gen_linalg_ops.einsum(operands, binary_equation))
    return inputs[0]


def _get_opt_einsum_contract_path(equation, shaped_inputs_tuple, optimize):
  """Returns the (memoized) result of opt_einsum.contract_path."""
  # Note: We use einsum_call=True, which is an internal api for opt_einsum,
  # to get the contraction path without having opt_einsum perform the actual
  # contractions.
  _, contractions = opt_einsum.contract_path(
      equation,
      *shaped_inputs_tuple,
      optimize=optimize,
      einsum_call=True,
      use_blas=True)
  # Return a tuple so that the cached value is not mutable.
  indices_and_equations = tuple([(expr[0], expr[2]) for expr in contractions])
  return indices_and_equations


# Cache the possibly expensive opt_einsum.contract_path call using lru_cache
# from the Python3+ standard library.
_get_opt_einsum_contract_path = functools.lru_cache(maxsize=128)(
    _get_opt_einsum_contract_path)
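# Note: lru_cache requires hashable arguments, which is why callers pass the
# equation string, a tuple of `shaped` namedtuples, and the optimize string
# rather than Tensors or lists.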


def _einsum_v2_parse_and_resolve_equation(equation, input_shapes):
  """Helper which validates einsum equation and resolves input shapes."""
  resolved_equation = equation.replace(' ', '')
  ellipsis_label = None
  if '...' in equation:
    # Replace ellipsis ('...') with '0' for (a) ease of parsing and (b) to
    # prevent opt_einsum from resolving them into named labels, since it
    # doesn't support broadcasting.
    ellipsis_label = '0'
    if ellipsis_label in resolved_equation:
      raise ValueError(
          f'Invalid character "{ellipsis_label}" in equation: {equation}.')
    resolved_equation = resolved_equation.replace('...', ellipsis_label)

  # Ensure there are no non-alphanumeric characters in the equation, including
  # periods (`.`) outside of ellipses. This is not a hard requirement; except
  # we use the special character '0' for ellipsis.
  allowed_labels = 'a-zA-Z'
  if ellipsis_label:
    allowed_labels += ellipsis_label
  match = re.match('^([{0},]*)(->[{0}]*)?$'.format(allowed_labels),
                   resolved_equation)
  if not match:
    raise ValueError(
        'Subscripts have incorrect format: {}'.format(resolved_equation))
  input_labels = match.group(1).split(',')
  output_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_labels):
    raise ValueError('Got {} inputs for equation "{}", expecting {}'.format(
        len(input_shapes), equation, len(input_labels)))

  # Special case: if there is no '->', then we create output subscripts from
  # labels appearing only once.
  if '->' not in resolved_equation:
    label_counts = collections.Counter(match.group(1))
    output_labels = ''.join([
        x for x in sorted(list(label_counts))
        if x != ',' and label_counts[x] == 1
    ])
    resolved_equation += '->' + output_labels
  # Validate output_labels.
  if output_labels and len(set(output_labels)) != len(output_labels):
    raise ValueError(
        'Output subscripts contain a label appearing more than once: {}'.format(
            equation))
  input_label_set = set(match.group(1))
  for label in output_labels:
    if label != ellipsis_label and label not in input_label_set:
      raise ValueError('Output subscripts contain the label {} not present '
                       'in the input subscripts.'.format(label))
  if ellipsis_label and output_labels:
    num_output_ellipses = output_labels.count(ellipsis_label)
    if num_output_ellipses > 1:
      raise ValueError(
          'Output subscripts contain multiple ellipses: {}'.format(equation))

  # Early return if <= 2 inputs. Resolved shapes are not needed.
  if len(input_shapes) <= 2:
    return resolved_equation, None, ellipsis_label

  # Create a map from axis labels to known dimensions. This is used to infer
  # unknown dimensions if a known dimension also has the same label.
  label_to_dim = collections.defaultdict(lambda: 1)
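  # For example, if one input labels an axis 'i' with static size 3 while
  # another input's 'i' axis has unknown (None) size, 'i' resolves to 3 in
  # the loop below; a label seen only with unknown sizes keeps the default
  # of 1.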
  for i, (labels, shape) in enumerate(zip(input_labels, input_shapes)):
    if shape is None:
      continue
    ellipsis_start = labels.find(ellipsis_label) if ellipsis_label else -1
    if ellipsis_start != -1:  # This input contains an ellipsis.
      if ellipsis_start != labels.rfind(ellipsis_label):
        raise ValueError(f'Too many ellipses in input label '
                         f'{labels.replace(ellipsis_label, "...")}.')
      if len(labels) > len(shape) + 1:
        raise ValueError('Too many named labels in {}th subscript string of'
                         ' equation {} for input shape {} '.format(
                             i, equation, shape))
      ellipsis_end = ellipsis_start + len(shape) + 1 - len(labels)
      shape[ellipsis_start:ellipsis_end] = ([
          np.prod(
              list(filter(None, shape[ellipsis_start:ellipsis_end])),
              dtype=np.int64)
      ])
    else:
      # This input does not contain an ellipsis.
      if len(labels) != len(shape):
        raise ValueError(
            'Number of named labels in input #{} of equation {} '
            'must be equal to the number of dimensions in shape {}'.format(
                i, equation, shape))
      for dim, label in zip(shape, labels):
        if dim is not None:
          label_to_dim[label] = max(label_to_dim[label], dim)

  resolved_shapes = []
  for labels in input_labels:
    resolved_shapes.append([label_to_dim[label] for label in labels])
  return resolved_equation, resolved_shapes, ellipsis_label