Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/iterator_autograph.py: 46%

35 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

"""Autograph specific overrides for tf.data.ops."""

16import functools 

17 

18import numpy as np 

19 

20from tensorflow.python.autograph.operators import control_flow 

21from tensorflow.python.autograph.operators import py_builtins 

22from tensorflow.python.data.ops import iterator_ops 

23from tensorflow.python.framework import tensor_conversion 

24from tensorflow.python.framework import tensor_spec 

25from tensorflow.python.ops import cond 

26from tensorflow.python.util import nest 

27 

28 

29# TODO(mdan): These checks should be easier. Fix the nest API. 

def _verify_spec_compatible(input_name, spec_name, input_, spec):
  """Verifies that a symbol has a type compatible with a given spec.

  Here, compatibility is viewed in the general TensorFlow sense: that the dtypes
  are the same after implicit conversion, if both are tensors.

  This verifier ensures consistent treatment of types across AutoGraph.

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. Python `bool`/`int`/`float`/`str` values and
      NumPy arrays are converted to tensors before their dtype is compared.
    spec: `tensor_spec.TensorSpec` that `input_` must be compatible with.

  Raises:
    ValueError: if `input_` is None.
    TypeError: if the dtype of `input_` (after the conversion above) differs
      from `spec.dtype`, including when `input_` has no `dtype` attribute.
  """
  assert isinstance(spec, tensor_spec.TensorSpec)
  # Compare the argument `input_`, not the `input` builtin: the builtin is
  # never None, so the old check could never fire and None values fell
  # through to the dtype check with a misleading "no dtype" message.
  if input_ is None:
    raise ValueError("{} cannot be None".format(input_name))

  # TODO(mdan): Use TensorCompatible when ready.
  if isinstance(input_, (bool, int, float, str, np.ndarray)):
    input_ = tensor_conversion.convert_to_tensor_v2(input_)

  input_dtype = getattr(input_, "dtype", None)

  if input_dtype != spec.dtype:
    input_dtype_str = "no dtype" if input_dtype is None else str(input_dtype)

    raise TypeError(
        "{} must have the same dtype as {}. Expected {}, got {}".format(
            input_name, spec_name, spec.dtype, input_dtype_str
        )
    )

66 

67 

def _verify_structure_compatible(input_name, spec_name, input_, spec):
  """Checks that a possibly nested value is type-compatible with a spec.

  Unlike `_verify_spec_compatible`, which handles a single TensorSpec, this
  accepts any structure understood by `nest`. It first requires `input_` and
  `spec` to share the same structure (with composites expanded). It then runs
  `_verify_spec_compatible` on every pair of corresponding leaves; see that
  function for the precise meaning of "compatible".

  Args:
    input_name: A name to use for `input_` in error messages.
    spec_name: A name to use for `spec` in error messages.
    input_: Any, value to verify. May, but doesn't need to, be a structure.
    spec: Any, value that `input_` must be compatible with. May, but doesn't
      need to, be a structure.

  Raises:
    TypeError: if the structures of `input_` and `spec` differ (the original
      `nest` error is chained as the cause).
    Any error raised by `_verify_spec_compatible` for a mismatched leaf pair
    propagates unchanged.
  """
  try:
    nest.assert_same_structure(input_, spec, expand_composites=True)
  except (ValueError, TypeError) as structure_error:
    message = "{} must have the same element structure as {}.\n\n{}".format(
        input_name, spec_name, str(structure_error))
    raise TypeError(message) from structure_error

  # Bind the names once so each leaf check reports the same labels.
  leaf_verifier = functools.partial(
      _verify_spec_compatible, input_name, spec_name)
  nest.map_structure(leaf_verifier, input_, spec)

97 

98 

def _next_tf_iterator(iterator, default=py_builtins.UNSPECIFIED):
  """AutoGraph override of the `next` builtin for TF iterators.

  Args:
    iterator: The iterator to take the next element from.
    default: Value to produce when no element is available. When omitted, the
      plain `next(iterator)` call is used, which raises at runtime instead.

  Returns:
    Without `default`: the result of `next(iterator)`. With `default`: the
    result of a `cond.cond` that yields the next element if the optional
    returned by `get_next_as_optional()` has a value, and `default` otherwise.

  Raises:
    Errors from `_verify_structure_compatible` when `default` is not
    compatible with `iterator.element_spec`.
  """
  if default is py_builtins.UNSPECIFIED:
    # Without a default, fall back to the "normal" behavior which raises
    # a runtime exception.
    return next(iterator)
  # Validate `default` before asking for the next element, so an incompatible
  # default is rejected without calling `get_next_as_optional()`, which may
  # advance the iterator.
  _verify_structure_compatible(
      "the default argument", "the iterate", default, iterator.element_spec
  )
  opt_iterate = iterator.get_next_as_optional()
  return cond.cond(
      opt_iterate.has_value(), opt_iterate.get_value, lambda: default
  )

111 

112 

def register_overrides():
  """Installs the AutoGraph overrides defined in this module.

  Registers `_next_tf_iterator` with `py_builtins.next_registry`, and
  `control_flow._tf_iterator_for_stmt` with `control_flow.for_loop_registry`,
  both keyed on `iterator_ops.OwnedIterator`.
  """
  py_builtins.next_registry.register(
      iterator_ops.OwnedIterator,
      _next_tf_iterator,
  )
  # pylint: disable=protected-access
  for_stmt_handler = control_flow._tf_iterator_for_stmt
  # pylint: enable=protected-access
  control_flow.for_loop_registry.register(
      iterator_ops.OwnedIterator,
      for_stmt_handler,
  )