Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/composite_tensor_ops.py: 38%

37 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""Operations for ExtensionTypes (aka Composite Tensors).""" 

16 

17from tensorflow.core.protobuf import composite_tensor_variant_pb2 

18from tensorflow.python.framework import composite_tensor 

19from tensorflow.python.framework import dtypes 

20from tensorflow.python.framework import ops 

21from tensorflow.python.ops import gen_composite_tensor_ops 

22from tensorflow.python.saved_model import nested_structure_coder 

23from tensorflow.python.util import nest 

24 

25 

def composite_tensor_to_variants(value, type_spec=None, name=None):
  """Encodes `value` as a scalar variant tensor.

  Args:
    value: The `ExtensionType` (or other `CompositeTensor`) value to encode.
    type_spec: Information about the value's type that should be included in
      the encoding. If `None`, `value._type_spec` is used.
    name: Optional name for the operation.

  Returns:
    A Tensor with shape=`()` and dtype=`tf.variant`.

  Raises:
    TypeError: If `value` is not a `CompositeTensor`.
    ValueError: If `type_spec` is not compatible with `value`.
  """
  if not isinstance(value, composite_tensor.CompositeTensor):
    raise TypeError("Expected `value` to be a CompositeTensor. "
                    f"Received {type(value)}.")

  if type_spec is None:
    type_spec = value._type_spec  # pylint: disable=protected-access
  if not type_spec.is_compatible_with(value):
    raise ValueError(f"`type_spec` {type_spec} is not compatible with `value` "
                     f"{value!r}.")

  # The TypeSpec travels with the encoding: it is serialized as a proto into
  # the op's `metadata` attr.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)

  # The op consumes the flat list of component tensors (nested composites are
  # expanded down to their leaf tensors).
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=nest.flatten(value, expand_composites=True),
      metadata=metadata.SerializeToString(),
      name=name)

58 

59 

def composite_tensor_from_variant(encoded, type_spec, name=None):
  """Returns the `ExtensionType` value encoded by a variant scalar tensor.

  Args:
    encoded: A Tensor returned by `composite_tensor_to_variants`.
    type_spec: The `TypeSpec` of the original value. This is used to determine
      the number and types of the component tensors that comprise the decoded
      value. Must be compatible with the `TypeSpec` serialized in `encoded`.
    name: Optional name for the operation.

  Returns:
    An `ExtensionType` value that is compatible with `type_spec`.

  Raises:
    TypeError: If `encoded` is not a Tensor with dtype=variant.
    ValueError: If the static shape of `encoded` is not compatible with a
      scalar shape.
    InvalidArgumentError: If `encoded` is not compatible with `type_spec`.
  """
  if not isinstance(encoded, ops.Tensor):
    raise TypeError(f"Expected `encoded` to be a Tensor, got {encoded!r}.")
  if encoded.dtype != dtypes.variant:
    raise TypeError("Expected `encoded` to have dtype=variant, got "
                    f"{encoded!r}.")
  # The encoding produced by `composite_tensor_to_variants` is a scalar.
  encoded.shape.assert_is_compatible_with(())

  # Serialize the expected TypeSpec into the op's `metadata` attr, mirroring
  # what `composite_tensor_to_variants` does on the encoding side.
  metadata = composite_tensor_variant_pb2.CompositeTensorVariantMetadata()
  metadata.type_spec_proto.CopyFrom(
      nested_structure_coder.encode_structure(type_spec).type_spec_value)

  # The op must know the dtype of every output component up front; take them
  # from the fully flattened `type_spec`.
  component_dtypes = [
      t.dtype for t in nest.flatten(type_spec, expand_composites=True)
  ]

  components = gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=encoded,
      metadata=metadata.SerializeToString(),
      Tcomponents=component_dtypes,
      name=name)
  # Rebuild the (possibly nested) composite value from its flat components.
  return nest.pack_sequence_as(type_spec, components, expand_composites=True)

98 

99 

@ops.RegisterGradient("CompositeTensorVariantFromComponents")
def _composite_tensor_to_variants_grad(op, grad):
  """Gradient function for the `CompositeTensorVariantFromComponents` op.

  The incoming gradient `grad` is decoded back into a list of per-component
  gradients. The forward op's own `metadata` and `Tcomponents` attrs are
  reused, so decoding targets the same encoding the forward op produced.

  Args:
    op: The forward `CompositeTensorVariantFromComponents` operation.
    grad: The gradient flowing into the op's single output.

  Returns:
    The result of `CompositeTensorVariantToComponents` applied to `grad`.
  """
  encoding_metadata = op.get_attr("metadata")
  component_types = op.get_attr("Tcomponents")
  return gen_composite_tensor_ops.CompositeTensorVariantToComponents(
      encoded=grad,
      metadata=encoding_metadata,
      Tcomponents=component_types)

106 

107 

@ops.RegisterGradient("CompositeTensorVariantToComponents")
def _composite_tensor_from_variant_grad(op, *grad):
  """Gradient function for the `CompositeTensorVariantToComponents` op.

  The per-component gradients are re-encoded into a single variant tensor
  using the forward op's `metadata` attr.

  Args:
    op: The forward `CompositeTensorVariantToComponents` operation.
    *grad: One gradient per output of `op`; entries may be `None`.

  Returns:
    The result of `CompositeTensorVariantFromComponents` over the merged
    component list.
  """
  assert len(grad) == len(op.outputs)
  # Begin from the op's outputs. Wherever an incoming gradient exists
  # (is not None), it replaces the corresponding output.
  merged_components = [
      forward_output if component_grad is None else component_grad
      for forward_output, component_grad in zip(op.outputs, grad)
  ]
  return gen_composite_tensor_ops.CompositeTensorVariantFromComponents(
      components=merged_components,
      metadata=op.get_attr("metadata"))