Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/async_checkpoint_helper.py: 20%
236 statements
« prev ^ index » next — coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for saving/loading Trackable objects asynchronously."""
17import atexit
18import collections
19import copy
20import threading
21import time
22import weakref
24from absl import logging
26from tensorflow.python.checkpoint import checkpoint_context
27from tensorflow.python.distribute import device_util
28from tensorflow.python.distribute.sharded_variable import ShardedVariable
29from tensorflow.python.eager import context
30from tensorflow.python.eager import def_function
31from tensorflow.python.eager import executor
32from tensorflow.python.framework import device as pydev
33from tensorflow.python.framework import ops
34from tensorflow.python.ops.resource_variable_ops import UninitializedVariable
35from tensorflow.python.ops.variables import Variable
36from tensorflow.python.saved_model.pywrap_saved_model import metrics
37from tensorflow.python.util import object_identity
# Captures the timestamp of the first Checkpoint instantiation or end of a write
# operation. Shared across (and mutated by) all Checkpoint instances in the
# process, so every access is guarded by the lock below.
_END_TIME_OF_LAST_ASYNC_WRITE = None
_END_TIME_OF_LAST_ASYNC_WRITE_LOCK = threading.Lock()

# API label for cell names used in async checkpoint metrics.
_ASYNC_CHECKPOINT = "async_checkpoint"

# Name of TPUEmbedding attribute. This is a temporary workaround
# to identify TPUEmbedding while avoiding import cycles.
_TPU_EMBEDDING_ATTR = "_create_copy_for_async_checkpoint"
52def _get_duration_microseconds(start_time_seconds, end_time_seconds):
53 """Calculate the duration between start and end time.
55 Args:
56 start_time_seconds: The start time in seconds.
57 end_time_seconds: The end time in seconds.
59 Returns:
60 The duration between the start and the end time. Return 0 if
61 end_time_seconds < start_time_seconds.
62 """
63 if end_time_seconds < start_time_seconds:
64 # Avoid returning negative value in case of clock skew.
65 return 0
66 return round((end_time_seconds - start_time_seconds) * 1000000)
class AsyncCheckpointHelper:
  """Helper class for async checkpoint."""

  def __init__(self, checkpointer_impl, root=None, **kwargs):
    """Initialize AsyncCheckpoint.

    Args:
      checkpointer_impl: The Checkpoint class to power the AsyncCheckpoint.
      root: The root object to checkpoint. `root` may be a trackable object or
        `WeakRef` of a trackable object.
      **kwargs: The keyword arguments representing the checkpointed variables.
    """
    # TODO(chienchunh): Make sure the processing for the root object is
    # consistent when integrating with the public API, e.g., adding all kwarg
    # items as the child of the root object.
    if root:
      # Dereference a weak reference (if given) so the underlying Checkpoint
      # always receives a concrete trackable object.
      trackable_root = root() if isinstance(root, weakref.ref) else root
      kwargs["root"] = trackable_root
      trackable_root._maybe_initialize_trackable()

    self._checkpointer_impl = checkpointer_impl
    self._checkpoint_items = kwargs

    # The underlying Checkpoint instance and its items.
    self._checkpoint = None
    self._checkpoint_options = None

    # The callback function that needs to be executed after checkpoint write.
    # Currently this is only applied to the scenario where CheckpointManager is
    # used, which triggers the _write() method.
    self._async_write_done_callback = None

    # The list of all nodes from the original checkpoint items.
    # TODO(chienchunh): Consider changing this to local variable.
    self._original_nodes = None
    # The mapping between the original and the copied resource variables.
    # The copied variables are used for the underlying checkpointing.
    self._object_map = None
    # A list of TPUEmbedding objects included in the checkpoint items.
    self._tpu_embedding_objects = None

    # Device on which the async save thread places the checkpoint ops; falls
    # back to CPU:0 when no device scope is active at construction time.
    self._default_device = device_util.current() or "CPU:0"
    self._default_device = device_util.canonicalize(self._default_device)

    self._save_file_prefix = None
    self._use_checkpoint_save = False
    self._async_save_thread = None
    self._async_save_thread_shutdown = False
    # Semaphores for writing/reading the cpu-copied variables (self._var_pairs)
    # TODO(chienchunh): Consider Queue/Condition instead of Semaphore.
    self._writer_sem = threading.Semaphore(1)
    self._reader_sem = threading.Semaphore(0)

    # Register to join the async save thread upon exit.
    atexit.register(self._join_async_save_thread)

    # Most recent exception raised inside the async save thread; surfaced to
    # callers by _check_async_thread_error().
    self._async_error = None

    # Record the process-wide "first checkpoint" timestamp exactly once.
    global _END_TIME_OF_LAST_ASYNC_WRITE
    with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
      if _END_TIME_OF_LAST_ASYNC_WRITE is None:
        _END_TIME_OF_LAST_ASYNC_WRITE = time.time()
  @def_function.function
  def _copy_from_cpu(self):
    """Copy the checkpointed variables from the host CPU to the accelerator.

    TODO(chienchunh): Get the concrete function before firstly called to avoid
    hanging the accelerators idle during function tracing.
    """
    for accelerator_var, cpu_var in self._object_map.items():
      if isinstance(accelerator_var, ShardedVariable) or hasattr(
          accelerator_var, _TPU_EMBEDDING_ATTR):
        # Skip for ShardedVariable and TPUEmbedding as their sub-variables will
        # be copied over separately through other entries in the object map.
        continue
      with ops.device(accelerator_var.device):
        accelerator_var.assign(cpu_var.read_value())
  @def_function.function
  def _copy_to_cpu(self):
    """Copy the checkpointed variables from the accelerator to the host CPU.

    TODO(chienchunh): Get the concrete function before firstly called to avoid
    hanging the accelerators idle during function tracing.
    """
    for accelerator_var, cpu_var in self._object_map.items():
      if isinstance(accelerator_var, ShardedVariable) or hasattr(
          accelerator_var, _TPU_EMBEDDING_ATTR):
        # Skip for ShardedVariable and TPUEmbedding as their sub-variables will
        # be copied over separately through other entries in the object map.
        continue
      with ops.device(cpu_var.device):
        cpu_var.assign(accelerator_var.read_value())
    # TPUEmbedding tables are not in the object map; they are pulled to the
    # host through the embedding's own retrieval hook.
    for tpu_embedding in self._tpu_embedding_objects:
      tpu_embedding._retrieve_variables()  # pylint: disable=protected-access
166 def _traverse_variables(self, to_traverse, visited):
167 """Create the copied nodes and variables while traversing the nodes.
169 This method performs a BFS to traverse the nodes while avoiding duplicated
170 visits. Throughout the process, self._mapping, self._original_nodes, and
171 self._var_pairs are populated.
173 Args:
174 to_traverse: A deque that stores the nodes to be traversed.
175 visited: A list of nodes that have been visited.
176 """
177 # pylint: disable=protected-access
178 while to_traverse:
179 current_trackable = to_traverse.popleft()
180 self._original_nodes.append(current_trackable)
182 if isinstance(current_trackable, (Variable, ShardedVariable)):
183 self._copy_trackable(current_trackable)
184 if hasattr(current_trackable, _TPU_EMBEDDING_ATTR):
185 self._handle_tpu_embedding(current_trackable)
187 for child in current_trackable._trackable_children(
188 save_type="checkpoint").values():
189 if child in visited:
190 continue
191 visited.add(child)
192 to_traverse.append(child)
193 # pylint: enable=protected-access
  def _ensure_initialized(self):
    """Initialize the async checkpoint internal state.

    Builds the CPU-copy object map, instantiates the underlying Checkpoint
    with it, and starts the async save thread. Idempotent: returns
    immediately once self._checkpoint exists.
    """
    if self._checkpoint is not None:
      return

    self._original_nodes = []
    self._object_map = object_identity.ObjectIdentityDictionary()
    self._tpu_embedding_objects = []

    # Add the top-level checkpoint items to be traversed.
    to_traverse = collections.deque([])
    visited = object_identity.ObjectIdentitySet()
    for v in self._checkpoint_items.values():
      if isinstance(v, (Variable, ShardedVariable)):
        self._copy_trackable(v)
      elif hasattr(v, _TPU_EMBEDDING_ATTR):
        self._handle_tpu_embedding(v)
      to_traverse.append(v)
      visited.add(v)
    self._traverse_variables(to_traverse, visited)

    # Copy for the slot variables.
    for current_trackable in self._original_nodes:
      # Note: dir() is used rather than hasattr() here to avoid triggering
      # custom __getattr__ code, see b/152031870 for context.
      if "get_slot_names" in dir(current_trackable):
        slot_names = current_trackable.get_slot_names()
        for slot_name in slot_names:
          for original_variable in self._original_nodes:
            if not isinstance(original_variable, Variable):
              continue
            try:
              original_slot_variable = current_trackable.get_slot(
                  original_variable, slot_name)
            except (AttributeError, KeyError):
              # Not every (variable, slot_name) pair has a slot variable.
              continue
            if isinstance(original_slot_variable, (Variable, ShardedVariable)):
              self._copy_trackable(original_slot_variable)

    # Initiate the underlying Checkpoint instance with the copied items.
    self._checkpoint = self._checkpointer_impl(**self._checkpoint_items)
    # Initiate the underlying Checkpoint instance's save_counter.
    save_counter = self._checkpoint.save_counter
    logging.info("Initializing async checkpoint's save_counter: %d",
                 save_counter)

    # Pass the object map of the copied variables to the underlying Checkpoint.
    self._checkpoint._saver._object_map = self._object_map  # pylint: disable=protected-access

    # Initiate the async thread for checkpoint saving.
    self._async_save_thread = threading.Thread(
        target=self._async_save, daemon=True)
    self._async_save_thread.start()
249 def _check_async_thread_error(self):
250 """Expose the most recent error from the async saving thread to the caller.
251 """
252 if self._async_error:
253 e = self._async_error
254 self._async_error = None
255 logging.error("Propagating the most recent error from the async thread "
256 "before joining: %s", str(e))
257 # This allows the registered at-exit method '_join_async_save_thread' to
258 # acquire the semaphore instead of timing out.
259 self._writer_sem.release()
260 raise e
262 def _join_async_save_thread(self):
263 """Join the async save thread.
265 The steps for terminating the async save thread:
266 1). Wait until the last async save event is done.
267 2). Set _async_save_thread_shutdown flag to false to indicate termination.
268 3). Trigger the async save thread to check and fail the while-predicate.
269 4). Join the async save thread. (The thread may finish before joining.)
270 """
271 # Expose the async thread error (if any) before joining the thread.
272 self._check_async_thread_error()
274 if self._writer_sem.acquire(timeout=300): # Step-1.
275 self._async_save_thread_shutdown = True # Step-2.
276 self._reader_sem.release() # Step-3.
277 logging.info("Joining the async save thread.")
278 if self._async_save_thread is not None:
279 self._async_save_thread.join() # Step-4.
280 else:
281 logging.error("Timeout waiting for the async save thread; terminating the"
282 " thread instead. The last checkpoint may be incomeplete.")
  def _async_save(self):
    """The thread function for the async checkpoint save.

    Loops until _join_async_save_thread sets the shutdown flag: each release
    of _reader_sem by save()/_write() wakes one iteration, which performs the
    actual checkpoint write and then releases _writer_sem.
    """
    # Use a dedicated synchronous executor so checkpoint ops issued from this
    # thread do not use async/streaming execution.
    with context.executor_scope(
        executor.new_executor(
            enable_async=False, enable_streaming_enqueue=False)):
      while self._reader_sem.acquire() and not self._async_save_thread_shutdown:
        logging.info("Starting async checkpoint save on the device: %s",
                     self._default_device)

        async_save_start_time = time.time()

        # Specify the ops placement on the worker if running with
        # coordinator-worker mode. This is required as launching a new thread
        # would clear the placement policy and make localhost the default
        # placement, while the main thread's default placement would be the
        # master worker's CPU:0.
        try:
          with ops.device(self._default_device):
            with checkpoint_context.async_metrics_context():
              if self._use_checkpoint_save:
                self._checkpoint.save(self._save_file_prefix,
                                      self._checkpoint_options)
              else:
                self._checkpoint._write(  # pylint: disable=protected-access
                    self._save_file_prefix,
                    options=self._checkpoint_options,
                    write_done_callback=self._async_write_done_callback)
        except Exception as e:  # pylint: disable=broad-except
          # Store the error so the main thread can re-raise it on the next
          # save/_write/join call via _check_async_thread_error().
          self._async_error = e
        finally:
          # Signal the main thread that this save attempt has finished,
          # successfully or not.
          self._writer_sem.release()

        async_save_end_time = time.time()
        metrics.AddAsyncCheckpointWriteDuration(
            api_label=_ASYNC_CHECKPOINT,
            microseconds=_get_duration_microseconds(async_save_start_time,
                                                    async_save_end_time))

        # Measure the elapsed time since the last checkpoint.
        # Due to the nature of async checkpoint, here it actually captures the
        # duration between the start_time of the previous checkpoint and the
        # start time of this checkpoint. As a result, the duration of the final
        # async checkpoint is excluded, which is fine since it does not take
        # much time.
        global _END_TIME_OF_LAST_ASYNC_WRITE
        with _END_TIME_OF_LAST_ASYNC_WRITE_LOCK:
          metrics.AddTrainingTimeSaved(
              api_label=_ASYNC_CHECKPOINT,
              microseconds=_get_duration_microseconds(
                  _END_TIME_OF_LAST_ASYNC_WRITE, async_save_start_time))
          _END_TIME_OF_LAST_ASYNC_WRITE = async_save_start_time
    logging.info("Async save thread reached the end of the execution.")
337 def _copy_for_variable(self, original_var):
338 """Create a new instance for the input trackable.
340 Args:
341 original_var: Input Variable object to be copied.
342 """
343 op_device = pydev.DeviceSpec.from_string(original_var.device).replace(
344 device_type="CPU", device_index=0).to_string()
345 with ops.device(op_device):
346 new_var = UninitializedVariable(
347 trainable=original_var.trainable,
348 shape=original_var.shape,
349 dtype=original_var.dtype,
350 name=original_var._shared_name) # pylint: disable=protected-access
351 self._object_map[original_var] = new_var
353 def _copy_for_sharded_variable(self, original_var):
354 """Create a new instance for the input ShardedVariable.
356 Args:
357 original_var: Input ShardedVariable object to be copied.
358 """
359 copied_vars = []
360 for v in original_var._variables: # pylint: disable=protected-access
361 self._copy_for_variable(v)
362 copied_vars.append(self._object_map[v])
363 self._object_map[original_var] = ShardedVariable(
364 copied_vars, name=original_var.name)
366 def _copy_trackable(self, original_trackable):
367 """Create a new instance for the input trackable.
369 Args:
370 original_trackable: The trackable instance to be copied.
372 Raises:
373 AttributeError: if the input trackable is not Variable or ShardedVariable.
374 """
375 if isinstance(original_trackable, ShardedVariable):
376 self._copy_for_sharded_variable(original_trackable)
377 elif isinstance(original_trackable, Variable):
378 self._copy_for_variable(original_trackable)
379 else:
380 raise AttributeError("Only Variable or ShardedVariable can be copied.")
382 def _handle_tpu_embedding(self, tpu_embedding):
383 """Handle TPUEmbedding.
385 Args:
386 tpu_embedding: TPUEmbedding object to be handled.
388 Raises:
389 AttributeError: if the input trackable is not TPUEmbedding type.
390 """
391 if not hasattr(
392 tpu_embedding, _TPU_EMBEDDING_ATTR
393 ) or not callable(tpu_embedding._create_copy_for_async_checkpoint): # pylint: disable=protected-access
394 raise AttributeError(
395 "Expecting TPUEmbedding type; got %s" % type(tpu_embedding)
396 )
398 # Create a dummy TPUEmbedding object and add it to the object_map. This is
399 # to prevent the TPUEmbedding's save_callback from being triggered because
400 # the embedding values have already being retrieved by AsyncCheckpoint.
401 # pylint: disable=protected-access
402 new_embedding = tpu_embedding._create_copy_for_async_checkpoint(
403 feature_config=tpu_embedding._feature_config,
404 optimizer=tpu_embedding._table_config[0]
405 if tpu_embedding._table_config
406 else None,
407 pipeline_execution_with_tensor_core=tpu_embedding._pipeline_execution_with_tensor_core,
408 )
409 self._object_map[tpu_embedding] = new_embedding
410 # pylint: enable=protected-access
412 if tpu_embedding not in self._tpu_embedding_objects:
413 self._tpu_embedding_objects.append(tpu_embedding)
  @property
  def save_counter(self):
    """An integer variable numbering the checkpoint events.

    This is maintained by the underlying tf.train.Checkpoint object employed by
    AsyncCheckpoint class. The number starts at 0 and gets incremented for each
    checkpoint event.

    Returns:
      The save counter variable.
    """
    # TODO(sagunb): Improve the solution for initializing save_counter.
    # If save_counter() is called before all the variables are created,
    # self._ensure_initialized() would construct the object_map without some
    # variables that need to be checkpointed, e.g., slot variables.
    self._ensure_initialized()
    return self._checkpoint.save_counter
433 def write(self, save_path, options=None):
434 """Save the checkpointed variables.
436 Args:
437 save_path: The file prefix of the checkpoint file.
438 options: Optional CheckpointOption instance.
440 Returns:
441 The full path of the checkpoint file.
442 """
443 self._write(save_path, options)
  def _write(self, save_path, options=None, write_done_callback=None):
    """Save the checkpointed variables.

    This method has exactly the same logic as save(), except it does not
    increment the underlying save_counter, which is done by the caller, e.g.,
    CheckpointManager.

    Args:
      save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOption instance.
      write_done_callback: Optional callback function executed after the async
        write is done.

    Returns:
      The full path of the checkpoint file.
    """
    self._ensure_initialized()

    write_start_time = time.time()

    # Copy the variable values to the host CPU. Acquiring _writer_sem blocks
    # until the previous async save (if any) has finished.
    if self._writer_sem.acquire():
      self._copy_to_cpu()

    # Surface the error from the async thread, if any.
    # This step should come after the semaphore acquisition step in the above,
    # so that it makes sure it waits until the previous async save finishes
    # storing the error.
    self._check_async_thread_error()

    # Trigger the async thread to checkpoint the cpu-copied variables.
    # Need to wait until the weight copying finishes before checkpoint save.
    context.async_wait()
    self._save_file_prefix = save_path
    # Route the async thread through Checkpoint._write() (no counter bump).
    self._use_checkpoint_save = False

    # Ensure that we do not request async checkpointing to the underlying
    # checkpointer as this could lead to an infinite loop.
    self._checkpoint_options = copy.copy(options) if options else None
    if self._checkpoint_options:
      self._checkpoint_options.experimental_enable_async_checkpoint = False

    self._async_write_done_callback = write_done_callback
    # Wake the async save thread to perform the actual write.
    self._reader_sem.release()

    write_end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_ASYNC_CHECKPOINT,
        microseconds=_get_duration_microseconds(write_start_time,
                                                write_end_time))

    return save_path
  def save(self, save_path, options=None):
    """Save the checkpointed variables.

    Args:
      save_path: The file prefix of the checkpoint file.
      options: Optional CheckpointOption instance.

    Returns:
      The full path of the checkpoint file.
    """
    # If this is the first time that AsyncCheckpoint.save() is called,
    # initialize the cpu-copied variables and create the pair-wise mapping
    # between the original model variables and the cpu-copied variables.
    #
    # This is not performed in the initializer because some variables, e.g.,
    # slot variables of the optimizer, were not created until actually running
    # the train function, so we could only get the complete list of the
    # variables after some train steps were run.
    self._ensure_initialized()

    save_start_time = time.time()

    # Copy the variable values to the host CPU. Acquiring _writer_sem blocks
    # until the previous async save (if any) has finished.
    if self._writer_sem.acquire():
      self._copy_to_cpu()

    # Surface the error from the async thread, if any.
    # This step should come after the semaphore acquisition step in the above,
    # so that it makes sure it waits until the previous async save finishes
    # storing the error.
    self._check_async_thread_error()

    # Retrieve the save counter from the underlying checkpoint object to
    # re-construct the full path of the checkpoint file.
    # This step has to happen before triggering the underlying checkpoint;
    # otherwise, the save_counter value may or may not have been updated.
    save_counter = self._checkpoint.save_counter.numpy() + 1
    full_path = "{}-{}".format(save_path, save_counter)

    # Trigger the async thread to checkpoint the cpu-copied variables.
    # Need to wait until the weight copying finishes before checkpoint save.
    context.async_wait()
    self._save_file_prefix = save_path
    # Route the async thread through Checkpoint.save() (increments counter).
    self._use_checkpoint_save = True

    # Ensure that we do not request async checkpointing to the underlying
    # checkpointer as this could lead to an infinite loop.
    self._checkpoint_options = copy.copy(options) if options else None
    if self._checkpoint_options:
      self._checkpoint_options.experimental_enable_async_checkpoint = False

    # Wake the async save thread to perform the actual save.
    self._reader_sem.release()

    save_end_time = time.time()
    metrics.AddCheckpointWriteDuration(
        api_label=_ASYNC_CHECKPOINT,
        microseconds=_get_duration_microseconds(save_start_time, save_end_time))

    return full_path
558 def read(self, save_path, options=None):
559 """Restore the checkpointed variables.
561 This method has exactly the same logic as restore(). This method is
562 implemented only to fulfill the duty of subclassing tf.train.Checkpoint.
564 Args:
565 save_path: The full name of the checkpoint file to be restored.
566 options: CheckpointOption instance.
568 Returns:
569 A load status object, which can be used to make assertions about the
570 status of a checkpoint restoration. See tf.train.Checkpoint.restore()
571 for more details.
572 """
573 return self.restore(save_path, options)
  def restore(self, save_path, options=None):
    """Restore the checkpointed variables.

    Args:
      save_path: The full name of the checkpoint file to be restored.
      options: CheckpointOption instance.

    Returns:
      A load status object, which can be used to make assertions about the
      status of a checkpoint restoration. See tf.train.Checkpoint.restore()
      for more details.
    """
    # Ensure that we do not request async checkpointing to the underlying
    # checkpointer as this could lead to an infinite loop.
    self._checkpoint_options = (
        copy.copy(options) if options else self._checkpoint_options)
    if self._checkpoint_options:
      self._checkpoint_options.experimental_enable_async_checkpoint = False

    # Wait for any ongoing checkpoint event to finish.
    with self._writer_sem:
      # If _checkpoint has not been initialized yet, it means the restore() is
      # called right after the coordinator is restarted. We directly restore
      # the checkpointed items through tf.train.Checkpoint.restore().
      if self._checkpoint is None:
        tmp_checkpoint = self._checkpointer_impl(**self._checkpoint_items)
        return tmp_checkpoint.restore(save_path, self._checkpoint_options)

      # Restore the values of the cpu-copied variables.
      status = self._checkpoint.restore(save_path, self._checkpoint_options)

      # Restore the values of the original model.
      self._copy_from_cpu()
      return status
610 def sync(self):
611 """Sync on any ongoing save or restore events."""
612 with self._writer_sem:
613 logging.info("Sync on ongoing save/restore.")