Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/importer.py: 17%

210 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""A utility function for importing TensorFlow graphs.""" 

16import contextlib 

17 

18from tensorflow.core.framework import graph_pb2 

19from tensorflow.python import tf2 

20from tensorflow.python.client import pywrap_tf_session as c_api 

21from tensorflow.python.framework import c_api_util 

22from tensorflow.python.framework import device as pydev 

23from tensorflow.python.framework import errors 

24from tensorflow.python.framework import function 

25from tensorflow.python.framework import op_def_registry 

26from tensorflow.python.framework import ops 

27from tensorflow.python.ops import control_flow_util 

28from tensorflow.python.util import compat 

29from tensorflow.python.util.deprecation import deprecated_args 

30from tensorflow.python.util.tf_export import tf_export 

31 

32 

def _IsControlInput(input_name):
  """Returns whether `input_name` denotes a control input.

  Control inputs are spelled '^operation_name' (a leading caret).
  """
  control_marker = '^'
  return input_name.startswith(control_marker)

36 

37 

def _ParseTensorName(tensor_name):
  """Parses a tensor name into an operation name and output index.

  This function will canonicalize tensor names as follows:

  * "foo:0"       -> ("foo", 0)
  * "foo:7"       -> ("foo", 7)
  * "foo"         -> ("foo", 0)
  * "foo:bar:baz" -> ValueError

  Args:
    tensor_name: The name of a tensor.

  Returns:
    A tuple containing the operation name, and the output index.

  Raises:
    ValueError: If `tensor_name` cannot be interpreted as the name of a tensor.
  """
  components = tensor_name.split(':')
  if len(components) == 2:
    # Expected format: 'operation_name:output_index'.
    try:
      output_index = int(components[1])
    except ValueError as e:
      # Chain the int() failure so the root cause is kept as __cause__.
      raise ValueError(f'Cannot convert {tensor_name!r} to a tensor name. '
                       'Second component of the name following the `:` should '
                       f'be an int. Got {components[1]}.') from e
    return components[0], output_index
  elif len(components) == 1:
    # Expected format: 'operation_name' (implicit 0th output).
    return components[0], 0
  else:
    raise ValueError(f"Cannot convert '{tensor_name}' to a tensor name. Tensor "
                     'names should not contain more than 1 `:`. Obtained '
                     f'{len(components) - 1}')

74 

75 

@contextlib.contextmanager
def _MaybeDevice(device):
  """Context manager that enters `ops.device(device)` only for a truthy device.

  For `None` or an empty string nothing is applied and the body simply runs.
  """
  device_scope = ops.device(device) if device else contextlib.nullcontext()
  with device_scope:
    yield

84 

85 

def _ProcessGraphDefParam(graph_def):
  """Type-checks `graph_def` and returns the `GraphDef` proto to import.

  If `graph_def` is already a `graph_pb2.GraphDef`, it is returned as-is after
  its NodeDefs are filled in place with default attr values from the registered
  OpDefs. The caller sees these changes.

  Any other object is merged into a fresh `GraphDef` instead. This duck-typed
  path supports dynamically-created message classes and only reads from the
  argument.

  Raises:
    TypeError: if `graph_def` cannot be merged into a `GraphDef`.
  """
  if isinstance(graph_def, graph_pb2.GraphDef):
    # NOTE(skyewm): this in-place mutation is undocumented behavior that at
    # least meta_graph.py depends on. It might make sense to move this to
    # meta_graph.py and have import_graph_def not modify the graph_def argument
    # (we'd have to make sure this doesn't break anything else.)
    for node_def in graph_def.node:
      registered_op_def = op_def_registry.get(node_def.op)
      # Unrecognized ops are assumed to be functions for now; TF_ImportGraphDef
      # will report an error if the op is actually missing.
      if registered_op_def is not None:
        _SetDefaultAttrValues(node_def, registered_op_def)
    return graph_def

  # `graph_def` could be a dynamically-created message, so try a duck-typed
  # approach.
  try:
    canonical_graph_def = graph_pb2.GraphDef()
    canonical_graph_def.MergeFrom(graph_def)
  except TypeError:
    raise TypeError('Argument `graph_def` must be a GraphDef proto.')
  return canonical_graph_def

114 

115 

def _ProcessInputMapParam(input_map):
  """Validates `input_map`, substituting a new empty dict for `None`.

  Raises:
    TypeError: if `input_map` is not a dict, or any key is not a string.
  """
  if input_map is None:
    return {}
  if not isinstance(input_map, dict):
    raise TypeError('Argument `input_map` must be a dictionary. Obtained '
                    f'{type(input_map).__name__}')
  has_bad_key = any(
      not isinstance(key, compat.bytes_or_text_types) for key in input_map)
  if has_bad_key:
    raise TypeError('All keys for argument `input_map` must be strings. '
                    f'Obtained keys: {list(input_map.keys())}')
  return input_map

129 

130 

def _ProcessReturnElementsParam(return_elements):
  """Type-checks and canonicalizes `return_elements`.

  Args:
    return_elements: None, or an iterable of operation/tensor names (str or
      bytes).

  Returns:
    None if `return_elements` is None; otherwise a tuple of `str` names in the
    original order.

  Raises:
    TypeError: if `return_elements` is a single string rather than a list of
      strings, or if any element is not a string.
  """
  if return_elements is None:
    return None
  # A bare string is itself an iterable of strings, so the element check below
  # would accept it and split it into one-character names. Reject it explicitly.
  if isinstance(return_elements, compat.bytes_or_text_types):
    raise TypeError('Argument `return_elements` must be a list of strings. '
                    f'Obtained {return_elements}.')
  # Materialize once: validating a generator/iterator would otherwise exhaust
  # it and an empty tuple would be returned.
  return_elements = list(return_elements)
  if not all(
      isinstance(x, compat.bytes_or_text_types) for x in return_elements):
    raise TypeError('Argument `return_elements` must be a list of strings. '
                    f'Obtained {return_elements}.')
  return tuple(compat.as_str(x) for x in return_elements)

140 

141 

def _FindAttrInOpDef(attr_name, op_def):
  """Returns the first AttrDef of `op_def` named `attr_name`, or None."""
  matching = (candidate for candidate in op_def.attr
              if attr_name == candidate.name)
  return next(matching, None)

147 

148 

def _RemoveDefaultAttrs(producer_op_list, graph_def):
  """Strips attrs unknown to the consumer that hold the producer's default.

  For every node whose op appears in `producer_op_list`, any attr that the
  locally registered OpDef does not declare is deleted when the producer's
  OpDef gives it a default value equal to the node's value. This lets a
  `graph_def` from a newer producer be read by an older consumer. `graph_def`
  is modified in place.

  Args:
    producer_op_list: OpList proto.
    graph_def: GraphDef proto
  """
  producer_ops_by_name = {op_def.name: op_def for op_def in producer_op_list.op}
  for node_def in graph_def.node:
    if node_def.op not in producer_ops_by_name:
      continue
    consumer_op_def = op_def_registry.get(node_def.op)
    if consumer_op_def is None:
      # Some custom op registrations won't show up here. That's OK, attribute
      # stripping just won't be available.
      continue
    producer_op_def = producer_ops_by_name[node_def.op]
    # Snapshot the keys because entries may be deleted from node_def.attr.
    for attr_name in list(node_def.attr):
      if _FindAttrInOpDef(attr_name, consumer_op_def) is not None:
        continue  # The consumer knows this attr; keep it.
      producer_attr_def = _FindAttrInOpDef(attr_name, producer_op_def)
      if (producer_attr_def and producer_attr_def.HasField('default_value') and
          node_def.attr[attr_name] == producer_attr_def.default_value):
        del node_def.attr[attr_name]

180 

181 

def _ConvertInputMapValues(name, input_map):
  """Returns `input_map` with every value guaranteed to be a Tensor.

  If all values are already Tensors, `input_map` itself is returned unchanged.
  Otherwise a new dict is built in which every value has gone through
  `ops.convert_to_tensor` under an '_inputs' name scope. Call this from inside
  the import name scope.

  Args:
    name: the `name` argument passed to import_graph_def
    input_map: the `input_map` argument passed to import_graph_def.

  Returns:
    A possibly-updated version of `input_map`.

  Raises:
    ValueError: if input map values cannot be converted due to empty name scope.
  """
  if all(isinstance(value, ops.Tensor) for value in input_map.values()):
    return input_map
  if name == '':  # pylint: disable=g-explicit-bool-comparison
    raise ValueError(
        'tf.import_graph_def() requires a non-empty `name` if `input_map` '
        'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
        '`input_map` values before calling tf.import_graph_def().')
  with ops.name_scope('_inputs'):
    return {key: ops.convert_to_tensor(value)
            for key, value in input_map.items()}

206 

207 

def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
                                     return_elements,
                                     validate_colocation_constraints,
                                     propagate_device_spec=False):
  """Populates the TF_ImportGraphDefOptions `options`.

  The following are configured on `options`:
  - the name prefix;
  - name uniquification (always on);
  - device-spec propagation;
  - input remappings (control dependencies for '^op' keys, tensor mappings
    otherwise);
  - the requested return elements;
  - colocation-constraint validation.
  """
  c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
  c_api.TF_ImportGraphDefOptionsSetUniquifyNames(options, True)
  c_api.TF_ImportGraphDefOptionsSetPropagateDeviceSpec(options,
                                                       propagate_device_spec)

  for source, destination in input_map.items():
    source = compat.as_str(source)
    if _IsControlInput(source):
      # '^op' remaps control dependencies on `op` to destination's operation.
      c_api.TF_ImportGraphDefOptionsRemapControlDependency(
          options,
          compat.as_str(source[1:]),
          destination._as_tf_output().oper)  # pylint: disable=protected-access
    else:
      op_name, output_index = _ParseTensorName(source)
      c_api.TF_ImportGraphDefOptionsAddInputMapping(
          options,
          compat.as_str(op_name),
          output_index,
          destination._as_tf_output())  # pylint: disable=protected-access

  for element in return_elements or ():
    # Names containing ':' are tensors; anything else is an operation.
    if ':' in element:
      op_name, output_index = _ParseTensorName(element)
      c_api.TF_ImportGraphDefOptionsAddReturnOutput(
          options, compat.as_str(op_name), output_index)
    else:
      c_api.TF_ImportGraphDefOptionsAddReturnOperation(
          options, compat.as_str(element))

  c_api.TF_ImportGraphDefOptionsSetValidateColocationConstraints(
      options, validate_colocation_constraints)

242 

243 

def _ProcessNewOps(graph):
  """Assigns devices to the TF_Operations newly added to `graph`.

  Each new op first has its device cleared. Ops without colocation constraints
  then get the graph's device functions applied, inside their original device
  scope (if any).

  Ops with colocation constraints instead take the device of the first
  colocated op that has one. A colocated op that cannot be found is skipped
  under TF2 / control flow v2 and raises ValueError otherwise.
  """
  # Op -> names of the ops it is colocated with (from its '_class' attr).
  colocated_with = {}

  for new_op in graph._add_new_tf_operations(compute_devices=False):  # pylint: disable=protected-access
    requested_device = new_op.device
    new_op._set_device('')  # pylint: disable=protected-access
    peer_names = _GetColocationNames(new_op)
    if not peer_names:
      with _MaybeDevice(requested_device):
        graph._apply_device_functions(new_op)  # pylint: disable=protected-access
      continue
    # Don't set a device for this op, since colocation constraints override
    # device functions and the original device. Note that this op's device may
    # still be set by the loop below.
    # TODO(skyewm): why does it override the original device?
    colocated_with[new_op] = peer_names

  # Propagate the device implied by colocation, for completeness. We assume
  # that if several colocated ops have devices they refer to the same device;
  # otherwise a runtime error occurs since colocation cannot be guaranteed. In
  # TF2 colocations have been removed from the public API and are treated as a
  # hint, so there is no runtime error.
  #
  # One possible improvement is to check all devices in the list for
  # compatibility at import time, which would require implementing a
  # compatibility function for device specs in python.
  for op, peer_names in colocated_with.items():
    peer_device = None
    for peer_name in peer_names:
      try:
        peer_op = graph._get_operation_by_name(peer_name)  # pylint: disable=protected-access
      except KeyError:
        # Do not error in TF2 if the colocation cannot be guaranteed
        if tf2.enabled() or control_flow_util.EnableControlFlowV2(graph):
          continue

        raise ValueError(f'Specified colocation to an op: {peer_name} that '
                         f'does not exist during import for op: {op.name}')
      if peer_op.device:
        peer_device = pydev.DeviceSpec.from_string(peer_op.device)
        break
    if peer_device:
      op._set_device(peer_device)  # pylint: disable=protected-access

294 

295 

def _GetColocationNames(op):
  """Returns names of the ops that `op` should be colocated with.

  Names are taken from the 'loc:@<name>' entries of `op`'s '_class' attr.
  Entries naming `op` itself are skipped.

  Returns:
    A list of op names. It is empty when `op` has no '_class' attr or no
    'loc:@' entries.
  """
  try:
    class_values = op.get_attr('_class')
  except ValueError:
    # No _class attr. Return an empty list rather than None so the return type
    # always matches the docstring (callers only rely on truthiness).
    return []
  colocation_prefix = 'loc:@'
  colocation_names = []
  for val in class_values:
    val = compat.as_str(val)
    if val.startswith(colocation_prefix):
      colocation_node_name = val[len(colocation_prefix):]
      if colocation_node_name != op.name:
        colocation_names.append(colocation_node_name)
  return colocation_names

311 

312 

def _GatherReturnElements(requested_return_elements, graph, results):
  """Maps the requested return-element names to imported Python objects.

  Names containing ':' consume the next returned output (as a `Tensor`). All
  other names consume the next returned operation (as an `Operation`). Both
  sequences are consumed in request order.

  Args:
    requested_return_elements: list of strings of operation and tensor names
    graph: Graph
    results: wrapped TF_ImportGraphDefResults

  Returns:
    list of `Operation` and/or `Tensor` objects
  """
  returned_outputs = c_api.TF_ImportGraphDefResultsReturnOutputs(results)
  returned_opers = c_api.TF_ImportGraphDefResultsReturnOperations(results)

  gathered = []
  next_output = 0
  next_oper = 0
  for element_name in requested_return_elements:
    if ':' not in element_name:
      gathered.append(
          graph._get_operation_by_tf_operation(returned_opers[next_oper]))  # pylint: disable=protected-access
      next_oper += 1
      continue
    gathered.append(
        graph._get_tensor_by_tf_output(returned_outputs[next_output]))  # pylint: disable=protected-access
    next_output += 1
  return gathered

340 

341 

def _SetDefaultAttrValues(node_def, op_def):
  """Fills in `node_def` attrs that are unset but have a default in `op_def`.

  An attr counts as unset when its value has no 'value' oneof set. Attrs
  without a default in `op_def` are left alone.
  """
  assert node_def.op == op_def.name
  for attr_def in op_def.attr:
    if not attr_def.HasField('default_value'):
      continue
    attr_name = attr_def.name
    current_value = node_def.attr[attr_name]
    is_unset = current_value is None or current_value.WhichOneof('value') is None
    if is_unset:
      node_def.attr[attr_name].CopyFrom(attr_def.default_value)

351 

352 

@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
                 'https://github.com/tensorflow/tensorflow/issues if you depend'
                 ' on this feature.', 'op_dict')
def import_graph_def(graph_def,
                     input_map=None,
                     return_elements=None,
                     name=None,
                     op_dict=None,
                     producer_op_list=None):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into
      the default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def`
      to `Tensor` objects. The values of the named input tensors in the
      imported graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in
      `graph_def` that will be returned as `Operation` objects; and/or
      tensor names in `graph_def` that will be returned as `Tensor` objects.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    op_dict: (Optional.) Deprecated, do not use.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  # `op_dict` is deprecated and ignored.
  del op_dict
  return _import_graph_def_internal(
      graph_def,
      input_map=input_map,
      return_elements=return_elements,
      name=name,
      producer_op_list=producer_op_list)

412 

413 

def import_graph_def_for_function(  # pylint: disable=invalid-name
    graph_def, name=None, propagate_device_spec=False):
  """Like import_graph_def but does not validate colocation constraints.

  No input map or return elements are requested, so the underlying import
  returns None.

  Args:
    graph_def: The `GraphDef` (or duck-typed equivalent) to import.
    name: (Optional.) Name prefix for the imported nodes.
    propagate_device_spec: Whether to keep the device assignments recorded in
      `graph_def`.
  """
  import_kwargs = dict(
      validate_colocation_constraints=False,
      name=name,
      propagate_device_spec=propagate_device_spec,
  )
  return _import_graph_def_internal(graph_def, **import_kwargs)

422 

423 

def _import_graph_def_internal(  # pylint: disable=invalid-name
    graph_def,
    input_map=None,
    return_elements=None,
    validate_colocation_constraints=True,
    name=None,
    producer_op_list=None,
    propagate_device_spec=False):
  """Imports the graph from `graph_def` into the current default `Graph`.

  This function provides a way to import a serialized TensorFlow
  [`GraphDef`](https://www.tensorflow.org/code/tensorflow/core/framework/graph.proto)
  protocol buffer, and extract individual objects in the `GraphDef` as
  `tf.Tensor` and `tf.Operation` objects. Once extracted,
  these objects are placed into the current default `Graph`. See
  `tf.Graph.as_graph_def` for a way to create a `GraphDef`
  proto.

  Note: when `graph_def` is a `GraphDef` proto it may be modified in place
  (default attr values are filled in, and, if `producer_op_list` is given,
  default-valued attrs unknown to this binary are removed).

  Args:
    graph_def: A `GraphDef` proto containing operations to be imported into the
      default graph.
    input_map: A dictionary mapping input names (as strings) in `graph_def` to
      `Tensor` objects. The values of the named input tensors in the imported
      graph will be re-mapped to the respective `Tensor` values.
    return_elements: A list of strings containing operation names in `graph_def`
      that will be returned as `Operation` objects; and/or tensor names in
      `graph_def` that will be returned as `Tensor` objects.
    validate_colocation_constraints: Whether to validate colocation constraints.
    name: (Optional.) A prefix that will be prepended to the names in
      `graph_def`. Note that this does not apply to imported function names.
      Defaults to `"import"`.
    producer_op_list: (Optional.) An `OpList` proto with the (possibly stripped)
      list of `OpDef`s used by the producer of the graph. If provided,
      unrecognized attrs for ops in `graph_def` that have their default value
      according to `producer_op_list` will be removed. This will allow some more
      `GraphDef`s produced by later binaries to be accepted by earlier binaries.
    propagate_device_spec: Whether to propagate assigned device information
      when importing a graph from a GraphDef into the current default `Graph`.

  Returns:
    A list of `Operation` and/or `Tensor` objects from the imported graph,
    corresponding to the names in `return_elements`,
    and None if `return_elements` is None.

  Raises:
    TypeError: If `graph_def` is not a `GraphDef` proto,
      `input_map` is not a dictionary mapping strings to `Tensor` objects,
      or `return_elements` is not a list of strings.
    ValueError: If `input_map`, or `return_elements` contains names that
      do not appear in `graph_def`, or `graph_def` is not well-formed (e.g.
      it refers to an unknown tensor).
  """
  graph_def = _ProcessGraphDefParam(graph_def)
  input_map = _ProcessInputMapParam(input_map)
  return_elements = _ProcessReturnElementsParam(return_elements)

  if producer_op_list is not None:
    # TODO(skyewm): make a copy of graph_def so we're not mutating the argument?
    _RemoveDefaultAttrs(producer_op_list, graph_def)

  graph = ops.get_default_graph()
  with ops.name_scope(name, 'import', input_map.values()) as scope:
    # Save unique prefix generated by name_scope
    if scope:
      assert scope.endswith('/')
      prefix = scope[:-1]
    else:
      prefix = ''

    # Generate any input map tensors inside name scope
    input_map = _ConvertInputMapValues(name, input_map)

  # `scoped_options` owns the C options object; keep it referenced for the
  # duration of the import.
  scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
  options = scoped_options.options
  _PopulateTFImportGraphDefOptions(options, prefix, input_map, return_elements,
                                   validate_colocation_constraints,
                                   propagate_device_spec)

  # _ProcessNewOps mutates the new operations. _mutation_lock ensures a
  # Session.run call cannot occur between creating the TF_Operations in the
  # TF_GraphImportGraphDefWithResults call and mutating them in
  # _ProcessNewOps.
  with graph._mutation_lock():  # pylint: disable=protected-access
    with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
      try:
        with graph._c_graph.get() as c_graph:  # pylint: disable=protected-access
          results = c_api.TF_GraphImportGraphDefWithResults(
              c_graph, serialized, options)
        results = c_api_util.ScopedTFImportGraphDefResults(results)
      except errors.InvalidArgumentError as e:
        # Convert to ValueError for backwards compatibility. Chain the original
        # error so its details remain available as the cause.
        raise ValueError(str(e)) from e

    _ProcessNewOps(graph)

  # Create _DefinedFunctions for any imported functions.
  #
  # We do this by creating _DefinedFunctions directly from `graph_def`, and
  # adding them to `graph`. Adding an existing function to a TF_Graph is a
  # no-op, so this only has the effect of updating the Python state (usually
  # _DefinedFunction.add_to_graph also adds the function to the TF_Graph).
  #
  # TODO(skyewm): fetch the TF_Functions directly from the TF_Graph
  # TODO(skyewm): avoid sending serialized FunctionDefs back to the TF_Graph
  if graph_def.library and graph_def.library.function:
    functions = function.from_library(graph_def.library)
    for f in functions:
      f.add_to_graph(graph)

  # Treat input mappings that don't appear in the graph as an error, because
  # they are likely to be due to a typo.
  missing_unused_input_keys = (
      c_api.TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper(
          results.results))
  if missing_unused_input_keys:
    missing_unused_input_keys = [
        compat.as_str(s) for s in missing_unused_input_keys
    ]
    missing_keys = ', '.join(missing_unused_input_keys)
    raise ValueError(
        'Attempted to map inputs that were not found in graph_def: '
        f'[{missing_keys}]')

  if return_elements is None:
    return None
  else:
    return _GatherReturnElements(return_elements, graph, results.results)