Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/slicing.py: 18%

80 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities for slicing in to a `LinearOperator`.""" 

import collections
import functools
import numpy as np

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.util import nest


__all__ = ['batch_slice']


def _prefer_static_where(condition, x, y):
  args = [condition, x, y]
  constant_args = [tensor_util.constant_value(a) for a in args]
  # Do this statically.
  if all(arg is not None for arg in constant_args):
    condition_, x_, y_ = constant_args
    return np.where(condition_, x_, y_)
  return array_ops.where(condition, x, y)
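
# Illustrative sketch (hypothetical example): when every argument already has
# a statically known value, `_prefer_static_where` folds the select with NumPy
# instead of emitting a `tf.where` op; arguments without a static value fall
# through to `array_ops.where`. Assumes TF 2.x, where
# `tensor_util.constant_value` passes plain Python/NumPy values through
# unchanged.
def _example_prefer_static_where():
  # All arguments are Python constants, so the result is a NumPy value
  # computed eagerly: np.where(True, 0, 5) -> array(0).
  return _prefer_static_where(True, 0, 5)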

def _broadcast_parameter_with_batch_shape(
    param, param_ndims_to_matrix_ndims, batch_shape):
  """Broadcasts `param` with the given batch shape, recursively."""
  if hasattr(param, 'batch_shape_tensor'):
    # Recursively broadcast every parameter inside the operator.
    override_dict = {}
    for name, ndims in param._experimental_parameter_ndims_to_matrix_ndims.items():  # pylint:disable=protected-access,line-too-long
      sub_param = getattr(param, name)
      override_dict[name] = nest.map_structure_up_to(
          sub_param, functools.partial(
              _broadcast_parameter_with_batch_shape,
              batch_shape=batch_shape), sub_param, ndims)
    parameters = dict(param.parameters, **override_dict)
    return type(param)(**parameters)

  base_shape = array_ops.concat(
      [batch_shape, array_ops.ones(
          [param_ndims_to_matrix_ndims], dtype=dtypes.int32)], axis=0)
  return array_ops.broadcast_to(
      param,
      array_ops.broadcast_dynamic_shape(base_shape, array_ops.shape(param)))
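
# Illustrative sketch (hypothetical example): a Tensor parameter such as a
# diagonal of shape [3, 2] (batch shape [3], one trailing matrix dimension)
# is broadcast against an operator batch shape of [4, 3] by reserving
# `param_ndims_to_matrix_ndims` trailing size-1 dimensions.
def _example_broadcast_parameter_with_batch_shape():
  diag = array_ops.ones([3, 2])
  batch_shape = np.array([4, 3], dtype=np.int32)
  # base_shape is [4, 3, 1]; broadcasting it against `diag`'s shape [3, 2]
  # yields a Tensor of shape [4, 3, 2].
  return _broadcast_parameter_with_batch_shape(
      diag, param_ndims_to_matrix_ndims=1, batch_shape=batch_shape)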

def _sanitize_slices(slices, intended_shape, deficient_shape):
  """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.

  Args:
    slices: iterable of slices received by `__getitem__`.
    intended_shape: int `Tensor` shape for which the slices were intended.
    deficient_shape: int `Tensor` shape to which the slices will be applied.
      Must have the same rank as `intended_shape`.

  Returns:
    sanitized_slices: Python `list` of slice objects.
  """
  sanitized_slices = []
  idx = 0
  for slc in slices:
    if slc is Ellipsis:  # Switch over to negative indexing.
      if idx < 0:
        raise ValueError('Found multiple `...` in slices {}'.format(slices))
      num_remaining_non_newaxis_slices = sum(
          s is not array_ops.newaxis for s in slices[
              slices.index(Ellipsis) + 1:])
      idx = -num_remaining_non_newaxis_slices
    elif slc is array_ops.newaxis:
      pass
    else:
      is_broadcast = intended_shape[idx] > deficient_shape[idx]
      if isinstance(slc, slice):
        # Slices are denoted by start:stop:step.
        start, stop, step = slc.start, slc.stop, slc.step
        if start is not None:
          start = _prefer_static_where(is_broadcast, 0, start)
        if stop is not None:
          stop = _prefer_static_where(is_broadcast, 1, stop)
        if step is not None:
          step = _prefer_static_where(is_broadcast, 1, step)
        slc = slice(start, stop, step)
      else:  # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
        slc = _prefer_static_where(is_broadcast, 0, slc)
      idx += 1
    sanitized_slices.append(slc)
  return sanitized_slices
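
# Illustrative sketch (hypothetical example): suppose the operator's batch
# shape is [4, 3] but one parameter only has batch shape [4, 1], i.e. it is
# broadcast along the second axis. Slices aimed at the [4, 3] shape are
# rewritten so they never overflow the size-1 axis.
def _example_sanitize_slices():
  intended = np.array([4, 3], dtype=np.int32)
  deficient = np.array([4, 1], dtype=np.int32)
  # The integer index 2 targets the broadcast dimension, so it is rewritten
  # to 0; the slice over the non-broadcast dimension keeps its bounds.
  # Returns [slice(1, 3, None), 0] (values may come back as NumPy values).
  return _sanitize_slices(
      [slice(1, 3), 2], intended_shape=intended, deficient_shape=deficient)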

def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
  # Broadcast the parameter to have full batch rank.
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, array_ops.ones_like(batch_shape))

  if hasattr(param, 'batch_shape_tensor'):
    param_batch_shape = param.batch_shape_tensor()
  else:
    param_batch_shape = array_ops.shape(param)
  # Truncate by param_ndims_to_matrix_ndims.
  param_batch_rank = array_ops.size(param_batch_shape)
  param_batch_shape = param_batch_shape[
      :(param_batch_rank - param_ndims_to_matrix_ndims)]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object like `tfb.Identity()` incapable of having any batch rank.
  if (tensor_util.constant_value(array_ops.size(batch_shape)) != 0 and
      tensor_util.constant_value(array_ops.size(param_batch_shape)) == 0):
    return param
  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend `param_slices` (which represents slicing into the parameter's
  # batch shape) with the parameter's event ndims. For example, if
  # `param_ndims_to_matrix_ndims == 1`, then `[i, ..., j]` would become
  # `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    if Ellipsis not in [
        slc for slc in slices if not tensor_util.is_tensor(slc)]:
      param_slices.append(Ellipsis)
    param_slices += [slice(None)] * param_ndims_to_matrix_ndims
  return param.__getitem__(tuple(param_slices))
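
# Illustrative sketch (hypothetical example, assumes TF 2.x eager execution):
# slicing a single Tensor parameter against the parent operator's batch
# shape. In real use `batch_shape` comes from `linop.batch_shape_tensor()`.
def _example_slice_single_param():
  diag = array_ops.ones([1, 3, 2])  # parameter batch shape [1, 3]
  batch_shape = np.array([4, 3], dtype=np.int32)  # operator batch shape
  # The parameter's leading batch dimension is broadcast (size 1), so the
  # `0:2` slice is clamped to `0:1`; the trailing matrix dimension is kept
  # by extending the slices with `[..., :]`. Result shape: [1, 2].
  return _slice_single_param(
      diag, param_ndims_to_matrix_ndims=1,
      slices=(slice(0, 2), 1), batch_shape=batch_shape)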

def batch_slice(linop, params_overrides, slices):
  """Slices `linop` along its batch dimensions.

  Args:
    linop: A `LinearOperator` instance.
    params_overrides: A `dict` of parameter overrides.
    slices: A `slice` or `int` or `int` `Tensor` or `tf.newaxis` or `tuple`
      thereof (e.g. the argument of a `__getitem__` method).

  Returns:
    new_linop: A batch-sliced `LinearOperator`.
  """
  if not isinstance(slices, collections.abc.Sequence):
    slices = (slices,)
  if len(slices) == 1 and slices[0] is Ellipsis:
    override_dict = {}
  else:
    batch_shape = linop.batch_shape_tensor()
    override_dict = {}
    for param_name, param_ndims_to_matrix_ndims in linop._experimental_parameter_ndims_to_matrix_ndims.items():  # pylint:disable=protected-access,line-too-long
      param = getattr(linop, param_name)
      # These represent optional `Tensor` parameters.
      if param is not None:
        override_dict[param_name] = nest.map_structure_up_to(
            param, functools.partial(
                _slice_single_param, slices=slices, batch_shape=batch_shape),
            param, param_ndims_to_matrix_ndims)
  override_dict.update(params_overrides)
  parameters = dict(linop.parameters, **override_dict)
  return type(linop)(**parameters)
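
# Illustrative sketch (hypothetical example, assumes TF 2.x and an operator
# that exposes parameter properties, e.g. `LinearOperatorDiag`): the public
# entry point rebuilds the operator from batch-sliced parameters.
def _example_batch_slice():
  from tensorflow.python.ops.linalg import linear_operator_diag  # pylint: disable=g-import-not-at-top

  diag = array_ops.ones([4, 3, 2])  # operator batch shape [4, 3]
  operator = linear_operator_diag.LinearOperatorDiag(diag)
  # Slice only the batch dimensions; each underlying 2x2 diagonal matrix is
  # preserved. The result is a LinearOperatorDiag with batch shape [2].
  return batch_slice(operator, params_overrides={}, slices=(slice(0, 2), 1))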