Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/polymorphic_function/polymorphic_function.py: 32%

508 statements  

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=unidiomatic-typecheck 

16"""API for defining graph functions with some additional eager semantics. 

17 

18tf.function utilizes varying configurations of TracingCompiler to allow 

19initializing `tf.Variable`s with subgraphs of the function. For example: 

20 

21```python 

22class M(tf.Module): 

23 def __init__(self): 

24 self.v_opinit = None 

25 self.v_arginit = None 

26 

27 @tf.function 

28 def __call__(self, x): 

29 # Variables are only created on the first call to the function. This is a 

30 # common pattern in layer libraries. 

31 if self.v_opinit is None: 

32 # self.v_opinit will outlive the function call, but `tf.ones` is traced as 

33 # part of the function body before the `tf.Variable` object is 

34 # created. This subgraph is easy to lift out of the function. 

35 self.v_opinit = tf.Variable(tf.ones([])) 

36 

37 # If arguments feed into variable initialization, it can be very tricky to 

38 # disentangle from the rest of the function. We don't attempt it. 

39 self.v_arginit = tf.Variable(tf.ones(tf.shape(x)) * tf.constant(2.)) 

40 return self.v_opinit + self.v_arginit + x 

41``` 

42 

43When using "TracingCompiler" directly, these patterns raise an error asking 

44the user to put the variable's initializer in a lambda. With tf.function they 

45work with eager semantics either by lifting the subgraph out of the function and 

46using it to initialize the variable, or by initializing variables on the first 

47call to the function (if they weren't already initialized by something else, 

48e.g. a checkpoint API). The latter requires tf.conds, and is not well supported 

49by TF-XLA, so we only do it when necessary. 

50 

51Since these patterns are relatively common in layer libraries, we expose the 

52wrapper in this file as `tf.function`. The defun concept in quarantine.py is a 

53legacy internal API. 

54 

55In order to support these variable initialization patterns, tf.function defines 

56a variable subtype (UnliftedInitializerVariable) which collects the input 

57subgraph. This type of variable replaces the regular variable type on the first 

58tf.function trace. To exclude initializers from the function body (the `tf.ones` 

59ops above and associated assignment operations), tf.function traces a second 

60time if it sees variables on the first call. 

61""" 

62 

63import functools 

64import os 

65import threading 

66import types as types_lib 

67import weakref 

68 

69from google.protobuf import text_format as _text_format 

70from google.protobuf.message import DecodeError 

71from tensorflow.core.framework import attr_value_pb2 

72from tensorflow.core.function import trace_type 

73from tensorflow.python.distribute.parallel_device import parallel_device 

74from tensorflow.python.eager import context 

75from tensorflow.python.eager import lift_to_graph 

76from tensorflow.python.eager import monitoring 

77from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib 

78from tensorflow.python.eager.polymorphic_function import autograph_util 

79from tensorflow.python.eager.polymorphic_function import compiler_ir 

80from tensorflow.python.eager.polymorphic_function import eager_function_run 

81from tensorflow.python.eager.polymorphic_function import function_spec as function_spec_lib 

82from tensorflow.python.eager.polymorphic_function import tracing_compiler 

83from tensorflow.python.framework import composite_tensor 

84from tensorflow.python.framework import errors 

85from tensorflow.python.framework import func_graph as func_graph_module 

86from tensorflow.python.framework import ops 

87from tensorflow.python.framework import tensor_spec 

88from tensorflow.python.ops import array_ops_stack 

89from tensorflow.python.ops import cond 

90from tensorflow.python.ops import control_flow_ops 

91from tensorflow.python.ops import control_flow_util 

92from tensorflow.python.ops import math_ops 

93from tensorflow.python.ops import resource_variable_ops 

94from tensorflow.python.platform import tf_logging as logging 

95from tensorflow.python.profiler import trace 

96from tensorflow.python.trackable import base as trackable 

97from tensorflow.python.types import core 

98from tensorflow.python.util import deprecation 

99from tensorflow.python.util import nest 

100from tensorflow.python.util import object_identity 

101from tensorflow.python.util import tf_decorator 

102from tensorflow.python.util import traceback_utils 

103from tensorflow.python.util.tf_export import tf_export 

104 

105FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY = 10 

106FREQUENT_TRACING_WARNING_THRESHOLD = 5 

107FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR = 2 

108ALLOW_DYNAMIC_VARIABLE_CREATION = False 

109 

110 
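# ALLOW_DYNAMIC_VARIABLE_CREATION is consulted in `Function._call` below: when
# it is True, creating variables on a non-first call is tolerated instead of
# raising a ValueError. `set_dynamic_variable_creation` toggles it globally.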

111def set_dynamic_variable_creation(is_allowed): 

112 global ALLOW_DYNAMIC_VARIABLE_CREATION 

113 ALLOW_DYNAMIC_VARIABLE_CREATION = is_allowed 

114 

115 

116_tf_function_counter = monitoring.Counter( 

117 "/tensorflow/core/tf_function_counter", 

118 "Counter for the number of tf.functions created when Eager execution is " 

119 "enabled.", 

120 # jit_compile is "0" or "1". 

121 "jit_compile") 

122 

123 

124class _FrequentTracingDetector(object): 

125 """Class keeping track of how many recent calls triggered tracing.""" 

126 

127 __slots__ = ["_calls_per_tracings", "_call_count", "_total_warning_count"] 

128 

129 def __init__(self): 

130 self._calls_per_tracings = [] 

131 self._total_warning_count = 0 

132 self._call_count = 0 

133 

134 def called_with_tracing(self, function_name, omit_warning): 

135 """Updates the list of most recent calls' tracing information. 

136 

137 Warns the user when recent calls caused retracing too often. 

138 

139 Args: 

140 function_name: the python function being traced. 

141 omit_warning: If 'True', this call will not warn the user even if 

142 retracing happens too often. 

143 """ 

144 self._call_count += 1 

145 self._calls_per_tracings.append(1) 

146 
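    # Trim the oldest per-tracing buckets so that `_calls_per_tracings` only
    # covers (approximately) the last FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY
    # calls; each list entry counts the calls served by one tracing.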

147 while self._calls_per_tracings: 

148 if (self._call_count - self._calls_per_tracings[0] > 

149 FREQUENT_TRACING_WARNING_MAX_CALL_HISTORY): 

150 self._call_count -= self._calls_per_tracings.pop(0) 

151 else: 

152 break 

153 

154 if (omit_warning or self._total_warning_count >= 

155 FREQUENT_TRACING_WARNING_MAX_WARNING_PER_DETECTOR): 

156 return 

157 if len(self._calls_per_tracings) >= FREQUENT_TRACING_WARNING_THRESHOLD: 

158 self._total_warning_count += 1 

159 logging.warning( 

160 "{} out of the last {} calls to {} triggered tf.function " 

161 "retracing. Tracing is expensive and the excessive number of " 

162 "tracings could be due to (1) creating @tf.function repeatedly in " 

163 "a loop, (2) passing tensors with different shapes, (3) passing " 

164 "Python objects instead of tensors. For (1), please define your " 

165 "@tf.function outside of the loop. For (2), @tf.function has " 

166 "reduce_retracing=True option that can avoid unnecessary " 

167 "retracing. For (3), please refer to " 

168 "https://www.tensorflow.org/guide/function#controlling_retracing" 

169 " and https://www.tensorflow.org/api_docs/python/tf/function for " 

170 " more details.".format( 

171 len(self._calls_per_tracings), self._call_count, function_name)) 

172 

173 def called_without_tracing(self): 

174 # We don't count tracing when users load a concrete function directly or 

175 # call get_concrete_function, so the first call may not be a tracing call. 

176 if not self._calls_per_tracings: 

177 self._calls_per_tracings = [0] 

178 self._calls_per_tracings[-1] += 1 

179 self._call_count += 1 

180 

181 

182class _FrequentTracingDetectorManager(object): 

183 """Class for the management of all _FrequentTracingDetector objects.""" 

184 

185 __slots__ = ["_detectors", "_lock"] 

186 

187 def __init__(self): 

188 self._detectors = weakref.WeakKeyDictionary() # GUARDED_BY(self._lock) 

189 self._lock = threading.Lock() 

190 

191 def _get_detector(self, key): 

192 if key not in self._detectors: 

193 self._detectors[key] = _FrequentTracingDetector() 

194 return self._detectors[key] 

195 

196 def called_without_tracing(self, key): 

197 with self._lock: 

198 detector = self._get_detector(key) 

199 detector.called_without_tracing() 

200 

201 def called_with_tracing(self, key, function_name, omit_warning): 

202 with self._lock: 

203 detector = self._get_detector(key) 

204 detector.called_with_tracing(function_name, omit_warning) 

205 

206 

207_frequent_tracing_detector_manager = _FrequentTracingDetectorManager() 

208 

209 

210class UnliftedInitializerVariable(resource_variable_ops.UninitializedVariable): 

211 """Variable which does not lift its initializer out of function context. 

212 

213 Instances of this variable, when created, build a graph which runs their 

214 initializer inside a tf.cond(is_initialized) block. 

215 

216 This can only be created inside a TracingCompiler called from 

217 (eventually) eager mode. That is, non-function-building graphs are not 

218 supported. 

219 """ 

220 

221 def __init__( 

222 self, 

223 initial_value=None, 

224 trainable=None, 

225 caching_device=None, 

226 name=None, 

227 dtype=None, 

228 constraint=None, 

229 add_initializers_to=None, 

230 synchronization=None, 

231 aggregation=None, 

232 shape=None, 

233 **unused_kwargs, 

234 ): 

235 """Creates a variable. 

236 

237 Args: 

238 initial_value: A `Tensor`, or Python object convertible to a `Tensor`, 

239 which is the initial value for the Variable. The initial value must have 

240 a shape specified unless `validate_shape` is set to False. Can also be a 

241 callable with no argument that returns the initial value when called. 

242 (Note that initializer functions from init_ops.py must first be bound to 

243 a shape before being used here.) 

244 trainable: If `True`, GradientTapes automatically watch uses of this 

245 Variable. 

246 caching_device: Optional device string or function describing where the 

247 Variable should be cached for reading. Defaults to the Variable's 

248 device. If not `None`, caches on another device. Typical use is to 

249 cache on the device where the Ops using the Variable reside, to 

250 deduplicate copying through `Switch` and other conditional statements. 

251 name: Optional name for the variable. Defaults to `'Variable'` and gets 

252 uniquified automatically. 

253 dtype: If set, initial_value will be converted to the given type. If None, 

254 either the datatype will be kept (if initial_value is a Tensor) or 

255 float32 will be used (if it is a Python object convertible to a Tensor). 

256 constraint: An optional projection function to be applied to the variable 

257 after being updated by an `Optimizer` (e.g. used to implement norm 

258 constraints or value constraints for layer weights). The function must 

259 take as input the unprojected Tensor representing the value of the 

260 variable and return the Tensor for the projected value (which must have 

261 the same shape). Constraints are not safe to use when doing asynchronous 

262 distributed training. 

263 add_initializers_to: if not None and not in legacy graph mode, the 

264 initializer tensor will be added to this list in addition to adding the 

265 assignment to the function. 

266 synchronization: Indicates when a distributed variable will be aggregated. 

267 Accepted values are constants defined in the class 

268 `tf.VariableSynchronization`. By default the synchronization is set to 

269 `AUTO` and the current `DistributionStrategy` chooses when to 

270 synchronize. 

271 aggregation: Indicates how a distributed variable will be aggregated. 

272 Accepted values are constants defined in the class 

273 `tf.VariableAggregation`. 

274 shape: (optional) The shape of this variable. If None, the shape of 

275 `initial_value` will be used. When setting this argument to 

276 `tf.TensorShape(None)` (representing an unspecified shape), the variable 

277 can be assigned with values of different shapes. 

278 

279 Raises: 

280 ValueError: If the initial value is not specified, or does not have a 

281 shape and `validate_shape` is `True`. 

282 RuntimeError: If called outside of a function definition. 

283 """ 

284 with ops.init_scope(): 

285 self._in_graph_mode = not context.executing_eagerly() 

286 if not ops.inside_function(): 

287 # If we've been init_scope()d out of the function definition, there's nothing to do 

288 # here; we can't really do the capturing or conditional logic. 

289 resource_variable_ops.ResourceVariable.__init__( 

290 self, initial_value=initial_value, trainable=trainable, 

291 caching_device=caching_device, name=name, dtype=dtype, 

292 constraint=constraint) 

293 return 

294 if initial_value is None: 

295 raise ValueError("`initial_value` must be a Tensor or a Python " 

296 "object convertible to a Tensor. Got None.") 

297 init_from_fn = callable(initial_value) 

298 

299 if constraint is not None and not callable(constraint): 

300 raise ValueError(f"`constraint` with type {type(constraint)} must be a " 

301 "callable.") 

302 

303 with ops.name_scope(name, "Variable", [] 

304 if init_from_fn else [initial_value]) as scope_name: 

305 with ops.name_scope("Initializer"): 

306 if init_from_fn: 

307 initial_value = initial_value() 

308 if isinstance(initial_value, trackable.CheckpointInitialValue): 

309 self._maybe_initialize_trackable() 

310 self._update_uid = initial_value.checkpoint_position.restore_uid 

311 initial_value = initial_value.wrapped_value 

312 

313 initial_value = ops.convert_to_tensor(initial_value, 

314 name="initial_value", dtype=dtype) 

315 assert initial_value is not None 

316 

317 # Don't use `shape or initial_value.shape` since TensorShape has 

318 # overridden `__bool__`. 

319 if shape is None: 

320 shape = initial_value.shape 

321 

322 # Use the constructor for UninitializedVariable to start. Outside the name 

323 # scope so we don't double up the prefix. 

324 super().__init__( 

325 trainable=trainable, 

326 caching_device=caching_device, 

327 name=name, 

328 shape=shape, 

329 dtype=initial_value.dtype, 

330 constraint=constraint, 

331 synchronization=synchronization, 

332 aggregation=aggregation, 

333 extra_handle_data=initial_value, 

334 **unused_kwargs) 

335 

336 with ops.name_scope(scope_name): 

337 if self._in_graph_mode: 

338 with ops.init_scope(): 

339 outer_graph = ops.get_default_graph() 

340 func_graph = ops.get_default_graph() 

341 function_placeholders = ( 

342 func_graph.inputs + func_graph.internal_captures) 

343 placeholder_ops = set( 

344 [tensor.op for tensor in function_placeholders]) 

345 lifted_initializer = lift_to_graph.lift_to_graph( 

346 [initial_value], outer_graph, 

347 disallowed_placeholders=placeholder_ops)[initial_value] 

348 with ops.init_scope(): 

349 self._initial_value = lifted_initializer 

350 with ops.name_scope("IsInitialized"): 

351 self._is_initialized_op = ( 

352 resource_variable_ops.var_is_initialized_op(self._handle)) 

353 if initial_value is not None: 

354 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 

355 self._initializer_op = resource_variable_ops.assign_variable_op( 

356 self._handle, lifted_initializer, name=n) 

357 elif context.executing_eagerly(): 

358 # In this case, both current scope and init scope are eager. 

359 # Assign_variable_op will be executed immediately. So we don't need to 

360 # add it to "add_initializers_to" to lift it out. 

361 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 

362 resource_variable_ops.assign_variable_op( 

363 self._handle, initial_value, name=n) 

364 else: 

365 # Init scope is eager but current scope is graph. We will lift out this 

366 # variable by adding it to "add_initializers_to". 

367 if add_initializers_to is not None: 

368 add_initializers_to.append((self, initial_value)) 

369 

370 def assign_fn(): 

371 with ops.name_scope("Assign") as n, ops.colocate_with(self._handle): 

372 resource_variable_ops.assign_variable_op( 

373 self._handle, 

374 initial_value, 

375 name=n) 

376 # Returning values to keep tf.cond happy. 

377 return ops.convert_to_tensor(1) 

378 def not_assign_fn(): 

379 return ops.convert_to_tensor(0) 

380 # Note: this cond is always guaranteed to run because we're inside a 

381 # TracingCompiler which will insert automatic control dependencies. 

382 # It will only execute assign_fn if lifting failed. 

383 graph = ops.get_default_graph() 

384 

385 # Capture the handle ahead of time in order to avoid querying the shape 

386 # of the handle, which helps async execution performance. 

387 graph.capture(self._handle, shape=()) 

388 cond.cond( 

389 resource_variable_ops.var_is_initialized_op(self._handle), 

390 not_assign_fn, assign_fn) 

391 

392 
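# Presumably the process-wide default for jit_compile: True when the
# TF_FUNCTION_JIT_COMPILE_DEFAULT environment variable is set to "true" or "1"
# (case-insensitive).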

393JIT_COMPILE_FUNCTIONS = ( 

394 os.getenv("TF_FUNCTION_JIT_COMPILE_DEFAULT", "false").lower() 

395 in ("true", "1")) 

396 

397 

398def _evaluate_var_is_initialized(variables): 

399 """Compute booleans indicating whether each variable is initialized.""" 

400 with ops.init_scope(): 

401 var_is_initialized = [] 

402 for v in variables: 

403 var_is_initialized.append( 

404 resource_variable_ops.var_is_initialized_op(v.handle)) 

405 try: 

406 # Stack all the var_is_initialized values into one tensor and interpret 

407 # the numpy value. This will reduce the number of RPCs between client and 

408 # worker in the remote case. 

409 return array_ops_stack.stack(var_is_initialized).numpy() 

410 except errors.UnimplementedError: 

411 # Some devices do not support implicit copy-off to host. Fall back to 

412 # variable-by-variable processing. 

413 for index, v in enumerate(variables): 

414 try: 

415 numpy_value = var_is_initialized[index].numpy() 

416 except errors.UnimplementedError: 

417 # This is a variable on a parallel device; we'll extract its value on 

418 # each replica and assert that they're identical. 

419 components = parallel_device.unpack(var_is_initialized[index]) 

420 with ops.device(None): 

421 components = array_ops_stack.stack(components) 

422 all_initialized = math_ops.reduce_all(components).numpy() 

423 any_initialized = math_ops.reduce_any(components).numpy() 

424 if all_initialized != any_initialized: 

425 raise NotImplementedError( 

426 f"Some but not all components of a parallel variable {v!r} " 

427 "were initialized between their creation in a tf.function and " 

428 "the function's trace having completed. This is not " 

429 "supported; consider initializing either all or none of the " 

430 "components, or moving initialization out of the function.") 

431 numpy_value = all_initialized 

432 var_is_initialized[index] = numpy_value 

433 return var_is_initialized 

434 

435 

436class OptionalXlaContext: 

437 """Wrapper for XLA context optionally applied under a context manager.""" 

438 

439 def __init__(self, is_compiled): 

440 wrap = is_compiled and not control_flow_util.GraphOrParentsInXlaContext( \ 

441 ops.get_default_graph()) 

442 self.xla_context = control_flow_ops.XLAControlFlowContext() \ 

443 if wrap else None 

444 

445 def __enter__(self): 

446 if self.xla_context: 

447 self.xla_context.Enter() 

448 

449 def __exit__(self, t, value, traceback): 

450 if self.xla_context: 

451 self.xla_context.Exit() 

452 

453 

454# TODO(mdan): Consider exposing this type for instance type checking. 

455@tf_export("__internal__.function.Function", v1=[]) 

456class Function(core.GenericFunction, trackable.Trackable): 

457 """A `tf.types.experimental.GenericFunction` created by `tf.function`. 

458 

459 Currently, individual methods/attributes under this class are not guaranteed 

460 by the TF API contract, and are subject to future changes. 

461 """ 

462 

463 def __init__(self, 

464 python_function, 

465 name, 

466 input_signature=None, 

467 autograph=True, 

468 jit_compile=None, 

469 reduce_retracing=False, 

470 experimental_implements=None, 

471 experimental_autograph_options=None, 

472 experimental_attributes=None,): 

473 """Initializes a `Function`. 

474 

475 Args: 

476 python_function: the function to be wrapped. 

477 name: the name given to it. 

478 input_signature: See the documentation for `tf.function`. 

479 autograph: See the documentation for `tf.function`. 

480 jit_compile: See the documentation for `tf.function`. 

481 reduce_retracing: See the documentation for `tf.function`. 

482 experimental_implements: See the documentation for `tf.function`. 

483 experimental_autograph_options: See the documentation for `tf.function`. 

484 experimental_attributes: See the documentation for `tf.function`. 

485 

486 Raises: 

487 ValueError: if `input_signature` is not None and the `python_function`'s 

488 argspec has keyword arguments. 

489 """ 

490 self._lock = threading.RLock() 

491 self._python_function = python_function 

492 self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature( 

493 python_function, 

494 input_signature, 

495 jit_compile=jit_compile, 

496 ) 

497 

498 self._attributes = {} 

499 if experimental_implements is not None: 

500 self._attributes = self._create_implements_attribute( 

501 experimental_implements 

502 ) 

503 

504 if experimental_attributes is not None: 

505 self._attributes.update(experimental_attributes) 

506 

507 for attribute in self._attributes: 

508 if attribute not in attributes_lib.POLYMORPHIC_FUNCTION_ALLOWLIST: 

509 raise ValueError( 

510 f"`{attribute} is not supported by tf.function as an attribute." 

511 ) 

512 

513 # If `True`, the function uses the rendezvous of the parent. This is only 

514 # needed to support code where raw send/recv operations are inserted and 

515 # when functions are run in graph mode where they may not be inlined. 

516 self._shared_rendezvous = None 

517 self._autograph = autograph 

518 self._experimental_autograph_options = experimental_autograph_options 

519 self._reduce_retracing = reduce_retracing 

520 self._jit_compile = jit_compile 

521 self._created_variables = None # GUARDED_BY(self._lock) 

522 self._variable_creation_fn = None # GUARDED_BY(self._lock) 

523 self._no_variable_creation_fn = None # GUARDED_BY(self._lock) 

524 self._descriptor_cache = weakref.WeakKeyDictionary() 

525 self._name = name 

526 self._key_for_call_stats = self._get_key_for_call_stats() 

527 self._omit_frequent_tracing_warning = False 

528 ops._tf_function_api_gauge.get_cell().set(True) # pylint: disable=protected-access 

529 

530 @property 

531 def name(self): 

532 return self._name 

533 

534 def __getstate__(self): 

535 """Custom pickling, to omit unpickleable objects.""" 

536 result = self.__dict__.copy() 

537 del result["_lock"] 

538 del result["_descriptor_cache"] 

539 del result["_key_for_call_stats"] 

540 return result 

541 

542 def __setstate__(self, state): 

543 """Restore from pickled state.""" 

544 self.__dict__ = state 

545 self._lock = threading.RLock() 

546 self._descriptor_cache = weakref.WeakKeyDictionary() 

547 self._key_for_call_stats = self._get_key_for_call_stats() 

548 

549 def _get_key_for_call_stats(self): 

550 """Returns key instance to track call stats and retracings. 

551 

552 The key instance is a best-effort attempt to preserve global consistency. 

553 """ 

554 target_function = self._python_function 

555 # `__wrapped__` is the conventional Python attribute in which a higher-order 

556 # function keeps a reference to the function it wraps. We also directly use 

557 # this attribute for dealing with a class method. See 

558 # `bound_method_wrapper` in `function.py`. If we don't use `__wrapped__`, 

559 # all class methods will return the same `bound_method_wrapper` instance 

560 # from this function. 
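    # A minimal sketch of how a `__wrapped__` chain is typically set up, e.g.
    # via `functools.wraps` (illustrative, not original code):
    #
    #   def deco(fn):
    #     @functools.wraps(fn)   # sets wrapper.__wrapped__ = fn
    #     def wrapper(*args, **kwargs):
    #       return fn(*args, **kwargs)
    #     return wrapper
    #
    # Following `__wrapped__` below therefore recovers the innermost function.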

561 while hasattr(target_function, "__wrapped__"): 

562 target_function = target_function.__wrapped__ 

563 

564 if hasattr(target_function, "__func__"): 

565 target_function = target_function.__func__ 

566 

567 if hasattr(target_function, "__code__"): 

568 return target_function.__code__ 

569 

570 return self._python_function 

571 

572 def _compiler_with_scope(self, scope): 

573 """Creates a TracingCompiler wrapped inside a variable creator scope.""" 

574 

575 weak_wrapped_fn = None 

576 compile_with_xla = self._jit_compile 

577 

578 def wrapped_fn(*args, **kwds): 

579 """Wraps `self._python_function` in a variable creator scope.""" 

580 # We register a variable creator with reduced priority. If an outer 

581 # variable creator is just modifying keyword arguments to the variable 

582 # constructor, this will work harmoniously. Since the `scope` registered 

583 # here actually creates the variable, it taking priority would otherwise 

584 # ignore the outer creator. 

585 # 

586 # If an outer variable creator calls the variable constructor manually, 

587 # for example creating a MirroredVariable, then they won't call our 

588 # creator. This means we won't be able to trace the initialization graph, 

589 # and so variable initializers can't depend on function arguments. This is 

590 # better than the alternative, tracing the initialization graph but giving 

591 # the user a variable type they didn't want. 

592 default_graph = ops.get_default_graph() 

593 with default_graph._variable_creator_scope(scope, priority=50): # pylint: disable=protected-access 

594 # __wrapped__ allows AutoGraph to swap in a converted function. We give 

595 # the function a weak reference to itself to avoid a reference cycle. 

596 with OptionalXlaContext(compile_with_xla): 

597 out = weak_wrapped_fn().__wrapped__(*args, **kwds) 

598 return out 

599 

600 weak_wrapped_fn = weakref.ref(wrapped_fn) 

601 

602 return self._compiler(tf_decorator.make_decorator( 

603 self._python_function, 

604 wrapped_fn)) 

605 

606 def _create_implements_attribute(self, implements_arg): 

607 """Creates the attribute value corresponding to attribute_lib.IMPLEMENTS.""" 

608 attributes = {} 

609 if isinstance(implements_arg, str): 

610 # First check if the attribute_lib.IMPLEMENTS is specified as a 

611 # NameAttrList. This is used when apart from the function name being 

612 # implemented, a list of attributes is also being specified. 

613 # The attributes are specified as key-value pairs in the NameAttrList 

614 # of the corresponding AttrValue. The function name will be in the 

615 # 'name' field of the NameAttrList. Else, it is just a string 

616 # corresponding to the function name. 

617 try: 

618 attr_value = attr_value_pb2.AttrValue() 

619 nameattrlist = attr_value_pb2.NameAttrList() 

620 _text_format.Merge(implements_arg, nameattrlist) 

621 attr_value.func.CopyFrom(nameattrlist) 

622 attributes[attributes_lib.IMPLEMENTS] = attr_value 

623 except (_text_format.ParseError, DecodeError): 

624 attributes[attributes_lib.IMPLEMENTS] = implements_arg 

625 return attributes 

626 

627 def _compiler(self, fn): 

628 """Returns a TracingCompiler generated from the input function.""" 

629 attributes = self._attributes.copy() 

630 

631 share = self._shared_rendezvous 

632 if share is not None: 

633 attributes[attributes_lib.SHARED_RENDEZVOUS] = share 

634 

635 if self._jit_compile is not None: 

636 attributes[attributes_lib.XLA_COMPILE] = bool(self._jit_compile) 

637 if self._jit_compile: 

638 attributes[attributes_lib.NO_INLINE] = True 

639 

640 try: 

641 name = fn.__name__ 

642 except AttributeError: 

643 name = "function" 

644 

645 if self._autograph: 

646 fn = autograph_util.py_func_from_autograph( 

647 fn, self._experimental_autograph_options) 

648 

649 return tracing_compiler.TracingCompiler( 

650 fn, 

651 name, 

652 input_signature=self.input_signature, 

653 attributes=attributes, 

654 autograph=self._autograph, 

655 jit_compile=self._jit_compile, 

656 reduce_retracing=self._reduce_retracing, 

657 autograph_options=self._experimental_autograph_options) 

658 

659 def _initialize(self, args, kwds, add_initializers_to=None): 

660 """Initializes, on the first call. 

661 

662 Creates two `Function`s, one that will allow creation of variables 

663 and one that won't. 

664 

665 Additionally runs a trace for the `Function` that allows creation 

666 of variables. 

667 

668 Args: 

669 args: Arguments to the underlying python callable. 

670 kwds: Keyword arguments to the python callable. 

671 add_initializers_to: Where to collect variable initializers, if not None. 

672 """ 

673 created_variables = [] 

674 

675 def variable_capturing_scope(next_creator, **kwds): 

676 """Creates UnliftedInitializerVariables and saves references to them.""" 

677 enable_variable_lifting = kwds.get("experimental_enable_variable_lifting") 

678 if enable_variable_lifting is None: 

679 enable_variable_lifting = True 

680 if not enable_variable_lifting: 

681 return next_creator(**kwds) 

682 v = UnliftedInitializerVariable( 

683 add_initializers_to=add_initializers_to, **kwds 

684 ) 

685 created_variables.append(weakref.ref(v)) 

686 return v 

687 

688 self._created_variables = created_variables 

689 self._variable_creation_fn = self._compiler_with_scope( 

690 variable_capturing_scope) 

691 self._variable_creation_fn._name = self._name # pylint: disable=protected-access 

692 # Force the definition of the function for these arguments 

693 self._concrete_variable_creation_fn = ( 

694 self._variable_creation_fn # pylint: disable=protected-access 

695 ._get_concrete_function_internal_garbage_collected( 

696 *args, **kwds)) 

697 

698 def invalid_creator_scope(*unused_args, **unused_kwds): 

699 """Disables variable creation.""" 

700 raise ValueError( 

701 "tf.function only supports singleton tf.Variables created on the " 

702 "first call. Make sure the tf.Variable is only created once or " 

703 "created outside tf.function. See " 

704 "https://www.tensorflow.org/guide/function#creating_tfvariables " 

705 "for more information.") 

706 

707 self._no_variable_creation_fn = self._compiler_with_scope( 

708 invalid_creator_scope) 

709 self._no_variable_creation_fn._name = self._name # pylint: disable=protected-access 

710 

711 def _clone(self, python_function): 

712 """Clone the function with different python function.""" 

713 f = Function( 

714 python_function=(self._python_function 

715 if python_function is None else python_function), 

716 name=self._name, 

717 input_signature=self.input_signature, 

718 autograph=self._autograph, 

719 jit_compile=self._jit_compile, 

720 reduce_retracing=self._reduce_retracing, 

721 experimental_attributes=self._attributes, 

722 experimental_autograph_options=self._experimental_autograph_options) 

723 

724 if self._shared_rendezvous: 

725 f._shared_rendezvous = self._shared_rendezvous # pylint: disable=protected-access 

726 

727 return f 

728 

729 def _decorate(self, decorator): 

730 """Allows the captured Python function to be decorated in place. 

731 

732 This method is only safe to call when the Function has not been called by a 

733 user. It makes sense to use this method to push a decorator into the 

734 function rather than wrapping the function in the decorator. 

735 

736 We use this in tf.Module to allow user annotated `tf.functions` to remain as 

737 `Function` objects but still automatically enter the Module name_scope 

738 when they are evaluated like all other methods. 

739 

740 Args: 

741 decorator: A callable accepting a single argument which is the function 

742 to decorate and returning a callable result. 

743 

744 Raises: 

745 ValueError: If the function has already been called. 

746 """ 

747 if self._variable_creation_fn is not None or self._no_variable_creation_fn is not None: 

748 raise ValueError( 

749 "Functions cannot be decorated after they have been traced.") 

750 

751 self._python_function = decorator(self._python_function) 

752 self._function_spec = function_spec_lib.FunctionSpec.from_function_and_signature( 

753 self._python_function, self.input_signature) 

754 

755 # TODO: Remove this private method after updating all its uses 

756 # A good moment to do this could be when the experimental label is removed 

757 def _get_tracing_count(self): 

758 return self.experimental_get_tracing_count() 

759 

760 def experimental_get_tracing_count(self): 

761 """Returns the number of times the function has been traced. 

762 

763 For more information on when a function is traced and when it is 

764 traced multiple times see https://www.tensorflow.org/guide/function. 

765 Example: 

766 

767 >>> @tf.function 

768 ... def double(a): 

769 ... return a + a 

770 >>> double(tf.constant(1)) 

771 >>> double(tf.constant(2)) 

772 >>> double.experimental_get_tracing_count() 

773 1 

774 >>> double(tf.constant("a")) 

775 >>> double.experimental_get_tracing_count() 

776 2 

777 

778 

779 The first time experimental_get_tracing_count is called 

780 it returns 1, as the function is traced the first 

781 time it is called, and the second time the same graph is used 

782 since we're calling it with a parameter of the same type. 

783 

784 The second time experimental_get_tracing_count is called 

785 it returns 2, as we called double with a 

786 different argument type, and so it was traced again. 

787 

788 """ 

789 result = self._no_variable_creation_fn.tracing_count if self._no_variable_creation_fn else 0 

790 result += self._variable_creation_fn.tracing_count if self._variable_creation_fn else 0 

791 return result 

792 

793 @property 

794 def _run_functions_eagerly(self): 

795 return eager_function_run.RUN_FUNCTIONS_EAGERLY 

796 

797 @traceback_utils.filter_traceback 

798 def __call__(self, *args, **kwds): 

799 # Implements GenericFunction.__call__. 

800 if self._run_functions_eagerly: 

801 with trace.Trace(self._name, tf_function_call="eager"): 

802 return self._python_function(*args, **kwds) 

803 

804 # Only count the statistics the first time, before initialization took 

805 # place. 

806 if self._created_variables is None: 

807 compiled = bool(self._jit_compile and 

808 not control_flow_util.GraphOrParentsInXlaContext( 

809 ops.get_default_graph())) 

810 # For nested functions, increment the counter only when a function with 

811 # jit_compile=True is called within a function with jit_compile=False. We 

812 # count this special case to correctly record that both jit_compile=True 

813 # and jit_compile=False is being used for parts of the outer function. 

814 if ops.executing_eagerly_outside_functions() and ( 

815 context.executing_eagerly() or compiled): 

816 # Labels must be strings in Python, so we convert 'compiled' to a string 

817 _tf_function_counter.get_cell(str(int(compiled))).increase_by(1) 

818 

819 tracing_count = self.experimental_get_tracing_count() 

820 with trace.Trace(self._name) as tm: 

821 # TODO(cheshire): Do not duplicate the XLAControlFlowContext annotation. 

822 compiler = "xla" if self._jit_compile else "nonXla" 

823 

824 with OptionalXlaContext(self._jit_compile): 

825 result = self._call(*args, **kwds) 

826 

827 new_tracing_count = self.experimental_get_tracing_count() 

828 without_tracing = (tracing_count == new_tracing_count) 

829 execution_mode = "notTraced" if without_tracing else "traced" 

830 tm.set_metadata(tf_function_call=execution_mode + "-" + compiler, 

831 tracing_count=new_tracing_count) 

832 

833 if context.executing_eagerly(): 

834 if without_tracing: 

835 _frequent_tracing_detector_manager.called_without_tracing( 

836 self._key_for_call_stats) 

837 else: 

838 _frequent_tracing_detector_manager.called_with_tracing( 

839 self._key_for_call_stats, self._python_function, 

840 self._omit_frequent_tracing_warning) 

841 

842 return result 

843 

844 def _call(self, *args, **kwds): 

845 """Calls the graph function.""" 

846 self._lock.acquire() 

847 if ALLOW_DYNAMIC_VARIABLE_CREATION: 

848 condition = self._created_variables and self._variable_creation_fn is None 

849 else: 

850 condition = self._created_variables 

851 if condition: 

852 # Release the lock early so that multiple threads can perform the call 

853 # in parallel. 

854 self._lock.release() 

855 # In this case we have created variables on the first call, so we run the 

856 # defunned version which is guaranteed to never create variables. 

857 return self._no_variable_creation_fn(*args, **kwds) # pylint: disable=not-callable 

858 elif self._variable_creation_fn is not None: 

859 # Release the lock early so that multiple threads can perform the call 

860 # in parallel. 

861 self._lock.release() 

862 # In this case we have not created variables on the first call. So we can 

863 # run the first trace but we should fail if variables are created. 

864 results = self._variable_creation_fn(*args, **kwds) 

865 if self._created_variables and not ALLOW_DYNAMIC_VARIABLE_CREATION: 

866 raise ValueError("Creating variables on a non-first call to a function" 

867 " decorated with tf.function.") 

868 return results 

869 

870 try: 

871 # This is the first call of __call__, so we have to initialize. 

872 initializers = [] 

873 self._initialize(args, kwds, add_initializers_to=initializers) 

874 finally: 

875 # At this point we know that the initialization is complete (or less 

876 # interestingly an exception was raised) so we no longer need a lock. 

877 self._lock.release() 

878 

879 if self._created_variables: 

880 try: 

881 # Attempt to initialize variables eagerly and without conds by lifting 

882 # out initialization graphs. This is the only initialization strategy 

883 # compatible with XLA at the moment. 

884 self._initialize_uninitialized_variables(initializers) 

885 except lift_to_graph.UnliftableError: 

886 pass # Fall through to cond-based initialization. 

887 else: 

888 # Lifting succeeded, so variables are initialized and we can run the 

889 # no_variable_creation function. 

890 return self._no_variable_creation_fn(*args, **kwds) 

891 else: 

892 _, _, filtered_flat_args = ( 

893 self._variable_creation_fn._function_spec # pylint: disable=protected-access 

894 .canonicalize_function_inputs( 

895 args, kwds)) 

896 # If we did not create any variables the trace we have is good enough. 

897 return self._concrete_variable_creation_fn._call_flat( # pylint: disable=protected-access 

898 filtered_flat_args, 

899 self._concrete_variable_creation_fn.captured_inputs) 

900 

901 def fn_with_cond(inner_args, inner_kwds, inner_filtered_flat_args): 

902 """Conditionally runs initialization if it's needed.""" 

903 condition = True 

904 for v, _ in initializers: 

905 condition = math_ops.logical_and( 

906 condition, resource_variable_ops.var_is_initialized_op( 

907 v.handle)) 

908 # We want to call no_variable_creation if possible because it avoids 

909 # recomputing potentially expensive initializers. 

910 return cond.cond( 

911 condition, 

912 lambda: self._no_variable_creation_fn(*inner_args, **inner_kwds), 

913 functools.partial( 

914 self._concrete_variable_creation_fn._call_flat, # pylint: disable=protected-access 

915 inner_filtered_flat_args, 

916 captured_inputs=self._concrete_variable_creation_fn 

917 .captured_inputs)) 

918 

919 # We've created variables and are unable to lift the initialization graphs, 

920 # so we fall back to initializing with conds while running the function. 

921 # TODO(b/216870587) Note that this path is not currently supported for XLA. 

922 if self._jit_compile: 

923 raise errors.UnimplementedError( 

924 None, None, 

925 "We failed to lift variable creations out of this tf.function, " 

926 "so this tf.function cannot be run on XLA. A possible workaround is " 

927 "to move variable creation outside of the XLA compiled function.") 

928 canon_args, canon_kwds, filtered_flat_args = ( 

929 self._variable_creation_fn._function_spec.canonicalize_function_inputs( # pylint: disable=protected-access 

930 args, kwds)) 

931 return tracing_compiler.TracingCompiler( 

932 fn_with_cond, "fn_with_cond")(canon_args, canon_kwds, 

933 filtered_flat_args) 

934 

935 def experimental_get_compiler_ir(self, *args, **kwargs): 

936 # Implements GenericFunction.experimental_get_compiler_ir 
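    # A minimal usage sketch (assuming an XLA-capable device is available):
    #
    #   @tf.function(jit_compile=True)
    #   def f(x):
    #     return x + 1
    #   hlo_text = f.experimental_get_compiler_ir(tf.constant(1.0))(stage="hlo")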

937 context.ensure_initialized() 

938 if not self._jit_compile: 

939 raise ValueError("Compiler IR can only be returned for functions marked " 

940 "with 'jit_compile=True'") 

941 

942 is_tensor_spec = lambda x: isinstance(x, tensor_spec.TensorSpec) 

943 

944 def _check_inputs(args, kwargs): 

945 all_inputs = list(args) + list(kwargs.values()) 

946 # Empty input is okay. 

947 if not all_inputs: 

948 return 

949 if any(map(is_tensor_spec, all_inputs)) and any( 

950 map(lambda x: not is_tensor_spec(x), all_inputs) 

951 ): 

952 raise ValueError( 

953 "experimental_get_compiler_ir supports either " 

954 "(1) all inputs are TensorSpec or " 

955 "(2) all inputs are tf.Tensor/python variables" 

956 ) 

957 

958 _check_inputs(args, kwargs) 

959 if ( 

960 len(args) + len(kwargs.values()) > 0 

961 and all(map(is_tensor_spec, args)) 

962 and all(map(is_tensor_spec, kwargs.values())) 

963 ): 

964 # Handle the case where inputs are not empty and all input types are tf.TensorSpec. 

965 concrete_fn = self.get_concrete_function(*args, **kwargs) 

966 return compiler_ir.from_concrete_function(concrete_fn) 

967 

968 concrete_fn = self.get_concrete_function(*args, **kwargs) 

969 fn_name = concrete_fn.name 

970 

971 # pylint: disable=protected-access 

972 _, _, filtered_flat_args = ( 

973 concrete_fn._function_spec.canonicalize_function_inputs(args, kwargs)) 

974 

975 def compiler_ir_generator(stage="hlo", device_name=None): 

976 device_name = compiler_ir.maybe_get_device_name(device_name) 

977 res_bytes = context.context().get_compiler_ir( 

978 device_name=device_name, 

979 function_name=fn_name, 

980 flat_args=list(filtered_flat_args), 

981 captured_inputs=concrete_fn.captured_inputs, 

982 stage=stage, 

983 ) 

984 if stage in ("hlo_serialized", "optimized_hlo_serialized", 

985 "optimized_hlo_proto_serialized"): 

986 return res_bytes 

987 else: 

988 return res_bytes.decode("utf-8") 

989 

990 return compiler_ir_generator 

991 

992 @property 

993 def python_function(self): 

994 """The python function wrapped in this tf.function.""" 

995 return self._python_function 

996 

997 @property 

998 def input_signature(self): 

999 return self._function_spec.input_signature 

1000 

1001 @property 

1002 def function_spec(self): 

1003 return self._function_spec 

1004 

1005 def pretty_printed_concrete_signatures(self, verbose=True): 

1006 joiner = "\n\n" if verbose else "\n" 

1007 return joiner.join([ 

1008 c.pretty_printed_signature(verbose=verbose) 

1009 for c in self._list_all_concrete_functions() 

1010 ]) 

1011 

1012 def _initialize_uninitialized_variables(self, initializers): 

1013 """Make and call a `ConcreteFunction` which initializes variables.""" 

1014 

1015 if not initializers: 

1016 return 

1017 

1018 var_is_initialized = _evaluate_var_is_initialized( 

1019 [v for v, _ in initializers]) 

1020 

1021 def initialize_variables(): 

1022 op_map = object_identity.ObjectIdentityDictionary() 

1023 

1024 inits = [] 

1025 for (v, init), is_initialized in zip(initializers, var_is_initialized): 

1026 with ops.init_scope(): 

1027 if is_initialized: 

1028 continue 

1029 inits.append(init) 

1030 

1031 if inits: 

1032 op_map = lift_to_graph.lift_to_graph( 

1033 inits, ops.get_default_graph(), op_map=op_map) 

1034 for (v, init), is_initialized in zip(initializers, var_is_initialized): 

1035 with ops.init_scope(): 

1036 if is_initialized: 

1037 continue 

1038 v.assign(op_map[init], read_value=False) 

1039 

1040 with ops.init_scope(): 

1041 # Note: using TracingCompiler here avoids an infinite recursion. 

1042 # Most of the code in this function runs eagerly with init_scope, where 

1043 # autograph is not necessary. 

1044 return tracing_compiler.TracingCompiler( 

1045 initialize_variables, "initialize_variables", 

1046 autograph=False).get_concrete_function()() 

1047 

1048 def get_initialization_function(self, *args, **kwargs): 

1049 """Returns a `ConcreteFunction` which initializes this function's variables. 

1050 

1051 Requires that this function hasn't been accessed yet through either calling 

1052 it or calling get_concrete_function. Fails if we cannot build an initializer 

1053 function which does not depend on the concrete values of the inputs to this 

1054 function. 

1055 

1056 Note that running this function will overwrite any values currently assigned 

1057 to variables, for example values restored from a checkpoint. 

1058 

1059 Args: 

1060 *args: arguments to the underlying python callable. 

1061 **kwargs: keyword arguments to the python callable. 

1062 

1063 Returns: 

1064 A `ConcreteFunction` object which initializes the variables of this 

1065 function. 

1066 

1067 Raises: 

1068 RuntimeError: if called after the variables have been initialized. 

1069 """ 

1070 with self._lock: 

1071 if self._variable_creation_fn is not None: 

1072 raise RuntimeError( 

1073 "get_initialization_function cannot be called after the function " 

1074 "has been used") 

1075 # Here we trace the function, collect the initializers, and attempt to 

1076 # extract them and run them eagerly. Fail only if we cannot do so. 

1077 initializers = [] 

1078 self._initialize(args, kwargs, add_initializers_to=initializers) 

1079 

1080 def initialize_variables(): 

1081 for v, init in initializers: 

1082 v.assign( 

1083 lift_to_graph.lift_to_graph([init], ops.get_default_graph())[init], 

1084 read_value=False) 

1085 

1086 # Note: using TracingCompiler here avoids an infinite recursion. 

1087 return tracing_compiler.TracingCompiler( 

1088 initialize_variables, "initialize_variables").get_concrete_function() 

1089 

1090 def _list_all_concrete_functions(self): 

1091 """Returns all concrete functions.""" 

1092 if self.input_signature is not None: 

1093 self.get_concrete_function() 

1094 concrete_functions = [] 

1095 # pylint: disable=protected-access 

1096 if self._variable_creation_fn: 

1097 concrete_functions.extend( 

1098 self._variable_creation_fn._list_all_concrete_functions()) 

1099 if self._no_variable_creation_fn: 

1100 concrete_functions.extend( 

1101 self._no_variable_creation_fn._list_all_concrete_functions()) 

1102 # pylint: enable=protected-access 

1103 return concrete_functions 

1104 

1105 def _list_all_concrete_functions_for_serialization(self): 

1106 """Returns all concrete functions for serialization. 

1107 

1108 Returns: 

1109 A list of instances of `ConcreteFunction`. 

1110 """ 

1111 seen_signatures = [] 

1112 if self.input_signature is not None: 

1113 seen_signatures.append((self.input_signature, {})) 

1114 else: 

1115 concrete_functions = self._list_all_concrete_functions() 

1116 for concrete_function in concrete_functions: 

1117 signature = concrete_function.structured_input_signature 

1118 flattened = nest.flatten(signature) 

1119 if any( 

1120 isinstance(arg, func_graph_module.UnknownArgument) 

1121 for arg in flattened): 

1122 logging.info("Unsupported signature for serialization: %s.", 

1123 signature) 

1124 continue 

1125 equal_to_signature = functools.partial( 

1126 function_spec_lib.is_same_structure, signature, check_values=True) 

1127 if not any(equal_to_signature(s) for s in seen_signatures): 

1128 seen_signatures.append(signature) 

1129 

1130 # Re-create concrete functions for these signatures. Re-creating ensures 

1131 # that if the cache key has changed, the function will be traced again. 

1132 concrete_functions = [] 

1133 for args, kwargs in seen_signatures: 

1134 concrete_functions.append(self.get_concrete_function(*args, **kwargs)) 

1135 return concrete_functions 

1136 

1137 def _trackable_children(self, save_type="checkpoint", **kwargs): 

1138 """For implementing `Trackable`.""" 

1139 if save_type == "checkpoint": 

1140 return {} 

1141 return {f"trace_{n}": fn for n, fn in 

1142 enumerate(self._list_all_concrete_functions_for_serialization())} 

1143 

1144 def _deserialization_dependencies(self, children): 

1145 """Returns concrete functions which must be loaded before this object.""" 

1146 return children 

1147 

1148 def _get_concrete_function_garbage_collected(self, *args, **kwargs): 

1149 """Returns a `ConcreteFunction` specialized to inputs and execution context. 

1150 

1151 Unlike `get_concrete_function(...)`, the graph will be deleted when the 

1152 returned function is deleted. It's useful to avoid creating a reference 

1153 cycle when you know for sure that the graph will no longer be needed once 

1154 the returned function is deleted. 

1155 

1156 Args: 

1157 *args: inputs to specialize on. 

1158 **kwargs: inputs to specialize on. 

1159 

1160 Returns: 

1161 A TensorFlow function which takes exactly one `tf.Tensor` per argument. 

1162 

1163 Raises: 

1164 ValueError: if this object has not yet been called on concrete values. 

1165 """ 

1166 with self._lock: 

1167 if self._variable_creation_fn is None: 

1168 initializers = [] 

1169 self._initialize(args, kwargs, add_initializers_to=initializers) 

1170 self._initialize_uninitialized_variables(initializers) 

1171 

1172 if self._created_variables: 

1173 # In this case we have created variables on the first call, so we run the 

1174 # version which is guaranteed to never create variables. 

1175 return self._no_variable_creation_fn._get_concrete_function_garbage_collected( # pylint: disable=protected-access 

1176 *args, **kwargs) 

1177 elif self._variable_creation_fn is not None: 

1178 # In this case we have not created variables on the first call. So we can 

1179 # run the first trace but we should fail if variables are created. 

1180 concrete = self._variable_creation_fn._get_concrete_function_garbage_collected( # pylint: disable=protected-access 

1181 *args, **kwargs) 

1182 if self._created_variables: 

1183 raise ValueError("Creating variables on a non-first call to a function" 

1184 " decorated with tf.function.") 

1185 return concrete 

1186 

1187 def get_concrete_function(self, *args, **kwargs): 

1188 # Implements GenericFunction.get_concrete_function. 

1189 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) 

1190 concrete._garbage_collector.release() # pylint: disable=protected-access 

1191 return concrete 

1192 

1193 def __tf_tracing_type__(self, _): 

1194 return trace_type.Weakref(weakref.ref(self)) 

1195 

1196 def __get__(self, instance, owner): 

1197 """Makes it possible to decorate instance methods.""" 

1198 del owner 

1199 # `instance` here is the instance that this `Function` was accessed through 

1200 # e.g., for 

1201 # 

1202 # class Foo: 

1203 # 

1204 # @tf.function 

1205 # def bar(self): 

1206 # ... 

1207 # 

1208 # foo = Foo() 

1209 # foo.bar() # `foo.bar` is a `Function` instance 

1210 # 

1211 # then `instance` will be `foo` (and `owner` will be `Foo`). For composite 

1212 # tensors, we can just treat `instance` as a normal parameter. But for 

1213 # other types, we create a new instance of `Function` here to allow 

1214 # different instances each to create variables once, thereby allowing 

1215 # methods to be decorated with tf.function. Keeps a cache to avoid retracing 

1216 # the function every time the descriptor is accessed. 

1217 # TODO(mdan): Identify types which can just be parameters more generically. 

1218 # 

1219 # The check that `instance._type_spec` is not None is used because certain classes 

1220 # (including subclasses of tf.linalg.LinearOperator) are subclasses of 

1221 # CompositeTensor but do not actually implement the required APIs. 

1222 # TODO(b/199278478): Fix those classes, then remove the check for 

1223 # `instance._type_spec is not None`. 

1224 if (isinstance(instance, composite_tensor.CompositeTensor) and 

1225 instance._type_spec is not None): # pylint: disable=protected-access 

1226 return types_lib.MethodType(self, instance) 

1227 if instance not in self._descriptor_cache: 

1228 if instance is None: 

1229 return self 

1230 # TODO(mdan): If the CompositeTensor path works, do the same here. 

1231 # It's unclear whether we need the tf-decorator, or could just call 

1232 # MethodType(self.clone(), instance) 

1233 self._descriptor_cache[instance] = ( 

1234 tracing_compiler.class_method_to_instance_method(self, instance)) 

1235 return self._descriptor_cache[instance] 

1236 

1237 

1238@tf_export("function") 

1239@deprecation.deprecated_args(None, 

1240 "experimental_compile is deprecated, use " 

1241 "jit_compile instead", "experimental_compile") 

1242@deprecation.deprecated_args(None, 

1243 "experimental_relax_shapes is deprecated, use " 

1244 "reduce_retracing instead", 

1245 "experimental_relax_shapes") 

1246@deprecation.deprecated_args(None, 

1247 "experimental_follow_type_hints is deprecated", 

1248 "experimental_follow_type_hints") 

1249def function( 

1250 func=None, 

1251 input_signature=None, 

1252 autograph=True, 

1253 jit_compile=None, 

1254 reduce_retracing=False, 

1255 experimental_implements=None, 

1256 experimental_autograph_options=None, 

1257 experimental_attributes=None, 

1258 experimental_relax_shapes=None, 

1259 experimental_compile=None, 

1260 experimental_follow_type_hints=None # pylint: disable=unused-argument 

1261) -> core.GenericFunction: 

1262 """Compiles a function into a callable TensorFlow graph. 

1263 

1264 `tf.function` constructs a `tf.types.experimental.GenericFunction` that 

1265 executes a TensorFlow graph (`tf.Graph`) created by trace-compiling the 

1266 TensorFlow operations in `func`. More information on the topic can be found 

1267 in [Introduction to Graphs and tf.function] 

1268 (https://www.tensorflow.org/guide/intro_to_graphs). 

1269 

1270 See [Better Performance with tf.function] 

1271 (https://www.tensorflow.org/guide/function) for tips on performance and 

1272 known limitations. 

1273 

1274 Example usage: 

1275 

1276 >>> @tf.function 

1277 ... def f(x, y): 

1278 ... return x ** 2 + y 

1279 >>> x = tf.constant([2, 3]) 

1280 >>> y = tf.constant([3, -2]) 

1281 >>> f(x, y) 

1282 <tf.Tensor: ... numpy=array([7, 7], ...)> 

1283 

1284 The trace-compilation allows non-TensorFlow operations to execute, but under 

1285 special conditions. In general, only TensorFlow operations are guaranteed to 

1286 run and create fresh results whenever the `GenericFunction` is called. 

1287 

1288 ## Features 

1289 

1290 `func` may use data-dependent Python control flow statements, including `if`, 

1291 `for`, `while`, `break`, `continue` and `return`: 

1292 

1293 >>> @tf.function 

1294 ... def f(x): 

1295 ... if tf.reduce_sum(x) > 0: 

1296 ... return x * x 

1297 ... else: 

1298 ... return -x // 2 

1299 >>> f(tf.constant(-2)) 

1300 <tf.Tensor: ... numpy=1> 

1301 

1302 `func`'s closure may include `tf.Tensor` and `tf.Variable` objects: 

1303 

1304 >>> @tf.function 

1305 ... def f(): 

1306 ... return x ** 2 + y 

1307 >>> x = tf.constant([-2, -3]) 

1308 >>> y = tf.Variable([3, -2]) 

1309 >>> f() 

1310 <tf.Tensor: ... numpy=array([7, 7], ...)> 

1311 

1312 `func` may also use ops with side effects, such as `tf.print`, `tf.Variable` 

1313 and others: 

1314 

1315 >>> v = tf.Variable(1) 

1316 >>> @tf.function 

1317 ... def f(x): 

1318 ... for i in tf.range(x): 

1319 ... v.assign_add(i) 

1320 >>> f(3) 

1321 >>> v 

1322 <tf.Variable ... numpy=4> 

1323 

1324 Important: Any Python side-effects (appending to a list, printing with 

1325 `print`, etc.) will only happen once, when `func` is traced. To have 

1326 side effects executed as part of your `tf.function`, they need to be written 

1327 as TF ops: 

1328 

1329 >>> l = [] 

1330 >>> @tf.function 

1331 ... def f(x): 

1332 ... for i in x: 

1333 ... l.append(i + 1) # Caution! Will only happen once when tracing 

1334 >>> f(tf.constant([1, 2, 3])) 

1335 >>> l 

1336 [<tf.Tensor ...>] 

1337 

1338 Instead, use TensorFlow collections like `tf.TensorArray`: 

1339 

1340 >>> @tf.function 

1341 ... def f(x): 

1342 ... ta = tf.TensorArray(dtype=tf.int32, size=0, dynamic_size=True) 

1343 ... for i in range(len(x)): 

1344 ... ta = ta.write(i, x[i] + 1) 

1345 ... return ta.stack() 

1346 >>> f(tf.constant([1, 2, 3])) 

1347 <tf.Tensor: ..., numpy=array([2, 3, 4], ...)> 

1348 

1349 ## `tf.function` creates polymorphic callables 

1350 

1351 Internally, `tf.types.experimental.GenericFunction` may contain multiple 

1352 `tf.types.experimental.ConcreteFunction`s, each specialized to arguments with 

1353 different data types or shapes, since TensorFlow can perform more 

1354 optimizations on graphs of specific shapes, dtypes and values of constant 

1355 arguments. `tf.function` treats any pure Python values as opaque objects (best 

1356 thought of as compile-time constants), and builds a separate `tf.Graph` for 

1357 each set of Python arguments that it encounters. 

1358 For more information, see the 

1359 [tf.function guide](https://www.tensorflow.org/guide/function#rules_of_tracing). 

1360 

1361 Executing a `GenericFunction` will select and execute the appropriate 

1362 `ConcreteFunction` based on the argument types and values. 

1363 

1364 To obtain an individual `ConcreteFunction`, use the 

1365 `GenericFunction.get_concrete_function` method. It can be called with the 

1366 same arguments as `func` and returns a 

1367 `tf.types.experimental.ConcreteFunction`. `ConcreteFunction`s are backed by a 

1368 single `tf.Graph`: 

1369 

1370 >>> @tf.function 

1371 ... def f(x): 

1372 ... return x + 1 

1373 >>> isinstance(f.get_concrete_function(1).graph, tf.Graph) 

1374 True 

1375 

1376 `ConcreteFunction`s can be executed just like `GenericFunction`s, but their 

1377 input is restricted to the types to which they're specialized. 
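
For example (a minimal sketch, with tensor reprs abbreviated as elsewhere in
this docstring), a `ConcreteFunction` traced for a float scalar accepts other
float scalars, while arguments of a different dtype or an incompatible shape
raise an error:

>>> cf = f.get_concrete_function(tf.constant(1.0))
>>> cf(tf.constant(2.0))
<tf.Tensor: ... numpy=3.0>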

1378 

1379 ## Retracing 

1380 

1381 `ConcreteFunction`s are built (traced) on the fly, as the `GenericFunction` is 

1382 called with new TensorFlow types or shapes, or with new Python values as 

1383 arguments. When `GenericFunction` builds a new trace, it is said that `func` 

1384 is retraced. Retracing is a frequent performance concern for `tf.function` as 

1385 it can be considerably slower than executing a graph that's already been 

1386 traced. It is ideal to minimize the amount of retracing in your code. 

1387 

1388 Caution: Passing Python scalars or lists as arguments to `tf.function` will 

1389 usually retrace. To avoid this, pass numeric arguments as Tensors whenever 

1390 possible: 

1391 

1392 >>> @tf.function 

1393 ... def f(x): 

1394 ... return tf.abs(x) 

1395 >>> f1 = f.get_concrete_function(1) 

1396 >>> f2 = f.get_concrete_function(2) # Slow - compiles new graph 

1397 >>> f1 is f2 

1398 False 

1399 >>> f1 = f.get_concrete_function(tf.constant(1)) 

1400 >>> f2 = f.get_concrete_function(tf.constant(2)) # Fast - reuses f1 

1401 >>> f1 is f2 

1402 True 

1403 

1404 Python numerical arguments should only be used when they take few distinct 

1405 values, such as hyperparameters like the number of layers in a neural network. 

1406 

1407 ## Input signatures 

1408 

1409 For Tensor arguments, `GenericFunction` creates a new `ConcreteFunction` for 

1410 every unique set of input shapes and datatypes. The example below creates two 

1411 separate `ConcreteFunction`s, each specialized to a different shape: 

1412 

1413 >>> @tf.function 

1414 ... def f(x): 

1415 ... return x + 1 

1416 >>> vector = tf.constant([1.0, 1.0]) 

1417 >>> matrix = tf.constant([[3.0]]) 

1418 >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix) 

1419 False 

1420 

1421 An "input signature" can be optionally provided to `tf.function` to control 

1422 this process. The input signature specifies the shape and type of each 

1423 Tensor argument to the function using a `tf.TensorSpec` object. More general 

1424 shapes (such as `shape=None` below) can be used. This ensures only one `ConcreteFunction` is created, and 

1425 restricts the `GenericFunction` to the specified shapes and types. It is 

1426 an effective way to limit retracing when Tensors have dynamic shapes. 

1427 

1428 >>> @tf.function( 

1429 ... input_signature=[tf.TensorSpec(shape=None, dtype=tf.float32)]) 

1430 ... def f(x): 

1431 ... return x + 1 

1432 >>> vector = tf.constant([1.0, 1.0]) 

1433 >>> matrix = tf.constant([[3.0]]) 

1434 >>> f.get_concrete_function(vector) is f.get_concrete_function(matrix) 

1435 True 

1436 

1437 ## Variables may only be created once 

1438 

1439 `tf.function` only allows creating new `tf.Variable` objects when it is called 

1440 for the first time: 

1441 

1442 >>> class MyModule(tf.Module): 

1443 ... def __init__(self): 

1444 ... self.v = None 

1445 ... 

1446 ... @tf.function 

1447 ... def __call__(self, x): 

1448 ... if self.v is None: 

1449 ... self.v = tf.Variable(tf.ones_like(x)) 

1450 ... return self.v * x 

1451 

1452 In general, it is recommended to create `tf.Variable`s outside of 

1453 `tf.function`. 

1454 In simple cases, persisting state across `tf.function` boundaries may be 

1455 implemented using a pure functional style in which state is represented by 

1456 `tf.Tensor`s passed as arguments and returned as return values. 

1457 

1458 Contrast the two styles below: 

1459 

1460 >>> state = tf.Variable(1) 

1461 >>> @tf.function 

1462 ... def f(x): 

1463 ... state.assign_add(x) 

1464 >>> f(tf.constant(2)) # Non-pure functional style 

1465 >>> state 

1466 <tf.Variable ... numpy=3> 

1467 

1468 >>> state = tf.constant(1) 

1469 >>> @tf.function 

1470 ... def f(state, x): 

1471 ... state += x 

1472 ... return state 

1473 >>> state = f(state, tf.constant(2)) # Pure functional style 

1474 >>> state 

1475 <tf.Tensor: ... numpy=3> 

1476 

1477 ## Python operations execute only once per trace 

1478 

1479 `func` may contain TensorFlow operations mixed with pure Python operations. 

1480 However, when the function is executed, only the TensorFlow operations will 

1481 run. The Python operations run only once, at trace time. If TensorFlow 

1482 operations depend on results from Python operations, those results will be 

1483 frozen into the graph. 

1484 

1485 >>> @tf.function 

1486 ... def f(a, b): 

1487 ... print('this runs at trace time; a is', a, 'and b is', b) 

1488 ... return b 

1489 >>> f(1, tf.constant(1)) 

1490 this runs at trace time; a is 1 and b is Tensor("...", shape=(), dtype=int32) 

1491 <tf.Tensor: shape=(), dtype=int32, numpy=1> 

1492 

1493 >>> f(1, tf.constant(2)) 

1494 <tf.Tensor: shape=(), dtype=int32, numpy=2> 

1495 

1496 >>> f(2, tf.constant(1)) 

1497 this runs at trace time; a is 2 and b is Tensor("...", shape=(), dtype=int32) 

1498 <tf.Tensor: shape=(), dtype=int32, numpy=1> 

1499 

1500 >>> f(2, tf.constant(2)) 

1501 <tf.Tensor: shape=(), dtype=int32, numpy=2> 

1502 

1503 Args: 

1504 func: The function to be compiled. If `func` is None, `tf.function` returns 

1505 a decorator that can be invoked with a single argument - `func`. In other 

1506 words, `tf.function(input_signature=...)(func)` is equivalent to 

1507 `tf.function(func, input_signature=...)`. The former can be used as 

1508 a decorator. 

1509 input_signature: A possibly nested sequence of `tf.TensorSpec` objects 

1510 specifying the shapes and dtypes of the Tensors that will be supplied to 

1511 this function. If `None`, a separate function is instantiated for each 

1512 inferred input signature. If input_signature is specified, every input to 

1513 `func` must be a `Tensor`, and `func` cannot accept `**kwargs`. 

1514 autograph: Whether autograph should be applied on `func` before tracing a 

1515 graph. Data-dependent Python control flow statements require 

1516 `autograph=True`. For more information, see the 

1517 [tf.function and AutoGraph guide]( 

1518 https://www.tensorflow.org/guide/function#autograph_transformations). 

1519 jit_compile: If `True`, compiles the function using 

1520 [XLA](https://tensorflow.org/xla). XLA performs compiler optimizations, 

1521 such as fusion, and attempts to emit more efficient code. This may 

1522 drastically improve performance. If set to `True`, 

1523 the whole function needs to be compilable by XLA, or an 

1524 `errors.InvalidArgumentError` is thrown. 

1525 If `None` (default), compiles the function with XLA when running on TPU 

1526 and goes through the regular function execution path when running on 

1527 other devices. 

1528 If `False`, executes the function without XLA compilation. Set this value 

1529 to `False` when directly running a multi-device function on TPUs (e.g. two 

1530 TPU cores, one TPU core and its host CPU). 

1531 Not all functions are compilable, see a list of 

1532 [sharp corners](https://tensorflow.org/xla/known_issues). A brief usage sketch follows this argument list. 

1533 reduce_retracing: When True, `tf.function` attempts to reduce the 

1534 amount of retracing, for example by using more generic shapes. This 

1535 can be controlled for user objects by customizing their associated 

1536 `tf.types.experimental.TraceType`. 

1537 experimental_implements: If provided, contains a name of a "known" function 

1538 this implements. For example "mycompany.my_recurrent_cell". 

1539 This is stored as an attribute of the inference function, 

1540 which can then be detected when processing the serialized function. 

1541 See [standardizing composite ops](https://github.com/tensorflow/community/blob/master/rfcs/20190610-standardizing-composite_ops.md) # pylint: disable=line-too-long 

1542 for details. For an example of utilizing this attribute see this 

1543 [example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc) 

1544 The code above automatically detects and substitutes functions that 

1545 implement "embedded_matmul" and allows TFLite to substitute its own 

1546 implementations. For instance, a TensorFlow user can use this 

1547 attribute to mark that their function also implements 

1548 `embedded_matmul` (perhaps more efficiently!) 

1549 by specifying it using this parameter: 

1550 `@tf.function(experimental_implements="embedded_matmul")` 

1551 This can either be specified as just the string name of the function or 

1552 a NameAttrList corresponding to a list of key-value attributes associated 

1553 with the function name. The name of the function will be in the 'name' 

1554 field of the NameAttrList. To define a formal TF op that this function 

1555 implements, try the experimental [composite TF](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tfr) 

1556 project. 

1557 experimental_autograph_options: Optional tuple of 

1558 `tf.autograph.experimental.Feature` values. 

1559 experimental_attributes: Optional dictionary of attributes to include in the 

1560 generated FunctionDefs. 

1561 experimental_relax_shapes: Deprecated. Use `reduce_retracing` 

1562 instead. 

1563 experimental_compile: Deprecated alias of `jit_compile`. 

1564 experimental_follow_type_hints: Deprecated. Please use input_signature or 

1565 reduce_retracing instead. 
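
A minimal usage sketch for the `jit_compile` and `reduce_retracing` arguments
described above (the function bodies here are purely illustrative):

>>> @tf.function(jit_compile=True)
... def squared_sum(x):
...   return tf.reduce_sum(x * x)

>>> @tf.function(reduce_retracing=True)
... def add_one(x):
...   return x + 1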

1566 

1567 Returns: 

1568 If `func` is not None, returns a `tf.types.experimental.GenericFunction`. 

1569 If `func` is None, returns a decorator that, when invoked with a single 

1570 `func` argument, returns a `tf.types.experimental.GenericFunction`. 

1571 

1572 Raises: 

1573 `ValueError` when attempting to use `jit_compile=True`, but XLA support is 

1574 not available. 

1575 """ 

1576 if jit_compile is None and JIT_COMPILE_FUNCTIONS: 

1577 jit_compile = True 

1578 

1579 # TODO(b/224808187): Remove after renaming usages. 

1580 if experimental_relax_shapes: 

1581 reduce_retracing = True 

1582 

1583 def decorated(inner_function): 

1584 try: 

1585 name = inner_function.__name__ 

1586 except AttributeError: 

1587 name = "function" 

1588 return tf_decorator.make_decorator( 

1589 inner_function, 

1590 decorator_name="tf.function", 

1591 decorator_func=Function( 

1592 inner_function, 

1593 name, 

1594 input_signature=input_signature, 

1595 autograph=autograph, 

1596 experimental_autograph_options=experimental_autograph_options, 

1597 reduce_retracing=reduce_retracing, 

1598 

1599 # TODO(b/171825496): Update once `experimental_compile` is removed 

1600 # entirely in favor of 'jit_compile'. 
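# `deprecated_argument_lookup` returns the new `jit_compile` value if it was
# set and otherwise falls back to the deprecated `experimental_compile`
# value; passing both raises a ValueError.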

1601 jit_compile=deprecation.deprecated_argument_lookup( 

1602 "jit_compile", 

1603 jit_compile, 

1604 "experimental_compile", 

1605 experimental_compile), 

1606 experimental_implements=experimental_implements, 

1607 experimental_attributes=experimental_attributes)) 

1608 

1609 # This code path is for the `foo = tf.function(foo, ...)` use case 

1610 if func is not None: 

1611 return decorated(func) 

1612 

1613 # This code path is for the 

1614 # 

1615 # @tf.function(...) 

1616 # def foo(...): 

1617 # ... 

1618 # 

1619 # use case, which is equivalent to `foo = tf.function(...)(foo)` 

1620 return decorated