Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/op_selector.py: 11%
194 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tools for selecting ops in a graph."""
17from tensorflow.python.framework import ops
18from tensorflow.python.util import object_identity
def is_differentiable(op):
  """Report whether a gradient is registered for the type of `op`.

  Args:
    op: an operation exposing `op_def.name`.
  Returns:
    True if `ops._gradient_registry.lookup(op.op_def.name)` yields a non-None
    value; False if that lookup yields None or raises `LookupError`.
  """
  registry = ops._gradient_registry  # pylint: disable=protected-access
  try:
    gradient_fn = registry.lookup(op.op_def.name)
  except LookupError:
    return False
  return gradient_fn is not None


def is_iterable(obj):
  """Return true if the object is iterable.

  A `tf.Tensor` is always reported as non-iterable, so helpers such as
  `make_list_of_t` treat a lone tensor as a single element. Any other object
  counts as iterable exactly when `iter(obj)` succeeds.

  Args:
    obj: any Python object.
  Returns:
    A bool.
  """
  if isinstance(obj, ops.Tensor):
    return False
  iterable = True
  try:
    iter(obj)
  except Exception:  # pylint: disable=broad-except
    iterable = False
  return iterable


def concatenate_unique(la, lb):
  """Append to `la` every element of `lb` that it does not already contain.

  New elements keep their relative order from `lb`. Duplicates inside `lb`
  are added only once. `la` is modified in place.

  Args:
    la: List of hashable Python objects; extended in place.
    lb: Iterable of hashable Python objects to merge into `la`.
  Returns:
    The same list object `la`, now also holding the missing elements of `lb`.
  """
  present = set(la)
  for item in lb:
    if item in present:
      continue
    present.add(item)
    la.append(item)
  return la


def get_tensors(graph):
  """Return every tensor produced by an operation of `graph`.

  Tensors are listed op by op, in the order of `graph.get_operations()`.
  Within an op, they follow the order of its `outputs`.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if isinstance(graph, ops.Graph):
    return [tensor
            for operation in graph.get_operations()
            for tensor in operation.outputs]
  raise TypeError("Expected a graph, got: {}".format(type(graph)))


def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: iterable of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph, which is returned as is.
    check_types: check that the elements in tops are of the given type(s). May
      be a single type or an iterable of types (tuple, list, set, ...). If
      None, the types (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is empty, just return None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable, or if one of its elements is not an
      instance of one of `check_types`.
    ValueError: if the graph is not unique, or if tops is empty and
      `none_if_empty` is False.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif isinstance(check_types, type) or not is_iterable(check_types):
    # A single type (even one whose metaclass makes it iterable, e.g. an Enum
    # class) is wrapped rather than expanded.
    check_types = (check_types,)
  else:
    # `isinstance` only accepts a type or a *tuple* of types. Normalize lists,
    # sets, etc. so they don't trigger an unrelated TypeError below.
    check_types = tuple(check_types)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join([str(t) for t in check_types]), type(op)))
    if g is None:
      g = op.graph
    elif g._graph_key != op.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g


def check_graphs(*args):
  """Verify that every argument with a graph shares that same graph object.

  Arguments whose `graph` attribute is None are skipped. The first non-None
  graph becomes the reference, and every later non-None graph must be that
  exact object (compared with `is`).

  Args:
    *args: objects exposing a `graph` attribute.
  Raises:
    ValueError: naming the index of the first argument whose graph differs
      from the reference graph.
  """
  reference = None
  for idx, item in enumerate(args):
    current = item.graph
    if current is None:
      continue
    if reference is None:
      reference = current
    elif current is not reference:
      raise ValueError(f"args[{idx}] does not belong to the same graph as "
                       "other arguments.")


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
      One-shot iterators (e.g. generators) are accepted and consumed.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`. (Elements that are
      not `tf.Tensor` are always dropped from the result. This flag only
      controls whether the graph check rejects them.)
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
    ValueError: if `check_graph` is `True` and the tensors do not all belong
      to the same graph.
  """
  if isinstance(ts, ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  if is_iterable(ts):
    # Materialize once. A one-shot iterator would otherwise pass the emptiness
    # check, be exhausted by the graph check below, and yield [] silently.
    ts = list(ts)
  else:
    ts = [ts]
  if not ts:
    return []
  if check_graph:
    check_types = None if ignore_ops else ops.Tensor
    get_unique_graph(ts, check_types=check_types)
  return [t for t in ts if isinstance(t, ops.Tensor)]


def get_generating_ops(ts):
  """Map each tensor in `ts` to the operation that produces it.

  Args:
    ts: a list of `tf.Tensor` (anything `make_list_of_t` accepts with
      `allow_graph=False`).
  Returns:
    A list with the producing `tf.Operation` of each tensor, in tensor order.
    An op appears once per tensor of `ts` that it produces.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  return [tensor.op for tensor in make_list_of_t(ts, allow_graph=False)]


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consumer appears once, in first-seen order: tensors are walked in order,
  and each tensor's `consumers()` in the order returned.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  # Identity set so deduplication is O(1) per op rather than a scan of `tops`.
  seen = object_identity.ObjectIdentitySet()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        tops.append(op)
  return tops


def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of `tf.Operation`.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation. One-shot iterators (e.g. generators) are accepted and consumed.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if True, silently ignore `tf.Tensor`. (Elements that are not
      `tf.Operation` are always dropped from the result. This flag only
      controls whether the graph check rejects them.)
  Returns:
    A newly created list of `tf.Operation` (for a `tf.Graph`, the result of
    its `get_operations()`).
  Raises:
    TypeError: if tops cannot be converted to a list of `tf.Operation`.
    ValueError: if `check_graph` is `True` and the operations do not all
      belong to the same graph.
  """
  if isinstance(tops, ops.Graph):
    if allow_graph:
      return tops.get_operations()
    raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  if is_iterable(tops):
    # Materialize once. A one-shot iterator would otherwise pass the emptiness
    # check, be exhausted by the graph check below, and yield [] silently.
    tops = list(tops)
  else:
    tops = [tops]
  if not tops:
    return []
  if check_graph:
    check_types = None if ignore_ts else ops.Operation
    get_unique_graph(tops, check_types=check_types)
  return [op for op in tops if isinstance(op, ops.Operation)]


def _get_inputs(op, only_differentiable):
  """Return the input tensors of `op` that a backward walk should follow.

  Args:
    op: an operation exposing `inputs`.
    only_differentiable: when truthy, ops for which `is_differentiable` is
      False contribute no inputs.
  Returns:
    `op.inputs` (the same object), or `[]` if `op` is filtered out.
  """
  inputs = op.inputs
  if not only_differentiable or is_differentiable(op):
    return inputs
  return []


def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors. One-shot iterators (e.g.
      generators) are accepted.
    inclusive: if True the given seed_ops are also part of the resulting set.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
      Ignored when `only_differentiable` is True.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A list of all the `tf.Operation` behind `seed_ops`, without duplicates,
    ordered wave by wave (breadth-first from the seeds).
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  # Control edges are never followed when restricting to differentiable ops.
  control_inputs = control_inputs and (not only_differentiable)

  if is_iterable(seed_ops):
    # Materialize so that peeking at the first element below does not consume
    # it from a one-shot iterator.
    seed_ops = list(seed_ops)
  else:
    seed_ops = [seed_ops]

  if not seed_ops:
    # Empty iterable.
    return []

  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # Identity set mirroring `result`, for O(1) "already visited" checks.
  visited = object_identity.ObjectIdentitySet(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in new_wave failed the `visited` test above, so all are new.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result


class UnliftableError(Exception):
  """Error signalling that a Tensor could not be lifted out of its graph."""

  # Marker read by autograph so that it leaves this error type unrewritten.
  ag_pass_through = True


def _as_operation(op_or_tensor):
  """Return the producing op of a `tf.Tensor`; return anything else unchanged."""
  is_tensor = isinstance(op_or_tensor, ops.Tensor)
  return op_or_tensor.op if is_tensor else op_or_tensor


def graph_inputs(op):
  """List the ops `op` depends on: data-input producers, then control inputs.

  Args:
    op: an operation exposing `inputs` (objects with an `.op`) and
      `control_inputs`.
  Returns:
    A new list: `[t.op for t in op.inputs]` followed by `op.control_inputs`.
  """
  predecessors = [tensor.op for tensor in op.inputs]
  predecessors.extend(op.control_inputs)
  return predecessors


def show_path(from_op, tensors, sources):
  """Find one path from `from_op` to any of `tensors`, ignoring `sources`.

  The search runs depth-first, backward from the ops of `tensors`, following
  data and control inputs (via `graph_inputs`). It never enters the producing
  ops of `sources`.

  Args:
    from_op: A `tf.Operation` (a `tf.Tensor` is replaced by its producing op).
    tensors: A `tf.Operation`, a `tf.Tensor`, or a list thereof.
    sources: A list of `tf.Tensor`.

  Returns:
    A python string "name (type) <- name (type) <- ..." that starts at one of
    `tensors` and ends at `from_op`, or "??" if none is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op

  targets = tensors if isinstance(tensors, list) else [tensors]
  final_ops = [_as_operation(target) for target in targets]

  visited = {source.op for source in sources}
  stack = list(final_ops)
  consumer_of = {}  # op -> an already-visited op that consumes it
  while stack:
    current = stack.pop()
    if current in visited:
      continue
    visited.add(current)
    if current == from_op:
      # Follow consumer links from from_op back up to one of the final ops.
      chain = [current]
      while chain[-1] not in final_ops:
        chain.append(consumer_of[chain[-1]])
      chain.reverse()
      return " <- ".join("%s (%s)" % (node.name, node.type) for node in chain)
    for producer in graph_inputs(current):
      if producer in visited or producer in sources:
        continue
      consumer_of[producer] = current
      stack.append(producer)
  return "??"


# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  The walk goes depth-first and backward from `init_tensor`, over data and
  control inputs (see `graph_inputs`). Ops already in `visited_ops` are
  skipped.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. When None, every "Placeholder" op is disallowed unless
      `add_sources` is True.
    visited_ops: A set of operations which were visited in a prior pass. Every
      op newly reached here is added to it.
    op_outputs: A mapping (indexed like a defaultdict of sets) from op to the
      ops that consume it. Each traversed edge is recorded here.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    An `ObjectIdentitySet` of the output tensors of every "Placeholder" op
    newly reached by this walk.

  Raises:
    UnliftableError: if init_tensor depends on a disallowed placeholder (see
      `disallowed_placeholders` and `add_sources`).
  """
  pending = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while pending:
    current = pending.pop()
    if current in visited_ops:
      continue
    visited_ops.add(current)

    if (disallowed_placeholders is not None and
        current in disallowed_placeholders):
      forbidden = True
    elif current.type == "Placeholder":
      forbidden = disallowed_placeholders is None and not add_sources
      extra_sources.update(current.outputs)
    else:
      forbidden = False

    if forbidden:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s" %
          (repr(init_tensor), repr(current),
           show_path(current, init_tensor, sources)))

    for producer in graph_inputs(current):
      op_outputs[producer].add(current)
      # NOTE(review): `producer` is an op, but `sources` is documented as
      # Tensors and `extra_sources` only receives `op.outputs`. Unless callers
      # put ops in `sources`, this membership test never matches. Stopping at
      # sources presumably relies on the caller pre-seeding `visited_ops` with
      # the sources' ops -- TODO confirm.
      if producer in visited_ops or producer in (sources or extra_sources):
        continue
      pending.append(producer)

  return extra_sources