Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/importer.py: 17%
210 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""A utility function for importing TensorFlow graphs."""
16import contextlib
18from tensorflow.core.framework import graph_pb2
19from tensorflow.python import tf2
20from tensorflow.python.client import pywrap_tf_session as c_api
21from tensorflow.python.framework import c_api_util
22from tensorflow.python.framework import device as pydev
23from tensorflow.python.framework import errors
24from tensorflow.python.framework import function
25from tensorflow.python.framework import op_def_registry
26from tensorflow.python.framework import ops
27from tensorflow.python.ops import control_flow_util
28from tensorflow.python.util import compat
29from tensorflow.python.util.deprecation import deprecated_args
30from tensorflow.python.util.tf_export import tf_export
def _IsControlInput(input_name):
  """Tells whether `input_name` refers to a control input.

  Control inputs are spelled '^operation_name' (a leading caret).
  """
  control_marker = '^'
  return input_name.startswith(control_marker)
def _ParseTensorName(tensor_name):
  """Parses a tensor name into an operation name and output index.

  This function will canonicalize tensor names as follows:

  * "foo:0"       -> ("foo", 0)
  * "foo:7"       -> ("foo", 7)
  * "foo"         -> ("foo", 0)
  * "foo:bar:baz" -> ValueError

  Args:
    tensor_name: The name of a tensor.

  Returns:
    A tuple containing the operation name, and the output index.

  Raises:
    ValueError: If `tensor_name` cannot be interpreted as the name of a tensor.
  """
  components = tensor_name.split(':')
  if len(components) == 2:
    # Expected format: 'operation_name:output_index'.
    try:
      output_index = int(components[1])
    except ValueError as e:
      # Chain the int() failure so tracebacks show it as the cause rather than
      # as an unrelated error raised "during handling".
      raise ValueError(f'Cannot convert {tensor_name!r} to a tensor name. '
                       'Second component of the name following the `:` should '
                       f'be an int. Got {components[1]}.') from e
    return components[0], output_index
  elif len(components) == 1:
    # Expected format: 'operation_name' (implicit 0th output).
    return components[0], 0
  else:
    raise ValueError(f"Cannot convert '{tensor_name}' to a tensor name. Tensor "
                     'names should not contain more than 1 `:`. Obtained '
                     f'{len(components) - 1}')
@contextlib.contextmanager
def _MaybeDevice(device):
  """Enters `ops.device(device)` only when `device` is truthy.

  For None or an empty device string this context manager does nothing.
  """
  if not device:
    yield
    return
  with ops.device(device):
    yield
def _ProcessGraphDefParam(graph_def):
  """Type-checks and possibly canonicalizes `graph_def`.

  Args:
    graph_def: A `GraphDef` proto, or a message that can be merged into one.

  Returns:
    A `GraphDef`. If `graph_def` is not a `graph_pb2.GraphDef` instance, a new
    `GraphDef` merged from it is returned. Otherwise `graph_def` itself is
    returned, after default attr values have been filled in place for every
    node whose op is found in the op registry.

  Raises:
    TypeError: If `graph_def` cannot be merged into a `GraphDef`.
  """
  if not isinstance(graph_def, graph_pb2.GraphDef):
    # `graph_def` could be a dynamically-created message, so try a duck-typed
    # approach: merge it into a fresh GraphDef.
    converted = graph_pb2.GraphDef()
    try:
      converted.MergeFrom(graph_def)
    except TypeError as e:
      raise TypeError('Argument `graph_def` must be a GraphDef proto.') from e
    return converted

  # If we're using the graph_def provided by the caller, modify graph_def
  # in-place to add attr defaults to the NodeDefs (this is visible to the
  # caller).
  # NOTE(skyewm): this is undocumented behavior that at least meta_graph.py
  # depends on. It might make sense to move this to meta_graph.py and have
  # import_graph_def not modify the graph_def argument (we'd have to make sure
  # this doesn't break anything else.)
  for node in graph_def.node:
    op_def = op_def_registry.get(node.op)
    if op_def is None:
      # Assume unrecognized ops are functions for now. TF_ImportGraphDef will
      # report an error if the op is actually missing.
      continue
    _SetDefaultAttrValues(node, op_def)
  return graph_def
def _ProcessInputMapParam(input_map):
  """Validates `input_map`, substituting an empty dict for None.

  Args:
    input_map: None, or a dict whose keys are `str`/`bytes` names.

  Returns:
    A new empty dict when `input_map` is None; otherwise `input_map` itself.

  Raises:
    TypeError: If `input_map` is not a dict, or any key is not a string.
  """
  if input_map is None:
    return {}
  if not isinstance(input_map, dict):
    raise TypeError('Argument `input_map` must be a dictionary. Obtained '
                    f'{type(input_map).__name__}')
  has_bad_key = any(
      not isinstance(key, compat.bytes_or_text_types) for key in input_map)
  if has_bad_key:
    raise TypeError('All keys for argument `input_map` must be strings. '
                    f'Obtained keys: {list(input_map.keys())}')
  return input_map
def _ProcessReturnElementsParam(return_elements):
  """Type-checks and canonicalizes `return_elements`.

  Args:
    return_elements: None, or an iterable of operation/tensor names given as
      `str` or `bytes`.

  Returns:
    None if `return_elements` is None; otherwise a tuple of the names
    converted to `str`.

  Raises:
    TypeError: If any element of `return_elements` is not a string.
  """
  if return_elements is None:
    return None
  # Materialize exactly once: the elements are both validated and converted
  # below, and a one-shot iterator (e.g. a generator) would otherwise be
  # exhausted by the validation pass, silently producing an empty result.
  elements = tuple(return_elements)
  if not all(isinstance(x, compat.bytes_or_text_types) for x in elements):
    raise TypeError('Argument `return_elements` must be a list of strings. '
                    f'Obtained {return_elements}.')
  return tuple(compat.as_str(x) for x in elements)
def _FindAttrInOpDef(attr_name, op_def):
  """Returns the first attr definition in `op_def` named `attr_name`, or None."""
  matches = (attr_def for attr_def in op_def.attr
             if attr_name == attr_def.name)
  return next(matches, None)
def _RemoveDefaultAttrs(producer_op_list, graph_def):
  """Strips attrs the consumer does not know about but the producer defaulted.

  For every node in `graph_def` whose op appears in `producer_op_list`, any
  attr that is absent from the locally registered OpDef, but has a default
  value in the producer's OpDef and is set to exactly that default, is
  deleted from the node in place.

  Args:
    producer_op_list: OpList proto.
    graph_def: GraphDef proto
  """
  producer_defs_by_name = {}
  for producer_def in producer_op_list.op:
    producer_defs_by_name[producer_def.name] = producer_def

  for node in graph_def.node:
    producer_op_def = producer_defs_by_name.get(node.op)
    if producer_op_def is None:
      continue
    consumer_op_def = op_def_registry.get(node.op)
    if consumer_op_def is None:
      # Some custom op registrations won't show up here. That's OK, attribute
      # stripping just won't be available.
      continue
    # Iterate over a snapshot of the keys because entries may be deleted.
    for key in list(node.attr):
      if _FindAttrInOpDef(key, consumer_op_def) is not None:
        continue  # The consumer understands this attr; keep it.
      producer_attr_def = _FindAttrInOpDef(key, producer_op_def)
      if not producer_attr_def:
        continue
      if not producer_attr_def.HasField('default_value'):
        continue
      if node.attr[key] == producer_attr_def.default_value:
        # Unknown attr had default value in producer, delete it so it can be
        # understood by consumer.
        del node.attr[key]
def _ConvertInputMapValues(name, input_map):
  """Makes every value of `input_map` a tensor.

  Must be called from inside the import name scope, because any conversion
  creates ops under a nested '_inputs' name scope.

  Args:
    name: the `name` argument passed to import_graph_def
    input_map: the `input_map` argument passed to import_graph_def.

  Returns:
    `input_map` unchanged if all its values are already `ops.Tensor`s,
    otherwise a new dict whose values are converted with
    `ops.convert_to_tensor`.

  Raises:
    ValueError: if input map values cannot be converted due to empty name scope.
  """
  if all(isinstance(value, ops.Tensor) for value in input_map.values()):
    return input_map
  if name == '':  # pylint: disable=g-explicit-bool-comparison
    raise ValueError(
        'tf.import_graph_def() requires a non-empty `name` if `input_map` '
        'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
        '`input_map` values before calling tf.import_graph_def().')
  with ops.name_scope('_inputs'):
    converted = {}
    for key, value in input_map.items():
      converted[key] = ops.convert_to_tensor(value)
  return converted
def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements,
                                     validate_colocation_constraints,
                                     propagate_device_spec=False):
  """Fills in the TF_ImportGraphDefOptions handle `options`.

  Args:
    options: The TF_ImportGraphDefOptions handle to populate.
    prefix: Name prefix for imported nodes (name uniquification is enabled).
    input_map: Dict from names in the GraphDef to replacement objects exposing
      `_as_tf_output()`. A key of the form '^op_name' remaps control
      dependencies; any other key is parsed as a tensor name.
    return_elements: Iterable of names to request back, or None. Names
      containing ':' are requested as outputs, others as operations.
    validate_colocation_constraints: Forwarded to
      TF_ImportGraphDefOptionsSetValidateColocationConstraints.
    propagate_device_spec: Forwarded to
      TF_ImportGraphDefOptionsSetPropagateDeviceSpec.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetPropagateDeviceSpec(options,
                                                       propagate_device_spec)

  for raw_source, replacement in input_map.items():
    source = compat.as_str(raw_source)
    if _IsControlInput(source):
      # Redirect control dependencies on the named op to the op producing
      # `replacement`.
      c_api.TF_ImportGraphDefOptionsRemapControlDependency(
          options, source[1:], replacement._as_tf_output().oper)  # pylint: disable=protected-access
      continue
    source_op, source_index = _ParseTensorName(source)
    c_api.TF_ImportGraphDefOptionsAddInputMapping(
        options, source_op, source_index, replacement._as_tf_output())  # pylint: disable=protected-access

  for element in return_elements or []:
    if ':' not in element:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
                                                       compat.as_str(element))
      continue
    op_name, output_index = _ParseTensorName(element)
    c_api.TF_ImportGraphDefOptionsAddReturnOutput(
        options, compat.as_str(op_name), output_index)

  c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
      options, validate_colocation_constraints)
def _ProcessNewOps(graph):
  """Finishes device assignment for the TF_Operations newly added to `graph`.

  Every new op first has its device cleared. Ops without colocation names
  then get the graph's device functions applied inside a device scope for
  their original device (if any). Ops with colocation names are left without
  a device function; afterwards each takes the device of the first op it is
  colocated with that has a device, if there is one.

  Raises:
    ValueError: If an op is colocated with an op name that does not exist in
      `graph`, unless TF2 behavior or control flow v2 is enabled.
  """
  # New op -> names of the ops it should be colocated with.
  colocated_ops = {}

  for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
    requested_device = new_op.device
    new_op._set_device('')  # pylint: disable=protected-access
    coloc_names = _GetColocationNames(new_op)
    if not coloc_names:
      with _MaybeDevice(requested_device):
        graph._apply_device_functions(new_op)  # pylint: disable=protected-access
      continue
    # Colocation constraints override device functions and the original
    # device, so no device is applied here; the loop below may still set one.
    # TODO(skyewm): why does it override the original device?
    colocated_ops[new_op] = coloc_names

  # Propagate a device onto colocated ops for completeness (the colocation
  # attribute already implies it). If several colocated ops carry devices,
  # they are assumed to agree; the first one found wins. In TF2 colocation is
  # only a hint, so a missing colocation target is tolerated.
  for op, coloc_op_list in colocated_ops.items():
    coloc_device = None
    for coloc_op_name in coloc_op_list:
      try:
        coloc_op = graph._get_operation_by_name(coloc_op_name)  # pylint: disable=protected-access
      except KeyError:
        # Do not error in TF2 if the colocation cannot be guaranteed
        if tf2.enabled() or control_flow_util.EnableControlFlowV2(graph):
          continue
        raise ValueError(f'Specified colocation to an op: {coloc_op_name} that '
                         f'does not exist during import for op: {op.name}')
      if coloc_op.device:
        coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
        break
    if coloc_device:
      op._set_device(coloc_device)  # pylint: disable=protected-access
def _GetColocationNames(op):
  """Returns names of the ops that `op` should be colocated with.

  Colocation targets are read from the op's `_class` attr: each entry of the
  form 'loc:@<op_name>' names one target. Entries naming `op` itself are
  skipped.

  Args:
    op: An `Operation`-like object exposing `name` and `get_attr`.

  Returns:
    A (possibly empty) list of op names. An empty list is returned when `op`
    has no `_class` attr, so callers always receive a list.
  """
  location_prefix = 'loc:@'
  try:
    class_values = op.get_attr('_class')
  except ValueError:
    # No _class attr
    return []
  colocation_names = []
  for val in class_values:
    val = compat.as_str(val)
    if val.startswith(location_prefix):
      colocation_node_name = val[len(location_prefix):]
      if colocation_node_name != op.name:
        colocation_names.append(colocation_node_name)
  return colocation_names
def _GatherReturnElements(requested_return_elements, graph, results):
  """Maps the import results back to the requested return elements, in order.

  Names containing ':' consume the next returned output (as a `Tensor`);
  other names consume the next returned operation (as an `Operation`).

  Args:
    requested_return_elements: list of strings of operation and tensor names
    graph: Graph
    results: wrapped TF_ImportGraphDefResults

  Returns:
    list of `Operation` and/or `Tensor` objects
  """
  returned_outputs = c_api.TF_ImportGraphDefResultsReturnOutputs(results)
  returned_opers = c_api.TF_ImportGraphDefResultsReturnOperations(results)

  gathered = []
  next_output = next_oper = 0
  for element_name in requested_return_elements:
    if ':' not in element_name:
      gathered.append(
          graph._get_operation_by_tf_operation(returned_opers[next_oper]))  # pylint: disable=protected-access
      next_oper += 1
      continue
    gathered.append(
        graph._get_tensor_by_tf_output(returned_outputs[next_output]))  # pylint: disable=protected-access
    next_output += 1
  return gathered
def _SetDefaultAttrValues(node_def, op_def):
  """Copies default attr values from `op_def` into `node_def` where unset.

  An attr counts as unset when its value is None or has no 'value' oneof
  field set.
  """
  assert node_def.op == op_def.name
  for attr_def in op_def.attr:
    if not attr_def.HasField('default_value'):
      continue
    attr_key = attr_def.name
    current = node_def.attr[attr_key]
    is_unset = current is None or current.WhichOneof('value') is None
    if is_unset:
      node_def.attr[attr_key].CopyFrom(attr_def.default_value)
@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
                 'https://github.com/tensorflow/tensorflow/issues if you depend'
                 ' on this feature.', 'op_dict')
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated and ignored; do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  del op_dict  # Deprecated; accepted only for backwards compatibility.
  return _import_graph_def_internal(
      graph_def,
      producer_op_list=producer_op_list,
      name=name,
      return_elements=return_elements,
      input_map=input_map)
def import_graph_def_for_function(  # pylint: disable=invalid-name
    graph_def, name=None, propagate_device_spec=False):
  """Imports `graph_def` like import_graph_def, minus colocation validation.

  No input map, return elements, or producer op list are used; colocation
  constraints are not validated, and `propagate_device_spec` is forwarded.
  """
  return _import_graph_def_internal(
      graph_def,
      name=name,
      propagate_device_spec=propagate_device_spec,
      validate_colocation_constraints=False)
def _import_graph_def_internal(  # pylint: disable=invalid-name
    graph_def,
    input_map=None,
    return_elements=None,
    validate_colocation_constraints=True,
    name=None,
    producer_op_list=None,
    propagate_device_spec=False):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into the
      default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in `graph_def`
      that will be returned as `Operation` objects; and/or tensor names in
      `graph_def` that will be returned as `Tensor` objects.
    validate_colocation_constraints: Whether to validate colocation constraints.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.
    propagate_device_spec: Whether to propagate assigned device information
      when importing a graph from a GraphDef into the current default `Graph`.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
                                   validate_colocation_constraints,
                                   propagate_device_spec)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
          results = c_api.TF_GraphImportGraphDefWithResults(
              c_graph, serialized, options)
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility, keeping the
        # original error attached as the cause for debugging.
        raise ValueError(str(e)) from e

    _ProcessNewOps(graph)

  # Create _DefinedFunctions for any imported functions.
  #
  # We do this by creating _DefinedFunctions directly from `graph_def`, and
  # adding them to `graph`. Adding an existing function to a TF_Graph is a
  # no-op, so this only has the effect of updating the Python state (usually
  # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
  #
  # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
  # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
  if graph_def.library and graph_def.library.function:
    functions = function.from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    missing_keys = ', '.join(missing_unused_input_keys)
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: '
        f'[{missing_keys}]')

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)