Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/ops/dataset_autograph.py: 33% of 102 statements

coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Autograph specific overrides for dataset_ops."""

from tensorflow.python.autograph.operators import control_flow
from tensorflow.python.autograph.operators import py_builtins
from tensorflow.python.data.experimental.ops import take_while_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond
from tensorflow.python.ops import control_flow_assert
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import nest

def _general_purpose_scan(ds, init_state, body):
  """Variant of Dataset.scan with semantics of general-purpose computation."""
  # Datasets are typically intended for data preprocessing. However, in
  # autograph loops they usually appear as general-purpose computations (for
  # example, a custom training loop). These two use cases require significantly
  # different optimization policies, the most important of which is the device
  # placement. The flag override for use_default_device below instructs the
  # runtime to treat the computation as general-purpose, rather than data
  # preprocessing.

  # Loaded lazily due to a circular dependency (dataset_ops ->
  # scan_op -> dataset_ops).
  # pylint: disable=g-import-not-at-top,protected-access
  from tensorflow.python.data.ops import scan_op
  return scan_op._ScanDataset(ds, init_state, body, use_default_device=False)
  # pylint: enable=g-import-not-at-top,protected-access

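# A minimal usage sketch (not part of this module) of how a scan threads
# state through a pipeline, shown with the public tf.data API; assumes
# TensorFlow 2.x:
#
#   import tensorflow as tf
#
#   ds = tf.data.Dataset.range(5)
#   # Running sum: the state is the accumulated total; each step emits it.
#   ds = ds.apply(tf.data.experimental.scan(
#       tf.constant(0, tf.int64),
#       lambda state, x: (state + x, state + x)))
#   list(ds.as_numpy_iterator())  # -> [0, 1, 3, 6, 10]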

def _tf_ag_dataset_for_stmt(
    ds, extra_test, body, get_state, set_state, symbol_names, opts
):
  """Overload of _dataset_for_stmt with early stopping. See for_stmt."""
  # Note: This is easier to follow with the insight that the computations in
  # a dataset pipeline are transposed (aka fused).
  # For example, given a pipeline input -> scan -> take_while -> reduce,
  # and a dataset with input [1, 2, 3], the computations occur in the following
  # order:
  #   reduce(take_while(scan(1)))
  #   reduce(take_while(scan(2)))
  #   reduce(take_while(scan(3)))

  init_vars = get_state()
  control_flow.verify_loop_init_vars(init_vars, symbol_names)

  # Workaround for Dataset.reduce not allowing empty state tensors - create
  # a dummy state variable that remains unused.
  # TODO(mdan): reduce should allow and match empty structures.
  if not init_vars:
    init_vars = (constant_op.constant(0),)
    symbol_names = ("<internal dummy>",)

    def dummy_set_state(unused_dummy):
      pass

    def dummy_get_state():
      return (constant_op.constant(0),)

    get_state, set_state = dummy_get_state, dummy_set_state

  def scan_body(scan_state, scan_inputs):
    """Main body of the Dataset.scan."""
    loop_vars, iterate = scan_state, scan_inputs
    set_state(loop_vars)

    def main_path():
      body(iterate)
      new_loop_vars = get_state()
      control_flow.verify_tf_loop_vars(
          init_vars,
          loop_vars,
          new_loop_vars,
          symbol_names,
          opts,
          check_shapes=False)
      return new_loop_vars

    if extra_test is not None:
      extra_cond = extra_test()
      new_loop_vars = cond.cond(extra_cond, main_path,
                                lambda: loop_vars)
    else:
      # TODO(mdan): the optimizer should be able to remove an invariant cond?
      extra_cond = (constant_op.constant(True),)  # dummy value, unused
      new_loop_vars = main_path()

    scan_outputs = new_loop_vars, extra_cond
    new_scan_state = new_loop_vars
    return new_scan_state, scan_outputs

  def take_while_predicate(unused_loop_vars, extra_cond):
    return extra_cond

  def reduce_body(unused_reduce_state, scan_outputs):
    output_loop_vars, unused_extra_cond = scan_outputs
    new_reduce_state = output_loop_vars
    return new_reduce_state

  ds = _general_purpose_scan(ds, init_vars, scan_body)
  if extra_test is not None:
    ds = ds.apply(take_while_ops.take_while(take_while_predicate))
  final_loop_vars = ds.reduce(init_vars, reduce_body)
  set_state(final_loop_vars)

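# A hedged sketch of the user code this override serves: inside tf.function,
# autograph routes a `for` loop over a dataset here, and a conditional
# `break` becomes the extra_test/take_while stage above. The function name
# below is illustrative, not part of this module:
#
#   import tensorflow as tf
#
#   @tf.function
#   def sum_until_10(ds):
#     total = tf.constant(0, tf.int64)
#     for x in ds:
#       if total >= 10:
#         break
#       total += x
#     return total
#
#   sum_until_10(tf.data.Dataset.range(100))  # -> 10; iteration stops early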

def _tf_ag_dataset_abs(ds):
  specs = nest.flatten(ds.element_spec)
  if len(specs) == 1:
    return ds.map(math_ops.abs, num_parallel_calls=dataset_ops.AUTOTUNE)
  return ds.map(
      lambda *e: nest.map_structure(math_ops.abs, e),
      num_parallel_calls=dataset_ops.AUTOTUNE)


def _tf_ag_dataset_len(s):
  """Autograph override of the builtin len for dataset_ops.DatasetV2."""
  l = s.cardinality()
  msg = gen_string_ops.string_join([
      "len requires dataset with definitive cardinality, got ",
      gen_string_ops.as_string(l),
  ])
  # TODO(yongtang): UNKNOWN is treated as an error.
  # In case there are more UNKNOWN cases for dataset, we could
  # use dataset.reduce() to find out the length (in an expensive way).
  with ops.control_dependencies([
      control_flow_assert.Assert(
          math_ops.logical_and(
              math_ops.not_equal(l, dataset_ops.INFINITE),
              math_ops.not_equal(l, dataset_ops.UNKNOWN)), [msg])
  ]):
    l = array_ops.identity(l)

  return l

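# A hedged usage sketch (illustrative, not part of this module): inside
# tf.function, the builtin len() on a dataset dispatches to the override
# above and fails at runtime for infinite or unknown cardinality:
#
#   import tensorflow as tf
#
#   @tf.function
#   def ds_len(ds):
#     return len(ds)
#
#   ds_len(tf.data.Dataset.range(3))           # -> 3
#   ds_len(tf.data.Dataset.range(3).repeat())  # assertion fails: INFINITE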

def _tf_ag_dataset_enumerate(ds, start=0):
  return ds.enumerate(start)


def _tf_ag_dataset_zip(*iterables, strict=False):
  if strict:
    raise ValueError("strict zip not supported by Dataset")
  return dataset_ops.DatasetV2.zip(iterables)


def _tf_ag_dataset_map(fn, *iterables):
  return dataset_ops.DatasetV2.zip(iterables).map(fn)


def _tf_ag_dataset_filter(fn, iterable):
  return iterable.filter(fn)

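# Hedged sketches of the three builtin overrides above, called directly for
# illustration (inside tf.function, autograph performs this dispatch):
#
#   import tensorflow as tf
#
#   a = tf.data.Dataset.range(3)     # 0, 1, 2
#   b = tf.data.Dataset.range(3, 6)  # 3, 4, 5
#
#   _tf_ag_dataset_zip(a, b)                      # pairs (0, 3), (1, 4), (2, 5)
#   _tf_ag_dataset_map(lambda x, y: x + y, a, b)  # sums 3, 5, 7
#   _tf_ag_dataset_filter(lambda x: x > 0, a)     # 1, 2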

# The any() operation is essentially "does a True element exist?". It can be
# translated to a filter that keeps only True elements, followed by take(1).
# This works in tf.data because filter and take are pipelined, so evaluation
# stops as soon as take(1) yields an element.
def _tf_ag_dataset_any(iterable):
  # Check that iterable.element_spec consists of a single tf.bool element.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "any" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(lambda x: x)
  ds = ds.take(1)
  ds = ds.reduce(constant_op.constant(False, dtype=dtypes.bool), lambda _, y: y)
  return ds

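# A hedged usage sketch (the function name is illustrative): the dataset
# must yield bool scalars, so a predicate is typically mapped on first:
#
#   import tensorflow as tf
#
#   @tf.function
#   def has_negative(ds):
#     return any(ds.map(lambda x: x < 0))
#
#   has_negative(tf.data.Dataset.range(5))  # -> False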

# The all() operation is similar to any(): translate it to a filter that
# keeps only False elements, followed by take(1), and check whether a False
# element exists.
def _tf_ag_dataset_all(iterable):
  # Check that iterable.element_spec consists of a single tf.bool element.
  specs = nest.flatten(iterable.element_spec)
  if len(specs) != 1 or specs[0].dtype != dtypes.bool:
    raise ValueError('in graph mode, the "all" builtin only supports datasets '
                     'that return bool scalars; got: {}'.format(
                         iterable.element_spec))
  ds = iterable.filter(math_ops.logical_not)
  ds = ds.take(1)
  ds = ds.reduce(constant_op.constant(True, dtype=dtypes.bool), lambda _, y: y)
  return ds

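# A hedged usage sketch, mirroring the any() example above:
#
#   import tensorflow as tf
#
#   @tf.function
#   def all_positive(ds):
#     return all(ds.map(lambda x: x > 0))
#
#   all_positive(tf.data.Dataset.range(1, 4))  # -> True (elements 1, 2, 3)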

def register_overrides():
  """Registers the autograph specific overrides for dataset_ops."""
  control_flow.for_loop_registry.register(
      dataset_ops.DatasetV2, _tf_ag_dataset_for_stmt
  )
  py_builtins.abs_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_abs)
  py_builtins.len_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_len)
  py_builtins.enumerate_registry.register(
      dataset_ops.DatasetV2, _tf_ag_dataset_enumerate
  )
  py_builtins.zip_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_zip)
  py_builtins.map_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_map)
  py_builtins.filter_registry.register(
      dataset_ops.DatasetV2, _tf_ag_dataset_filter
  )
  py_builtins.any_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_any)
  py_builtins.all_registry.register(dataset_ops.DatasetV2, _tf_ag_dataset_all)