# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=unidiomatic-typecheck
"""Utility to lift subgraphs."""

import collections

from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import op_selector
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.util import compat
from tensorflow.python.util import object_identity
from tensorflow.python.util.tf_export import tf_export


UnliftableError = op_selector.UnliftableError


def _as_operation(op_or_tensor):
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def _constant_inputs(op_or_tensor):
  return all(_as_operation(i).type == u"Const"
             and not _as_operation(i).control_inputs
             for i in op_selector.graph_inputs(_as_operation(op_or_tensor)))
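

# Example of what `_constant_inputs` accepts (an illustrative sketch, not part
# of the original module; `constant_op` would need to be imported from
# tensorflow.python.framework): a `PlaceholderWithDefault` fed by a lone
# `Const` op passes the check, which is what lets `_copy_source` below clone
# its default value into the destination graph.
#
#   with ops.Graph().as_default():
#     default = constant_op.constant(1.0)
#     t = array_ops.placeholder_with_default(default, shape=[])
#   assert _constant_inputs(t)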


# Represents an input to `copied_op` which must be updated once
# `old_graph_tensor` has been copied.
_InputMutation = collections.namedtuple(
    "_InputMutation",
    ["copied_op", "input_index", "old_graph_tensor"])


# Represents a control input to `copied_op` which must be added once
# `old_graph_op` has been copied.
_ControlMutation = collections.namedtuple(
    "_ControlMutation",
    ["copied_op", "old_graph_op"])


def _copy_non_source(op, graph, op_map, base_graph):
  """Copy an op directly to a given graph.

  Generally `op`'s inputs should already have been copied. If this is not the
  case, for example with v1 while_loops, then `_copy_non_source` inserts
  placeholders for the unavailable Tensors and returns a list of required
  mutations.

  Args:
    op: The op to be copied.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    base_graph: The graph we're copying from, for any necessary functions.
  Returns:
    A tuple of (required_inputs, required_control_inputs):
    required_inputs:
      A list of `_InputMutation` tuples containing inputs to `copied_op` which
      must be updated once `old_graph_tensor` has been copied.
    required_control_inputs:
      A list of `_ControlMutation` tuples containing control inputs to
      `copied_op` which must be added once `old_graph_op` has been copied.
  """
  input_mutations = []
  control_mutations = []
  copied_inputs = []
  for input_index, original_input in enumerate(op.inputs):
    copied_input = op_map.get(original_input, None)
    if copied_input is None:
      # An input for this op is missing due to a loop in the graph. We'll
      # insert a placeholder for now and return information about the required
      # post-hoc mutation.
      copied_input = array_ops.placeholder(
          name="unused_control_flow_input",
          shape=original_input.shape,
          dtype=original_input.dtype)
      input_mutations.append(
          # `copied_op` is filled in below, after we've created it.
          _InputMutation(copied_op=None,
                         input_index=input_index,
                         old_graph_tensor=original_input))
    copied_inputs.append(copied_input)

  copied_control_inputs = []
  for original_control_input in op.control_inputs:
    copied_control_input = op_map.get(original_control_input, None)
    if copied_control_input is None:
      control_mutations.append(
          _ControlMutation(copied_op=None,
                           old_graph_op=original_control_input))
    else:
      copied_control_inputs.append(copied_control_input)

  # Don't copy the _tpu_replicate attribute onto lifted ops. This attribute is
  # used to signal that the op was built inside a tpu_replicate context; if
  # we're lifting it to another graph we're similarly lifting it into another
  # context.
  with ops.control_dependencies(copied_control_inputs), ops.device(op.device):
    # pylint: disable=protected-access
    f = base_graph._functions.get(op.type, None)
    if f is not None and compat.as_str(f.name) not in graph._functions:
      f.add_to_graph(graph)
    # pylint: enable=protected-access

    # Create a new op in the destination graph if it doesn't exist before.
    copied_op = graph.create_op(
        op_type=op.type,
        inputs=copied_inputs,
        dtypes=[x.dtype for x in op.outputs],
        attrs={
            key: value for key, value in op.node_def.attr.items()
            if not key.startswith("_class") and
            not key.startswith("_tpu_replicate")
        },  # b/128981532.
        name=op.name)
  op_map[op] = copied_op
  for i, o in enumerate(op.outputs):
    op_map[o] = copied_op.outputs[i]

  return ([mutation._replace(copied_op=copied_op)
           for mutation in input_mutations],
          [mutation._replace(copied_op=copied_op)
           for mutation in control_mutations])
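

# Illustrative sketch (not part of the original module; `dst_graph` and
# `src_graph` are hypothetical names) of the deferred-patch pattern that
# `lift_to_graph` applies at the bottom of this file: the mutations returned
# here are replayed only after every op has been copied, so cycles introduced
# by v1 while_loops resolve correctly.
#
#   pending_inputs, pending_controls = _copy_non_source(
#       op=op, graph=dst_graph, op_map=op_map, base_graph=src_graph)
#   ...  # copy the remaining ops
#   for m in pending_inputs:
#     m.copied_op._update_input(m.input_index, op_map[m.old_graph_tensor])
#   for m in pending_controls:
#     m.copied_op._add_control_input(op_map[m.old_graph_op])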


def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
                 base_graph):
  """Create a source in a graph based on a Tensor from a different graph.

  This function creates a placeholder analog of `s` in a graph with the
  following behavior:

  1) If s is a captured Tensor or Variable and handle_captures is set to True,
     simply capture it in the new graph as well.

  2) If s is a PlaceholderWithDefault whose default is a constant, preserve
     said default in the new graph.

  3) When applicable, copy resource variable metadata from `s` to the newly
     created placeholder.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
    handle_captures: A boolean indicating whether to re-capture s in the new
      graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping s back to the Tensor or Variable that it
      captures.
    base_graph: The graph being copied from.
  """
  if handle_captures and s in inverse_captures:
    copied_placeholder = graph.capture(inverse_captures[s], name=s.op.name)
  elif s.op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Copy the default value to the graph.
    default_value = s.op.inputs[0]
    unavailable_inputs, unavailable_control_inputs = _copy_non_source(
        op=default_value.op, graph=graph, op_map=op_map,
        base_graph=base_graph)
    if unavailable_inputs or unavailable_control_inputs:
      raise AssertionError(
          "Could not copy source node {} because it has inputs."
          .format(default_value))

    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder_with_default(
          input=op_map[default_value], shape=s.shape, name=s.op.name)
  else:
    with ops.device(s.op.device):
      copied_placeholder = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=s.op.name)

    base_handle = resource_variable_ops.get_resource_handle_data(s)
    if base_handle.shape_and_type:
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          copied_placeholder,
          base_handle,
          graph_mode=True)

  op_map[s] = copied_placeholder
  # Add an entry for the op of the source tensor so that if there are any nodes
  # depending on that op via control dependencies it can work correctly.
  op_map[s.op] = copied_placeholder.op
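

# Illustrative sketch (hypothetical variable `v`; not part of the original
# module): when lifting between FuncGraphs with `handle_captures=True`, a
# tensor `s` that `base_graph` captured from an external resource handle is
# re-captured in the destination graph (branch 1 above) rather than becoming a
# dangling placeholder:
#
#   # inverse_captures maps s -> v.handle; built by `lift_to_graph` below.
#   copied = graph.capture(v.handle, name=s.op.name)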


@tf_export("__internal__.lift_to_graph", v1=[])
def lift_to_graph(tensors,
                  graph,
                  sources=None,
                  disallowed_placeholders=None,
                  add_sources=False,
                  handle_captures=False,
                  base_graph=None,
                  op_map=None):
  """Copies the tensors and all of their inputs recursively to `graph`.

  Args:
    tensors: The Tensors to lift.
    graph: The graph to lift to.
    sources: Optional sequence of nodes to start from. If omitted the whole
      subgraph which feeds into `tensors` is lifted.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      `sources` should be added to `sources` rather than treated as errors.
    handle_captures: A boolean indicating whether to re-capture captured
      tensors in the new graph or simply create vanilla placeholders for them.
    base_graph: The graph from which to lift ops. This will be inferred if not
      specified.
    op_map: A map containing all the existing nodes that have been lifted to
      the destination graph, so they won't be lifted and copied again.

  Returns:
    A mapping from ops and tensors in `base_graph` to their copies in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  variable_init_tensors = []
  init_tensors = []
  for tensor in tensors:
    if isinstance(tensor, resource_variable_ops.ResourceVariable):
      variable_init_tensors.append(tensor)
    else:
      init_tensors.append(tensor)
  base_graph = base_graph or init_tensors[0].graph
  op_map = op_map or object_identity.ObjectIdentityDictionary()

  # Check that the initializer does not depend on any placeholders.
  sources = object_identity.ObjectIdentitySet(sources or [])
  visited_ops = set(x.op for x in sources)
  op_outputs = collections.defaultdict(set)

  # First we extract the subgraph between init_tensors and sources.
  for init_tensor in init_tensors:
    sources.update(op_selector.map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources))

  # Try to topologically sort the nodes we've extracted. Now we know how many of
  # their outputs are part of this subgraph.
  ops_to_copy = []
  marked_ops = set([])
  ops_to_visit = [_as_operation(t) for t in init_tensors
                  if not op_outputs[_as_operation(t)]]
  unvisited_ops = set(ops_to_visit)
  while unvisited_ops:
    while ops_to_visit:
      op = ops_to_visit.pop()
      if op in marked_ops:
        continue
      marked_ops.add(op)
      ops_to_copy.append(op)
      for inp in op_selector.graph_inputs(op):
        # Don't lift the TPUReplicateMetadata nodes out of the function,
        # because they have no registered kernels.
        if inp.type == "TPUReplicateMetadata":
          continue
        unvisited_ops.add(inp)
        if (all(x in marked_ops for x in op_outputs[inp]) and
            inp not in sources):
          ops_to_visit.append(inp)
    unvisited_ops.difference_update(marked_ops)
    if unvisited_ops:
      # `unvisited_ops` should only have elements if the graph has a loop. In
      # this case we want to keep copying and there's no topological ordering;
      # we'll do ugly post-hoc mutations instead.
      ops_to_visit.append(next(iter(unvisited_ops)))

  # When the topological sort fails due to loops, it can result in exceptions
  # later when copying a node whose inputs haven't been copied yet. We can
  # improve that pseudo-topological order slightly by putting the ops without
  # inputs, such as constants, at the start of the topological order (i.e. at
  # the end of ops_to_copy).
  ops_to_copy.sort(key=(lambda op: len(op_selector.graph_inputs(op)) == 0))

  # When lifting from one FuncGraph to another, we will need to capture the
  # relevant tensors as well.
  captures = []
  inverse_captures = object_identity.ObjectIdentityDictionary()
  internal_captures = []
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    captures = base_graph.captures
    for external_capture, internal_capture in captures:
      inverse_captures[internal_capture] = external_capture
    internal_captures = base_graph.internal_captures

  # ops_to_copy now holds a reverse topologically sorted list of ops which
  # ends in the initializer. We copy those to the outermost graph and
  # build the initialization op there.
  with graph.as_default():
    for i in variable_init_tensors:
      op_map[i] = i
    source_ops = set()
    # Add the sources in the same order as the original graph.
    for s in internal_captures:
      if s in sources:
        sources.remove(s)
        source_ops.add(s.op)
        _copy_source(
            s=s,
            graph=graph,
            op_map=op_map,
            handle_captures=handle_captures,
            inverse_captures=inverse_captures,
            base_graph=base_graph)
    for s in sources:
      source_ops.add(s.op)
      _copy_source(
          s=s,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures,
          base_graph=base_graph)

    input_mutations = []
    control_mutations = []
    for op in reversed(ops_to_copy):
      if op in source_ops or op in op_map:
        continue
      new_input_mutations, new_control_mutations = _copy_non_source(
          op=op, graph=graph, op_map=op_map, base_graph=base_graph)
      input_mutations.extend(new_input_mutations)
      control_mutations.extend(new_control_mutations)

    # Mutate the new graph to insert any loops which existed in the source
    # graph due to v1 while_loops.
    #
    # pylint: disable=protected-access
    with graph._mutation_lock():
      for mutation in input_mutations:
        mutation.copied_op._update_input(
            mutation.input_index, op_map[mutation.old_graph_tensor])
      for mutation in control_mutations:
        # Don't lift the TPUReplicateMetadata nodes out of the function,
        # because they have no registered kernels.
        if mutation.old_graph_op.type == "TPUReplicateMetadata":
          continue
        mutation.copied_op._add_control_input(op_map[mutation.old_graph_op])
    # pylint: enable=protected-access

    return op_map
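

# Illustrative usage sketch (not part of the original module; `f` and `target`
# are hypothetical names). Lifting a concrete function's output, together with
# every op feeding it, into a fresh graph via the public export above:
#
#   import tensorflow as tf
#
#   @tf.function
#   def f():
#     return tf.constant(1.0) + tf.constant(2.0)
#
#   concrete = f.get_concrete_function()
#   target = tf.Graph()
#   op_map = tf.__internal__.lift_to_graph(
#       tensors=concrete.graph.outputs, graph=target)
#   lifted = op_map[concrete.graph.outputs[0]]  # A Tensor in `target` now.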