Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/linalg/slicing.py: 18%
80 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for slicing into a `LinearOperator`."""
17import collections
18import functools
19import numpy as np
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import tensor_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.util import nest
27__all__ = ['batch_slice']
def _prefer_static_where(condition, x, y):
  """Selects between `x` and `y`, evaluating statically when possible.

  Args:
    condition: value (or `Tensor`) choosing between `x` and `y`.
    x: value (or `Tensor`) used where `condition` holds.
    y: value (or `Tensor`) used where `condition` does not hold.

  Returns:
    The result of `np.where` when `tensor_util.constant_value` yields a
    value for all three arguments; otherwise the result of
    `array_ops.where` applied to the original arguments.
  """
  static_values = [
      tensor_util.constant_value(arg) for arg in (condition, x, y)]
  # Any argument without a static value forces the graph-level op.
  if any(value is None for value in static_values):
    return array_ops.where(condition, x, y)
  static_condition, static_x, static_y = static_values
  return np.where(static_condition, static_x, static_y)
def _broadcast_parameter_with_batch_shape(
    param, param_ndims_to_matrix_ndims, batch_shape):
  """Broadcasts `param` with the given batch shape, recursively.

  Args:
    param: either a value accepted by `array_ops.broadcast_to`, or an object
      exposing `batch_shape_tensor` (in which case each of its own parameters
      is broadcast in turn and the object is rebuilt).
    param_ndims_to_matrix_ndims: number of trailing size-1 dimensions appended
      to `batch_shape` before broadcasting a non-operator `param`.
    batch_shape: int `Tensor` batch shape to broadcast against.

  Returns:
    A broadcast `Tensor`, or a new instance of `type(param)` built from
    `param.parameters` with the broadcast sub-parameters substituted.
  """
  if not hasattr(param, 'batch_shape_tensor'):
    # Leaf case: pad the batch shape with ones for the trailing dimensions and
    # broadcast the parameter up to the combined shape.
    trailing_ones = array_ops.ones(
        [param_ndims_to_matrix_ndims], dtype=dtypes.int32)
    base_shape = array_ops.concat([batch_shape, trailing_ones], axis=0)
    target_shape = array_ops.broadcast_dynamic_shape(
        base_shape, array_ops.shape(param))
    return array_ops.broadcast_to(param, target_shape)

  # Operator case: recursively broadcast every parameter inside the operator.
  broadcast_fn = functools.partial(
      _broadcast_parameter_with_batch_shape, batch_shape=batch_shape)
  ndims_by_name = param._experimental_parameter_ndims_to_matrix_ndims  # pylint:disable=protected-access
  broadcast_params = {}
  for sub_name, sub_ndims in ndims_by_name.items():
    sub_value = getattr(param, sub_name)
    broadcast_params[sub_name] = nest.map_structure_up_to(
        sub_value, broadcast_fn, sub_value, sub_ndims)
  return type(param)(**dict(param.parameters, **broadcast_params))
63def _sanitize_slices(slices, intended_shape, deficient_shape):
64 """Restricts slices to avoid overflowing size-1 (broadcast) dimensions.
66 Args:
67 slices: iterable of slices received by `__getitem__`.
68 intended_shape: int `Tensor` shape for which the slices were intended.
69 deficient_shape: int `Tensor` shape to which the slices will be applied.
70 Must have the same rank as `intended_shape`.
71 Returns:
72 sanitized_slices: Python `list` of slice objects.
73 """
74 sanitized_slices = []
75 idx = 0
76 for slc in slices:
77 if slc is Ellipsis: # Switch over to negative indexing.
78 if idx < 0:
79 raise ValueError('Found multiple `...` in slices {}'.format(slices))
80 num_remaining_non_newaxis_slices = sum(
81 s is not array_ops.newaxis for s in slices[
82 slices.index(Ellipsis) + 1:])
83 idx = -num_remaining_non_newaxis_slices
84 elif slc is array_ops.newaxis:
85 pass
86 else:
87 is_broadcast = intended_shape[idx] > deficient_shape[idx]
88 if isinstance(slc, slice):
89 # Slices are denoted by start:stop:step.
90 start, stop, step = slc.start, slc.stop, slc.step
91 if start is not None:
92 start = _prefer_static_where(is_broadcast, 0, start)
93 if stop is not None:
94 stop = _prefer_static_where(is_broadcast, 1, stop)
95 if step is not None:
96 step = _prefer_static_where(is_broadcast, 1, step)
97 slc = slice(start, stop, step)
98 else: # int, or int Tensor, e.g. d[d.batch_shape_tensor()[0] // 2]
99 slc = _prefer_static_where(is_broadcast, 0, slc)
100 idx += 1
101 sanitized_slices.append(slc)
102 return sanitized_slices
def _slice_single_param(
    param, param_ndims_to_matrix_ndims, slices, batch_shape):
  """Slices into the batch shape of a single parameter.

  Args:
    param: The original parameter to slice; either a `Tensor` or an object
      with batch shape (LinearOperator).
    param_ndims_to_matrix_ndims: `int` number of right-most dimensions used for
      inferring matrix shape of the `LinearOperator`. For non-Tensor
      parameters, this is the number of this param's batch dimensions used by
      the matrix shape of the parent object.
    slices: iterable of slices received by `__getitem__`.
    batch_shape: The parameterized object's batch shape `Tensor`.

  Returns:
    new_param: Instance of the same type as `param`, batch-sliced according to
      `slices`.
  """
  # Broadcast the parameter against an all-ones shape of the same rank as
  # `batch_shape`, so that it has full batch rank.
  full_rank_ones = array_ops.ones_like(batch_shape)
  param = _broadcast_parameter_with_batch_shape(
      param, param_ndims_to_matrix_ndims, full_rank_ones)

  if not hasattr(param, 'batch_shape_tensor'):
    untruncated_shape = array_ops.shape(param)
  else:
    untruncated_shape = param.batch_shape_tensor()
  # Drop the trailing `param_ndims_to_matrix_ndims` dimensions.
  num_batch_dims = (
      array_ops.size(untruncated_shape) - param_ndims_to_matrix_ndims)
  param_batch_shape = untruncated_shape[:num_batch_dims]

  # At this point the param should have full batch rank, *unless* it's an
  # atomic object incapable of having any batch rank; such a param is returned
  # as-is.
  if tensor_util.constant_value(array_ops.size(batch_shape)) != 0:
    if tensor_util.constant_value(array_ops.size(param_batch_shape)) == 0:
      return param

  param_slices = _sanitize_slices(
      slices, intended_shape=batch_shape, deficient_shape=param_batch_shape)

  # Extend `param_slices` (which indexes only the batch shape) to cover the
  # trailing dimensions as well. For example, with one trailing dimension,
  # `[i, ..., j]` becomes `[i, ..., j, :]`.
  if param_ndims_to_matrix_ndims > 0:
    # Tensors are excluded before the membership test.
    non_tensor_slices = [
        entry for entry in slices if not tensor_util.is_tensor(entry)]
    if Ellipsis not in non_tensor_slices:
      param_slices.append(Ellipsis)
    param_slices.extend([slice(None)] * param_ndims_to_matrix_ndims)
  return param.__getitem__(tuple(param_slices))
155def batch_slice(linop, params_overrides, slices):
156 """Slices `linop` along its batch dimensions.
158 Args:
159 linop: A `LinearOperator` instance.
160 params_overrides: A `dict` of parameter overrides.
161 slices: A `slice` or `int` or `int` `Tensor` or `tf.newaxis` or `tuple`
162 thereof. (e.g. the argument of a `__getitem__` method).
164 Returns:
165 new_linop: A batch-sliced `LinearOperator`.
166 """
167 if not isinstance(slices, collections.abc.Sequence):
168 slices = (slices,)
169 if len(slices) == 1 and slices[0] is Ellipsis:
170 override_dict = {}
171 else:
172 batch_shape = linop.batch_shape_tensor()
173 override_dict = {}
174 for param_name, param_ndims_to_matrix_ndims in linop._experimental_parameter_ndims_to_matrix_ndims.items(): # pylint:disable=protected-access,line-too-long
175 param = getattr(linop, param_name)
176 # These represent optional `Tensor` parameters.
177 if param is not None:
178 override_dict[param_name] = nest.map_structure_up_to(
179 param, functools.partial(
180 _slice_single_param, slices=slices, batch_shape=batch_shape),
181 param, param_ndims_to_matrix_ndims)
182 override_dict.update(params_overrides)
183 parameters = dict(linop.parameters, **override_dict)
184 return type(linop)(**parameters)