Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg_ops_impl.py: 25%

32 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Operations for linear algebra.""" 

16 

17import numpy as np 

18 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import ops 

21from tensorflow.python.ops import array_ops 

22from tensorflow.python.ops import math_ops 

23from tensorflow.python.util import compat 

24 

25# Names below are lower_case. 

26# pylint: disable=invalid-name 

27 

28 

def eye(num_rows,
        num_columns=None,
        batch_shape=None,
        dtype=dtypes.float32,
        name=None):
  """Build an identity matrix, or a batch of identity matrices.

  Public documentation lives on `linalg_ops.eye`; this is the implementation.

  If all of `num_rows`, `num_columns` and `batch_shape` are non-`Tensor`
  values, the output shapes are assembled as Python lists. Otherwise they are
  assembled with `array_ops.concat`, after `batch_shape` is converted to an
  `int32` tensor.

  Args:
    num_rows: Python integral value or `Tensor` giving the number of rows.
    num_columns: Optional Python integral value or `Tensor` giving the number
      of columns. When omitted, `num_rows` is used and the result is square.
    batch_shape: Optional list/tuple or `Tensor` of leading batch dimensions.
      `None` means no batch dimensions.
    dtype: Element type of the returned tensor.
    name: Optional name for the enclosing name scope (default scope: 'eye').

  Returns:
    A tensor of shape `batch_shape + [num_rows, num_columns]` with ones on the
    main diagonal and zeros elsewhere.

  Raises:
    TypeError: If neither `num_rows` nor `num_columns` is a `Tensor` and at
      least one of them is not an instance of `compat.integral_types`.
  """
  with ops.name_scope(
      name, default_name='eye', values=[num_rows, num_columns, batch_shape]):
    # Record "no column count given" before the default is substituted. When
    # a dimension is a runtime `Tensor`, this is the only squareness signal.
    columns_omitted = num_columns is None
    if batch_shape is None:
      batch_shape = []
    if columns_omitted:
      num_columns = num_rows

    if any(isinstance(dim, ops.Tensor) for dim in (num_rows, num_columns)):
      # Dynamic dimensions: the diagonal length is computed as an op. The
      # matrix is treated as square only if `num_columns` was omitted.
      is_square = columns_omitted
      diag_size = math_ops.minimum(num_rows, num_columns)
    else:
      # Static dimensions: validate types, then settle squareness and the
      # diagonal length in Python.
      # NOTE(review): the message says "positive", but only the type is
      # checked here; zero or negative integers are not rejected by this code.
      if not all(isinstance(dim, compat.integral_types)
                 for dim in (num_rows, num_columns)):
        raise TypeError(
            'Arguments `num_rows` and `num_columns` must be positive integer '
            f'values. Received: num_rows={num_rows}, num_columns={num_columns}')
      is_square = num_rows == num_columns
      diag_size = np.minimum(num_rows, num_columns)

    matrix_dims = [num_rows, num_columns]
    if any(isinstance(part, ops.Tensor) for part in (batch_shape, diag_size)):
      # Part of the shape is only known at run time: build both shapes as
      # int32 tensors.
      batch_shape = ops.convert_to_tensor(
          batch_shape, name='shape', dtype=dtypes.int32)
      diag_shape = array_ops.concat((batch_shape, [diag_size]), axis=0)
      full_shape = (None if is_square else
                    array_ops.concat((batch_shape, matrix_dims), axis=0))
    else:
      # Everything is static: build both shapes as plain Python lists.
      batch_dims = list(batch_shape)
      diag_shape = batch_dims + [diag_size]
      full_shape = None if is_square else batch_dims + matrix_dims

    diag_ones = array_ops.ones(diag_shape, dtype=dtype)
    if not is_square:
      # Rectangular result: start from zeros of the full shape and write ones
      # onto the main diagonal.
      zero_matrix = array_ops.zeros(full_shape, dtype=dtype)
      return array_ops.matrix_set_diag(zero_matrix, diag_ones)
    return array_ops.matrix_diag(diag_ones)

78 

79# pylint: enable=invalid-name,redefined-builtin