Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/types/core.py: 81%
43 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Core TensorFlow types."""
17import sys
18import textwrap
20from typing import Union
22import numpy as np
24from tensorflow.python.types import doc_typealias
25from tensorflow.python.util.tf_export import tf_export
# pylint:disable=g-import-not-at-top
# `Protocol` and `runtime_checkable` are part of the standard `typing` module
# from Python 3.8 onwards; older interpreters fall back to the
# `typing_extensions` backport.
if sys.version_info >= (3, 8):
  from typing import Protocol
  from typing import runtime_checkable
else:
  from typing_extensions import Protocol
  from typing_extensions import runtime_checkable
# pylint:enable=g-import-not-at-top
36# TODO(mdan): Consider adding ABC once the dependence on isinstance is reduced.
37# TODO(mdan): Add type annotations.
40# TODO(b/178822082): Revisit this API when tf.types gets more resource.
@tf_export("__internal__.types.Tensor", v1=[])
class Tensor(object):
  """The base class of all dense Tensor objects.

  A dense tensor has a static data type (dtype), and may have a static rank and
  shape. Tensor objects are immutable. Mutable objects may be backed by a Tensor
  which holds the unique handle that identifies the mutable object.
  """

  @property
  def dtype(self):
    """Data type of the tensor.

    This base implementation is an empty stub and evaluates to `None`.
    """
    pass

  @property
  def shape(self):
    """Shape of the tensor.

    This base implementation is an empty stub and evaluates to `None`.
    """
    pass
59# `ops.EagerTensor` subclasses `Symbol` by way of subclassing `ops.Tensor`;
60# care should be taken when performing `isinstance` checks on `Value`, e.g.:
61#
62# ```
63# if isinstance(x, core.Symbol) and not isinstance(x, core.Value):
64# ...
65# ```
class Symbol(Tensor):
  """Symbolic "graph" Tensor.

  These objects represent the output of an op definition and do not carry a
  value.
  """
  pass
class Value(Tensor):
  """Tensor that can be associated with a value (aka "eager tensor").

  These objects represent the (usually future) output of executing an op
  immediately.
  """

  def numpy(self):
    """Accessor for the value associated with this tensor.

    This base implementation is an empty stub and returns `None`.
    NOTE(review): presumably overridden by concrete eager tensor classes to
    return a NumPy value -- confirm against the implementing class.
    """
    pass
@tf_export("types.experimental.Callable", v1=[])
class Callable:
  """Base class for TF callables like those created by tf.function.

  Note: Callables are conceptually very similar to `tf.Operation`: a
  `tf.Operation` is a kind of callable.
  """

  def __call__(self, *args, **kwargs):
    """Executes this callable.

    This behaves like a regular op - in eager mode, it immediately starts
    execution, returning results. In graph mode, it creates ops which return
    symbolic TensorFlow values (like `tf.Tensor`, `tf.data.Dataset`,
    etc.). For example, `tf.function` callables typically generate a
    `tf.raw_ops.PartitionedCall` op, but not always - the
    exact operations being generated are an internal implementation detail.

    Args:
      *args: positional argument for this call
      **kwargs: keyword arguments for this call

    Returns:
      The execution results.
    """
@tf_export("types.experimental.ConcreteFunction", v1=[])
class ConcreteFunction(Callable):
  """Base class for graph functions.

  A `ConcreteFunction` encapsulates a single graph function definition and
  is differentiable under `tf.GradientTape` contexts.
  """
121# TODO(mdan): Name just `types.Function`, for historic continuity?
@tf_export("types.experimental.GenericFunction", v1=[])
class GenericFunction(Callable):
  """Base class for polymorphic graph functions.

  Graph functions are Python callable objects that dispatch calls to a
  TensorFlow graph. Polymorphic graph functions can be backed by multiple TF
  graphs, and automatically select the appropriate specialization based on the
  type of input they were called with. They may also create specializations on
  the fly if necessary, for example by tracing.

  Also see `tf.function`.
  """

  def get_concrete_function(self, *args, **kwargs) -> ConcreteFunction:
    """Returns a `ConcreteFunction` specialized to input types.

    The arguments specified by `args` and `kwargs` follow normal function call
    rules. The returned `ConcreteFunction` has the same set of positional and
    keyword arguments as `self`, but their types are compatible with the types
    specified by `args` and `kwargs` (though not necessarily equal).

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.constant(1.0))
    >>> f_concrete = f.get_concrete_function(x=tf.constant(1.0))

    Unlike normal calls, `get_concrete_function` allows type specifiers instead
    of TensorFlow objects, so for example `tf.Tensor`s may be replaced with
    `tf.TensorSpec`s.

    >>> @tf.function
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))

    If the function definition allows only one specialization, `args` and
    `kwargs` may be omitted altogether.

    >>> @tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
    ... def f(x):
    ...   return x
    >>> f_concrete = f.get_concrete_function()

    The returned `ConcreteFunction` can be called normally:

    >>> f_concrete(tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>
    >>> f_concrete(x=tf.constant(1.0))
    <tf.Tensor: shape=(), dtype=float32, numpy=1.0>

    Args:
      *args: inputs to specialize on.
      **kwargs: inputs to specialize on.

    Returns:
      A `ConcreteFunction`.
    """
    pass

  def experimental_get_compiler_ir(self, *args, **kwargs):
    """Returns compiler IR for the compiled function.

    This API is intended *only* for debugging as there are no guarantees on
    backwards compatibility of returned IR or the allowed values of `stage`.

    Args:
      *args: compilation args supports inputs either: (1) all inputs are
        TensorSpec or (2) all inputs are tf.Tensor/Python variables.
      **kwargs: Keyword arguments used for compilation. Same requirement as
        compilation args.

    Returns:
      Function callable with the following kwargs:
        - `stage` at which the compiler IR should be serialized. Allowed values
          are:
          - `hlo`: HLO output after conversion from TF
            (https://www.tensorflow.org/xla/operation_semantics).
          - `hlo_serialized`: Like stage=`hlo`, but the output is a serialized
            HLO module proto (a bytes object).
          - `optimized_hlo`: HLO after compiler optimizations.
          - `optimized_hlo_serialized`: Like stage=`optimized_hlo`, but the
            output is a serialized HLO module proto (a bytes object).
          - `optimized_hlo_dot`: optimized HLO in DOT format suitable for
            Graphviz.
        - `device_name` can be either None, in which case the preferred device
          is used for compilation, or a device name. It can be a full device
          name, or a partial one, e.g., `/device:CPU:0`.

      For example, for

      ```python
      @tf.function(jit_compile=True)
      def f(x):
        return x + 1

      f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')
      ```

      the output is:

      ```
      HloModule a_inference_f_13__.9

      ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
        %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
        %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
        %constant.3 = f32[] constant(1)
        %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
        %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                                     f32[10,10]{1,0} %broadcast.4)
        %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
        %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
        ROOT %get-tuple-element.8 = f32[10,10]{1,0}
          get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
      }
      ```

      Here is another example using tf.TensorSpec inputs:

      ```python
      y = tf.Variable(tf.zeros([10, 20], dtype=tf.float32))

      @tf.function(jit_compile=True)
      def f(x):
        return x + y

      hlo_str = f.experimental_get_compiler_ir(tf.TensorSpec(shape=(10,
      20)))(stage='hlo')
      ```

      The output is:

      ```
      HloModule a_inference_f_120__.8,
      entry_computation_layout={(f32[10,20]{1,0},f32[10,20]{1,0})->f32[10,20]{1,0}}

      ENTRY %a_inference_f_120__.8 (arg0.1: f32[10,20], arg1.2: f32[10,20]) ->
      f32[10,20] {
        %arg0.1 = f32[10,20]{1,0} parameter(0), parameter_replication={false},
        metadata={op_name="XLA_Args"}
        %reshape.3 = f32[10,20]{1,0} reshape(f32[10,20]{1,0} %arg0.1)
        %arg1.2 = f32[10,20]{1,0} parameter(1), parameter_replication={false},
        metadata={op_name="XLA_Args"}
        %add.4 = f32[10,20]{1,0} add(f32[10,20]{1,0} %reshape.3, f32[10,20]{1,0}
        %arg1.2), metadata={op_type="AddV2" op_name="add"
        source_file="<ipython-input-16-ea04879c1873>" source_line=4}
        %reshape.5 = f32[10,20]{1,0} reshape(f32[10,20]{1,0} %add.4),
        metadata={op_name="XLA_Retvals"}
        %tuple.6 = (f32[10,20]{1,0}) tuple(f32[10,20]{1,0} %reshape.5),
        metadata={op_name="XLA_Retvals"}
        ROOT %get-tuple-element.7 = f32[10,20]{1,0}
        get-tuple-element((f32[10,20]{1,0}) %tuple.6), index=0,
        metadata={op_name="XLA_Retvals"}
      }
      ```

    The HLO module accepts a flat list of inputs. To retrieve the order
    of these inputs signatures, users can call the
    `concrete_fn.structured_input_signature` and `concrete_fn.captured_inputs`:

    ```python
    # Use concrete_fn to get the hlo_module flat_args.
    concrete_fn = f.get_concrete_function(tf.TensorSpec(shape=(10, 20)))
    flat_args = list(
        tf.nest.flatten(concrete_fn.structured_input_signature)
        ) + concrete_fn.captured_inputs
    ```

    Raises:
      ValueError:
        (1) If an invalid `stage` is selected
        (2) or if applied to a function which is not compiled
        (`jit_compile=True` is not set).
        (3) or if input shapes are not fully defined for tf.TensorSpec inputs
      TypeError: When called with input in graph mode.
    """
    pass
@runtime_checkable
class TensorProtocol(Protocol):
  """Protocol type for objects that can be converted to Tensor.

  Because the protocol is decorated with `runtime_checkable`, `isinstance`
  checks against it succeed for any object exposing a `__tf_tensor__` member.
  """

  def __tf_tensor__(self, dtype=None, name=None):
    """Converts this object to a Tensor.

    Args:
      dtype: data type for the returned Tensor
      name: a name for the operations which create the Tensor

    Returns:
      A Tensor.
    """
    pass
318# TODO(rahulkamat): Add missing types that are convertible to Tensor.
# Static-typing alias covering the Tensor base class, objects implementing
# `TensorProtocol`, Python scalars/strings/bytes, tuples and lists, and NumPy
# arrays and scalars.
TensorLike = Union[Tensor, TensorProtocol, int, float, bool, str, bytes,
                   complex, tuple, list, np.ndarray, np.generic]
# A `typing.Union` alias cannot carry its own docstring, so the documentation
# text is registered separately through `doc_typealias`.
doc_typealias.document(
    obj=TensorLike,
    doc=textwrap.dedent("""\
      Union of all types that can be converted to a `tf.Tensor` by `tf.convert_to_tensor`.

      This definition may be used in user code. Additional types may be added
      in the future as more input types are supported.

      Example:

      ```
      def foo(x: TensorLike):
        pass
      ```

      This definition passes static type verification for:

      ```
      foo(tf.constant([1, 2, 3]))
      foo([1, 2, 3])
      foo(np.array([1, 2, 3]))
      ```
      """),
)
# Export the alias as a module-level constant under the public API name.
tf_export("types.experimental.TensorLike").export_constant(
    __name__, "TensorLike")