Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/composite_tensor_ops.py: 38%
37 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# =============================================================================
15"""Operations for ExtensionTypes (aka Composite Tensors)."""
17from tensorflow.core.protobuf import composite_tensor_variant_pb2
18from tensorflow.python.framework import composite_tensor
19from tensorflow.python.framework import dtypes
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import gen_composite_tensor_ops
22from tensorflow.python.saved_model import nested_structure_coder
23from tensorflow.python.util import nest
def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  The component tensors of `value` are flattened (with
  `expand_composites=True`) and passed, together with a serialized
  description of `type_spec`, to the `CompositeTensorVariantFromComponents`
  op, which packs them into a single variant tensor.

  Args:
    value: The `ExtensionType` value to encode. Must be a `CompositeTensor`.
    type_spec: Information about the value's type that should be included in
      the encoding. If `None`, `value._type_spec` is used.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")

  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` "
                     f"{value!r}.")

  # The TypeSpec travels alongside the components as a serialized
  # `CompositeTensorVariantMetadata` proto attached to the op.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)

  # NOTE(review): components are flattened from `value`'s own structure, not
  # from `type_spec`; presumably any compatible spec flattens to the same
  # components -- confirm.
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=nest.flatten(value, expand_composites=True),
      metadata=metadata.SerializeToString(),
      name=name)
def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value. This is used to determine
      the number and types of the component tensors that comprise the decoded
      value. Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with the
      scalar shape `()`.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  # The encoding produced by `composite_tensor_to_variants` is a scalar.
  encoded.shape.assert_is_compatible_with(())

  # Serialize the caller's `type_spec` so the op receives it as metadata
  # (presumably the kernel checks it against the spec stored in `encoded`;
  # confirm against the op definition).
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)

  # One dtype per leaf component of `type_spec`; tells the op how many
  # outputs to produce and of which types.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  # Reassemble the flat component list into the composite described by
  # `type_spec`.
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)
@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient registered for `CompositeTensorVariantFromComponents`.

  Decodes the incoming variant gradient into per-component gradients by
  running the inverse op, `CompositeTensorVariantToComponents`, configured
  with the forward op's `metadata` and `Tcomponents` attributes.

  Args:
    op: The forward `CompositeTensorVariantFromComponents` operation.
    grad: The gradient with respect to `op`'s variant output.

  Returns:
    The component tensors produced by `CompositeTensorVariantToComponents`,
    used as the gradients for `op`'s component inputs.
  """
  serialized_metadata = op.get_attr("metadata")
  component_types = op.get_attr("Tcomponents")
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad,
      metadata=serialized_metadata,
      Tcomponents=component_types)
@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient registered for `CompositeTensorVariantToComponents`.

  Re-encodes the per-component gradients into a single variant tensor using
  the inverse op, `CompositeTensorVariantFromComponents`, with the forward
  op's `metadata` attribute.

  Args:
    op: The forward `CompositeTensorVariantToComponents` operation.
    *grad: One gradient (possibly `None`) per output of `op`.

  Returns:
    A variant tensor holding the gradient with respect to `op`'s `encoded`
    input.
  """
  outputs = op.outputs
  # Internal invariant: exactly one gradient entry per output of `op`.
  # (Checked here so the `zip` below cannot silently truncate.)
  assert len(grad) == len(outputs)
  # `components` is `op.outputs`, but with any tensors for which we're
  # taking the gradient replaced by the corresponding value from `grad`.
  # NOTE(review): outputs without a gradient are passed through unchanged
  # rather than zeroed -- presumably so non-differentiable components keep
  # their values; confirm.
  components = [
      output if g is None else g for output, g in zip(outputs, grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=components, metadata=op.get_attr("metadata"))