1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tracing Compiler implementation."""
16
17import collections
18import contextlib
19import threading
20import types as types_lib
21from typing import List
22import weakref
23
24from tensorflow.core.function import trace_type
25from tensorflow.core.function.capture import capture_container
26from tensorflow.core.function.polymorphism import function_cache
27from tensorflow.core.function.polymorphism import function_type as function_type_lib
28from tensorflow.python.eager import monitoring
29from tensorflow.python.eager.polymorphic_function import attributes as attributes_lib
30from tensorflow.python.eager.polymorphic_function import function_context
31from tensorflow.python.eager.polymorphic_function import function_spec
32from tensorflow.python.eager.polymorphic_function import monomorphic_function
33from tensorflow.python.eager.polymorphic_function import tf_method_target
34from tensorflow.python.framework import func_graph as func_graph_module
35from tensorflow.python.framework import ops
36from tensorflow.python.platform import tf_logging as logging
37from tensorflow.python.profiler import trace
38from tensorflow.python.util import compat
39from tensorflow.python.util import lazy_loader
40from tensorflow.python.util import tf_decorator
41
# Loaded lazily due to a circular dependency (roughly
# tf.function->autograph->dataset->tf.function).
# TODO(b/133251390): Use a regular import.
ag_ctx = lazy_loader.LazyLoader(
    "ag_ctx", globals(),
    "tensorflow.python.autograph.core.ag_ctx")

# Monitoring counter for graph-building time. It is fed through a
# `monitoring.MonitoredTimer` in `TracingCompiler._maybe_define_function`, and
# only when tracing does not happen inside another function, so nested traces
# are not counted twice.
_graph_building_time_counter = monitoring.Counter(
    "/tensorflow/core/tf_function/graph_building_time_usecs",
    "Time for tf.function to build a graph (us).")
52
53
54# TODO(fmuham): Revamp the API of this class to be 100% compiler-focused.
class TracingCompiler:
  """Generates, caches and dispatches traced Monomorphic Concrete Functions.

  The tracing is done using the Python source function with respect to inputs
  and other options specified by constructor.

  See the documentation for `tf.function` for more information on the semantics
  of defined functions.

  `TracingCompiler` class is thread-compatible meaning that minimal usage of
  tf.function (defining and calling) is thread-safe, but if users call other
  methods or invoke the base `python_function` themselves, external
  synchronization is necessary.

  In addition, TracingCompiler is not reentrant, so recursive functions need
  to call the wrapped function, not the wrapper.
  """

  def __init__(self,
               python_function,
               name,
               input_signature=None,
               attributes=None,
               autograph=True,
               autograph_options=None,
               reduce_retracing=False,
               capture_by_value=None,
               jit_compile=None):
    """Initializes a `TracingCompiler`.

    Args:
      python_function: the function to be wrapped.
      name: the name given to it.
      input_signature: a possibly nested sequence of `TensorSpec` objects
        specifying the input signature of this function. If `None`, a separate
        function is instantiated for each inferred input signature.
      attributes: dict, extra keyword arguments that will be added as attribute
        of the function. Every key must be listed in
        `attributes_lib.TRACING_COMPILER_ALLOWLIST`.
      autograph: whether to use autograph to compile `python_function`. See
        https://www.tensorflow.org/guide/autograph for more information.
      autograph_options: Experimental knobs to control behavior when
        `autograph=True`. See https://www.tensorflow.org/guide/autograph for
        more information.
      reduce_retracing: When True, `tf.function` uses
        `tf.types.experimental.TraceType` to trace supertypes of arguments to
        reduce the number of traces.
      capture_by_value: Experimental. Whether to capture resource variables by
        value or reference. If None, will inherit from a parent context or
        default to False.
      jit_compile: Force-compile the function with XLA, cf. tf.function doc on
        jit_compile.

    Raises:
      ValueError: if `input_signature` is not None and the `python_function`'s
        argspec has keyword arguments, or if `attributes` contains a key that
        is not in `attributes_lib.TRACING_COMPILER_ALLOWLIST`.
    """
    self._python_function = python_function
    # Truthy only when an `IMPLEMENTS` attribute was supplied.
    pure_function = attributes and attributes_lib.IMPLEMENTS in attributes
    self._function_spec = (
        function_spec.FunctionSpec.from_function_and_signature(
            python_function, input_signature, is_pure=pure_function
        )
    )
    self._name = name
    self._autograph = autograph
    self._autograph_options = autograph_options
    self._reduce_retracing = reduce_retracing
    self._function_cache = function_cache.FunctionCache()

    self._function_attributes = attributes or {}
    for attribute in self._function_attributes:
      if attribute not in attributes_lib.TRACING_COMPILER_ALLOWLIST:
        raise ValueError(
            f"TracingCompiler does not support `{attribute}` as an attribute."
        )

    self._capture_by_value = capture_by_value
    # Incremented once per call to `_create_concrete_function`.
    self.tracing_count = 0
    # Maintain a dict of all captures: identifier -> lambda function. It's used
    # to get runtime values for all captures during ConcreteFunction dispatch.
    self._func_captures = capture_container.FunctionCaptures()
    self._lock = threading.RLock()
    # _descriptor_cache maps an instance of a class to an instance-specific
    # `TracingCompiler`, used to make sure tf.function-decorated methods
    # create different functions for each instance. Keys are held weakly.
    self._descriptor_cache = weakref.WeakKeyDictionary()
    self._jit_compile = jit_compile

  def __call__(self, *args, **kwargs):
    """Calls a graph function specialized to the inputs."""
    # The lock only guards lookup/tracing; the traced function is executed
    # after it has been released.
    with self._lock:
      (concrete_function,
       filtered_flat_args) = self._maybe_define_function(args, kwargs)
    return concrete_function._call_flat(  # pylint: disable=protected-access
        filtered_flat_args, captured_inputs=concrete_function.captured_inputs)

  @property
  def python_function(self):
    """Returns the wrapped Python function."""
    return self._python_function

  @property
  def function_spec(self):
    """Returns the `FunctionSpec` built from the wrapped Python function."""
    return self._function_spec

  @property
  def input_signature(self):
    """Returns the input signature."""
    return self._function_spec.input_signature

  def _maybe_define_concrete_function(self, args, kwargs):
    """Like `_maybe_define_function`, defaulting inputs to the signature.

    When an input signature exists and the caller passed neither positional
    nor keyword arguments, the input signature itself is used as `args`.

    Caller must hold self._lock.

    Args:
      args: positional inputs to specialize on.
      kwargs: keyword inputs to specialize on.

    Returns:
      Whatever `_maybe_define_function` returns for the resolved inputs.
    """
    if self.input_signature and not args and not kwargs:
      # TODO(b/215596825): Throw error here if multiple entries are defined.
      args = self.input_signature
      kwargs = {}

    return self._maybe_define_function(args, kwargs)

  def _get_concrete_function_internal_garbage_collected(self, *args, **kwargs):
    """Returns a concrete function which cleans up its graph function."""
    with self._lock:
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
    return concrete_function

  def _get_concrete_function_internal(self, *args, **kwargs):
    """Bypasses error checking when getting a graph function."""
    concrete_function = self._get_concrete_function_internal_garbage_collected(
        *args, **kwargs)
    # We're returning this concrete function to someone, and they may keep a
    # reference to the FuncGraph without keeping a reference to the
    # ConcreteFunction object. So we won't clean up the reference cycles
    # manually and instead will leave them to Python's garbage collector.
    concrete_function._garbage_collector.release()  # pylint: disable=protected-access
    return concrete_function

  def _get_concrete_function_garbage_collected(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Unlike `get_concrete_function(...)`, the graph will be deleted when the
    returned function is deleted. It's useful to avoid creating a reference
    cycle when you know for sure that the graph will be no longer used without
    the returned function.

    Besides looking up or tracing the function, this assigns a unique keyword
    name to every non-capture graph input (stored in `_arg_keywords`) and
    records how many of them there are (`_num_positional_args`).

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      The `ConcreteFunction` for the given inputs.
    """
    if self.input_signature and (args or kwargs):
      # Check to see if a valid type can be generated from the args, kwargs
      self._function_spec.make_canonicalized_monomorphic_type(args, kwargs)

    # pylint: disable=protected-access
    with self._lock:
      concrete_function, _ = self._maybe_define_concrete_function(args, kwargs)
      seen_names = set()
      concrete_function._arg_keywords = []
      prefix_counts = {}
      graph = concrete_function.graph
      # Captures come after the user-facing inputs in `graph.inputs`.
      num_captures = len(
          graph.internal_captures + graph.deferred_internal_captures)
      num_positional = len(graph.inputs) - num_captures
      for arg in graph.inputs[:num_positional]:
        user_arg_name = compat.as_str(arg.op.get_attr("_user_specified_name"))
        proposal = user_arg_name
        # De-duplicate repeated names by appending "_1", "_2", ...
        while proposal in seen_names:
          index = prefix_counts.get(user_arg_name, 1)
          proposal = f"{user_arg_name}_{index}"
          prefix_counts[user_arg_name] = index + 1
        seen_names.add(proposal)
        concrete_function._arg_keywords.append(proposal)
      # Anything can be a positional argument, in the same order as .inputs
      concrete_function._num_positional_args = num_positional
      return concrete_function
    # pylint: enable=protected-access

  def get_concrete_function(self, *args, **kwargs):
    """Returns a `ConcreteFunction` specialized to inputs and execution context.

    Args:
      *args: inputs to specialize on. Can be concrete values (e.g. 1) or
        `tf.Tensor` or `tf.TensorSpec`.
      **kwargs: keyword inputs to specialize on. Concrete values (e.g. 1) or
        `tf.Tensor` or `tf.TensorSpec`.

    Returns:
      The `ConcreteFunction` for the given inputs, with its garbage collector
      released.
    """
    concrete_function = self._get_concrete_function_garbage_collected(
        *args, **kwargs)
    concrete_function._garbage_collector.release()  # pylint: disable=protected-access
    return concrete_function

  def _list_all_concrete_functions(
      self) -> List[monomorphic_function.ConcreteFunction]:
    """Returns every concrete function currently held in the function cache."""
    return self._function_cache.values()

  def __get__(self, instance, owner):
    """Makes it possible to decorate instance methods."""
    del owner
    # `instance` here is the instance that this `TracingCompiler` was
    # accessed through e.g., for
    #
    #   class Foo:
    #
    #     @tf.function
    #     def bar(self):
    #       ...
    #
    #   foo = Foo()
    #   foo.bar()  # `foo.bar` is a `tf.function` instance
    #
    # then `instance` will be `foo` (and `owner` will be `Foo`).  We create a
    # new instance of `TracingCompiler` here to allow different instances
    # to create variables once, thereby allowing methods to be decorated with
    # tf.function. Keeps a cache to avoid retracing the function every time the
    # descriptor is accessed.
    if instance is None:
      # Accessed through the class rather than an instance: nothing to bind.
      return self

    if instance not in self._descriptor_cache:
      # If there is no instance-specific `TracingCompiler` in the cache, we
      # construct an instance-specific `TracingCompiler` that uses a weak
      # reference to the instance (so that the instance will be correctly
      # gc'd), and add the wrapped function to the descriptor cache.
      self._descriptor_cache[instance] = class_method_to_instance_method(
          self, instance)

    # Return the cached `TracingCompiler` for the instance
    return self._descriptor_cache[instance]

  def _create_concrete_function(self, args, kwargs, func_graph):
    """Create a `ConcreteFunction` from `args`, `kwargs`, and `func_graph`.

    Args:
      args: positional arguments to trace `python_function` with.
      kwargs: keyword arguments to trace `python_function` with.
      func_graph: the `FuncGraph` to trace into.

    Returns:
      The newly traced `ConcreteFunction`.
    """
    self.tracing_count += 1

    spec_arg_names = self._function_spec.arg_names
    arg_names = spec_arg_names[:len(args)]
    num_missing_args = len(args) - len(spec_arg_names)
    if num_missing_args > 0:
      # Must have variable positional args if there are missing args.
      var_arg_name = next(
          p.name
          for p in self._function_spec.function_type.parameters.values()
          if p.kind is function_type_lib.Parameter.VAR_POSITIONAL
      )
      # Name the surplus positional args "<vararg>_0", "<vararg>_1", ...
      arg_names = arg_names + [
          f"{var_arg_name}_{i}" for i in range(num_missing_args)
      ]

    concrete_function = monomorphic_function.ConcreteFunction(
        func_graph_module.func_graph_from_py_func(
            self._name,
            self._python_function,
            args,
            kwargs,
            None,
            func_graph=func_graph,
            arg_names=arg_names,
            capture_by_value=self._capture_by_value,
            create_placeholders=False),
        self._function_attributes,
        spec=self.function_spec,
        # Tell the ConcreteFunction to clean up its graph once it goes out of
        # scope. This is not the default behavior since it gets used in some
        # places (like Keras) where the FuncGraph lives longer than the
        # ConcreteFunction.
        shared_func_graph=False)
    return concrete_function

  def _maybe_define_function(self, args, kwargs):
    """Gets a function for these inputs, defining it if necessary.

    Caller must hold self._lock.

    Args:
      args: The varargs for the Python function.
      kwargs: The keyword args for the Python function.

    Returns:
      A graph function corresponding to the input signature implied by args and
      kwargs, as well as filtered flattened inputs (only Tensors and Variables)
      that the object should be called with.

    Raises:
      ValueError: If inputs are incompatible with the input signature.
      TypeError: If the function inputs include non-hashable objects
      RuntimeError: If there's an internal bug (inconsistency) in handling
        shape relaxation retracing.
    """
    args, kwargs, filtered_flat_args = (
        self._function_spec.canonicalize_function_inputs(args, kwargs))

    if self.input_signature is not None:
      # The signature replaces the leading positional args; any remaining
      # positional args are kept as they are.
      args = (*self.input_signature, *args[len(self.input_signature):])

    # Get runtime values of captures
    captures = self._func_captures.get_by_ref_snapshot()

    current_func_context = function_context.make_function_context()

    # cache_key_deletion_observer is useless here. It's based on all captures.
    # A new cache key will be built later when saving ConcreteFunction because
    # only active captures should be saved.
    lookup_func_type, lookup_func_context = (
        self._function_spec.make_canonicalized_monomorphic_type(
            args, kwargs, captures))
    concrete_function = self._function_cache.lookup(current_func_context,
                                                    lookup_func_type)
    if concrete_function is not None:
      return concrete_function, filtered_flat_args

    # Use a timer for graph building only if not already inside a function. This
    # avoids double counting graph building time for nested functions.
    if ops.inside_function():
      graph_building_timer = contextlib.nullcontext()
    else:
      graph_building_timer = monitoring.MonitoredTimer(
          _graph_building_time_counter.get_cell())

    with graph_building_timer:
      with trace.Trace("tf.function-graph_building"):
        logging.vlog(
            1, "Creating new FuncGraph for Python function %r (key: %r, %r)",
            self._python_function, current_func_context, lookup_func_type)
        logging.vlog(2, "Python function signature [args: %s] [kwargs: %s]",
                     args, kwargs)
        ag_status = (
            ag_ctx.Status.ENABLED
            if self._autograph else ag_ctx.Status.DISABLED)
        with ag_ctx.ControlStatusCtx(
            status=ag_status, options=self._autograph_options):
          func_graph = func_graph_module.FuncGraph(
              self._name, capture_by_value=self._capture_by_value)
          # Only generalize the type when there is no explicit signature and
          # retracing reduction was requested.
          if self.input_signature is None and self._reduce_retracing:
            target_func_type = self._function_cache.generalize(
                current_func_context, lookup_func_type)
          else:
            target_func_type = lookup_func_type
          placeholder_mapping = lookup_func_context.get_placeholder_mapping()
          placeholder_context = trace_type.InternalPlaceholderContext(
              func_graph, placeholder_mapping)
          with func_graph.as_default():
            placeholder_bound_args = target_func_type.placeholder_arguments(
                placeholder_context)
          # From here on the function is traced with placeholders, not with
          # the caller's values.
          args = placeholder_bound_args.args
          kwargs = placeholder_bound_args.kwargs

          concrete_function = self._create_concrete_function(
              args, kwargs, func_graph)

          # TODO(b/263520817): Remove access to private attribute.
          graph_capture_container = concrete_function.graph.function_captures
          # Maintain the list of all captures
          self._func_captures.merge_by_ref_with(graph_capture_container)
          # Get current active captures snapshot
          captures = graph_capture_container.get_by_ref_snapshot()

          # Create a cache_key with args and captures
          traced_func_type = _insert_capture_type(
              target_func_type, captures, lookup_func_context)

          self._function_cache.add(current_func_context, traced_func_type,
                                   concrete_function)

          return concrete_function, filtered_flat_args
416
417
def class_method_to_instance_method(original_function, instance):
  """Constructs a new `TracingCompiler` with `self` bound.

  Only a weak reference to `instance` is stored, so the object built here does
  not keep `instance` alive.

  Args:
    original_function: the class-level function object; expected to be either
      a `TracingCompiler` or a `def_function.Function`. Its `python_function`
      is called with the instance as first argument.
    instance: the object the method was accessed through.

  Returns:
    The result of `tf_decorator.make_decorator(bound_method, instance_func)`,
    where `instance_func` is a new object of `type(original_function)` built
    around a wrapper that supplies the instance.
  """
  weak_instance = weakref.ref(instance)

  # Note: while we could bind to a weakref proxy instead, that causes the
  # bound method to be unhashable.
  bound_method = types_lib.MethodType(
      original_function.python_function,
      tf_method_target.TfMethodTarget(weak_instance,
                                      original_function.python_function))

  # original_function is expected to be either `TracingCompiler` or
  # def_function.Function
  assert hasattr(original_function, "_name")
  assert hasattr(original_function, "_autograph")
  assert hasattr(original_function, "_function_spec")
  assert hasattr(original_function, "python_function")

  # Placeholder; rebound below to a weak reference to `bound_method_wrapper`
  # so that the wrapper can look itself up without a strong self-reference.
  weak_bound_method_wrapper = None

  def bound_method_wrapper(*args, **kwargs):
    """Wraps either a dummy MethodType or a converted AutoGraph function."""
    # __wrapped__ allows AutoGraph to swap in a converted function.
    strong_bound_method_wrapper = weak_bound_method_wrapper()
    wrapped_fn = strong_bound_method_wrapper.__wrapped__

    if wrapped_fn is strong_bound_method_wrapper.__original_wrapped__:
      # If __wrapped__ was not replaced, then call original_function.
      # TODO(mdan): For better consistency, use the wrapper's call().
      wrapped_fn = original_function.python_function
      # The instance is dereferenced at call time and passed as `self`.
      return wrapped_fn(weak_instance(), *args, **kwargs)

    # If __wrapped__ was replaced, then it is always an unbound function.
    # However, the replacer is still responsible for attaching self properly.
    # TODO(mdan): Is it possible to do it here instead?
    return wrapped_fn(*args, **kwargs)

  weak_bound_method_wrapper = weakref.ref(bound_method_wrapper)

  # pylint: disable=protected-access
  # We make a dummy MethodType object to generate the correct bound method
  # signature. The actual call is to a function with a weak reference to
  # `instance`.
  # NOTE(review): only name, autograph, input_signature, reduce_retracing and
  # jit_compile are forwarded to the new object; other constructor options of
  # `original_function` are not passed on here -- confirm this is intended.
  instance_func = type(original_function)(
      tf_decorator.make_decorator(bound_method, bound_method_wrapper),
      name=original_function._name,
      autograph=original_function._autograph,
      input_signature=original_function.input_signature,
      reduce_retracing=original_function._reduce_retracing,
      jit_compile=original_function._jit_compile)
  # pylint: enable=protected-access

  # We wrap the bound method with tf_decorator so inspection works correctly
  wrapped_instance_func = tf_decorator.make_decorator(bound_method,
                                                      instance_func)
  return wrapped_instance_func
474
475
def _insert_capture_type(original_func_type, captures, type_context):
  """Returns a `FunctionType` combining existing parameters with capture types.

  Args:
    original_func_type: function type whose `parameters` are reused unchanged.
    captures: mapping from capture name to captured value.
    type_context: context handed to `trace_type.from_value` for each capture.

  Returns:
    A new `function_type_lib.FunctionType` built from the original parameters
    and an ordered mapping of capture name to the trace type of its value.
  """
  # Capture types are computed first, in the iteration order of `captures`.
  typed_captures = collections.OrderedDict(
      (capture_name, trace_type.from_value(capture_value, type_context))
      for capture_name, capture_value in captures.items())
  parameters = original_func_type.parameters.values()
  return function_type_lib.FunctionType(parameters, typed_captures)