Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/monomorphic_function.py: 19%

731 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=unidiomatic-typecheck 

16"""Implementation for Monomorphic Functions (including Differentiable ones).""" 

17 

18import collections 

19import pprint 

20 

21from tensorflow.core.framework import attr_value_pb2 

22from tensorflow.core.function.polymorphism import function_type as function_type_lib 

23from tensorflow.python import pywrap_tfe 

24from tensorflow.python.eager import backprop_util 

25from tensorflow.python.eager import context 

26from tensorflow.python.eager import forwardprop_util 

27from tensorflow.python.eager import record 

28from tensorflow.python.eager.graph_only_ops import graph_placeholder 

29from tensorflow.python.eager.polymorphic_function import atomic_function 

30from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib 

31from tensorflow.python.eager.polymorphic_function import function_spec 

32from tensorflow.python.eager.polymorphic_function import saved_model_exported_concrete 

33from tensorflow.python.framework import composite_tensor 

34from tensorflow.python.framework import dtypes 

35from tensorflow.python.framework import errors 

36from tensorflow.python.framework import func_graph as func_graph_module 

37from tensorflow.python.framework import indexed_slices 

38from tensorflow.python.framework import ops 

39from tensorflow.python.framework import tensor_shape 

40from tensorflow.python.framework import tensor_spec 

41from tensorflow.python.framework import type_spec 

42from tensorflow.python.ops import array_ops 

43from tensorflow.python.ops import default_gradient 

44from tensorflow.python.ops import gradients_util 

45from tensorflow.python.ops import handle_data_util 

46from tensorflow.python.ops import resource_variable_ops 

47from tensorflow.python.platform import tf_logging as logging 

48from tensorflow.python.profiler import trace 

49from tensorflow.python.trackable import base as trackable 

50from tensorflow.python.types import core 

51from tensorflow.python.util import _pywrap_utils 

52from tensorflow.python.util import compat 

53from tensorflow.python.util import nest 

54from tensorflow.python.util import object_identity 

55 

56 

57def _is_type_subset(a, b): 

58 """Returns true if `b` is a subset of type `a` (or if a is not a TypeSpec.)""" 

59 if isinstance(a, type_spec.TypeSpec): 

60 return a.most_specific_compatible_type(b) == a 

61 return True 

62 

63 

64def _parse_func_attrs(attributes): 

65 """Convert the keyword arguments into function_def attributes. 

66 

67 Currently only supports primitive types: bool, int, float and string.

68 

69 Args: 

70 attributes: the dictionary of attributes. 

71 Returns: 

72 A dict of attributes where the key is the name of attribute and the value 

73 is the AttrValue proto. 

74 Raises: 

75 ValueError: If `attributes` contains an unallowlisted name or an

76 unsupported value type.

77 """ 

78 attrs = {} 

79 for key, value in attributes.items(): 

80 if key not in attributes_lib.MONOMORPHIC_FUNCTION_ALLOWLIST: 

81 raise ValueError( 

82 f"ConcreteFunction does not support `{key}` as an attribute.") 

83 if isinstance(value, attr_value_pb2.AttrValue): 

84 attrs[key] = value 

85 # bool type check has to happen before int since bool is a subclass of int. 

86 elif isinstance(value, bool): 

87 attrs[key] = attr_value_pb2.AttrValue(b=value) 

88 elif isinstance(value, int): 

89 attrs[key] = attr_value_pb2.AttrValue(i=value) 

90 elif isinstance(value, float): 

91 attrs[key] = attr_value_pb2.AttrValue(f=value) 

92 elif isinstance(value, (str, bytes)): 

93 attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value)) 

94 else: 

95 raise ValueError(f"Attribute {key} must be bool, int, float, string, or " 

96 f"AttrValue. Got {type(value)}.") 

97 return attrs 
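# Illustrative sketch (not part of the covered module): the primitive ->
# AttrValue conversion that _parse_func_attrs performs, written against the
# public proto API. The helper name is an assumption, and any real attribute
# key must appear on attributes_lib.MONOMORPHIC_FUNCTION_ALLOWLIST.
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.util import compat

def _to_attr_value_sketch(value):
  # bool is checked before int because bool is a subclass of int.
  if isinstance(value, bool):
    return attr_value_pb2.AttrValue(b=value)
  if isinstance(value, int):
    return attr_value_pb2.AttrValue(i=value)
  if isinstance(value, float):
    return attr_value_pb2.AttrValue(f=value)
  if isinstance(value, (str, bytes)):
    return attr_value_pb2.AttrValue(s=compat.as_bytes(value))
  raise ValueError(f"Unsupported attribute value type: {type(value)}")

# _to_attr_value_sketch(True).b  -> True
# _to_attr_value_sketch("x").s   -> b"x"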

98 

99_FORWARD_PREFIX = "__forward_" 

100_BACKWARD_PREFIX = "__backward_" 

101_INFERENCE_PREFIX = "__inference_" 

102 

103 

104def _forward_name(n): 

105 """The name of a generated forward defun named n.""" 

106 return "%s%s_%s" % (_FORWARD_PREFIX, n, ops.uid()) 

107 

108 

109def _backward_name(n): 

110 """The name of a generated backward defun named n.""" 

111 return "%s%s_%s" % (_BACKWARD_PREFIX, n, ops.uid()) 

112 

113 

114def _inference_name(n): 

115 """The name of a forward-but-no-gradient defun named n.""" 

116 return "%s%s_%s" % (_INFERENCE_PREFIX, n, ops.uid()) 

117 

118 

119def _create_forward_backward_with_graph(attrs, forward_graph, backwards_graph): 

120 """Creates forward and backward functions from the function graphs.""" 

121 forward_function_name = _forward_name(forward_graph.name) 

122 common_attributes = dict(attrs) 

123 # NB: forward and backward functions need to drop the "_implements"

124 # attribute, because their signature contains all the intermediate tensors

125 # that they compute. Thus they don't have a stable signature which can 

126 # be directly optimized downstream. 

127 # See for more details: 

128 # https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md#appendix-future-support-for-optimizing-gradient-functions 

129 common_attributes.pop(attributes_lib.IMPLEMENTS, None) 

130 backward_function_attr = _parse_func_attrs( 

131 {attributes_lib.FORWARD_FUNCTION: forward_function_name}) 

132 backward_function_attr.update(common_attributes) 

133 backward_function = ConcreteFunction( 

134 backwards_graph, attrs=backward_function_attr) 

135 forward_function_attr = _parse_func_attrs({ 

136 attributes_lib.BACKWARD_FUNCTION: 

137 backward_function.name}) 

138 forward_function_attr.update(common_attributes) 

139 forward_function = atomic_function.from_func_graph( 

140 forward_function_name, forward_graph, forward_graph.inputs, 

141 forward_graph.outputs, forward_function_attr) 

142 return forward_function, backward_function 

143 

144 

145class _DelayedRewriteGradientFunctions(object): 

146 """Caches forward/backward functions with a delayed forward rewrite.""" 

147 

148 def __init__(self, func_graph, attrs, func_graph_deleter): 

149 """Construct an inference function and initialize caches.""" 

150 # A map from the number of forward function outputs with accepted gradients 

151 # to forward and backward functions, used to cache non-tape backward 

152 # function generation. 

153 self._cached_function_pairs = {} 

154 self._func_graph = func_graph 

155 self._inference_function = atomic_function.from_func_graph( 

156 _inference_name(self._func_graph.name), self._func_graph, 

157 self._func_graph.inputs, self._func_graph.outputs, attrs) 

158 self._attrs = attrs 

159 self._gradient_name = None 

160 # Note that the FuncGraph is mutated later, so we need to inspect it now to 

161 # figure out the user-specified outputs of the inference function. 

162 self._num_inference_outputs = len(self._func_graph.outputs) 

163 self._func_graph_deleter = func_graph_deleter 

164 

165 def forward_backward(self, num_doutputs=None): 

166 """A possibly-cached pair of forward and backward functions.""" 

167 if num_doutputs is None: 

168 num_doutputs = self._num_inference_outputs 

169 forward_backward = self._cached_function_pairs.get(num_doutputs) 

170 if forward_backward is not None: 

171 return forward_backward 

172 forward, backward = self._construct_forward_backward(num_doutputs) 

173 self._cached_function_pairs[num_doutputs] = (forward, backward) 

174 return forward, backward 

175 

176 def _construct_forward_backward(self, num_doutputs): 

177 """Constructs a pair of forward and backward functions. 

178 

179 Args: 

180 num_doutputs: The constructed backprop function will take output gradients 

181 for the first `num_doutputs` outputs of the forward function. Defaults 

182 to the number of outputs for the inference function, but when 

183 higher-order gradients are computed this will increase to include side 

184 outputs. 

185 

186 Returns: 

187 A pair of (forward_function, backward_function): 

188 forward_function: A re-generated inference function (an 

189 AtomicFunction) to account for new side outputs, if any extra 

190 were required when building the backward pass. 

191 backward_function: A ConcreteFunction that takes `num_doutputs`

192 arguments and returns gradients with respect to inputs of the forward 

193 function. 

194 """ 

195 trainable_outputs = [ 

196 output for output in self._func_graph.outputs[:num_doutputs] 

197 if backprop_util.IsTrainable(output)] 

198 

199 signature = [] 

200 for t in trainable_outputs: 

201 signature.append( 

202 tensor_spec.TensorSpec(*default_gradient.shape_and_dtype(t))) 

203 

204 def _backprop_function(*grad_ys): 

205 with ops.device(None): 

206 return gradients_util._GradientsHelper( # pylint: disable=protected-access 

207 trainable_outputs, 

208 self._func_graph.inputs, 

209 grad_ys=grad_ys, 

210 src_graph=self._func_graph) 

211 

212 with self._func_graph.as_default(): 

213 backwards_graph = func_graph_module.FuncGraph( 

214 _backward_name(self._func_graph.name)) 

215 func_graph_module.func_graph_from_py_func( 

216 name=backwards_graph.name, 

217 python_func=_backprop_function, 

218 args=[], kwargs={}, 

219 signature=signature, 

220 func_graph=backwards_graph) 

221 backwards_graph_captures = backwards_graph.external_captures 

222 captures_from_forward = [ 

223 c for c in backwards_graph_captures if 

224 not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph] 

225 

226 existing_outputs = object_identity.ObjectIdentitySet( 

227 self._func_graph.outputs) 

228 for capture in captures_from_forward: 

229 if capture not in existing_outputs: 

230 existing_outputs.add(capture) 

231 self._func_graph.outputs.append(capture) 

232 

233 forward_function, backward_function = _create_forward_backward_with_graph( 

234 self._attrs, self._func_graph, backwards_graph) 

235 return forward_function, backward_function 

236 

237 def _rewrite_forward_and_call_backward(self, op, *doutputs): 

238 """Add outputs to the forward call and feed them to the grad function.""" 

239 forward_function, backwards_function = self.forward_backward(len(doutputs)) 

240 if not backwards_function.outputs: 

241 return backwards_function.structured_outputs 

242 

243 op.graph._add_function_recursive(forward_function) # pylint: disable=protected-access 

244 

245 # pylint: disable=protected-access 

246 # Rewrite an inference call op to be a forward call op 

247 op._set_func_attr("f", forward_function.name) 

248 op._set_type_list_attr( 

249 "Tout", 

250 [ 

251 o.dtype.as_datatype_enum 

252 for o in forward_function.function_type.flat_outputs 

253 ], 

254 ) 

255 truncated_outputs = forward_function.function_type.flat_outputs[ 

256 len(op.outputs) : 

257 ] 

258 op._add_outputs( 

259 [o.dtype.as_datatype_enum for o in truncated_outputs], 

260 [o.shape for o in truncated_outputs], 

261 ) 

262 for i in range(len(op.outputs)): 

263 output_type = forward_function.function_type.flat_outputs[i] 

264 handle_data = output_type.dtype._handle_data 

265 if handle_data: 

266 handle_data_util.set_handle_data(op.outputs[i], handle_data) 

267 # pylint: enable=protected-access 

268 

269 capture_mapping = dict( 

270 zip((ops.tensor_id(t) for t in self._func_graph.outputs), op.outputs)) 

271 remapped_captures = [ 

272 capture_mapping.get(ops.tensor_id(capture), capture) 

273 for capture in backwards_function.captured_inputs 

274 ] 

275 

276 # Replace Nones with zeros since we're calling a graph function which 

277 # expects numeric inputs. 

278 cleaned_doutputs = [] 

279 for doutput, placeholder in zip(doutputs, self._func_graph.outputs): 

280 if backprop_util.IsTrainable(placeholder): 

281 if isinstance(doutput, indexed_slices.IndexedSlices): 

282 # Gradient passed to a backward ConcreteFunction must be tf.Tensor, 

283 # so we convert tf.IndexedSlices to tf.Tensor. 

284 cleaned_doutputs.append(ops.convert_to_tensor(doutput)) 

285 elif doutput is not None: 

286 cleaned_doutputs.append(doutput) 

287 else: 

288 cleaned_doutputs.append(default_gradient.zeros_like(placeholder)) 

289 

290 # Compute the gradients using the side outputs 

291 return backwards_function._call_flat( # pylint: disable=protected-access 

292 cleaned_doutputs, remapped_captures) 

293 

294 def get_gradient_function(self): 

295 """Returns gradient function. 

296 

297 The gradient rewrites an inference call op to a forward call op, but does 

298 not modify a pre-existing forward call op. It then computes the gradient 

299 from the output's gradients and the side outputs of the forward op. 

300 """ 

301 return self._rewrite_forward_and_call_backward 
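# Illustrative sketch (not part of the covered module): the delayed-rewrite
# path backs symbolic gradients such as tf.gradients inside a graph, where the
# inference call op is only rewritten to a forward op once gradients are
# actually requested. The function and tensor names below are assumptions.
import tensorflow as tf

@tf.function
def _f(x):
  return tf.sin(x)

_g = tf.Graph()
with _g.as_default():
  _x = tf.compat.v1.placeholder(tf.float32, shape=[])
  _y = _f(_x)                        # inserts a (Stateful)PartitionedCall op
  [_dy_dx] = tf.gradients(_y, [_x])  # triggers the forward rewrite + backward call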

302 

303 def forward(self, inference_args=None, input_tangents=None): 

304 """A forward function with only user-specified outputs. 

305 

306 The call operation for the returned inference function can be rewritten into 

307 a forward function. This only happens if the backward function (from the 

308 `backward` method) ends up being used to compute gradients. 

309 

310 This approach avoids constructing unnecessary graphs, but it only works if 

311 we are calling this function when not executing eagerly. 

312 

313 Args: 

314 inference_args: A flat list of Tensors, arguments to the inference 

315 function. Unused, but taken for compatibility with 

316 _TapeGradientFunctions. 

317 input_tangents: A flat list of Tensors, jvps associated with 

318 `inference_args`. Unused; if required, tape functions must be used 

319 instead. 

320 

321 Returns: 

322 An atomic_function.AtomicFunction. 

323 """ 

324 del inference_args # unused 

325 if input_tangents: 

326 # This class does not support special-cased forwardprop. The arguments are 

327 # here for compatibility with _TapeGradientFunctions. 

328 raise errors.InternalError("unexpectedly got forwardprop information in " 

329 "a class that does not support forwardprop.") 

330 return self._inference_function 

331 

332 def _backward(self, outputs): 

333 """Fetch a backward function for `outputs` from the forward function.""" 

334 def _backward_function(*args): 

335 call_op = outputs[0].op 

336 return self._rewrite_forward_and_call_backward(call_op, *args) 

337 return _backward_function, outputs 

338 

339 def record(self, flat_outputs, inference_args, input_tangents): 

340 """Record the function call operation. 

341 

342 _DelayedRewriteGradientFunctions supports only first-order backprop tape 

343 gradients (and then only when graph building). It does not work with 

344 higher-order tape gradients or forward autodiff, but does work with 

345 higher-order symbolic gradients (tf.gradients). 

346 

347 Args: 

348 flat_outputs: The result of running `forward`. 

349 inference_args: A flat list of Tensors with inference inputs to the 

350 operation. 

351 input_tangents: A flat list of Tensors with input tangents consumed by the 

352 operation. 

353 """ 

354 backward_function, to_record = self._backward(flat_outputs) 

355 record.record_operation( 

356 self._inference_function.cached_definition.signature.name, 

357 to_record, 

358 inference_args + input_tangents, 

359 backward_function, 

360 ) 

361 

362 

363# Contains information about a forward function wrapped to compute jvps. 

364_ForwardWrapper = collections.namedtuple( 

365 "_ForwardWrapper", ( 

366 # The wrapper Graph. 

367 "graph", 

368 # A flat list of non-tangent Tensor outputs from the wrapped forward 

369 # function. 

370 "outputs", 

371 # Indices for output tangents, same format as 

372 # forwardprop_util.pack_tangents. 

373 "output_indices", 

374 # A flat list of tangents for `outputs`. 

375 "output_tangents")) 

376 

377 

378class _TapeGradientFunctions(object): 

379 """Caches forward and backward functions compatible with eager gradients. 

380 

381 In contrast to the delayed-rewrite approach in 

382 `_DelayedRewriteGradientFunctions` which only works with delayed execution, 

383 the forward function generated by this class has a fixed set of outputs which 

384 may be preserved by a tape in order to compute gradients later. 

385 

386 This class is abstract; its child classes differ in how many side outputs of 

387 the forward function their backward function accepts gradients for, which 

388 determines whether higher-order tape gradients are possible. 

389 """ 

390 

391 def __init__(self, func_graph, attrs, func_graph_deleter, 

392 forwardprop_input_indices, delayed_rewrite_functions, 

393 need_gradients_for_jvps): 

394 self._func_graph = func_graph 

395 self._forward_graph = None 

396 self._attrs = attrs 

397 self._forward = None 

398 self._backward = None 

399 self._num_outputs = len(func_graph.outputs) 

400 self._func_graph_deleter = func_graph_deleter 

401 self._forwardprop_input_indices = forwardprop_input_indices 

402 self._forwardprop_output_indices = None 

403 self._num_forwardprop_outputs = 0 

404 self._num_inference_outputs = len(func_graph.outputs) 

405 self._num_trainable_inference_outputs = len( 

406 [t for t in func_graph.outputs if backprop_util.IsTrainable(t)]) 

407 self._delayed_rewrite_functions = delayed_rewrite_functions 

408 self._need_gradients_for_jvps = need_gradients_for_jvps 

409 

410 def _build_functions_for_outputs( 

411 self, outputs, inference_args, input_tangents): 

412 """Forward+backward functions where the backward function sees `outputs`.""" 

413 # First figure out which of `outputs` are trainable. We'll accept gradients 

414 # for each of these in the backward function. 

415 trainable_outputs = [] 

416 trainable_indices = [] 

417 for index, output in enumerate(outputs): 

418 

419 if backprop_util.IsTrainable(output): 

420 trainable_outputs.append(output) 

421 trainable_indices.append(index) 

422 

423 backwards_graph = func_graph_module.FuncGraph( 

424 _backward_name(self._func_graph.name)) 

425 with backwards_graph.as_default(): 

426 gradients_wrt_outputs = [] 

427 for output in trainable_outputs: 

428 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( 

429 output) 

430 gradient_placeholder = graph_placeholder(gradient_dtype, gradient_shape) 

431 handle_data_util.copy_handle_data(output, gradient_placeholder) 

432 gradients_wrt_outputs.append(gradient_placeholder) 

433 with ops.device(None): 

434 gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access 

435 trainable_outputs, 

436 self._func_graph.inputs, 

437 grad_ys=gradients_wrt_outputs, 

438 src_graph=self._func_graph) 

439 

440 if input_tangents: 

441 # Convert IndexedSlices to dense tensors (as we do elsewhere for 

442 # function gradients). Our C++ bindings don't know how to handle them 

443 # currently. 

444 gradients_wrt_inputs = nest.map_structure( 

445 lambda x: ops.convert_to_tensor(x) if x is not None else None, 

446 gradients_wrt_inputs) 

447 captures_from_forward = [ 

448 c for c in backwards_graph.external_captures 

449 if not isinstance(c, ops.EagerTensor) and c.graph is self._func_graph 

450 ] 

451 existing_outputs = object_identity.ObjectIdentitySet( 

452 self._func_graph.outputs) 

453 for capture in captures_from_forward: 

454 if capture not in existing_outputs: 

455 existing_outputs.add(capture) 

456 self._func_graph.outputs.append(capture) 

457 

458 # The ordering of `backwards_graph.inputs` is important: inputs of 

459 # `backward_function` correspond to outputs (including 

460 # side outputs) of `self._tape_forward_function`. 

461 backwards_graph.inputs = ( 

462 gradients_wrt_outputs + backwards_graph.internal_captures) 

463 backwards_graph.outputs.extend( 

464 grad 

465 for grad in nest.flatten(gradients_wrt_inputs, expand_composites=True) 

466 if grad is not None) 

467 backwards_graph.structured_outputs = gradients_wrt_inputs 

468 

469 forward_function, backward_function = _create_forward_backward_with_graph( 

470 self._attrs, self._func_graph, backwards_graph) 

471 

472 if not input_tangents: 

473 # There is no need to special-case forwardprop, so we can return the 

474 # forward+backward pair we've created without further wrapping. 

475 return (forward_function, self._func_graph, backward_function, 

476 # No forwardprop outputs. 

477 None, 0) 

478 forward_wrapper = self._wrap_forward_function_with_jvps( 

479 forward_function, backward_function, inference_args, input_tangents) 

480 (wrapped_backwards_graph, 

481 forward_wrapper) = self._wrap_backward_function_with_jvp_backprop( 

482 backward_function, gradients_wrt_outputs, forward_wrapper) 

483 # Now that we've added new captures, we need to make sure forward outputs 

484 # are in the same order the backward function expects them to be in: 

485 # [inference outputs] + [jvps] + [side outputs] + [captures]. 

486 forward_wrapper = self._shuffle_forward_outputs(forward_wrapper) 

487 (wrapped_forward_function, 

488 wrapped_backward_function) = _create_forward_backward_with_graph( 

489 self._attrs, forward_wrapper.graph, wrapped_backwards_graph) 

490 if (len(inference_args) + len(input_tangents) 

491 != len(forward_wrapper.graph.inputs)): 

492 raise errors.InternalError( 

493 f"The forward graph had {len(forward_wrapper.graph.inputs)} inputs, " 

494 f"but we expected {len(inference_args) + len(input_tangents)} " 

495 f"({len(inference_args)} inference inputs and " 

496 f"{len(input_tangents)} input tangents).") 

497 return (wrapped_forward_function, forward_wrapper.graph, 

498 wrapped_backward_function, forward_wrapper.output_indices, 

499 len(forward_wrapper.output_tangents)) 

500 

501 def _wrap_forward_function_with_jvps( 

502 self, forward_function, backward_function, 

503 inference_args, input_tangents): 

504 """Adds inline JVP computation to a forward function.""" 

505 forward_wrapper_graph = func_graph_module.FuncGraph( 

506 _forward_name(self._func_graph.name)) 

507 with forward_wrapper_graph.as_default(): 

508 # Tell forward accumulators to free up space for new JVP computations, 

509 # since one may be in the process of computing a JVP (if that computation 

510 # triggered this function building). 

511 # 

512 # We'll make symbolic versions of input JVPs, run the forward function 

513 # under forward accumulators to get symbolic output JVPs, then set those 

514 # as outputs of the new wrapped forward function. 

515 with forwardprop_util.push_forwardprop_state(): 

516 forward_captures = { 

517 ops.tensor_id(internal): external 

518 for external, internal in self._func_graph.captures} 

519 for input_index, real_input in enumerate(self._func_graph.inputs): 

520 # This loop is more or less equivalent to running tf.identity on each 

521 # of self._func_graph.inputs. However, doing that also captures jvps 

522 # for resource handles, which confuses the jvp capturing code below 

523 # (since primal inputs are interwoven with jvp inputs). 

524 input_placeholder = array_ops.placeholder( 

525 dtype=real_input.dtype, 

526 shape=real_input.shape) 

527 capture = forward_captures.get(ops.tensor_id(real_input)) 

528 if capture is not None: 

529 forward_wrapper_graph.add_capture(capture, input_placeholder) 

530 if capture.dtype == dtypes.resource: 

531 handle_data_util.copy_handle_data(capture, input_placeholder) 

532 else: 

533 forward_wrapper_graph.inputs.append(input_placeholder) 

534 for inp, arg in zip(forward_wrapper_graph.inputs, inference_args): 

535 record.record_operation( 

536 "captured_value", [inp], [arg], 

537 backward_function=lambda x: [x], 

538 forward_function=lambda x: [x]) 

539 num_inference_inputs = len(inference_args) 

540 for tape_indices in self._forwardprop_input_indices: 

541 for input_index, jvp_index in tape_indices: 

542 input_placeholder = forward_wrapper_graph.inputs[input_index] 

543 if len(forward_wrapper_graph.inputs) != jvp_index: 

544 raise errors.InternalError( 

545 f"Expected {jvp_index} forward graph inputs, " 

546 f"got {len(forward_wrapper_graph.inputs)}.") 

547 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( 

548 input_placeholder) 

549 jvp_placeholder = graph_placeholder(gradient_dtype, gradient_shape) 

550 external_jvp = input_tangents[jvp_index - num_inference_inputs] 

551 forward_wrapper_graph.add_capture(external_jvp, jvp_placeholder) 

552 tensor_shape.TensorShape( 

553 external_jvp.shape).assert_is_compatible_with( 

554 jvp_placeholder.shape) 

555 record.record_operation( 

556 "captured_value", 

557 [jvp_placeholder], 

558 [external_jvp], 

559 backward_function=lambda x: [x], 

560 forward_function=lambda x: [x]) 

561 forward_inputs = forward_wrapper_graph.inputs[:num_inference_inputs] 

562 gradient_function = ( 

563 self._delayed_rewrite_functions._rewrite_forward_and_call_backward) # pylint: disable=protected-access 

564 with ops.get_default_graph()._override_gradient_function( # pylint: disable=protected-access 

565 {"PartitionedCall": gradient_function, 

566 "StatefulPartitionedCall": gradient_function}): 

567 forward_outputs = forward_function(*forward_inputs) 

568 if isinstance(forward_outputs, ops.Operation): 

569 # _wrapped_backward_function expects a list, but if the function has 

570 # no outputs its call() returns an Operation. We need to undo that 

571 # so we don't cause problems later. 

572 forward_outputs = [] 

573 py_backward, _ = self._wrap_backward_function( 

574 self._func_graph, backward_function, forward_outputs) 

575 # We will never request backward tape gradients for this operation 

576 # directly since we're wrapping the call; forwardprop will call the 

577 # backward function (and nested forward accumulators may build 

578 # higher-order gradients), but any watching GradientTapes should ignore 

579 # it. 

580 # 

581 # TODO(allenl): It might be better to explicitly stop backward recording 

582 # so we don't use the second-order tape cases unnecessarily. 

583 record.record_operation_forwardprop_only( 

584 forward_function.cached_definition.signature.name, 

585 forward_outputs, forward_inputs, py_backward, None) 

586 output_indices, output_tangents = ( 

587 pywrap_tfe.TFE_Py_PackJVPs(forward_outputs)) 

588 output_tangents = [forward_wrapper_graph.capture(t) 

589 for t in output_tangents] 

590 return _ForwardWrapper( 

591 graph=forward_wrapper_graph, outputs=forward_outputs, 

592 output_indices=output_indices, output_tangents=output_tangents) 

593 

594 def _wrap_backward_function_with_jvp_backprop( 

595 self, backward_function, gradients_wrt_outputs, forward_wrapper): 

596 """Wraps `backward_function` to include gradients for JVPs.""" 

597 wrapped_backwards_graph = func_graph_module.FuncGraph( 

598 _backward_name(self._func_graph.name)) 

599 with wrapped_backwards_graph.as_default(): 

600 py_backward, recorded_outputs = self._wrap_backward_function( 

601 self._func_graph, backward_function, forward_wrapper.outputs) 

602 trainable_index = 0 

603 forward_doutputs = [] 

604 doutput_args = [] 

605 for output in recorded_outputs: 

606 if backprop_util.IsTrainable(output): 

607 doutput = gradients_wrt_outputs[trainable_index] 

608 doutput_placeholder = graph_placeholder(doutput.dtype, doutput.shape) 

609 doutput_args.append(doutput_placeholder) 

610 forward_doutputs.append(doutput_placeholder) 

611 trainable_index += 1 

612 else: 

613 doutput_args.append(None) 

614 

615 dinputs = py_backward(*doutput_args) 

616 existing_outputs = object_identity.ObjectIdentitySet( 

617 forward_wrapper.outputs + forward_wrapper.output_tangents) 

618 num_processed_output_tangents = 0 

619 gradients_wrt_output_tangents = [] 

620 tangent_doutputs = [] 

621 output_tangents = forward_wrapper.output_tangents 

622 output_indices = forward_wrapper.output_indices 

623 if self._need_gradients_for_jvps: 

624 # TODO(allenl): Consider using a throwaway graph to avoid extra gradient 

625 # evaluations; gradients for jvps may have common subgraphs. 

626 while num_processed_output_tangents != len(output_tangents): 

627 for output in output_tangents[num_processed_output_tangents:]: 

628 gradient_shape, gradient_dtype = default_gradient.shape_and_dtype( 

629 output) 

630 placeholder = graph_placeholder(gradient_dtype, gradient_shape) 

631 gradients_wrt_output_tangents.append(placeholder) 

632 tangent_doutputs.append(placeholder) 

633 num_processed_output_tangents = len(output_tangents) 

634 with ops.device(None): 

635 gradients_wrt_inputs = gradients_util._GradientsHelper( # pylint: disable=protected-access 

636 output_tangents, 

637 forward_wrapper.graph.inputs, 

638 grad_ys=gradients_wrt_output_tangents, 

639 src_graph=forward_wrapper.graph) 

640 dinputs = [ 

641 backprop_util.AggregateIndexedSlicesGradients((existing, new)) 

642 for existing, new in zip(dinputs, gradients_wrt_inputs) 

643 if existing is not None or new is not None] 

644 dinputs.extend(gradients_wrt_inputs[len(dinputs):]) 

645 captures_from_forward = [ 

646 c for c in wrapped_backwards_graph.external_captures 

647 if (not isinstance(c, ops.EagerTensor) 

648 and c.graph is forward_wrapper.graph)] 

649 for capture in captures_from_forward: 

650 if capture not in existing_outputs: 

651 existing_outputs.add(capture) 

652 forward_wrapper.outputs.append(capture) 

653 output_indices, output_tangents = ( 

654 forwardprop_util.pack_tangents(forward_wrapper.outputs)) 

655 output_tangents = [forward_wrapper.graph.capture(t) 

656 for t in output_tangents] 

657 for t in output_tangents: 

658 existing_outputs.add(t) 

659 wrapped_backwards_graph.inputs = ( 

660 forward_doutputs[:self._num_trainable_inference_outputs] 

661 + tangent_doutputs 

662 + forward_doutputs[self._num_trainable_inference_outputs:] 

663 + wrapped_backwards_graph.internal_captures) 

664 wrapped_backwards_graph.structured_outputs = dinputs 

665 wrapped_backwards_graph.outputs = [t for t in dinputs if t is not None] 

666 return (wrapped_backwards_graph, 

667 forward_wrapper._replace(output_indices=output_indices, 

668 output_tangents=output_tangents)) 

669 

670 def _shuffle_forward_outputs(self, forward_wrapper): 

671 """Reorders function outputs so captures are last.""" 

672 def _index_map(original): 

673 if original < self._num_inference_outputs: 

674 return original 

675 if original >= len(forward_wrapper.outputs): 

676 return (original - len(forward_wrapper.outputs) 

677 + self._num_inference_outputs) 

678 return original + len(forward_wrapper.output_tangents) 

679 output_indices = nest.map_structure( 

680 _index_map, forward_wrapper.output_indices) 

681 forward_wrapper.graph.outputs = ( 

682 forward_wrapper.outputs[:self._num_inference_outputs] 

683 + forward_wrapper.output_tangents 

684 + forward_wrapper.outputs[self._num_inference_outputs:]) 

685 return forward_wrapper._replace(output_indices=output_indices) 

686 

687 def forward(self, inference_args, input_tangents): 

688 """Construct or fetch a forward function with side-outputs. 

689 

690 When graph building without a tape active, symbolic gradients rely on 

691 regenerating the backward function for higher-order gradients (to account 

692 for new side outputs of the rewritten forward function call). Thus there is 

693 no fixed backward function for this case. However, when a tape is active 

694 (eager or graph building), we generate fixed backward and forward functions 

695 at forward function call time. 

696 

697 This difference between the tape and non-tape cases is to avoid building 

698 unneeded backward functions while graph building (where we may or may not 

699 eventually need gradients). 

700 

701 Args: 

702 inference_args: A flat list of Tensors, arguments to the inference 

703 function. 

704 input_tangents: A flat list of Tensors, jvps associated with 

705 `inference_args`. 

706 

707 Returns: 

708 A forward atomic_function.AtomicFunction. 

709 """ 

710 if self._forward is None: 

711 (self._forward, self._forward_graph, self._backward, 

712 self._forwardprop_output_indices, self._num_forwardprop_outputs) = ( 

713 self._forward_and_backward_functions(inference_args, input_tangents)) 

714 return self._forward 

715 

716 def _wrap_backward_function(self, forward_graph, backward, outputs): 

717 """Create a backward function given `outputs` from the forward function.""" 

718 capture_mapping = dict( 

719 zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs)) 

720 captured_inputs = backward.captured_inputs 

721 remapped_captures = [ 

722 capture_mapping.get(ops.tensor_id(capture), capture) 

723 for capture in captured_inputs 

724 ] 

725 if any(t.graph is forward_graph for t in remapped_captures 

726 if not isinstance(t, ops.EagerTensor)): 

727 incorrect_mapping = [t for t in remapped_captures 

728 if (not isinstance(t, ops.EagerTensor) and 

729 t.graph is not forward_graph)] 

730 raise errors.InternalError("Failed to map all backward graph captures to " 

731 "the forward graph. Incorrectly mapped: " 

732 f"{incorrect_mapping}.") 

733 # We may need to use zeros_like to get a zero for variant Tensors with 

734 # unconnected gradients. We do that in advance so we don't have to hold on 

735 # to the outputs themselves, which may not be needed otherwise. 

736 variant_zeros_like = {} 

737 backward_function_inputs = (len(backward.inputs) - len(captured_inputs)) 

738 recorded_outputs = [] 

739 trainable_recorded_outputs = 0 

740 skip_positions = [] 

741 if self._num_forwardprop_outputs and not self._need_gradients_for_jvps: 

742 relevant_outputs = ( 

743 outputs[:self._num_inference_outputs] 

744 + outputs[self._num_inference_outputs 

745 + self._num_forwardprop_outputs:]) 

746 else: 

747 relevant_outputs = outputs 

748 for output_index, output in enumerate(relevant_outputs): 

749 if trainable_recorded_outputs < backward_function_inputs: 

750 recorded_outputs.append(output) 

751 if backprop_util.IsTrainable(output): 

752 trainable_recorded_outputs += 1 

753 else: 

754 skip_positions.append(output_index) 

755 if output.dtype == dtypes.variant: 

756 variant_zeros_like[output_index] = default_gradient.zeros_like(output) 

757 

758 def _backward_function_wrapper(*args): 

759 """Process output gradients and call the backward function.""" 

760 if not backward.outputs: 

761 return backward.structured_outputs 

762 

763 processed_args = [] 

764 input_index = 0 

765 for output_index, arg in enumerate(args): 

766 # Convert IndexedSlices to dense tensors. The IndexedSlices optimization 

767 # is only really effective when doing tf.gather(variable) as the 

768 # adjoint functions for most operations are unlikely to preserve the 

769 # sparsity in IndexedSlices. 

770 if isinstance(arg, indexed_slices.IndexedSlices): 

771 arg = ops.convert_to_tensor(arg) 

772 if output_index in skip_positions: 

773 continue 

774 if arg is None: 

775 # We're calling a (non-polymorphic) ConcreteFunction, so we need to 

776 # have a Tensor value for each Tensor we thought would be trainable 

777 # based on its dtype, even if it ended up being unconnected. 

778 input_placeholder = backward.inputs[ 

779 input_index] 

780 if input_placeholder.dtype == dtypes.variant: 

781 arg = variant_zeros_like[output_index] 

782 else: 

783 arg = array_ops.zeros( 

784 *default_gradient.shape_and_dtype(input_placeholder)) 

785 processed_args.append(arg) 

786 input_index += 1 

787 if input_index >= backward_function_inputs: 

788 break 

789 return backward._call_flat( # pylint: disable=protected-access 

790 processed_args, remapped_captures) 

791 

792 return _backward_function_wrapper, recorded_outputs 

793 

794 def record(self, flat_outputs, inference_args, input_tangents): 

795 """Record the function call operation. 

796 

797 For backprop, indicates the backward function to use and which new Tensors 

798 must be watched. For forwardprop from eager, the function call itself will 

799 have produced tangents which need to be recorded. 

800 

801 Args: 

802 flat_outputs: The result of running `forward`. 

803 inference_args: A flat list of Tensors with inference inputs to the 

804 operation. 

805 input_tangents: A flat list of Tensors with input tangents consumed by the 

806 operation. 

807 """ 

808 backward_function, to_record = self._wrap_backward_function( 

809 self._forward_graph, self._backward, flat_outputs) 

810 if self._forwardprop_output_indices: 

811 record.record_operation_backprop_only( 

812 self._forward.cached_definition.signature.name, 

813 to_record, inference_args, 

814 backward_function) 

815 record.record_operation_forwardprop_only( 

816 self._forward.cached_definition.signature.name, 

817 flat_outputs, inference_args + input_tangents, 

818 backward_function, 

819 self._forwardprop_output_indices) 

820 else: 

821 record.record_operation(self._forward.cached_definition.signature.name, 

822 to_record, inference_args + input_tangents, 

823 backward_function) 

824 

825 

826class _FirstOrderTapeGradientFunctions(_TapeGradientFunctions): 

827 """Caches tape-friendly functions for first-order gradients.""" 

828 

829 def __init__(self, func_graph, attrs, func_graph_deleter, 

830 forwardprop_input_indices, delayed_rewrite_functions, 

831 need_gradients_for_jvps): 

832 super().__init__(func_graph, attrs, func_graph_deleter, 

833 forwardprop_input_indices, delayed_rewrite_functions, 

834 need_gradients_for_jvps) 

835 self._func_graph_deleter = func_graph_deleter 

836 self._forwardprop_input_indices = forwardprop_input_indices 

837 

838 def _forward_and_backward_functions(self, inference_args, input_tangents): 

839 """Shortcut for when only first-order gradients are required. 

840 

841 The returned backward function does not accept gradients with respect to 

842 side outputs of forward_function. This is fine as long as the user can't

843 possibly request second order tape gradients, as when they've used a single 

844 non-persistent GradientTape. Since we don't need the backward function to 

845 take gradients with respect to side outputs, we can skip some potentially 

846 slow graph building. 

847 

848 Args: 

849 inference_args: A flat list of Tensors, arguments to the inference 

850 function. 

851 input_tangents: A flat list of Tensors, jvps associated with 

852 `inference_args`. 

853 

854 Returns: 

855 A tuple of (forward_function, backward_function): 

856 forward_function: Takes the same inputs as the inference function, but 

857 returns side outputs used by backward_function in addition to the 

858 inference function's outputs. 

859 backward_function: Takes side outputs from forward_function and 

860 gradients with respect to the "real" outputs of forward_function and 

861 returns gradients with respect to the inputs. 

862 """ 

863 outputs = self._func_graph.outputs[:self._num_inference_outputs] 

864 return self._build_functions_for_outputs( 

865 outputs, inference_args, input_tangents) 
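# Illustrative sketch (not part of the covered module): the first-order path
# above is what a single non-persistent GradientTape exercises; gradients with
# respect to side outputs are never requested. The function name is an
# assumption.
import tensorflow as tf

@tf.function
def _cube(x):
  return x * x * x

_x = tf.constant(2.0)
with tf.GradientTape() as _tape:   # non-persistent: first-order only
  _tape.watch(_x)
  _y = _cube(_x)
_dy_dx = _tape.gradient(_y, _x)    # 3 * x**2 = 12.0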

866 

867 

868class _HigherOrderTapeGradientFunctions(_TapeGradientFunctions): 

869 """Caches tape-friendly functions for higher-order gradients.""" 

870 

871 # TODO(b/136189779): Cond/while under a tape may need similar logic. Consider 

872 # generalizing if so. 

873 def _forward_and_backward_functions(self, inference_args, input_tangents): 

874 """Forward and backward functions suitable for higher-order gradients. 

875 

876 Unlike in `_FirstOrderTapeGradientFunctions`, the backward function built by 

877 this method accepts gradients for all of the outputs of the returned forward 

878 function, including side outputs. 

879 

880 Args: 

881 inference_args: A flat list of Tensors, arguments to the inference 

882 function. 

883 input_tangents: A flat list of Tensors, jvps associated with 

884 `inference_args`. 

885 

886 Returns: 

887 A tuple of (forward_function, backward_function): 

888 forward_function: Takes the same inputs as the inference function, but 

889 returns side outputs used by backward_function in addition to the 

890 inference function's outputs. 

891 backward_function: Takes side outputs from forward_function and 

892 gradients with respect to all of its outputs, real and side. Returns 

893 gradients with respect to the inputs. 

894 """ 

895 outputs = [] 

896 iteration_count = 0 

897 # First we need to figure out how many side outputs from the forward pass 

898 # will be required. We do this in a temporary graph to avoid actually 

899 # running multiple copies of the backward pass (one per _GradientsHelper 

900 # call). 

901 # 

902 # While computing gradients, the backward function captures Tensors from 

903 # the forward function. We add these as side outputs of the original 

904 # function. However, we then need to accept output gradients with respect 

905 # to these side outputs for higher order gradients to work. Thus we loop 

906 # until the number of outputs of the function stabilizes. Note that this 

907 # is only required for tape gradients, where we need to declare in advance 

908 # all of the forward op's outputs: symbolic gradients with tf.gradients 

909 # instead rely on regenerating backward functions when higher-order 

910 # gradients are requested. 

911 while (len(outputs) < len(self._func_graph.outputs) 

912 # It's possible for gradient generation to add new ops to the forward 

913 # pass. If all of the new outputs are non-trainable, there's no 

914 # reason to continue. 

915 and any(backprop_util.IsTrainable(output) 

916 for output in self._func_graph.outputs[len(outputs):])): 

917 iteration_count += 1 

918 if iteration_count >= 20 and iteration_count % 5 == 0: 

919 new_op_with_trainable_output = None 

920 num_new_trainable_outputs = 0 

921 for output in self._func_graph.outputs[len(outputs):]: 

922 if backprop_util.IsTrainable(output): 

923 num_new_trainable_outputs += 1 

924 new_op_with_trainable_output = output.op 

925 logging.warning( 

926 ("Determining side outputs for the function '{}' is taking longer " 

927 "than expected ({} iterations, typically this converges in 5 or " 

928 "so). This could indicate that a gradient registration is adding " 

929 "new ops to the forward pass every time gradients are generated. " 

930 "{} new trainable output(s) were added this iteration, one from " 

931 "the following op:\n {}\nThis may indicate a TensorFlow bug, or " 

932 "an issue in a tf.custom_gradient.") 

933 .format( 

934 self._func_graph.name, iteration_count, 

935 num_new_trainable_outputs, new_op_with_trainable_output)) 

936 outputs = list(self._func_graph.outputs) 

937 self._build_functions_for_outputs( 

938 outputs, inference_args, input_tangents) 

939 

940 (forward_function, forward_graph, 

941 backward_function, output_indices, num_output_tangents) = ( 

942 self._build_functions_for_outputs( 

943 outputs, inference_args, input_tangents)) 

944 if (len(self._func_graph.outputs) > len(outputs) 

945 and any(backprop_util.IsTrainable(output) 

946 for output in self._func_graph.outputs[len(outputs):])): 

947 raise errors.InternalError( 

948 "Unexpectedly added new outputs to the forward function when " 

949 "building the backward function: " 

950 f"{self._func_graph.outputs[len(outputs):]}.") 

951 return (forward_function, forward_graph, backward_function, output_indices, 

952 num_output_tangents) 
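# Illustrative sketch (not part of the covered module): nested GradientTapes
# are the kind of workload that requires the higher-order path above, since
# the outer tape needs gradients with respect to side outputs of the forward
# function. The function name is an assumption.
import tensorflow as tf

@tf.function
def _pow3(x):
  return x * x * x

_x = tf.constant(2.0)
with tf.GradientTape() as _outer:
  _outer.watch(_x)
  with tf.GradientTape() as _inner:
    _inner.watch(_x)
    _y = _pow3(_x)
  _dy = _inner.gradient(_y, _x)    # 3 * x**2 = 12.0
_d2y = _outer.gradient(_dy, _x)    # 6 * x    = 12.0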

953 

954 

955class _ForwardBackwardCall(object): 

956 """Holds the state of a function call between execution and recording.""" 

957 

958 __slots__ = [ 

959 "_functions", "_inference_args", "_input_tangents", "_tape_watching" 

960 ] 

961 

962 def __init__(self, functions, inference_args, input_tangents, tape_watching): 

963 """Collects information about the function call. 

964 

965 Args: 

966 functions: An object which produces forward and backward functions, either 

967 a _DelayedRewriteGradientFunctions or a _TapeGradientFunctions object. 

968 inference_args: A flat list of Tensors, arguments to the inference 

969 function. 

970 input_tangents: A flat list of Tensors, jvps associated with 

971 `inference_args`. 

972 tape_watching: Boolean, with True indicating that recording is necessary. 

973 """ 

974 self._functions = functions 

975 self._inference_args = inference_args 

976 self._input_tangents = input_tangents 

977 self._tape_watching = tape_watching 

978 

979 def forward(self): 

980 """Builds or retrieves a forward function for this call.""" 

981 forward_function = self._functions.forward( 

982 self._inference_args, self._input_tangents) 

983 return forward_function, self._inference_args + self._input_tangents 

984 

985 def record(self, flat_outputs): 

986 """Given outputs from the execution of `forward`, records the operation.""" 

987 if (self._tape_watching 

988 and not isinstance(flat_outputs, ops.Operation) 

989 and flat_outputs is not None): 

990 # We only record function calls which have outputs, and then only when a 

991 # tape is watching. 

992 self._functions.record( 

993 flat_outputs, self._inference_args, self._input_tangents) 

994 

995 

996class ConcreteFunction(core.ConcreteFunction, trackable.Trackable): 

997 """A `tf.types.experimental.ConcreteFunction` created from `tf.function`.""" 

998 

999 def __init__(self, func_graph, attrs=None, shared_func_graph=True, spec=None): 

1000 """Initialize a `ConcreteFunction`. 

1001 

1002 Args: 

1003 func_graph: An instance of FuncGraph: the function body to wrap. 

1004 attrs: (optional) dict mapping names of attributes to their AttrValue 

1005 values. Attributes in `attrs` will be included in this function's 

1006 definition. 

1007 shared_func_graph: If False, the ConcreteFunction takes ownership of 

1008 `func_graph` and will break reference cycles when it is deleted. This 

1009 makes the FuncGraph inoperable. 

1010 spec: FunctionSpec for the original function. If not specified, then this 

1011 ConcreteFunction may only be called using the flat signature. 

1012 

1013 Raises: 

1014 ValueError: If number of input_placeholders is not equal to the number 

1015 of function inputs. 

1016 """ 

1017 # _arg_keywords and _num_positional_args define the flat signature. They 

1018 # are assigned after construction. 

1019 self._arg_keywords = None 

1020 self._num_positional_args = None 

1021 

1022 self._func_graph = func_graph 

1023 self._captured_inputs = self._func_graph.external_captures + self._func_graph.deferred_external_captures 

1024 

1025 # spec defines the structured signature. 

1026 self._set_function_spec(spec) 

1027 

1028 if attrs and attributes_lib.IMPLEMENTS in attrs: 

1029 # The alternative is to silently drop the "implements" tag,

1030 # but it seems likely that would lead to hard-to-catch bugs.

1031 # Another alternative is to make func_body preserve the order

1032 # of arguments if variables are present. Yet another option

1033 # is to automatically replace variables passed as arguments to functions

1034 # with v.read_value() whenever the "implements" tag is present.

1035 # Anytime we annotate an existing function we probably want to wrap

1036 # it with a safe read_value for backward compatibility.

1037 has_resource_vars = any( 

1038 inp.dtype == dtypes.resource for inp in self.inputs) 

1039 

1040 assert not any((has_resource_vars, self._captured_inputs)), ( 

1041 'Function {name} has "{attr}={value}" attribute and thus cannot '

1042 "depend on any tensors outside of its signature or modify variables. " 

1043 "\n\nNote: variables are always captured and cause function " 

1044 "re-tracing for every variable called.\n" 

1045 " inputs: {inputs}\n captures: {captured}\n\n" 

1046 "To pass a variable to such function use " 

1047 "use variable.read_value().".format( 

1048 name=func_graph.name, 

1049 attr=attributes_lib.IMPLEMENTS, 

1050 value=attrs[attributes_lib.IMPLEMENTS], 

1051 inputs=self.inputs, 

1052 captured=self._captured_inputs)) 

1053 self._output_shapes = tuple( 

1054 output.shape for output in self._func_graph.outputs) 

1055 self._attrs = _parse_func_attrs(attrs or {}) 

1056 

1057 if shared_func_graph: 

1058 self._garbage_collector = None 

1059 else: 

1060 self._garbage_collector = ConcreteFunctionGarbageCollector(func_graph) 

1061 

1062 # Pairs of forward and backward functions used for computing gradients. 

1063 # 

1064 # These each get a reference to the FuncGraph deleter since they use the 

1065 # FuncGraph directly. 

1066 self._delayed_rewrite_functions = _DelayedRewriteGradientFunctions( 

1067 func_graph, self._attrs, self._garbage_collector) 

1068 self._first_order_tape_functions = {} 

1069 self._higher_order_tape_functions = {} 

1070 # Cache the inference function to avoid a (Python) function call when not 

1071 # building gradients. 

1072 self._inference_function = self._delayed_rewrite_functions.forward() 

1073 

1074 def _set_function_spec(self, spec): 

1075 """Enables the structured signature by supplying a spec.""" 

1076 self._function_spec = None 

1077 self._pre_initialized_function_spec = spec 

1078 self._initialize_function_spec() 

1079 

1080 def _initialize_function_spec(self): 

1081 """Updates `self._function_spec` to include varargs and bound variables. 

1082 

1083 Adds new positional arguments for any varargs (i.e., for args that are 

1084 in `structured_input_signature`, but not in the original fullargspec.args). 

1085 

1086 Replaces `defaults` and `kwonlydefaults` with the `BOUND_VALUE`, for 

1087 all args and kwargs in `structured_input_signature`. 

1088 

1089 Sets `varkw` and `varargs` to None. 

1090 """ 

1091 if self._pre_initialized_function_spec is None: 

1092 return # e.g., SavedBareConcreteFunction doesn't have function_spec yet. 

1093 assert not self._function_spec, "already initialized" 

1094 spec = self._pre_initialized_function_spec 

1095 unconstrainted_poly_type = function_type_lib.FunctionType( 

1096 [ 

1097 function_type_lib.Parameter(p.name, p.kind, p.optional, None) 

1098 for p in spec.function_type.parameters.values() 

1099 ] 

1100 ) 

1101 arg_specs, kwarg_specs = self.structured_input_signature 

1102 

1103 _, func_type, _ = function_type_lib.canonicalize_to_monomorphic( 

1104 arg_specs, 

1105 { 

1106 function_type_lib.sanitize_arg_name(k): v 

1107 for k, v in kwarg_specs.items() 

1108 }, 

1109 self._pre_initialized_function_spec.default_values, 

1110 {}, 

1111 unconstrainted_poly_type, 

1112 ) 

1113 

1114 self._function_spec = function_spec.FunctionSpec( 

1115 func_type, 

1116 {d: function_spec.BOUND_VALUE for d in spec.default_values}, 

1117 spec.is_pure, 

1118 name=self._func_graph.name, 

1119 ) 

1120 

1121 @property 

1122 def variables(self): 

1123 """Sequence of variables for this function.""" 

1124 return tuple(self._func_graph.variables) 

1125 

1126 def set_variables(self, variables): 

1127 self._func_graph.variables = variables 

1128 

1129 @property 

1130 def trainable_variables(self): 

1131 """Sequence of trainable variables for this function.""" 

1132 return tuple(self._func_graph.trainable_variables) 

1133 

1134 def __call__(self, *args, **kwargs): 

1135 """Executes the wrapped function. 

1136 

1137 ConcreteFunctions have two signatures: 

1138 

1139 * The signature of the original function wrapped by this ConcreteFunction. 

1140 * A flat signature, where each argument accepts a single Tensor. 

1141 

1142 The original function signature is generally preferred, but the flat input 

1143 signature is supported for backward compatibility. 

1144 

1145 ### Original Function Signature 

1146 

1147 When calling a ConcreteFunction with the signature of the original function, 

1148 each argument must match the type or value that was used when the 

1149 ConcreteFunction's graph was traced. In particular: 

1150 

1151 * Tensor arguments (including CompositeTensors, such as RaggedTensor) must 

1152 have matching `TypeSpec`s. 

1153 * Non-Tensor arguments (such as booleans or ints) must have equal values. 

1154 * Nested arguments (such as lists, tuples, or dictionaries) must have the 

1155 same nesting structure; and each nested value must have a matching type 

1156 or value. 

1157 

1158 The default value for any arguments that were traced with non-Tensor values 

1159 is the value that was used in the trace. Arguments that were traced with 

1160 tensor arguments do not have a default value (even if the original function 

1161 had a default value for that argument). 

1162 

1163 ### Flat Signature 

1164 

1165 When calling a ConcreteFunction with the flat signature, the arguments 

1166 correspond to the flattened component tensors of the arguments that were 

1167 used to construct the ConcreteFunction. Parameter names are assigned based 

1168 on `TensorSpec.name` (when specified) or the original argument names (with 

1169 suffixes automatically added for nested arguments or composite tensors with 

1170 multiple components). 

1171 

1172 Args: 

1173 *args: Positional arguments to the concrete function. 

1174 **kwargs: Keyword arguments to the concrete function. 

1175 

1176 Returns: 

1177 The result of applying the TF function on the given Tensors. 

1178 

1179 Raises: 

1180 AssertionError: If this `ConcreteFunction` was not created through 

1181 `get_concrete_function`. 

1182 TypeError: If the arguments do not match the function's signature. 

1183 """ 

1184 return self._call_impl(args, kwargs) 
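# Illustrative sketch (not part of the covered module): obtaining and calling a
# ConcreteFunction through its structured signature. The function, TensorSpec
# and values below are assumptions.
import tensorflow as tf

@tf.function
def _scale(x, factor=2.0):
  return x * factor

_cf = _scale.get_concrete_function(tf.TensorSpec([None], tf.float32))
# Structured signature: same arguments as the original Python function; the
# non-Tensor `factor` is frozen to the value it had when the graph was traced.
print(_cf(tf.constant([1.0, 2.0])))      # [2.0, 4.0]
print(_cf(x=tf.constant([1.0, 2.0])))    # keyword form is also accepted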

1185 

1186 def _call_impl(self, args, kwargs): 

1187 """See `__call__` for details.""" 

1188 with trace.Trace(self._func_graph.name, tf_function_call="concrete"): 

1189 # Construct the list of input tensors: check if the structured signature 

1190 # applies first; and if not, then use the flat signature. 

1191 if self._function_spec is not None: 

1192 try: 

1193 return self._call_with_structured_signature(args, kwargs) 

1194 except TypeError as structured_err: 

1195 try: 

1196 return self._call_with_flat_signature(args, kwargs) 

1197 except TypeError: 

1198 raise structured_err 

1199 

1200 return self._call_with_flat_signature(args, kwargs) 

1201 

1202 def _call_with_flat_signature(self, args, kwargs): 

1203 """Executes the wrapped function with the flat signature. 

1204 

1205 Args: 

1206 args: Positional arguments to the concrete function. 

1207 kwargs: Keyword arguments to the concrete function. 

1208 

1209 Returns: 

1210 The result of applying the function on the Tensors/Variables contained in 

1211 `args` and `kwargs`. 

1212 Raises: 

1213 TypeError: if `args` and `kwargs` do not match the flat signature of this 

1214 `ConcreteFunction`. 

1215 """ 

1216 if len(args) > self._num_positional_args: 

1217 raise TypeError( 

1218 f"{self._flat_signature_summary()} takes {self._num_positional_args} " 

1219 f"positional arguments, got {len(args)}.") 

1220 args = list(args) 

1221 kwargs = dict(kwargs) 

1222 kwargs = { 

1223 function_type_lib.sanitize_arg_name(k): v for k, v in kwargs.items() 

1224 } 

1225 for keyword in self._arg_keywords[len(args):]: 

1226 try: 

1227 args.append( 

1228 kwargs.pop( 

1229 function_type_lib.sanitize_arg_name(compat.as_str(keyword)))) 

1230 except KeyError: 

1231 specified_keywords = ( 

1232 list(self._arg_keywords[:len(args)]) + list(kwargs.keys())) 

1233 missing_required_args = sorted( 

1234 set(self._arg_keywords) - set(specified_keywords)) 

1235 raise TypeError(f"{self._flat_signature_summary()} missing required " 

1236 f"arguments: {', '.join(missing_required_args)}.") 

1237 if kwargs: 

1238 positional_arg_keywords = set(self._arg_keywords[:len(args)]) 

1239 for unused_key in kwargs: 

1240 if unused_key in positional_arg_keywords: 

1241 raise TypeError(f"{self._flat_signature_summary()} got two values " 

1242 f"for '{unused_key}'.") 

1243 raise TypeError(f"{self._flat_signature_summary()} got unexpected " 

1244 f"keyword arguments: {', '.join(sorted(kwargs))}.") 

1245 

1246 for i, arg in enumerate(args): 

1247 if not isinstance( 

1248 arg, (ops.Tensor, resource_variable_ops.BaseResourceVariable)): 

1249 raise TypeError(f"{self._flat_signature_summary()}: expected argument " 

1250 f"#{i} (zero-based) to be a Tensor; " 

1251 f"got {type(arg).__name__} ({arg}).") 

1252 return self._call_flat(args, self.captured_inputs) 

1253 
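Editor's note: the keyword handling in `_call_with_flat_signature` can be summarized with a small standalone sketch (plain Python, hypothetical helper name, argument-name sanitization omitted):

```python
def flatten_call_args(arg_keywords, num_positional, args, kwargs):
  """Maps (args, kwargs) onto the flat positional order, mirroring the checks
  above: too many positionals, missing required keywords, and leftover
  keywords all raise TypeError."""
  if len(args) > num_positional:
    raise TypeError(f"takes {num_positional} positional arguments, "
                    f"got {len(args)}")
  args, kwargs = list(args), dict(kwargs)
  for keyword in arg_keywords[len(args):]:
    if keyword not in kwargs:
      raise TypeError(f"missing required argument: {keyword}")
    args.append(kwargs.pop(keyword))
  if kwargs:
    raise TypeError(f"unexpected keyword arguments: {sorted(kwargs)}")
  return args

assert flatten_call_args(["x", "y"], 2, (1,), {"y": 2}) == [1, 2]
```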

1254 def _call_with_structured_signature(self, args, kwargs): 

1255 """Executes the wrapped function with the structured signature. 

1256 

1257 Args: 

1258 args: Positional arguments to the concrete function. 

1259 kwargs: Keyword arguments to the concrete function. 

1260 

1261 Returns: 

1262 The result of applying the function on the Tensors/Variables contained in 

1263 `args` and `kwargs`. 

1264 Raises: 

1265 TypeError: if `args` and `kwargs` do not match the structured signature 

1266 of this `ConcreteFunction`. 

1267 """ 

1268 args, kwargs, filtered_flat_args = ( 

1269 self._function_spec.canonicalize_function_inputs(args, kwargs)) 

1270 return self._call_flat( 

1271 filtered_flat_args, 

1272 captured_inputs=self.captured_inputs) 

1273 

1274 def _call_flat(self, args, captured_inputs): 

1275 """Executes the wrapped function. 

1276 

1277 Args: 

1278 args: a list of Tensors or Variables. Arguments from the Python function 

1279 should be filtered before calling this method: objects aside from 

1280 Tensors, CompositeTensors, and Variables are ignored. Any 

1281 CompositeTensors other than ResourceVariables should be expanded before 

1282 calling this method. 

1283 captured_inputs: the captured inputs that are also part of the input args 

1284 to the actual execution. By default, it should be self._captured_inputs. 

1285 Returns: 

1286 The result of applying the TF function to `args`. 

1287 

1288 Raises: 

1289 ValueError: If `args` contains anything other than Tensors or Variables. 

1290 """ 

1291 ctx = context.context() 

1292 executing_eagerly = ctx.executing_eagerly() 

1293 

1294 # Copy saveable status of function's graph to current FuncGraph. 

1295 default_graph = ops.get_default_graph() 

1296 if default_graph.building_function and not self._func_graph.saveable: 

1297 default_graph.mark_as_unsaveable(self._func_graph.saving_errors) 

1298 

1299 if (record.could_possibly_record() or 

1300 hasattr(default_graph, "watch_variable")): 

1301 for v in self._func_graph.variables: 

1302 resource_variable_ops.variable_accessed(v) 

1303 

1304 tensor_inputs = [] 

1305 variables_used = set() 

1306 for i, arg in enumerate(args): 

1307 if isinstance(arg, resource_variable_ops.BaseResourceVariable): 

1308 # We can pass a variable more than once, and in this case we need to 

1309 # pass its handle only once. 

1310 if id(arg.handle) in variables_used: 

1311 continue 

1312 resource_variable_ops.variable_accessed(arg) 

1313 tensor_inputs.append(arg.handle) 

1314 variables_used.add(id(arg.handle)) 

1315 elif isinstance(arg, ops.Tensor): 

1316 tensor_inputs.append(arg) 

1317 else: 

1318 raise ValueError(f"{i:d}-th input {arg} must be a Tensor, got " 

1319 f"{type(arg)} when calling {self._func_graph.name}.") 

1320 

1321 if not executing_eagerly: 

1322 for i, tensor_input in enumerate(tensor_inputs): 

1323 # Cannot compare shapes in these cases 

1324 # TODO(b/216506654): Consider moving this check elsewhere and making it 

1325 # work for all types (e.g. by including shape for Variables). 

1326 if (tensor_input.dtype == dtypes.resource or 

1327 tensor_input.dtype == dtypes.variant): 

1328 continue 

1329 

1330 # If we're graph building, shape inference is on. We check for input 

1331 # compatibility up front to avoid hard to debug incompatibilities 

1332 # later. 

1333 graph_input_shape = tensor_shape.TensorShape( 

1334 self._func_graph.inputs[i].shape) 

1335 if not graph_input_shape.is_compatible_with(tensor_input.shape): 

1336 raise ValueError( 

1337 f"Tensor {tensor_input} is not compatible with the shape this " 

1338 f"function was traced with. Expected shape " 

1339 f"{self._func_graph.inputs[i].shape}, but got shape " 

1340 f"{tensor_input.shape}.\n\nIf you called get_concrete_function, " 

1341 f"you may need to pass a tf.TensorSpec(..., shape=...) with a " 

1342 f"less specific shape, having None on axes which can vary.") 

1343 

1344 args = tensor_inputs + captured_inputs 

1345 possible_gradient_type = gradients_util.PossibleTapeGradientTypes(args) 

1346 if (possible_gradient_type == gradients_util.POSSIBLE_GRADIENT_TYPES_NONE 

1347 and executing_eagerly): 

1348 # No tape is watching; skip to running the function. 

1349 return self._build_call_outputs(self._inference_function(*args)) 

1350 forward_backward = self._select_forward_and_backward_functions( 

1351 args, 

1352 possible_gradient_type, 

1353 executing_eagerly) 

1354 forward_function, args_with_tangents = forward_backward.forward() 

1355 if executing_eagerly: 

1356 flat_outputs = forward_function(*args_with_tangents) 

1357 else: 

1358 with default_graph._override_gradient_function( # pylint: disable=protected-access 

1359 {"PartitionedCall": self._get_gradient_function(), 

1360 "StatefulPartitionedCall": self._get_gradient_function()}): 

1361 flat_outputs = forward_function(*args_with_tangents) 

1362 forward_backward.record(flat_outputs) 

1363 return self._build_call_outputs(flat_outputs) 

1364 
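Editor's note: to illustrate the dispatch in `_call_flat`, a short sketch (assuming TF 2.x eager execution): with no tape watching, the inference-only path is taken; under a `tf.GradientTape`, a forward/backward pair is selected so the call stays differentiable.

```python
import tensorflow as tf

@tf.function
def square(x):
  return x * x

cf = square.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
x = tf.constant(3.0)

cf(x)  # no tape recording: POSSIBLE_GRADIENT_TYPES_NONE, fast path

with tf.GradientTape() as tape:
  tape.watch(x)
  y = cf(x)              # a single tape: first-order forward/backward pair
print(tape.gradient(y, x).numpy())   # 6.0
```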

1365 @property 

1366 def name(self): 

1367 """`ConcreteFunction` name.""" 

1368 return self._delayed_rewrite_functions.forward().name 

1369 

1370 @property 

1371 def graph(self): 

1372 """Returns the graph from which this function was constructed.""" 

1373 return self._func_graph 

1374 

1375 @property 

1376 def inputs(self): 

1377 """Returns tensors in `self.graph` corresponding to arguments.""" 

1378 return self._func_graph.inputs 

1379 

1380 @property 

1381 def structured_input_signature(self): 

1382 """Returns structured signature for this concrete function. 

1383 

1384 Returns: 

1385 A tuple `(args, kwargs)`, where: 

1386 

1387 * `args` is a tuple that specifies the expected type or value for each 

1388 positional argument. 

1389 * `kwargs` is a dictionary that specifies the expected type or value 

1390 for each keyword-only argument. 

1391 

1392 The type or value for each argument is specified using one of the 

1393 following: 

1394 

1395 * A `tf.TypeSpec`, indicating that a Tensor or other TensorFlow-native 

1396 value is expected. 

1397 * A Python value, such as an integer, indicating that an equal value 

1398 is expected. 

1399 * A nested structure of `tf.TypeSpec`s and Python values, indicating 

1400 that a corresponding nested structure is expected. 

1401 """ 

1402 return self._func_graph.structured_input_signature 

1403 
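Editor's note: for example (a hedged sketch assuming TF 2.x; the exact `repr` of the specs may differ by version), the structured input signature of a traced function with a bound default looks like:

```python
import tensorflow as tf

@tf.function
def add(a, b=1):
  return a + b

cf = add.get_concrete_function(tf.TensorSpec(shape=[2], dtype=tf.int32))
args, kwargs = cf.structured_input_signature
print(args)    # e.g. (TensorSpec(shape=(2,), dtype=tf.int32, name='a'), 1)
print(kwargs)  # {}
```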

1404 @property 

1405 def outputs(self): 

1406 """Returns tensors in `self.graph` corresponding to returned tensors.""" 

1407 return self._func_graph.outputs 

1408 

1409 @property 

1410 def structured_outputs(self): 

1411 """Returns outputs in `self.graph` as returned by the original function.""" 

1412 return self._func_graph.structured_outputs 

1413 

1414 def set_external_captures(self, captures): 

1415 """Updates the function capture values. 

1416 

1417 The new values must have tensor types and shapes consistent with the 

1418 original captures of the concrete function, but a value capture may be 

1419 replaced with a deferred one and vice versa. 

1420 

1421 Args: 

1422 captures: A list of tensors or closures. Tensors are value captures, and 

1423 closures are call-time (deferred) captures. 

1424 """ 

1425 # TODO(wxinyi): 1. verify that the new captures' type spec is compatible 

1426 # with the original's. However, doing so requires MirroredVariable captures 

1427 # initialized. 2. replace the original/new captures/deferred 

1428 # captures in the wrapped graph. Doing such for a capture-to-deferred 

1429 # capture replacement requires more arguments than the deferred capture 

1430 # itself, e.g. default value, spec. 

1431 self._captured_inputs = captures 

1432 

1433 def replace_capture_with_deferred_capture(self, 

1434 tensor, 

1435 closure, 

1436 spec, 

1437 placeholder=None, 

1438 default_value=None): 

1439 """Replaces existing capture `tensor` with a deferred capture `closure`. 

1440 

1441 This API removes the capture `tensor` from the concrete function's 

1442 captured inputs list and places the deferred capture `closure` in its 

1443 spot, so the order of captured inputs is preserved. The old `tensor` and 

1444 the new `closure` share the same internal placeholder, which can either 

1445 be passed through the `placeholder` argument or, when omitted, located 

1446 by indexing `tensor` in the external captured inputs list. For this to 

1447 work, the new deferred capture must have an output spec (given by the 

1448 `spec` argument) that is compatible with both the internal placeholder 

1449 (`placeholder`) and the original capture (`tensor`); otherwise a 

1450 `ValueError` is raised. 

1451 

1452 For example, 

1453 

1454 ```python 

1455 bool_captured_tensor = tf.constant(True) 

1456 float_captured_tensor = tf.constant([3.], dtype=tf.float32) 

1457 value = tf.constant([2.], dtype=tf.float32) 

1458 

1459 @tf.function 

1460 def fn(): 

1461 deferred_tensor = ops.get_default_graph().capture_call_time_value( 

1462 lambda: value, 

1463 tf.TensorSpec(shape=(1,), dtype=tf.float32)) 

1464 if bool_captured_tensor: 

1465 return deferred_tensor 

1466 else: 

1467 return deferred_tensor + float_captured_tensor 

1468 

1469 concrete_fn = fn.get_concrete_function() 

1470 print(concrete_fn()) # tf.Tensor([2.], shape=(1,), dtype=float32) 

1471 

1472 new_bool_captured_tensor = tf.constant(False) 

1473 def bool_closure(): 

1474 return new_bool_captured_tensor 

1475 

1476 concrete_fn.replace_capture_with_deferred_capture( 

1477 bool_captured_tensor, 

1478 bool_closure, 

1479 spec=tf.TensorSpec(shape=(), dtype=tf.bool)) 

1480 

1481 print(concrete_fn()) # tf.Tensor([5.], shape=(1,), dtype=float32) 

1482 ``` 

1483 

1484 Args: 

1485 tensor: Tensor already captured. This `tensor` should be listed in 

1486 concrete_function.captured_inputs except when that list is empty, such 

1487 as when the concrete function is restored from SavedModel. 

1488 closure: function which takes no arguments, to be evaluated at function 

1489 call time, returning a nest of tensors compatible with `spec`. 

1490 spec: nest of TypeSpec for the value to capture. 

1491 placeholder: optional. The internal placeholder corresponding to the 

1492 captured `tensor` and the new `closure`. 

1493 default_value: optional value to use in environments that cannot safely 

1494 evaluate closure. 

1495 """ 

1496 capture_index = None 

1497 for i, capture in enumerate(self._captured_inputs): 

1498 if id(tensor) == id(capture): 

1499 capture_index = i 

1500 break 

1501 

1502 if placeholder is None: 

1503 if capture_index is None: 

1504 raise ValueError( 

1505 f"Did not find `tensor` argument {tensor} in the ConcreteFunction's" 

1506 " captured inputs list, and did not receive a placeholder argument." 

1507 " Thus we're unable to infer the internal placeholder. ") 

1508 

1509 placeholder = self.inputs[-len(self._captured_inputs) + capture_index] 

1510 

1511 if not (spec.is_compatible_with(tensor) or 

1512 spec.is_compatible_with(placeholder)): 

1513 raise ValueError( 

1514 f"Attempting to substitute closure with spec {spec} that's " 

1515 f"incompatible with the original capture {tensor} or the internal " 

1516 f"placeholder {placeholder}.") 

1517 

1518 self._func_graph.replace_capture_with_deferred_capture( 

1519 tensor=tensor, 

1520 closure=closure, 

1521 spec=spec, 

1522 placeholder=placeholder, 

1523 default_value=default_value) 

1524 

1525 if capture_index is not None: 

1526 self._captured_inputs[capture_index] = closure 

1527 

1528 @property 

1529 def captured_inputs(self): 

1530 """Returns external Tensors captured by this function. 

1531 

1532 self.__call__(*args) passes `args + self.captured_inputs` to the function. 

1533 """ 

1534 return nest.flatten( 

1535 [x() if callable(x) else x for x in self._captured_inputs], 

1536 expand_composites=True) 

1537 
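Editor's note: a short sketch (TF 2.x, hypothetical function) showing how a tensor closed over by the Python function surfaces in `captured_inputs` and is appended to the call:

```python
import tensorflow as tf

c = tf.constant([10.0])

@tf.function
def add_capture(x):
  return x + c           # `c` becomes an external capture of the FuncGraph

cf = add_capture.get_concrete_function(
    tf.TensorSpec(shape=[1], dtype=tf.float32))
print(cf.captured_inputs)          # [<tf.Tensor ... numpy=array([10.], ...)>]
print(cf(tf.constant([1.0])))      # tf.Tensor([11.], shape=(1,), ...)
```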

1538 @property 

1539 def function_def(self): 

1540 """Returns a `FunctionDef` object representing this function.""" 

1541 return self._delayed_rewrite_functions.forward().cached_definition 

1542 

1543 @property 

1544 def output_shapes(self): 

1545 """The function's output shapes.""" 

1546 return nest.map_structure( 

1547 lambda x: getattr(x, "shape", tensor_shape.TensorShape(None)), 

1548 composite_tensor.replace_composites_with_components( 

1549 self._func_graph.structured_outputs), 

1550 expand_composites=False) 

1551 

1552 @property 

1553 def output_dtypes(self): 

1554 # TODO(akshayka): Consider removing this. 

1555 return nest.map_structure( 

1556 lambda x: x.dtype if x is not None else None, 

1557 composite_tensor.replace_composites_with_components( 

1558 self._func_graph.structured_outputs), 

1559 expand_composites=False) 

1560 
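Editor's note: as a usage sketch (TF 2.x), `output_shapes` and `output_dtypes` mirror the structure returned by the traced Python function:

```python
import tensorflow as tf

@tf.function
def split(x):
  return {"head": x[0], "tail": x[1:]}

cf = split.get_concrete_function(tf.TensorSpec(shape=[4], dtype=tf.float32))
print(cf.output_shapes)   # e.g. {'head': TensorShape([]), 'tail': TensorShape([3])}
print(cf.output_dtypes)   # {'head': tf.float32, 'tail': tf.float32}
```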

1561 def add_to_graph(self, g=None, overwrite=False): 

1562 """Registers the function, adds it to the graph g or default graph. 

1563 

1564 Args: 

1565 g: If specified, registers the function with this graph. Defaults to the 

1566 current context (either the default graph or the eager context). 

1567 overwrite: A bool. If True, its forward function will overwrite 

1568 any existing function of the same signature name in the graph `g`. 

1569 """ 

1570 # If we are not executing eagerly, add the function to the default graph 

1571 # when no graph is specified. 

1572 # In eager execution, the function definition was already added to the 

1573 # context during construction, so there is nothing to do here. 

1574 

1575 if not context.executing_eagerly() and not g: 

1576 g = ops.get_default_graph() 

1577 

1578 if g is not None: 

1579 g._add_function_recursive(self._delayed_rewrite_functions.forward()) # pylint: disable=protected-access 

1580 
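Editor's note: a hedged sketch (TF 2.x; graph-internals behaviour may vary by version) of registering a concrete function's definition in an explicitly constructed graph:

```python
import tensorflow as tf

@tf.function
def double(x):
  return 2.0 * x

cf = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))

g = tf.Graph()
cf.add_to_graph(g)   # registers the FunctionDef (and any dependencies) in `g`
print([f.signature.name for f in g.as_graph_def().library.function])
```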

1581 def add_gradient_functions_to_graph(self, g=None): 

1582 """Add forward/backward functions to graph `g` or the current context.""" 

1583 if not context.executing_eagerly() and not g: 

1584 g = ops.get_default_graph() 

1585 g._add_function_recursive(self._delayed_rewrite_functions.forward()) # pylint: disable=protected-access 

1586 forward_function, backward_function = ( 

1587 self._delayed_rewrite_functions.forward_backward()) 

1588 g._add_function_recursive(forward_function) # pylint: disable=protected-access 

1589 backward_function.add_to_graph(g) 

1590 

1591 def _get_gradient_function(self): 

1592 """Returns gradient function. It will be lazily created at first call.""" 

1593 return self._delayed_rewrite_functions._rewrite_forward_and_call_backward # pylint: disable=protected-access 

1594 

1595 def _select_forward_and_backward_functions( 

1596 self, args, possible_gradient_type, executing_eagerly): 

1597 """Selects forward and backward functions based on the calling context. 

1598 

1599 The forward function computes the "real" function outputs, `self._outputs`, 

1600 and any extra values needed by the corresponding backward function. 

1601 

1602 Args: 

1603 args: A flat list of Tensors with all of the inputs to the forward 

1604 function (including user-specified and captured inputs). 

1605 possible_gradient_type: One of gradients_util.POSSIBLE_GRADIENT_TYPES_*. 

1606 executing_eagerly: Boolean, the value of context.executing_eagerly(). 

1607 

1608 Returns: 

1609 An object with a `forward` method returning a tuple of (forward_function : 

1610 AtomicFunction, augmented_arguments : List), and a corresponding 

1611 `record` method which takes outputs from the forward function and records 

1612 the operation. forward_function should be called with augmented_arguments. 

1613 """ 

1614 if executing_eagerly: 

1615 input_tangents = forwardprop_util.pack_tangents(args) 

1616 else: 

1617 input_tangents = forwardprop_util.TangentInfo() 

1618 need_gradients_for_jvps = record.should_record_backprop( 

1619 input_tangents.tangents) 

1620 # Allows re-use of forward and backward function pairs depending on the 

1621 # tapes and forward accumulators watching its inputs. 

1622 cache_key = (need_gradients_for_jvps, input_tangents.indices) 

1623 if (possible_gradient_type 

1624 == gradients_util.POSSIBLE_GRADIENT_TYPES_FIRST_ORDER): 

1625 if input_tangents.indices or executing_eagerly: 

1626 # There is a single non-persistent tape active, so the user can only 

1627 # request first-order gradients from a tape. We can spend less time 

1628 # graph building since we know this. 

1629 # 

1630 # We may still end up computing higher-order gradients, but that'd be 

1631 # through `tf.gradients`, which can re-write the forward pass and so 

1632 # needs no preparation here. 

1633 functions = self._first_order_tape_functions.get(cache_key, None) 

1634 if functions is None: 

1635 functions = _FirstOrderTapeGradientFunctions( 

1636 self._func_graph, self._attrs, self._garbage_collector, 

1637 forwardprop_input_indices=input_tangents.indices, 

1638 delayed_rewrite_functions=self._delayed_rewrite_functions, 

1639 need_gradients_for_jvps=need_gradients_for_jvps) 

1640 self._first_order_tape_functions[cache_key] = functions 

1641 return _ForwardBackwardCall( 

1642 functions, args, input_tangents.tangents, tape_watching=True) 

1643 else: 

1644 # We can avoid computing second-order gradients in some cases by doing a 

1645 # delayed rewrite when graph building. Since we know we'll only compute 

1646 # first-order tape gradients, the delayed rewrite is safe: we won't need 

1647 # to tell the tape about side outputs. 

1648 # 

1649 # TODO(allenl): This case is really dirty. It would be better if we 

1650 # could temporarily pop all of the current tapes to avoid 

1651 # accidentally taking second-order gradients. 

1652 return _ForwardBackwardCall( 

1653 self._delayed_rewrite_functions, args, input_tangents.tangents, 

1654 tape_watching=True) 

1655 elif (possible_gradient_type 

1656 == gradients_util.POSSIBLE_GRADIENT_TYPES_HIGHER_ORDER): 

1657 # Either there's a persistent tape watching, or there are multiple nested 

1658 # tapes. Either way, the user may request higher-order gradients. We'll 

1659 # spend a bit more time and make sure higher-order gradients are correct. 

1660 functions = self._higher_order_tape_functions.get( 

1661 cache_key, None) 

1662 if functions is None: 

1663 functions = _HigherOrderTapeGradientFunctions( 

1664 self._func_graph, self._attrs, self._garbage_collector, 

1665 forwardprop_input_indices=input_tangents.indices, 

1666 delayed_rewrite_functions=self._delayed_rewrite_functions, 

1667 need_gradients_for_jvps=need_gradients_for_jvps) 

1668 self._higher_order_tape_functions[cache_key] = functions 

1669 return _ForwardBackwardCall(functions, args, input_tangents.tangents, 

1670 tape_watching=True) 

1671 # else possible_gradient_type == POSSIBLE_GRADIENT_TYPES_NONE, meaning no 

1672 # tape is recording. 

1673 return _ForwardBackwardCall( 

1674 self._delayed_rewrite_functions, args, input_tangents.tangents, 

1675 tape_watching=False) 

1676 
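Editor's note: the three branches above correspond to how the function is being differentiated at call time. A small sketch (TF 2.x) that exercises the nested-tape, higher-order path:

```python
import tensorflow as tf

@tf.function
def cube(x):
  return x * x * x

cf = cube.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.float32))
x = tf.constant(2.0)

with tf.GradientTape() as outer:
  outer.watch(x)
  with tf.GradientTape() as inner:
    inner.watch(x)
    y = cf(x)                     # nested tapes -> higher-order gradient path
  dy = inner.gradient(y, x)       # 3 * x**2 = 12.0
d2y = outer.gradient(dy, x)       # 6 * x    = 12.0
print(dy.numpy(), d2y.numpy())
```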

1677 def _build_call_outputs(self, result): 

1678 """Maps the fdef output list to actual output structure. 

1679 

1680 Args: 

1681 result: The flat output list as defined by the FunctionDef. 

1682 Returns: 

1683 The actual call output. 

1684 """ 

1685 # TODO(jlchu): call C++ version in function.cc when speed is improved 

1686 if self._func_graph.structured_outputs is None: 

1687 return result 

1688 

1689 # Replace outputs with results, skipping over any 'None' values. 

1690 outputs_list = nest.flatten( 

1691 self._func_graph.structured_outputs, expand_composites=True) 

1692 j = 0 

1693 for i, o in enumerate(outputs_list): 

1694 if o is not None: 

1695 handle_data_util.copy_handle_data(self.outputs[j], result[j]) 

1696 outputs_list[i] = result[j] 

1697 j += 1 

1698 ret = nest.pack_sequence_as(self._func_graph.structured_outputs, 

1699 outputs_list, expand_composites=True) 

1700 return ret 

1701 
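Editor's note: the repacking above can be pictured with a standalone `tf.nest` sketch (hypothetical structure; not the internal code path): flat graph results are spliced back into the structured outputs, skipping literal `None`s, then packed.

```python
import tensorflow as tf

structured = {"a": None, "b": (0.0, 0.0)}            # shape of the return value
flat_results = [tf.constant(1.0), tf.constant(2.0)]  # actual graph outputs

leaves = tf.nest.flatten(structured, expand_composites=True)
j = 0
for i, leaf in enumerate(leaves):
  if leaf is not None:            # `None` outputs produce no graph tensor
    leaves[i] = flat_results[j]
    j += 1
print(tf.nest.pack_sequence_as(structured, leaves, expand_composites=True))
# {'a': None, 'b': (<tf.Tensor 1.0>, <tf.Tensor 2.0>)}
```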

1702 @property 

1703 def _as_name_attr_list(self): 

1704 """Returns a `NameAttrList` representing this function.""" 

1705 ret = attr_value_pb2.NameAttrList(name=self.name) 

1706 for name, value in self._attrs.items(): 

1707 ret.attr[name].CopyFrom(value) 

1708 return ret 

1709 

1710 def _structured_signature_summary(self, default_values=False): 

1711 """Returns a string summarizing this function's structured signature. 

1712 

1713 Args: 

1714 default_values: If true, then include default values in the signature. 

1715 

1716 Returns: 

1717 A `string`. 

1718 """ 

1719 # Note: we can't just use self._function_spec.signature_summary(), because 

1720 # that would show "BOUND_VALUE" as the default value for all arguments. 

1721 assert self._function_spec is not None 

1722 arg_specs, kwarg_specs = self.structured_input_signature 

1723 arg_names = list(self._function_spec.arg_names) 

1724 

1725 # If an explicit input_signature is provided to @tf.function, then any 

1726 # arguments with defaults that are not covered by that explicit signature 

1727 # are simply dropped from the signature. 

1728 # TODO(b/159639913) Look into whether dropping arguments with default values 

1729 # from the signature is the right thing to do. 

1730 arg_names = arg_names[:len(arg_specs)] 

1731 

1732 if default_values: 

1733 for i in range(len(arg_names)): 

1734 if not _contains_type_spec(arg_specs[i]): 

1735 arg_names[i] += "={}".format(arg_specs[i]) 

1736 if kwarg_specs: 

1737 arg_names.append("*") 

1738 for name, spec in kwarg_specs.items(): 

1739 arg_names.append(name) 

1740 if default_values and not _contains_type_spec(spec): 

1741 arg_names[-1] += "={}".format(spec) 

1742 signature = f"{self._func_graph.name}({', '.join(arg_names)})" 

1743 

1744 return signature 

1745 

1746 def _flat_signature_summary(self): 

1747 """Returns a string summarizing this function's flat signature.""" 

1748 assert self._arg_keywords is not None 

1749 assert self._num_positional_args is not None 

1750 arg_names = list(self._arg_keywords)  # copy; extended below 

1751 if self._num_positional_args > len(arg_names): 

1752 arg_names.extend( 

1753 "<arg{}>".format(i + 1) 

1754 for i in range(len(arg_names), self._num_positional_args)) 

1755 return f"{self._func_graph.name}({', '.join(arg_names)})" 

1756 

1757 def pretty_printed_signature(self, verbose=True): 

1758 """Returns a string summarizing the signature of this concrete function.""" 

1759 if not verbose: 

1760 return self._structured_signature_summary(default_values=True) 

1761 

1762 def pretty_print_spec(spec): 

1763 """Returns a string describing the spec for a single argument.""" 

1764 if isinstance(spec, tensor_spec.TensorSpec): 

1765 return "{} Tensor, shape={}".format(spec.dtype.name, spec.shape) 

1766 elif nest.is_nested(spec): 

1767 pieces = nest.flatten(spec, expand_composites=False) 

1768 markers = [_Marker("<{}>".format(i + 1)) for i in range(len(pieces))] 

1769 structure = nest.pack_sequence_as(spec, markers) 

1770 # Ensure dictionaries are sorted by key (for determinism) 

1771 result = pprint.pformat(structure, width=10000) 

1772 for (marker, piece) in zip(markers, pieces): 

1773 result += "\n {}: {}".format(marker, pretty_print_spec(piece)) 

1774 return result 

1775 else: 

1776 return repr(spec) 

1777 

1778 lines = [self._structured_signature_summary(default_values=True)] 

1779 arg_specs, kwarg_specs = self.structured_input_signature 

1780 names = list(self._function_spec.arg_names) 

1781 

1782 # If an explicit input_signature is provided to @tf.function, then any 

1783 # arguments with defaults that are not covered by that explicit signature 

1784 # are simply dropped from the signature. 

1785 # TODO(b/159639913) Look into whether dropping arguments with default values 

1786 # from the signature is the right thing to do. 

1787 

1788 # Note: we can skip bound args, since we already displayed their bound 

1789 # value in the signature summary. 

1790 arg_details = [] 

1791 for (name, spec) in zip(names[:len(arg_specs)], list(arg_specs)): 

1792 if _contains_type_spec(spec): 

1793 arg_details.append(" {}: {}".format(name, pretty_print_spec(spec))) 

1794 

1795 if kwarg_specs: 

1796 for kwarg in sorted(kwarg_specs): 

1797 spec = kwarg_specs[kwarg] 

1798 if _contains_type_spec(spec): 

1799 arg_details.append(" {}: {}".format( 

1800 kwarg, pretty_print_spec(spec))) 

1801 

1802 if arg_details: 

1803 lines.append(" Args:") 

1804 lines.extend(arg_details) 

1805 lines.append(" Returns:") 

1806 

1807 def spec_from_value(value): 

1808 # For loaded function, structured_outputs are already specs. 

1809 if isinstance(value, type_spec.TypeSpec): 

1810 return value 

1811 return type_spec.type_spec_from_value(value) 

1812 

1813 lines.append(" {}".format( 

1814 pretty_print_spec( 

1815 nest.map_structure(spec_from_value, self.structured_outputs)))) 

1816 

1817 return "\n".join(lines) 

1818 
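Editor's note: for instance (a sketch assuming TF 2.x; exact formatting may differ across versions):

```python
import tensorflow as tf

@tf.function
def add(a, b=1.0):
  return a + b

cf = add.get_concrete_function(tf.TensorSpec(shape=[None], dtype=tf.float32))
print(cf.pretty_printed_signature(verbose=False))
# e.g. add(a, b=1.0)
print(cf.pretty_printed_signature())
# e.g. add(a, b=1.0)
#        Args:
#          a: float32 Tensor, shape=(None,)
#        Returns:
#          float32 Tensor, shape=(None,)
```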

1819 def __repr__(self): 

1820 if self._function_spec is not None: 

1821 return "<ConcreteFunction {} at 0x{:X}>".format( 

1822 self.pretty_printed_signature(verbose=False), id(self)) 

1823 elif not (self._num_positional_args is None or self._arg_keywords is None): 

1824 return "<ConcreteFunction {} at 0x{:X}>".format( 

1825 self._flat_signature_summary(), id(self)) 

1826 else: 

1827 return object.__repr__(self) 

1828 

1829 def __str__(self): 

1830 if self._function_spec is not None: 

1831 return "ConcreteFunction {}".format(self.pretty_printed_signature()) 

1832 else: 

1833 return self.__repr__() 

1834 

1835 def _trackable_children(self, save_type="checkpoint", **kwargs): 

1836 """Implements `Trackable`.""" 

1837 if save_type == "checkpoint": 

1838 # Checkpoint dependencies do not include functions at all. Users 

1839 # expect the checkpointed variables to be saved using the model 

1840 # architecture, e.g. `model.layers[1].kernel` or `model.variables`. 

1841 return {} 

1842 

1843 captured_trackables = {} 

1844 for n, (capture, _) in enumerate(self.graph.captures): 

1845 if (capture.dtype not in (dtypes.variant, dtypes.resource) and 

1846 not resource_variable_ops.is_resource_variable(capture)): 

1847 # Variant/resource type tensors are skipped since we have no way of 

1848 # getting the `Trackable` wrapper for these tensors. The wrappers are 

1849 # expected to be elsewhere in the saved object graph. 

1850 # TODO(b/223866972): Directly encode/decode tensor captures. 

1851 

1852 # Resource variable captures are also skipped at this time, to maintain 

1853 # existing behavior. 

1854 # TODO(b/217979389): Return the non-constant captures as children. 

1855 

1856 captured_trackables[f"capture_{n}"] = capture 

1857 

1858 return captured_trackables 

1859 

1860 def _deserialization_dependencies(self, children): 

1861 return children 

1862 

1863 def _export_to_saved_model_graph(self, object_map, tensor_map, 

1864 **unused_kwargs): 

1865 if not self.graph.saveable: 

1866 raise ValueError( 

1867 (f"Unable to save function {self.name} for the following reason(s):\n" 

1868 + "\n".join(self.graph.saving_errors))) 

1869 self.add_to_graph() 

1870 object_map[self] = saved_model_exported_concrete.ExportedConcreteFunction( 

1871 self, tensor_map) 

1872 return [] 

1873 

1874 

1875_pywrap_utils.RegisterType("Tensor", ops.Tensor) 

1876_pywrap_utils.RegisterType("EagerTensor", ops.EagerTensor) 

1877_pywrap_utils.RegisterType("IndexedSlices", indexed_slices.IndexedSlices) 

1878 

1879 

1880class ConcreteFunctionGarbageCollector: 

1881 """Cleans up reference cycles when a `ConcreteFunction` goes out of scope.""" 

1882 

1883 __slots__ = ["_func_graph"] 

1884 

1885 def __init__(self, func_graph): 

1886 self._func_graph = func_graph 

1887 

1888 def release(self): 

1889 """Call off the FuncGraph deletion.""" 

1890 self._func_graph = None 

1891 

1892 def __del__(self): 

1893 if func_graph_module is None or self._func_graph is None: 

1894 return 

1895 try: 

1896 func_graph_module.dismantle_func_graph(self._func_graph) 

1897 except: # pylint: disable=bare-except 

1898 pass 

1899 

1900 

1901class _Marker: 

1902 """Markers used to pretty-print nested args in function signatures.""" 

1903 

1904 __slots__ = ["_s"] 

1905 

1906 def __init__(self, s): 

1907 self._s = s 

1908 

1909 def __repr__(self): 

1910 return str(self._s) 

1911 

1912 

1913def _contains_type_spec(value): 

1914 return any(isinstance(x, type_spec.TypeSpec) for x in nest.flatten(value))