Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/types/trace.py: 84%
43 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""tf.function tracing types.
17See `core.GenericFunction` and `core.ConcreteFunction`.
19`GenericFunction` assigns types to call arguments, forming a signature.
20Function signatures are used to match arguments to `ConcreteFunction`s.
For example, when a new `ConcreteFunction` is traced, it is assigned
the signature of the arguments it was traced with. Subsequent call arguments
which match its signature will be dispatched to the same `ConcreteFunction`.
24If no `ConcreteFunction` with a matching signature is found, a new one may be
25traced (a process known as retracing).
26"""
28import abc
29from typing import Any, List, Optional, Sequence
31from typing_extensions import Protocol
32from typing_extensions import runtime_checkable
34from tensorflow.python.types import core
35from tensorflow.python.util.tf_export import tf_export
36from tensorflow.tools.docs import doc_controls
@tf_export("types.experimental.TraceType", v1=[])
class TraceType(metaclass=abc.ABCMeta):
  """Represents the type of object(s) for tf.function tracing purposes.

  `TraceType` is an abstract class that other classes might inherit from to
  provide information regarding associated class(es) for the purposes of
  tf.function tracing. The typing logic provided through this mechanism will be
  used to make decisions regarding usage of cached concrete functions and
  retracing.

  For example, if we have the following tf.function and classes:
  ```python
  @tf.function
  def get_mixed_flavor(fruit_a, fruit_b):
    return fruit_a.flavor + fruit_b.flavor

  class Fruit:
    flavor = tf.constant([0, 0])

  class Apple(Fruit):
    flavor = tf.constant([1, 2])

  class Mango(Fruit):
    flavor = tf.constant([3, 4])
  ```

  tf.function does not know when to re-use an existing concrete function in
  regards to the `Fruit` class so naively it retraces for every new instance.
  ```python
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function again
  ```

  However, we, as the designers of the `Fruit` class, know that each subclass
  has a fixed flavor and we can reuse an existing traced concrete function if
  it was the same subclass. Avoiding such unnecessary tracing of concrete
  functions can have significant performance benefits.

  ```python
  class FruitTraceType(tf.types.experimental.TraceType):
    def __init__(self, fruit):
      self.fruit_type = type(fruit)
      self.fruit_value = fruit

    def is_subtype_of(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

    def most_specific_common_supertype(self, others):
      return self if all(self == other for other in others) else None

    def placeholder_value(self, placeholder_context=None):
      return self.fruit_value

    # `__eq__` and `__hash__` are abstract in `TraceType`; they must be
    # implemented for `FruitTraceType` to be instantiable.
    def __eq__(self, other):
      return (type(other) is FruitTraceType and
              self.fruit_type is other.fruit_type)

    def __hash__(self):
      return hash(self.fruit_type)

  class Fruit:

    def __tf_tracing_type__(self, context):
      return FruitTraceType(self)
  ```

  Now if we try calling it again:
  ```python
  get_mixed_flavor(Apple(), Mango()) # Traces a new concrete function
  get_mixed_flavor(Apple(), Mango()) # Re-uses the traced concrete function
  ```
  """

  @abc.abstractmethod
  def is_subtype_of(self, other: "TraceType") -> bool:
    """Returns True if `self` is a subtype of `other`.

    For example, `tf.function` uses subtyping for dispatch:
    if `a.is_subtype_of(b)` is True, then an argument of `TraceType`
    `a` can be used as argument to a `ConcreteFunction` traced with a
    `TraceType` `b`.

    Args:
      other: A TraceType object to be compared against.

    Returns:
      True if `self` is a subtype of `other`, False otherwise.

    Example:

    ```python
    class Dimension(TraceType):
      def __init__(self, value: Optional[int]):
        self.value = value

      def is_subtype_of(self, other):
        # Either the value is the same or other has a generalized value that
        # can represent any specific ones.
        return (self.value == other.value) or (other.value is None)
    ```
    """

  @abc.abstractmethod
  def most_specific_common_supertype(
      self, others: Sequence["TraceType"]) -> Optional["TraceType"]:
    """Returns the most specific supertype of `self` and `others`, if it exists.

    The returned `TraceType` is a supertype of `self` and `others`, that is,
    they are all subtypes (see `is_subtype_of`) of it.
    It is also most specific, that is, it has no subtype that is also
    a common supertype of `self` and `others`.

    If `self` and `others` have no common supertype, this returns `None`.

    Args:
      others: A sequence of TraceTypes.

    Returns:
      The most specific common supertype as a `TraceType`, or `None` if no
      common supertype exists.

    Example:
    ```python
    class Dimension(TraceType):
      def __init__(self, value: Optional[int]):
        self.value = value

      def most_specific_common_supertype(self, others):
        # Either all the values are the same, or a generalized value that
        # can represent any specific one is needed.
        if all(self.value == other.value for other in others):
          return self
        else:
          return Dimension(None)
    ```
    """

  @abc.abstractmethod
  def placeholder_value(self, placeholder_context) -> Any:
    """Creates a placeholder for tracing.

    tf.function traces with the placeholder value rather than the actual value.
    For example, a placeholder value can represent multiple different
    actual values. This means that the trace generated with that placeholder
    value is more general and reusable which saves expensive retracing.

    Args:
      placeholder_context: A `PlaceholderContext` container for context
        information when creating a placeholder value.

    For the `Fruit` example shared above, implementing:

    ```python
    class FruitTraceType:
      def placeholder_value(self, placeholder_context):
        return Fruit()
    ```
    instructs tf.function to trace with the `Fruit()` objects
    instead of the actual `Apple()` and `Mango()` objects when it receives a
    call to `get_mixed_flavor(Apple(), Mango())`. For example, Tensor arguments
    are replaced with Tensors of similar shape and dtype, output from
    a tf.Placeholder op.

    More generally, placeholder values are the arguments of a tf.function,
    as seen from the function's body:
    ```python
    @tf.function
    def foo(x):
      # Here `x` is the placeholder value
      ...

    foo(x) # Here `x` is the actual value
    ```
    """

  @doc_controls.do_not_doc_inheritable
  def _to_tensors(self, value) -> List[core.Tensor]:
    """Breaks down a value of this type into Tensors.

    The default implementation contributes no Tensors and returns an empty
    list.

    Args:
      value: An input value belonging to this TraceType

    Returns:
      List of Tensors.
    """
    del value
    return []

  @doc_controls.do_not_doc_inheritable
  def _flatten(self) -> List["TraceType"]:
    """Returns the TraceTypes (e.g. TensorSpecs) matching `_to_tensors` values.

    The default implementation returns an empty list, which is consistent with
    the default `_to_tensors` producing no Tensors.
    """
    return []

  @doc_controls.do_not_doc_inheritable
  def _cast(self, value, casting_context) -> Any:  # pylint:disable=unused-argument
    """Cast value to this type.

    The default implementation performs no conversion: it returns `value`
    unchanged after checking that it equals this type's placeholder value.

    Args:
      value: An input value belonging to this TraceType.
      casting_context: A context reserved for future usage such as to determine
        casting rules.

    Returns:
      The value casted to this TraceType.

    Raises:
      AssertionError: If `_cast` is not overridden by a subclass and `value`
        is not equal to `self.placeholder_value(PlaceholderContext())`.
    """
    # Raise explicitly instead of using `assert`, so the check still runs
    # when Python is started with -O (which strips assert statements).
    is_placeholder = value == self.placeholder_value(PlaceholderContext())
    if not is_placeholder:
      raise AssertionError(f"Can not cast {value!r} to type {self!r}")
    return value

  @abc.abstractmethod
  def __hash__(self) -> int:
    """Returns a hash consistent with `__eq__` for this TraceType."""
    pass

  @abc.abstractmethod
  def __eq__(self, other) -> bool:
    """Returns True if `other` is the same TraceType as `self`."""
    pass
@tf_export("types.experimental.TracingContext", v1=[])
class TracingContext(metaclass=abc.ABCMeta):
  """Holds information scoped to tracing a group of objects.

  A `TracingContext` is a container for the flags and variables that can in
  any way influence how a class implementing `__tf_tracing_type__` is traced.
  The same context is shared by all `__tf_tracing_type__` calls made while
  constructing the TraceType for one particular set of objects.
  """
class PlaceholderContext:
  """Holds context information for generating placeholders within a scope.

  An instance is what `TraceType.placeholder_value` receives as its
  `placeholder_context` argument.
  """
class CastContext:
  """Holds context information and rules for casting values to a TypeSpec."""
@runtime_checkable
class SupportsTracingProtocol(Protocol):
  """Protocol through which custom classes can control tf.function retracing."""

  @doc_controls.doc_private
  @abc.abstractmethod
  def __tf_tracing_type__(self, context: TracingContext) -> TraceType:
    """Returns the tracing type of this object.

    When a tf.function is traced, its signature is built from tracing types,
    and those types are also used to match later call arguments against
    existing signatures. On each call of a Function object, tf.function
    inspects the tracing types of the call arguments: an existing signature
    with a matching type is reused if there is one; otherwise a new function
    is traced and its signature uses the tracing types of those arguments.

    Args:
      context: A context object created for each function call, used to
        track information about the call arguments as a whole.

    Returns:
      The tracing type of this object.
    """
# TODO(b/219556836): Direct tf_export decorator adds non-method members to the
# Protocol which breaks @runtime_checkable since it does not support them.
# To work around this, the protocol is exported as a module-level constant via
# `export_constant` rather than being decorated with `tf_export`.
tf_export("types.experimental.SupportsTracingProtocol", v1=[]).export_constant(
    __name__, "SupportsTracingProtocol")