Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/iterator_ops.py: 38% (69 statements)


# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Iterator ops."""

from tensorflow.python.checkpoint import checkpoint_management
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import options as options_lib
from tensorflow.python.framework import ops
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _convert_external_state_policy_to_enum(external_state_policy):
  """Converts a string or enum `external_state_policy` to its enum form."""
  if isinstance(external_state_policy, options_lib.ExternalStatePolicy):
    return external_state_policy
  if external_state_policy == "warn":
    return options_lib.ExternalStatePolicy.WARN
  if external_state_policy == "ignore":
    return options_lib.ExternalStatePolicy.IGNORE
  if external_state_policy == "fail":
    return options_lib.ExternalStatePolicy.FAIL
  raise ValueError(
      f"Invalid `ExternalStatePolicy`. Supported values include 'warn', "
      f"'ignore', and 'fail'. Received {external_state_policy}.")
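
# A minimal usage sketch for the private helper above (illustrative only,
# not part of the public API):
#
#   policy = _convert_external_state_policy_to_enum("warn")
#   assert policy is options_lib.ExternalStatePolicy.WARN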


@tf_export("data.experimental.make_saveable_from_iterator")
@deprecation.deprecated(
    None, "`make_saveable_from_iterator` is intended for use in TF1 with "
    "`tf.compat.v1.Saver`. In TF2, use `tf.train.Checkpoint` instead.")
def make_saveable_from_iterator(iterator, external_state_policy=None):
  """Returns a SaveableObject for saving/restoring iterator state using Saver.

  Args:
    iterator: Iterator.
    external_state_policy: A string that identifies how to handle input
      pipelines that depend on external state. Possible values are
      'ignore': The external state is silently ignored.
      'warn': The external state is ignored, logging a warning.
      'fail': The operation fails upon encountering external state.
      By default we set it to 'fail'.

  Returns:
    A SaveableObject for saving/restoring iterator state using Saver.

  Raises:
    ValueError: If iterator does not support checkpointing.
    ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
      'fail'.

  For example:

  ```python
  with tf.Graph().as_default():
    ds = tf.data.Dataset.range(10)
    iterator = ds.make_initializable_iterator()
    # Build the iterator SaveableObject.
    saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
    # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
    # it can be automatically saved using Saver.
    tf.compat.v1.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
    saver = tf.compat.v1.train.Saver()

    while continue_training:
      ... Perform training ...
      if should_save_checkpoint:
        saver.save()
  ```

  Note: When restoring the iterator, the existing iterator state is completely
  discarded. This means that any changes you may have made to the Dataset
  graph will be discarded as well! This includes the new Dataset graph
  that you may have built during validation. So, while running validation,
  make sure to run the initializer for the validation input pipeline after
  restoring the checkpoint.

  Note: Not all iterators support checkpointing yet. Attempting to save the
  state of an unsupported iterator will throw an error.
  """
  if external_state_policy is None:
    external_state_policy = "fail"
  policy_enum = _convert_external_state_policy_to_enum(external_state_policy)
  return iterator_ops._IteratorSaveable(  # pylint: disable=protected-access
      iterator._iterator_resource,  # pylint: disable=protected-access
      iterator._iterator_resource.name,  # pylint: disable=protected-access
      external_state_policy=policy_enum)
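
# A minimal restore-side sketch, continuing the TF1 graph-mode example in the
# docstring above (`saver`, `iterator`, and `checkpoint_path` are illustrative
# names carried over from it, not part of this module):
#
#   with tf.compat.v1.Session() as sess:
#     saver.restore(sess, checkpoint_path)   # restores iterator state too
#     value = sess.run(iterator.get_next())  # resumes from the saved position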


@tf_export("data.experimental.CheckpointInputPipelineHook")
class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
  """Checkpoints input pipeline state every N steps or seconds.

  This hook saves the state of the iterators in the `Graph` so that when
  training is resumed the input pipeline continues from where it left off.
  This could potentially avoid overfitting in certain pipelines where the
  number of training steps per eval is small compared to the dataset
  size or if the training pipeline is pre-empted.

  Differences from `CheckpointSaverHook`:
  1. Saves only the input pipelines in the "iterators" collection and not the
     global variables or other saveable objects.
  2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.

  Example of checkpointing the training pipeline:

  ```python
  est = tf.estimator.Estimator(model_fn)
  while True:
    est.train(
        train_input_fn,
        hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
        steps=train_steps_per_eval)
    # Note: We do not pass the hook here.
    metrics = est.evaluate(eval_input_fn)
    if should_stop_the_training(metrics):
      break
  ```

  This hook should be used if the input pipeline state needs to be saved
  separately from the model checkpoint. Doing so may be useful for a few
  reasons:
  1. The input pipeline checkpoint may be large, if there are large shuffle
     or prefetch buffers for instance, and may bloat the checkpoint size.
  2. If the input pipeline is shared between training and validation, restoring
     the checkpoint during validation may override the validation input
     pipeline.

  For saving the input pipeline checkpoint alongside the model weights use
  `tf.data.experimental.make_saveable_from_iterator` directly to create a
  `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
  that you will need to be careful not to restore the training iterator during
  eval. You can do that by not adding the iterator to the `SAVEABLE_OBJECTS`
  collection when building the eval graph.
  """

  def __init__(self, estimator, external_state_policy=None):
    """Initializes a `CheckpointInputPipelineHook`.

    If the input pipeline depends on external state (e.g. seeds for
    RandomUniform) beyond the input pipeline, this hook would be unable to
    serialize and deserialize that state. If it's acceptable to ignore that
    state, change the `external_state_policy` argument to 'warn' or 'ignore'.
    For example:

    ```python
    est = tf.estimator.Estimator(model_fn)
    while True:
      est.train(
          train_input_fn,
          hooks=[tf.data.experimental.CheckpointInputPipelineHook(
              est, external_state_policy='warn')],
          steps=train_steps_per_eval)
      # Note: We do not pass the hook here.
      metrics = est.evaluate(eval_input_fn)
      if should_stop_the_training(metrics):
        break
    ```

    Args:
      estimator: Estimator.
      external_state_policy: A string that identifies how to handle input
        pipelines that depend on external state. Possible values are
        'ignore': The external state is silently ignored.
        'warn': The external state is ignored, logging a warning.
        'fail': The operation fails upon encountering external state.
        By default we set it to 'fail'.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of saver or scaffold should be set.
      ValueError: If `external_state_policy` is not one of 'warn', 'ignore' or
        'fail'.
    """

    if external_state_policy is None:
      external_state_policy = "fail"
    self._external_state_policy = _convert_external_state_policy_to_enum(
        external_state_policy)
    # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
    # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
    # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
    # "model.ckpt". We intentionally choose the input pipeline checkpoint
    # prefix to be different to avoid conflicts with the model checkpoint.

    # pylint: disable=protected-access
    checkpoint_prefix = "input"
    if estimator._config.num_worker_replicas > 1:
      # Distributed setting.
      suffix = "_{}_{}".format(estimator._config.task_type,
                               estimator._config.task_id)
      checkpoint_prefix += suffix
    # pylint: enable=protected-access

    # We use a composition paradigm instead of inheriting from
    # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
    # to check whether a `CheckpointSaverHook` is already present in the list
    # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
    # would thwart this behavior. This hook checkpoints *only the iterators*
    # and not the graph variables.
    self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
        estimator.model_dir,
        save_secs=estimator._config.save_checkpoints_secs,  # pylint: disable=protected-access
        save_steps=estimator._config.save_checkpoints_steps,  # pylint: disable=protected-access
        checkpoint_basename=checkpoint_prefix + ".ckpt")

    # Name for the protocol buffer file that will contain the list of most
    # recent checkpoints stored as a `CheckpointState` protocol buffer.
    # This file, kept in the same directory as the checkpoint files, is
    # automatically managed by the `Saver` to keep track of recent checkpoints.
    # The default name used by the `Saver` for this file is "checkpoint". Here
    # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
    # `checkpoint_dir` is the same as the model checkpoint directory, there are
    # no conflicts during restore.
    self._latest_filename = "checkpoint_" + checkpoint_prefix
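    # For example (illustrative values): on a distributed worker with
    # task_type="worker" and task_id=1, checkpoints are written as
    # "input_worker_1.ckpt-<step>*" and tracked in a state file named
    # "checkpoint_input_worker_1".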

  def begin(self):
    # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
    # collection if no `Saver` or `Scaffold` is provided.
    # pylint: disable=protected-access
    if (self._checkpoint_saver_hook._saver is None and
        self._checkpoint_saver_hook._scaffold is None):
      iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
      saveables = [
          iterator_ops._IteratorSaveable(
              i, i.name, external_state_policy=self._external_state_policy)
          for i in iterators
      ]
      self._checkpoint_saver_hook._saver = _CustomSaver(
          saveables, self._latest_filename, sharded=True)
    # pylint: enable=protected-access
    self._checkpoint_saver_hook.begin()

  def after_create_session(self, session, coord):
    # If a new session was created, we set _first_run to True so that we can
    # restore if needed.
    self._first_run = True

  def _restore_or_save_initial_ckpt(self, session):
    # Ideally this should be run in after_create_session but is not for the
    # following reason:
    # Currently there is no way of enforcing an order of running the
    # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
    # is run *after* this hook. That is troublesome because
    # 1. If a checkpoint exists and this hook restores it, the initializer hook
    #    will override it.
    # 2. If no checkpoint exists, this hook will try to save an uninitialized
    #    iterator which will result in an exception.
    #
    # As a temporary fix we enter the following implicit contract between this
    # hook and the _DatasetInitializerHook.
    # 1. The _DatasetInitializerHook initializes the iterator in the call to
    #    after_create_session.
    # 2. This hook saves the iterator on the first call to `before_run()`,
    #    which is guaranteed to happen after `after_create_session()` of all
    #    hooks have been run.

    # Check if there is an existing checkpoint. If so, restore from it.
    # pylint: disable=protected-access
    latest_checkpoint_path = checkpoint_management.latest_checkpoint(
        self._checkpoint_saver_hook._checkpoint_dir,
        latest_filename=self._latest_filename)
    if latest_checkpoint_path:
      self._checkpoint_saver_hook._get_saver().restore(session,
                                                       latest_checkpoint_path)
    else:
      # The checkpoint saved here is the state at step "global_step".
      # Note: We do not save the GraphDef or MetaGraphDef here.
      global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
      self._checkpoint_saver_hook._save(session, global_step)
      self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
    # pylint: enable=protected-access

  def before_run(self, run_context):
    if self._first_run:
      self._restore_or_save_initial_ckpt(run_context.session)
      self._first_run = False
    return self._checkpoint_saver_hook.before_run(run_context)

  def after_run(self, run_context, run_values):
    self._checkpoint_saver_hook.after_run(run_context, run_values)

  def end(self, session):
    self._checkpoint_saver_hook.end(session)


class _CustomSaver(saver_lib.Saver):
  """`Saver` with a different default `latest_filename`.

  This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
  the model ckpt saved by the `CheckpointSaverHook`.
  """

  def __init__(self, var_list, latest_filename, sharded=False):
    super(_CustomSaver, self).__init__(var_list, sharded=sharded)
    self._latest_filename = latest_filename

  def save(self,
           sess,
           save_path,
           global_step=None,
           latest_filename=None,
           meta_graph_suffix="meta",
           write_meta_graph=True,
           write_state=True,
           strip_default_attrs=False):
    return super(_CustomSaver, self).save(
        sess, save_path, global_step, latest_filename or self._latest_filename,
        meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
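
# A minimal sketch of how `_CustomSaver` keeps the input pipeline's checkpoint
# state file separate from the model's (paths and names are illustrative):
#
#   saver = _CustomSaver(saveables, latest_filename="checkpoint_input")
#   saver.save(sess, "/tmp/ckpts/input.ckpt", global_step=100)
#   # Writes "input.ckpt-100*" and records it in "checkpoint_input" rather
#   # than the default "checkpoint" file used by the model `Saver`.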