Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg_ops_impl.py: 25%
32 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Operations for linear algebra."""
17import numpy as np
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import math_ops
23from tensorflow.python.util import compat
25# Names below are lower_case.
26# pylint: disable=invalid-name
def eye(num_rows,
        num_columns=None,
        batch_shape=None,
        dtype=dtypes.float32,
        name=None):
  """Construct an identity matrix, or a batch of matrices.

  See `linalg_ops.eye` for the public-facing documentation.

  Args:
    num_rows: Number of rows of each matrix. Either a `Tensor` or a value of
      one of `compat.integral_types`.
    num_columns: Optional number of columns of each matrix. When omitted the
      result is square, with `num_rows` columns.
    batch_shape: Optional leading batch dimensions, either a list-like of
      sizes or a `Tensor`. Defaults to no batch dimensions. When any size is
      only known at run time it is converted to an `int32` `Tensor`.
    dtype: Element type passed to the `ones`/`zeros` ops that build the result.
    name: Optional name for the enclosing name scope (defaults to `'eye'`).

  Returns:
    A tensor with ones on the main diagonal and zeros elsewhere, of shape
    `batch_shape + [num_rows, num_columns]`.

  Raises:
    TypeError: If neither `num_rows` nor `num_columns` is a `Tensor` and at
      least one of them is not an integral type.
  """
  # Capture the arguments as the caller passed them (before defaulting), so
  # the name scope sees exactly what the caller provided.
  scope_values = [num_rows, num_columns, batch_shape]
  with ops.name_scope(name, default_name='eye', values=scope_values):
    # Decide squareness from the caller's intent before `num_columns` gets
    # its default; the static branch below may refine this.
    square = num_columns is None
    if batch_shape is None:
      batch_shape = []
    if num_columns is None:
      num_columns = num_rows
    matrix_dims = [num_rows, num_columns]

    dynamic_dims = (isinstance(num_rows, ops.Tensor) or
                    isinstance(num_columns, ops.Tensor))
    if dynamic_dims:
      # The diagonal length can only be computed at run time.
      diag_size = math_ops.minimum(num_rows, num_columns)
    else:
      both_integral = (isinstance(num_rows, compat.integral_types) and
                       isinstance(num_columns, compat.integral_types))
      if not both_integral:
        raise TypeError(
            'Arguments `num_rows` and `num_columns` must be positive integer '
            f'values. Received: num_rows={num_rows}, num_columns={num_columns}')
      # Both sizes are concrete, so squareness and diagonal length are known.
      square = num_rows == num_columns
      diag_size = np.minimum(num_rows, num_columns)

    if isinstance(batch_shape, ops.Tensor) or isinstance(diag_size, ops.Tensor):
      # At least one dimension is dynamic: assemble the shapes as tensors.
      batch_shape = ops.convert_to_tensor(
          batch_shape, name='shape', dtype=dtypes.int32)
      diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
      full_shape = (None if square else
                    array_ops.concat((batch_shape, matrix_dims), axis=0))
    else:
      # Everything is static: plain Python lists suffice.
      batch_shape = list(batch_shape)
      diag_shape = batch_shape + [diag_size]
      full_shape = None if square else batch_shape + matrix_dims

    diag_ones = array_ops.ones(diag_shape, dtype=dtype)
    if square:
      return array_ops.matrix_diag(diag_ones)
    # Rectangular case: start from zeros and overwrite the main diagonal.
    zero_matrix = array_ops.zeros(full_shape, dtype=dtype)
    return array_ops.matrix_set_diag(zero_matrix, diag_ones)
79# pylint: enable=invalid-name,redefined-builtin