Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py: 22% of 441 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""FuncGraph and related functionality."""

import traceback
from typing import Any, Callable, Hashable
import weakref

from tensorflow.core.function import trace_type
from tensorflow.core.function.capture import capture_container
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager.polymorphic_function import composite_tensor_utils
from tensorflow.python.framework import auto_control_deps
from tensorflow.python.framework import composite_tensor
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.saved_model import save_context
from tensorflow.python.types import core
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util import variable_utils
from tensorflow.python.util.tf_export import tf_export


ALLOWLIST_COLLECTIONS = [
    ops.GraphKeys.GLOBAL_VARIABLES,
    ops.GraphKeys.LOCAL_VARIABLES,
    ops.GraphKeys.TRAINABLE_VARIABLES,
    variable_scope._VARSTORE_KEY,  # pylint: disable=protected-access
    variable_scope._VARSCOPESTORE_KEY  # pylint: disable=protected-access
]


class UnknownArgument(object):
  """Signifies an argument which is not currently handled."""


def convert_structure_to_signature(structure, arg_names=None,
                                   signature_context=None):
  """Convert a potentially nested structure to a signature.

  Args:
    structure: Structure to convert, where the top level collection is a list
      or a tuple.
    arg_names: Optional list of argument names that has the same number of
      elements as `structure` and is used for naming the corresponding
      TensorSpecs.
    signature_context: TraceType InternalTracingContext to generate alias_ids
      for mutable objects, like ResourceVariables.

  Returns:
    Identical structure that has TensorSpec objects instead of Tensors and
    UnknownArgument instead of any unsupported types.
  """

  def encode_arg(arg, path):
    """A representation for this argument, for converting into signatures."""
    if isinstance(arg, ops.Tensor):
      user_specified_name = None
      try:
        user_specified_name = compat.as_str(
            arg.op.get_attr("_user_specified_name"))
      except (ValueError, AttributeError):
        pass

      if path and user_specified_name and user_specified_name != path[0]:
        # The user has explicitly named the argument differently than the name
        # of the function argument.
        name = user_specified_name
      else:
        name = tensor_spec.sanitize_spec_name("_".join(str(p) for p in path))
      return tensor_spec.TensorSpec(arg.shape, arg.dtype, name)
    if isinstance(arg, resource_variable_ops.ResourceVariable):
      return trace_type.from_value(arg, signature_context)
    if isinstance(arg, composite_tensor.CompositeTensor):
      # TODO(b/133606651) Do we need to inject arg_name?
      return arg._type_spec  # pylint: disable=protected-access
    if isinstance(arg, (
        int,
        float,
        bool,
        str,
        type(None),
        dtypes.DType,
        tensor_spec.TensorSpec,
        type_spec.TypeSpec,
    )):
      return arg
    return UnknownArgument()

  # We are using the flattened paths to name the TensorSpecs. We need an
  # explicit name for them downstream.
  flattened = nest.flatten_with_tuple_paths(structure)
  if arg_names:
    if len(arg_names) != len(structure):
      raise ValueError(
          "Passed in arg_names don't match actual signature (%s)." % arg_names)
    # Replace all top-level names with their actual arg_names. If a path before
    # was "(2,'a',1)", it will become "(arg_names[2],'a',1)".
    flattened = [
        ((arg_names[path[0]],) + path[1:], arg) for path, arg in flattened
    ]

  mapped = [encode_arg(arg, path) for path, arg in flattened]
  return nest.pack_sequence_as(structure, mapped)


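# Illustrative usage sketch (hypothetical values, assuming an eager TF2
# context): Tensors become named TensorSpecs while supported Python literals
# pass through unchanged.
#
#   structure = [constant_op.constant([1.0, 2.0]), {"rate": 0.5}]
#   sig = convert_structure_to_signature(structure, arg_names=["x", "opts"])
#   # sig == [TensorSpec(shape=(2,), dtype=float32, name="x"), {"rate": 0.5}]
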

@tf_export("__internal__.FuncGraph", v1=[])
class FuncGraph(ops.Graph):
  """Graph representing a function body.

  Attributes:
    name: The name of the function.
    inputs: Placeholder tensors representing the inputs to this function. The
      tensors are in this FuncGraph. This represents "regular" inputs as well
      as captured inputs (i.e. the values of self.captures), with the regular
      inputs coming first.
    outputs: Tensors that will be returned by this function. The tensors are
      in this FuncGraph.
    control_outputs: Operations that must be executed before the function
      represented by this graph can be said to have been executed.
    structured_input_signature: A tuple of (args, kwargs), which are both
      possibly-nested python objects that were received by this function. Note
      that these structures might contain Python `None`s.
    structured_outputs: A possibly-nested python object which will be returned
      by this function. The Tensors in this structure are the same as those of
      self.outputs. Note that this structure might contain Python `None`s.
    variables: Variables that should be watched during function execution.
    outer_graph: The graph this function is defined in. May be another
      FuncGraph or the global default Graph.
    captures: Maps external tensor -> internal tensor (i.e. input placeholder).
      The entries are in the order they were captured.
    seed: The graph-level random seed.
    capture_by_value: If True, the func graph will capture Variables by value
      instead of reference.
  """

  def __init__(self,
               name,
               collections=None,
               capture_by_value=None,
               structured_input_signature=None,
               structured_outputs=None):
    """Construct a new FuncGraph.

    The graph will inherit its graph key, collections, seed, and distribution
    strategy stack from the current context or graph.

    Args:
      name: the name of the function.
      collections: a dictionary of collections this FuncGraph should start
        with. If not specified (None), the FuncGraph will read (but not write
        to) the outer graph's collections that are not allowlisted, and both
        read and write to the outer graph's collections that are allowlisted.
        The current allowlisted collections are the global variables, the
        local variables, and the trainable variables. Defaults to None.
      capture_by_value: An optional boolean. If True, the func graph will
        capture Variables by value instead of reference. By default inherit
        from outer graphs, and failing that will default to False.
      structured_input_signature: Optional. The structured input signature to
        use for initializing the FuncGraph. See the docstring for FuncGraph
        for more information.
      structured_outputs: Optional. The structured outputs to use for
        initializing the FuncGraph. See the docstring for FuncGraph for more
        information.
    """
    super().__init__()
    self.name = name
    # TODO(panzf): Separate captures from non-capture inputs in self.inputs
    self.inputs = []
    self.outputs = []
    self.control_outputs = []
    self.structured_input_signature = structured_input_signature
    self.structured_outputs = structured_outputs
    self._resource_tensor_inputs = object_identity.ObjectIdentitySet()
    self._weak_variables = []
    self._watched_variables = object_identity.ObjectIdentityWeakSet()
    self.is_control_flow_graph = False

    self._function_captures = capture_container.FunctionCaptures()
    outer_graph = ops.get_default_graph()
    self._weak_outer_graph = weakref.ref(outer_graph)
    while outer_graph.building_function:
      outer_graph = outer_graph.outer_graph
    # If self._weak_outer_graph is deleted, we revert to the outermost Graph
    # active when the FuncGraph was traced. This will not be a FuncGraph.
    self._fallback_outer_graph = outer_graph
    # If not None, records the names of output args of this function. Used to
    # preserve the output names in the signature of a serialized+deserialized
    # function. Private at the moment mostly because it's often out of date.
    self._output_names = None
    # Inherit capture-by-value from outer graph.
    if capture_by_value is not None:
      self.capture_by_value = capture_by_value
    elif self.outer_graph is not None and isinstance(self.outer_graph,
                                                     FuncGraph):
      self.capture_by_value = self.outer_graph.capture_by_value
    else:
      self.capture_by_value = False

    self._building_function = True

    graph = self.outer_graph

    if context.executing_eagerly():
      self.seed = context.global_seed()
      # [for tf-data user migration from TF1.0 to 2.0] seed_used keeps track
      # of any None op_seed for a random op in the function, in which case we
      # end up using the function seed, which could be unintended behavior for
      # the op.
      self._seed_used = False
    else:
      self.seed = graph.seed
      self._seed_used = False
      # TODO(allenl): Figure out if we can remove colocation stack
      # specialization (currently used in cond_v2), here and in the cache key.
      self._colocation_stack = graph._colocation_stack.copy()  # pylint: disable=protected-access

    if collections is None:
      for collection_name in graph.get_all_collection_keys():
        if collection_name not in ALLOWLIST_COLLECTIONS:
          self._collections[collection_name] = graph.get_collection(
              collection_name)
      for collection_name in ALLOWLIST_COLLECTIONS:
        self._collections[collection_name] = graph.get_collection_ref(
            collection_name)
    else:
      self._collections = collections

    # Keep track of whether this FuncGraph is exportable to SavedModel. Use
    # `graph.mark_as_unsaveable(reason)` to mark this FuncGraph and any
    # dependent functions as unsaveable.
    self._saveable = True
    self._saving_errors = set()

    # Keep track of callbacks to run when this graph exits the default scope.
    self._scope_exit_callbacks = None

  def __str__(self):
    return "FuncGraph(name=%s, id=%s)" % (self.name, id(self))

  def watch_variable(self, v):
    """Marks the variable v as accessed while building this graph."""
    # Don't watch `v` if it is one of the ResourceVariable input arguments.
    if (isinstance(v, resource_variable_ops.ResourceVariable) and
        v.handle in self._resource_tensor_inputs):
      return

    while self is not None and isinstance(self, FuncGraph):
      self._watched_variables.add(v)
      self = self.outer_graph

  def capture_call_time_value(self,
                              closure,
                              spec,
                              key=None,
                              default_value=None,
                              placeholder=None):
    """Returns a placeholder which at call time has the value closure().

    `tf.function` supports the notion of captures, that is, it allows Python
    functions to have closure variables, which bind over some value outside
    the function. However, this name binding is "early binding" performed
    before the program is run, i.e.,
    ```
    @tf.function
    def f():
      return x

    x = tf.constant(1)
    f()  # returns 1

    x = tf.constant(2)
    f()  # still returns 1!
    ```
    while in Python, name binding is performed as the program is running.
    ```
    def f():
      return x

    x = 1
    f()  # returns 1

    x = 2
    f()  # returns 2
    ```
    `capture_call_time_value` allows tf.function to mimic late binding as a
    Python function does, by passing in a `closure` callable argument to be
    executed when the tf.function is invoked eagerly. E.g.
    ```
    @tf.function
    def f():
      return ops.get_default_graph().capture_call_time_value(lambda: x)

    x = tf.constant(1)
    f()  # returns 1

    x = tf.constant(2)
    f()  # returns 2
    ```
    Note that a `capture_call_time_value` function itself does not work well
    in the saving process (since the tf.function in which it's called is not
    invoked eagerly) unless passed a `default_value` argument. At saving time,
    the `default_value` argument is returned instead.

    Args:
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      key: optional. If not None, multiple calls to `capture_call_time_value`
        with the same key in the same graph will return the same placeholder,
        and the first closure will be used at function call time.
      default_value: optional value to return in environments that cannot
        safely evaluate closure.
      placeholder: optional. If not None, the graph will take the passed-in
        `placeholder` as the internal capture instead of creating a new one.
        This is useful when loading from a SavedModel.

    Returns:
      Nest of placeholders which, at function call time, will be fed with the
      result of calling closure().

    Raises:
      ValueError: at function call time, if the return value of closure() is
        not compatible with `spec`.
    """
    if key is None:
      key = object()
    if key not in self._function_captures.by_ref_internal:
      trace_ctx = trace_type.InternalTracingContext(True)
      spec = trace_type.from_value(spec, trace_ctx)

      if placeholder is None:
        placeholder_ctx = trace_type.InternalPlaceholderContext(self)
        placeholder = spec.placeholder_value(placeholder_ctx)

      def wrapped_closure():

        # One major case requiring returning a `default_value` is when passing
        # a concrete function to `save`, i.e.
        # serving_fn = serve_fn.get_concrete_function(...)
        # model.save(save_dir, signatures={"serving_default": serving_fn})
        # `serving_fn` has deferred captures added through
        # `capture_call_time_value`. It can't be saved correctly since
        # `wrapped_closure` will end up executing under a default Graph instead
        # of FuncGraph. The user of `capture_call_time_value` also cannot
        # conditionally avoid this call since presence of `save_context` when
        # executing `wrapped_closure` is not known at tracing time of
        # `serving_fn`.
        if save_context.in_save_context() and default_value is not None:
          return default_value
        # TODO(wxinyi): raise an error if in save context but no default value.

        if not context.executing_eagerly():
          graph = ops.get_default_graph()
          assert isinstance(
              graph,
              FuncGraph), "This API should only be used in a TF2 environment."

          with graph.as_default():
            ret_nest = graph.capture_call_time_value(
                closure, spec, key=key, default_value=default_value)
        else:
          ret_nest = closure()

        ret_nest = spec._cast(ret_nest, trace_type.InternalCastContext)  # pylint: disable=protected-access
        return spec._to_tensors(ret_nest)  # pylint: disable=protected-access

      wrapped_closure.output_spec = spec
      self._function_captures.add_or_replace(
          key=key,
          external=wrapped_closure,
          internal=placeholder,
          tracetype=spec,
          is_by_ref=True)
    return self._function_captures.by_ref_internal[key]

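  # Illustrative sketch (hypothetical usage): providing `default_value` keeps
  # a deferred capture saveable, since the closure is not evaluated inside a
  # save context.
  #
  #   @tf.function
  #   def f():
  #     g = ops.get_default_graph()
  #     return g.capture_call_time_value(
  #         lambda: x, tensor_spec.TensorSpec([], dtypes.int32),
  #         default_value=constant_op.constant(0))
  #
  #   x = tf.constant(1)
  #   f()  # returns 1 eagerly; a SavedModel export would record 0 instead.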

  def control_dependencies(self, control_inputs):
    """Handles control dependencies.

    FuncGraph wraps Graph's control_dependencies logic by first filtering out
    any external tensors / operations and storing them in the graph's
    control_captures member. Any consumers of this function graph must then
    decide how to handle the control captures.

    Args:
      control_inputs: A list of `Operation` or `Tensor` objects which must be
        executed or computed before running the operations defined in the
        context. Can also be `None` to clear the control dependencies.

    Returns:
      A context manager that specifies control dependencies for all
      operations constructed within the context.

    Raises:
      TypeError: If `control_inputs` is not a list of `Operation` or
        `Tensor` objects.
    """
    if control_inputs is None:
      return super().control_dependencies(control_inputs)

    filtered_control_inputs = []
    for c in control_inputs:
      # Check for _UnreadVariable
      if (isinstance(c, indexed_slices.IndexedSlices) or
          (hasattr(c, "_handle") and hasattr(c, "op"))):
        c = c.op
      graph_element = ops._as_graph_element(c)  # pylint: disable=protected-access
      if graph_element is None:
        graph_element = c
      if graph_element is not None and getattr(graph_element, "graph",
                                               None) is not self:
        self._function_captures.control.add(graph_element)
      else:
        filtered_control_inputs.append(graph_element)
    return super().control_dependencies(filtered_control_inputs)

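  # Illustrative sketch (hypothetical names): an op created under a dependency
  # on a tensor or op from another graph gets no direct control edge; the
  # external element is instead recorded as a control capture for the caller
  # to wire up.
  #
  #   with func_graph.as_default():
  #     with func_graph.control_dependencies([external_op]):
  #       y = array_ops.identity(x)
  #   # external_op now appears in func_graph's control captures.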

  def as_default(self):
    outer_cm = super().as_default()

    @tf_contextlib.contextmanager
    def inner_cm():
      """Context manager for copying distribute.Strategy scope information."""
      # pylint: disable=protected-access
      # TODO(b/112906995, nareshmodi): distribution strategy depends on
      # inheriting this stack from the default graph even in eager mode. Maybe
      # it should be part of the eager context? This would also allow us to
      # remove a get_default_graph() call from the function cache lookup.
      graph = ops.get_default_graph()
      old_strategy_stack = self._distribution_strategy_stack
      self._distribution_strategy_stack = list(
          graph._distribution_strategy_stack)

      # We ignore device placements from any outer scopes while tracing the
      # function when possible, to avoid hard-coding them in the function
      # graph. "Default" placements come from the PartitionedCallOp's
      # placement, so that the same trace of the Python function may be placed
      # on several different devices and saved functions may be placed on new
      # devices when restored.
      # However, we need to preserve the outer device stack in the following
      # cases in non-eager context:
      # 1. The device stack contains a callable.
      # 2. When using distribution strategy with legacy graph mode.
      old_device_stack = self._device_function_stack
      if (not context.executing_eagerly() and
          (device_stack_has_callable(graph._device_function_stack) or
           (self._distribution_strategy_stack and
            not ops.executing_eagerly_outside_functions()))):
        # Hard-code devices from device functions in the function body
        self._device_function_stack = graph._device_function_stack.copy()

      old_creator_stack = self._variable_creator_stack
      self._variable_creator_stack = graph._variable_creator_stack
      # Inherit the graph key, since this is used for matching variables in
      # optimizers.
      old_graph_key = self._graph_key
      self._graph_key = graph._graph_key
      # pylint: enable=protected-access

      old_scope_exit_callbacks = self._scope_exit_callbacks
      self._scope_exit_callbacks = []

      with outer_cm as g:
        try:
          yield g
        finally:
          try:
            for fn in self._scope_exit_callbacks:
              fn()
          finally:
            self._scope_exit_callbacks = old_scope_exit_callbacks
            self._distribution_strategy_stack = old_strategy_stack
            self._device_function_stack = old_device_stack
            self._variable_creator_stack = old_creator_stack
            self._graph_key = old_graph_key

    return inner_cm()

  @property
  def outer_graph(self):
    """The Graph this FuncGraph is nested in.

    Functions may capture Tensors from graphs they are nested in (transitive).

    Returns:
      A Graph object. Initially set to the current default graph when the
      FuncGraph was created. If the previous `outer_graph` was deleted because
      the function that owns it was deleted, `outer_graph` is reset to the
      outermost default graph active when the FuncGraph was created. This
      FuncGraph won't have captured anything from the new `outer_graph` (and
      likely not from the previous setting, since that would have created a
      strong reference), but it is returned so that FuncGraphs always have a
      parent.
    """
    current = self._weak_outer_graph()
    if current is None:
      return self._fallback_outer_graph
    return current

  @outer_graph.setter
  def outer_graph(self, new_outer_graph):
    """Sets `outer_graph` to `new_outer_graph`."""
    self._weak_outer_graph = weakref.ref(new_outer_graph)

  @property
  def output_types(self):
    return [t.dtype for t in self.outputs]

  @property
  def output_shapes(self):
    return [t.shape for t in self.outputs]

  @property
  def trainable_variables(self):
    """A sequence of trainable variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of trainable variables for this func graph.
    """
    return tuple(v for v in self.variables if v.trainable)

  @property
  def variables(self):
    """A sequence of variables accessed by this FuncGraph.

    Note that functions keep only weak references to variables. Calling the
    function after a variable it accesses has been deleted is an error.

    Returns:
      Sequence of variables for this func graph.
    """

    def deref(weak_v):
      v = weak_v()
      if v is None:
        raise AssertionError(
            "Called a function referencing variables which have been deleted. "
            "This likely means that function-local variables were created and "
            "not referenced elsewhere in the program. This is generally a "
            "mistake; consider storing variables in an object attribute on "
            "first call.")
      return v

    return tuple(deref(v) for v in self._weak_variables)

  @variables.setter
  def variables(self, var_list):
    self._weak_variables = [weakref.ref(v) for v in var_list]

  def _capture_by_value(
      self,
      op_type,
      inputs,
      dtypes,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    # When capturing by value, do the read outside
    reverse_captures = dict((id(v), k) for k, v in self.captures)
    uncaptured_inputs = [reverse_captures.get(id(t), t) for t in inputs]
    with ops.init_scope():
      if context.executing_eagerly():
        attr_list = ("dtype", int(attrs["dtype"].type))
        value, = execute.execute(
            compat.as_bytes(op_type), 1, uncaptured_inputs, attr_list,
            context.context())
      else:
        op = ops.get_default_graph()._create_op_internal(  # pylint: disable=protected-access
            op_type, uncaptured_inputs, dtypes, input_types, name, attrs,
            op_def, compute_device)
        value = op.outputs[0]
    captured_value = self.capture(value)
    return captured_value.op

  def _create_op_internal(
      self,
      op_type,
      inputs,
      dtypes=None,  # pylint: disable=redefined-outer-name
      input_types=None,
      name=None,
      attrs=None,
      op_def=None,
      compute_device=True):
    """Like Graph.create_op, except handles external input tensors.

    This overload adds functionality to create_op to "capture" any external
    input tensors, i.e. tensors from the eager context or outer function
    graphs if this is a nested function. See `capture` for more information.

    Args:
      op_type: The `Operation` type to create. This corresponds to the
        `OpDef.name` field for the proto that defines the operation.
      inputs: A list of `Tensor` objects that will be inputs to the
        `Operation`.
      dtypes: (Optional) A list of `DType` objects that will be the types of
        the tensors that the operation produces.
      input_types: (Optional.) A list of `DType`s that will be the types of
        the tensors that the operation consumes. By default, uses the base
        `DType` of each input in `inputs`. Operations that expect
        reference-typed inputs must specify `input_types` explicitly.
      name: (Optional.) A string name for the operation. If not specified, a
        name is generated based on `op_type`.
      attrs: (Optional.) A dictionary where the key is the attribute name (a
        string) and the value is the respective `attr` attribute of the
        `NodeDef` proto that will represent the operation (an `AttrValue`
        proto).
      op_def: (Optional.) The `OpDef` proto that describes the `op_type` that
        the operation will have.
      compute_device: (Optional.) If True, device functions will be executed
        to compute the device property of the Operation.

    Returns:
      An `Operation` object.
    """
    if self.capture_by_value and op_type in [
        "ReadVariableOp", "ResourceGather"
    ]:
      return self._capture_by_value(op_type, inputs, dtypes, input_types,
                                    name, attrs, op_def, compute_device)

    # This capturing logic interacts poorly with control flow contexts which
    # want to replace inputs of ops far too late in the process. This can lead
    # the context to get confused and try to create an Enter for an Enter. We
    # can detect this here and skip the additional Enter which can confuse
    # loop validation logic.
    if op_type == "Enter" and inputs[0].op.type == "Enter":
      if inputs[0].op.get_attr("frame_name") == attrs["frame_name"].s:
        return inputs[0].op
    # Calling AddValue on the control flow contexts to force creation of the
    # backward accumulators in the original graph before we create
    # placeholders to capture the inputs.
    ctxt = ops.get_default_graph()._control_flow_context  # pylint: disable=protected-access
    # Use a different list to avoid modifying the original inputs list.
    captured_inputs = []
    for inp in inputs:
      # TPU Estimator defines a control flow context with no AddValue method.
      if ctxt is not None and hasattr(ctxt, "AddValue"):
        inp = ctxt.AddValue(inp)
      inp = self.capture(inp)
      captured_inputs.append(inp)
    return super()._create_op_internal(  # pylint: disable=protected-access
        op_type, captured_inputs, dtypes, input_types, name, attrs, op_def,
        compute_device)

  def capture(self, tensor, name=None, shape=None):
    return self._function_captures.capture_by_value(self, tensor, name)

  def _validate_in_scope(self, tensor):
    inner_graph = tensor.graph
    while inner_graph is not None and isinstance(inner_graph, FuncGraph):
      if inner_graph is self:
        try:
          tb = tensor.op.traceback
        except AttributeError:
          tensor_traceback = "<unknown>"
        else:
          tensor_traceback_list = []
          for frame in traceback.format_list(tb.get_user_frames()):
            tensor_traceback_list.extend(
                [f"  {line}" for line in frame.split("\n") if line.strip()])
          tensor_traceback = "\n".join(tensor_traceback_list)
        # Keep in sync with tfe_wrapper.cc.
        # TODO(b/200991648): Unify those two paths.
        raise errors.InaccessibleTensorError(
            f"{tensor!r} is out of scope and cannot be used here. Use return "
            "values, explicit Python locals or TensorFlow collections to "
            "access it.\n"
            "Please see https://www.tensorflow.org/guide/function#all_outputs_of_a_tffunction_must_be_return_values "  # pylint: disable=line-too-long
            "for more information.\n\n"
            f"{tensor!r} was defined here:\n{tensor_traceback}\n\n"
            f"The tensor {tensor!r} cannot be accessed from {self}, because "
            f"it was defined in {tensor.graph}, which is out of scope.")
      inner_graph = inner_graph.outer_graph

  # TODO(panzf): Rename this method along with usages in cond/while graph.
  def _capture_helper(self, tensor, name):
    return self._function_captures._create_placeholder_helper(  # pylint: disable=protected-access
        self, tensor, name)

  def _experimental_capture_side_input_by_ref(self, identifier: Hashable,
                                              func: Callable[[], Any]) -> Any:
    """Implements capturing a side input by reference for tf.function.

    Note that this API will only register the capture in the func_graph where
    it is called. In the case of nested graphs, like nested tf.function or
    tf.while, the outer graph is not aware of this capture in the inner graph.
    Thus, the outer tf.function will not retrace when the by-ref capture
    changes. It's the user's responsibility to call this API in the outer
    func_graph as well if proper retracing is needed.

    For example:

    ```
    x = 1

    # Correct usage
    @tf.function
    def f_1():
      graph = tf.compat.v1.get_default_graph()
      # Capture the same x for the outer tf.function
      graph._experimental_capture_side_input_by_ref("x", lambda: x)

      @tf.function
      def g():
        graph = tf.compat.v1.get_default_graph()
        cap_x = graph._experimental_capture_side_input_by_ref("x", lambda: x)
        return cap_x + 1

      return g()

    # Incorrect usage
    @tf.function
    def f_2():

      @tf.function
      def g():
        graph = tf.compat.v1.get_default_graph()
        cap_x = graph._experimental_capture_side_input_by_ref("x", lambda: x)
        return cap_x + 1

      return g()

    assert f_1() == 2
    assert f_2() == 2
    x = 2
    assert f_1() == 3
    assert f_2() == 2  # This is incorrect
    ```

    Args:
      identifier: A hashable object as the key for the capture.
      func: A Python function that takes no arguments and returns the value of
        the side input. The function is evaluated at function call time.

    Returns:
      A nested structure with the same structure as the side input. Tensors
      are replaced with placeholders, and non-tensors remain the same.
    """
    if context.executing_eagerly():
      return func()

    def maybe_convert_to_tensor():
      value = func()
      if not (isinstance(value, core.Value) or isinstance(value, core.Symbol)):
        value = constant_op.constant(value)
      return value

    placeholder = self._function_captures._capture_by_ref(  # pylint: disable=protected-access
        self, maybe_convert_to_tensor, identifier)
    return placeholder

  @property
  def captures(self):
    """Ordered list of tuples containing external and internal captures."""
    return self._function_captures.by_val_capture_tuples

  def add_capture(self, tensor, placeholder):
    """Capture a specific tensor and utilize the provided placeholder.

    Args:
      tensor: Tensor to capture.
      placeholder: Provided placeholder for the tensor.
    """
    self._function_captures.add_or_replace(
        key=id(tensor),
        external=tensor,
        internal=placeholder,
        is_by_ref=False)
    self.inputs.append(placeholder)

  def replace_capture(self, tensor, placeholder):
    """Replaces an already existing capture."""
    self._function_captures.add_or_replace(
        key=id(tensor),
        external=tensor,
        internal=placeholder,
        is_by_ref=False)

  def replace_capture_with_deferred_capture(self,
                                            tensor,
                                            closure,
                                            spec,
                                            placeholder,
                                            default_value=None):
    """Replaces existing capture `tensor` with a deferred capture `closure`.

    Caution: It is the caller's responsibility to make sure that, after
    calling this function, the TypeSpec of the `inputs` (i.e. internal
    placeholders) and the `_captured_inputs` (i.e. external captures) of a
    concrete function that wraps this function graph are still compatible.
    Thus the user should pair usage of this function with
    `ConcreteFunction.set_external_captures` to make sure the order still
    matches. For example,
    ```
    # concrete_fn._captured_inputs == [tensor1, tensor2, tensor3]
    # concrete_fn.inputs == [placeholder1, placeholder2, placeholder3]
    # replace external capture `tensor2` with a deferred_capture, i.e., a
    # closure, `closure2`
    concrete_fn.graph.replace_capture_with_deferred_capture(tensor2,
                                                            closure2,
                                                            some_spec,
                                                            placeholder2,
                                                            some_default)
    concrete_fn.set_external_captures([tensor1, closure2, tensor3])
    ```

    Args:
      tensor: Tensor already captured.
      closure: function which takes no arguments, to be evaluated at function
        call time, returning a nest of tensors compatible with `spec`.
      spec: nest of TypeSpec for the value to capture.
      placeholder: the internal placeholder corresponding to the captured
        `tensor`.
      default_value: optional value to use in environments that cannot safely
        evaluate closure.
    """
    self._function_captures.pop(id(tensor), is_by_ref=False)
    self.capture_call_time_value(
        closure,
        spec,
        key=id(tensor),
        default_value=default_value,
        placeholder=placeholder)

  @property
  def external_captures(self):
    """External tensors captured by this function."""
    return list(self._function_captures.by_val_external.values())

  @property
  def internal_captures(self):
    """Placeholders in this function corresponding to captured tensors."""
    return list(self._function_captures.by_val_internal.values())

  @property
  def deferred_external_captures(self):
    """Ordered nest of tensors whose placeholders will be fed at call time."""
    return list(self._function_captures.by_ref_external.values())

  @property
  def deferred_internal_captures(self):
    """List of nests of placeholders which will be fed at call time."""
    return list(self._function_captures.by_ref_internal.values())

  @property
  def variable_captures(self):
    """Map of python object ids of variables to variables which are captured."""
    return self.variables

  @property
  def function_captures(self):
    return self._function_captures

  def mark_as_unsaveable(self, error_message):
    """Marks this FuncGraph as unsaveable.

    Any attempts to export this FuncGraph will raise an error with the
    specified message.

    Args:
      error_message: List or string containing the error message to be raised
        when saving this FuncGraph to SavedModel.
    """
    self._saveable = False
    if isinstance(error_message, str):
      error_message = [error_message]
    self._saving_errors.update(error_message)

  @property
  def saveable(self):
    """Returns whether this FuncGraph is saveable."""
    return self._saveable

  @property
  def saving_errors(self):
    """Returns the set of errors preventing this FuncGraph from being saved."""
    return self._saving_errors

  def _add_scope_exit_callback(self, fn):
    """Add a function to call when this graph exits the default scope."""
    if not callable(fn):
      raise TypeError("fn is not callable: {}".format(fn))
    if self._scope_exit_callbacks is None:
      raise RuntimeError(
          "Attempting to add a scope exit callback, but the default graph is "
          "not the context scope graph. Did you forget to call "
          "'with graph.as_default(): ...'?")
    self._scope_exit_callbacks.append(fn)


def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None,
                            create_placeholders=True):
  """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the
      shapes and dtypes of the arguments. When a signature is provided, `args`
      and `kwargs` are ignored, and `python_func` is traced with Tensors
      conforming to `signature`. If `None`, the shapes and dtypes are inferred
      from the inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph, else a new one is built and returned.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input
      placeholders recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start with.
      If not specified (None), the FuncGraph will read (but not write to) the
      outer graph's collections that are not allowlisted, and both read and
      write to the outer graph's collections that are allowlisted. The current
      allowlisted collections are the global variables, the local variables,
      and the trainable variables. Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will
      capture Variables by value instead of reference. By default inherit from
      outer graphs, and failing that will default to False.
    create_placeholders: An optional boolean. If True, then the func graph
      will create placeholders for the inputs as graph ops. If False, the
      input args and kwargs will be treated as the input placeholders.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None`, a
      `Tensor` or a `tf.experimental.ExtensionType`.
  """
  if op_return_value is not None:
    assert isinstance(op_return_value, ops.Tensor), op_return_value
  if func_graph is None:
    func_graph = FuncGraph(
        name, collections=collections, capture_by_value=capture_by_value)
  assert isinstance(func_graph, FuncGraph)
  if add_control_dependencies:
    deps_control_manager = auto_control_deps.AutomaticControlDependencies()
  else:
    deps_control_manager = ops.NullContextmanager()

  with func_graph.as_default(), deps_control_manager as deps_ctx:
    current_scope = variable_scope.get_variable_scope()
    default_use_resource = current_scope.use_resource
    current_scope.set_use_resource(True)

    if signature is not None:
      args = signature
      kwargs = {}

    if create_placeholders:
      func_args, func_kwargs = _create_placeholders(args, kwargs, arg_names)
    else:
      func_args, func_kwargs = args, kwargs

    input_trace_types = trace_type.from_value([func_args, func_kwargs])
    func_graph.inputs = input_trace_types._to_tensors([func_args, func_kwargs])  # pylint: disable=protected-access
    for arg in func_graph.inputs:
      if arg.dtype == dtypes.resource:
        func_graph._resource_tensor_inputs.add(arg)  # pylint:disable=protected-access

    signature_context = trace_type.InternalTracingContext()
    # Convert all Tensors into TensorSpecs before saving the structured inputs.
    # If storing pure concrete functions that are not called through
    # polymorphic functions, we don't have access to FunctionSpec, so we need
    # to call the TensorSpecs by their `arg_names` for later binding.
    func_graph.structured_input_signature = (
        convert_structure_to_signature(
            func_args, arg_names, signature_context=signature_context),
        convert_structure_to_signature(
            func_kwargs, signature_context=signature_context))

    # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
    # Variables to help check whether mutation happens in calling the function.
    # Copy the recursive list, tuple and map structure, but not base objects.
    func_args_before = nest.pack_sequence_as(
        func_args,
        nest.flatten(func_args, expand_composites=True),
        expand_composites=True)
    func_kwargs_before = nest.pack_sequence_as(
        func_kwargs,
        nest.flatten(func_kwargs, expand_composites=True),
        expand_composites=True)

    def convert(x):
      """Converts a function output to a Tensor."""
      if x is None:
        return None
      if op_return_value is not None and isinstance(x, ops.Operation):
        # TODO(b/79881896): we currently can't capture external control deps,
        # so this won't work if x needs to be captured (i.e. if python_func
        # returns captured Operations).
        with ops.control_dependencies([x]):
          x = array_ops.identity(op_return_value)
      elif not isinstance(x, tensor_array_ops.TensorArray):
        try:
          x = ops.convert_to_tensor_or_composite(x)
        except (ValueError, TypeError):
          raise TypeError(
              "To be compatible with tf.function, Python functions "
              "must return zero or more Tensors or ExtensionTypes or None "
              f"values; in compilation of {str(python_func)}, found return "
              f"value of type {type(x).__name__}, which is not a Tensor or "
              "ExtensionType.")
      if add_control_dependencies:
        x = deps_ctx.mark_as_return(x)
      return x

    _, original_func = tf_decorator.unwrap(python_func)
    func_outputs = python_func(*func_args, **func_kwargs)

    # invariant: `func_outputs` contains only Tensors, CompositeTensors,
    # TensorArrays and `None`s.
    func_outputs = variable_utils.convert_variables_to_tensors(func_outputs)
    func_outputs = nest.map_structure(
        convert, func_outputs, expand_composites=True)

    # Flatten and unflatten func_args and func_kwargs to maintain parity with
    # the flattening above, which sorts by key.
    func_args = nest.pack_sequence_as(
        func_args,
        nest.flatten(func_args, expand_composites=True),
        expand_composites=True)
    func_kwargs = nest.pack_sequence_as(
        func_kwargs,
        nest.flatten(func_kwargs, expand_composites=True),
        expand_composites=True)
    check_func_mutation(func_args_before, func_kwargs_before, func_args,
                        func_kwargs, original_func)
    current_scope.set_use_resource(default_use_resource)

    inputs = []
    for arg in composite_tensor_utils.flatten_with_variables(
        [func_args, func_kwargs]):
      if isinstance(arg, resource_variable_ops.BaseResourceVariable):
        # Even if an argument variable was not used in the function, we've
        # already manually captured the resource Tensor when creating argument
        # placeholders.
        capture = func_graph._function_captures.pop(id(arg.handle), False)  # pylint: disable=protected-access
        assert len(capture) >= 2
        resource_placeholder = capture[1]
        if resource_placeholder is None:
          continue
        inputs.append(resource_placeholder)
      elif isinstance(arg, ops.Tensor):
        inputs.append(arg)
    func_graph.inputs = (
        inputs + func_graph.internal_captures + nest.flatten(
            func_graph.deferred_internal_captures, expand_composites=True))
    func_graph.structured_outputs = func_outputs
    # Returning a closed-over tensor does not trigger convert_to_tensor.
    func_graph.outputs.extend(
        func_graph.capture(x)
        for x in flatten(func_graph.structured_outputs)
        if x is not None)

    func_graph.variables = func_graph._watched_variables  # pylint: disable=protected-access

  if add_control_dependencies:
    func_graph.control_outputs.extend(deps_control_manager.ops_which_must_run)
    func_graph.collective_manager_ids_used = (
        deps_control_manager.collective_manager_ids_used)

  return func_graph


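# Illustrative usage sketch (hypothetical): tracing a simple Python function
# into a FuncGraph via a signature of TensorSpecs.
#
#   def add(a, b):
#     return a + b
#
#   fg = func_graph_from_py_func(
#       "add", add, None, None,
#       signature=(tensor_spec.TensorSpec([], dtypes.float32),
#                  tensor_spec.TensorSpec([], dtypes.float32)))
#   fg.inputs   # two scalar float32 placeholders
#   fg.outputs  # one scalar float32 tensor holding the sum
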

def maybe_captured(tensor):
  """If `tensor` is a captured value placeholder, returns the original value.

  Args:
    tensor: Tensor.

  Returns:
    A tensor, potentially from a different Graph/FuncGraph.
  """
  if (not isinstance(tensor, ops.EagerTensor) and
      tensor.op.graph.building_function and tensor.op.type == "Placeholder"):
    for input_t, placeholder_t in tensor.op.graph.captures:
      if tensor == placeholder_t:
        return maybe_captured(input_t)
  # pylint: enable=protected-access
  return tensor


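# Illustrative sketch (hypothetical names): inside a FuncGraph, an externally
# captured tensor shows up as a "Placeholder" op; `maybe_captured` follows the
# capture mapping back out, recursively across nested graphs.
#
#   # placeholder_t = internal placeholder created when input_t was captured
#   maybe_captured(placeholder_t)  # returns the original input_t
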

def device_stack_has_callable(device_stack):
  """Checks whether a device stack contains a callable."""
  return any(
      callable(spec._device_name_or_function)  # pylint: disable=protected-access
      for spec in device_stack.peek_objs())


def has_mutation(n1, n2):
  """Returns true if n1 and n2 are different (using `is` to compare leaves)."""
  try:
    nest.assert_same_structure(n1, n2, expand_composites=True)
  except ValueError:
    return True

  for arg1, arg2 in zip(
      nest.flatten(n1, expand_composites=True),
      nest.flatten(n2, expand_composites=True)):
    if arg1 is not arg2:
      return True

  return False


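# Illustrative sketch: leaves are compared by identity, so replacing an
# element counts as a mutation even when the replacement is equal by value.
#
#   x = object()
#   has_mutation([x], [x])         # False: same structure, same leaf
#   has_mutation([x], [object()])  # True: different leaf object
#   has_mutation([x], [x, x])      # True: structures differ
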

def check_func_mutation(old_args, old_kwargs, new_args, new_kwargs, func):
  """Checks that the arguments to a function are not modified."""
  if not has_mutation((old_args, old_kwargs), (new_args, new_kwargs)):
    return

  # Mutation detected; construct a useful error message.
  func_name = getattr(func, "__qualname__", getattr(func, "__name__", func))
  signature = tf_inspect.signature(func)
  try:
    old_bound = signature.bind(*old_args, **old_kwargs).arguments
    new_bound = signature.bind(*new_args, **new_kwargs).arguments
  except TypeError as e:
    # This occurs when the function is called with the (deprecated)
    # "flat signature". See ConcreteFunction._call_with_flat_signature. In
    # this case, we can't report which arguments were modified.
    raise ValueError(
        f"{func_name}{signature} should not modify its Python input "
        f"arguments. Check if it modifies any lists or dicts passed as "
        f"arguments. Modifying a copy is allowed.") from e

  assert set(old_bound) == set(new_bound)
  modified_args = [
      arg_name for arg_name in new_bound
      if has_mutation(old_bound[arg_name], new_bound[arg_name])
  ]
  changes = ", ".join(modified_args)
  raise ValueError(f"{func_name}{signature} should not modify its Python "
                   f"input arguments. Modifying a copy is allowed. The "
                   f"following parameter(s) were modified: {changes}")


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def flatten(sequence):
  """Like nest.flatten w/ expand_composites, but returns flow for TensorArrays.

  Args:
    sequence: A nested structure of Tensors, CompositeTensors, and
      TensorArrays.

  Returns:
    A list of tensors.
  """
  flat_sequence = nest.flatten(sequence, expand_composites=True)
  return [
      item.flow if isinstance(item, tensor_array_ops.TensorArray) else item
      for item in flat_sequence
  ]


# TODO(edloper): If TensorArray becomes a CompositeTensor, then delete this.
def pack_sequence_as(structure, flat_sequence):
  """Like `nest.pack_sequence_as` but also builds TensorArrays from flows.

  Args:
    structure: The structure to pack into. May contain Tensors,
      CompositeTensors, or TensorArrays.
    flat_sequence: An iterable containing tensors.

  Returns:
    A nested structure.

  Raises:
    ValueError: if `structure` and `flat_sequence` are not compatible.
  """
  flat_sequence = list(flat_sequence)
  flattened_structure = nest.flatten(structure, expand_composites=True)
  if len(flattened_structure) != len(flat_sequence):
    raise ValueError("Mismatch in element count")
  for i in range(len(flat_sequence)):
    if isinstance(flattened_structure[i], tensor_array_ops.TensorArray):
      flat_sequence[i] = tensor_array_ops.build_ta_with_new_flow(
          old_ta=flattened_structure[i], flow=flat_sequence[i])
  return nest.pack_sequence_as(structure, flat_sequence,
                               expand_composites=True)


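# Illustrative sketch (hypothetical `x`): `flatten` and `pack_sequence_as`
# round-trip a structure containing a TensorArray via its flow tensor.
#
#   ta = tensor_array_ops.TensorArray(dtypes.float32, size=2)
#   flat = flatten({"ta": ta, "x": x})          # [ta.flow, x] (keys sorted)
#   packed = pack_sequence_as({"ta": ta, "x": x}, flat)
#   # packed["ta"] is a TensorArray rebuilt from the flow tensor.
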

def _create_placeholders(args, kwargs, arg_names=None):
  """Create placeholders given positional args and keyword args."""
  signature_context = trace_type.InternalTracingContext(
      is_legacy_signature=True)
  arg_trace_types = trace_type.from_value(tuple(args), signature_context)
  kwarg_trace_types = trace_type.from_value(kwargs, signature_context)

  placeholder_mapping = signature_context.get_placeholder_mapping()
  placeholder_context = trace_type.InternalPlaceholderContext(
      ops.get_default_graph(), placeholder_mapping)

  if arg_names is None:
    arg_names = [None] * len(arg_trace_types.components)

  # Create placeholders for trace type args and trace type kwargs.
  func_args = []
  for name, trace_type_arg in zip(arg_names, arg_trace_types.components):
    placeholder_context.update_naming_scope(name)
    placeholder = trace_type_arg.placeholder_value(placeholder_context)
    func_args.append(placeholder)

  func_kwargs = {}
  for name, trace_type_kwarg in sorted(kwarg_trace_types.mapping.items()):
    placeholder_context.update_naming_scope(name)
    placeholder = trace_type_kwarg.placeholder_value(placeholder_context)
    func_kwargs[name] = placeholder

  return tuple(func_args), func_kwargs


def dismantle_func_graph(func_graph):
  """Removes reference cycles in `func_graph` FuncGraph.

  Helpful for making sure the garbage collector doesn't need to run when
  the FuncGraph goes out of scope, e.g. in tests using defun with
  @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True).

  Args:
    func_graph: A `FuncGraph` object to destroy. `func_graph` is unusable
      after this function.
  """
  func_graph._function_captures.clear()  # pylint: disable=protected-access
  ops.dismantle_graph(func_graph)


def override_func_graph_name_scope(func_graph, name_scope):
  func_graph._name_stack = name_scope  # pylint: disable=protected-access