# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""DTensor Checkpoint.

Note that this module contains deprecated functionality, and DTensor-related
checkpointing has been integrated with tf.train.Checkpoint. It can be used out
of the box to save and restore dtensors.
"""

from typing import Dict, List, Optional
import weakref

from tensorflow.core.protobuf import trackable_object_graph_pb2

from tensorflow.dtensor.python import api
from tensorflow.dtensor.python import d_variable
from tensorflow.dtensor.python import gen_dtensor_ops
from tensorflow.dtensor.python import layout
from tensorflow.dtensor.python import save_restore
from tensorflow.python.checkpoint import checkpoint as util
from tensorflow.python.checkpoint import checkpoint_options
from tensorflow.python.checkpoint import graph_view as graph_view_lib
from tensorflow.python.checkpoint import restore as restore_lib
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.trackable import base
from tensorflow.python.trackable import data_structures
from tensorflow.python.training import py_checkpoint_reader
from tensorflow.python.training.saving import saveable_object
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import tf_export

class _DSaver:  # pylint: disable=protected-access
  """A single device saver that places tensors on DTensor Device."""

  def __init__(self, mesh: layout.Mesh,
               saveable_objects: List[saveable_object.SaveableObject]):
    self._saveable_objects = saveable_objects
    self._mesh = mesh

  def save(
      self,
      file_prefix: str,
      options: Optional[checkpoint_options.CheckpointOptions] = None
  ) -> Optional[ops.Operation]:
    """Saves the saveable objects to a checkpoint with `file_prefix`.

    Also queries the generated shards from the distributed DTensor SaveV2 ops
    and runs a MergeV2 on them. Each op here is backed by a global barrier to
    avoid races between multiple clients.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      An `Operation`, or None when executing eagerly.
    """
    if options is not None and options.experimental_io_device is not None:
      raise ValueError(
          "Specified experimental_io_device in DTensor checkpoint is not "
          "supported.")
    del options
    tensor_names = []
    tensors = []
    tensor_slices = []
    for saveable in self._saveable_objects:
      for spec in saveable.specs:
        tensor = spec.tensor
        # A tensor value of `None` indicates that this SaveableObject gets
        # recorded in the object graph, but that no value is saved in the
        # checkpoint.
        if tensor is not None:
          if api.device_name() != spec.device:
            # Some small tensors are placed on CPU0 by the save manager and
            # broadcast to the DTensor mesh, e.g., SaveCounter.
            tensor = api.pack(
                [tensor] * self._mesh.host_mesh().num_local_devices(),
                layout.Layout.replicated(
                    self._mesh.host_mesh(), rank=tensor.shape.rank))
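            # Illustrative note (added; not in the original source):
            # `api.pack` behaves like `dtensor.pack`, combining one component
            # tensor per local host device into a single DTensor with the given
            # replicated layout, e.g. for a two-device host mesh:
            #   api.pack([t, t],
            #            layout.Layout.replicated(host_mesh, rank=t.shape.rank))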

          tensor_names.append(spec.name)
          tensors.append(tensor)
          tensor_slices.append(spec.slice_spec)
    return save_restore.sharded_save(self._mesh, file_prefix, tensor_names,
                                     tensor_slices, tensors)

  def restore(
      self,
      file_prefix: str,
      options: Optional[checkpoint_options.CheckpointOptions] = None
  ) -> Dict[str, ops.Operation]:
    """Restores the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object. This is unused in DTensor.

    Returns:
      A dictionary mapping from SaveableObject names to restore operations.
    """
    if options is not None and options.experimental_io_device is not None:
      raise ValueError(
          "Specified experimental_io_device in DTensor checkpoint is not "
          "supported.")
    del options
    restore_specs = []
    tensor_structure = []
    for saveable in self._saveable_objects:
      saveable_tensor_structure = []
      tensor_structure.append(saveable_tensor_structure)
      # DTensor change 1: Gather shapes and layouts from the original saveable
      # specs.
      # Note that this relies on the fact that the variables are already
      # initialized -- which isn't the behavior we want eventually.
      # TODO(b/159035705): Handle the variable initialization in restore.
      for spec in saveable.specs:
        saveable_tensor_structure.append(spec.name)
        if isinstance(spec, d_variable.DSaveSpec):
          restore_specs.append((spec.name, spec.slice_spec, spec.dtype,
                                spec.layout, spec.global_shape))
        # Fall back to replicated layouts for non-DTensor saves that construct
        # a normal SaveSpec.
        elif isinstance(spec, saveable_object.SaveSpec):
          restore_specs.append(
              (spec.name, spec.slice_spec, spec.dtype,
               layout.Layout.replicated(self._mesh.host_mesh(),
                                        spec.tensor.shape.rank).to_string(),
               spec.tensor.shape.as_list()))
    tensor_names, tensor_slices, tensor_dtypes, layouts, global_shapes = zip(
        *restore_specs)
    with ops.device(api.device_name()):
      # DTensor change 2: Run the customized DTensor RestoreV2 op rather than
      # the stock TF io_ops.RestoreV2.
      restored_tensors = gen_dtensor_ops.d_tensor_restore_v2(
          prefix=file_prefix,
          tensor_names=tensor_names,
          shape_and_slices=tensor_slices,
          input_shapes=global_shapes,
          input_layouts=layouts,
          dtypes=tensor_dtypes)
    structured_restored_tensors = nest.pack_sequence_as(tensor_structure,
                                                        restored_tensors)
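    # Illustrative note (added; not in the original source):
    # `nest.pack_sequence_as` regroups the flat list of restored tensors to
    # mirror the per-saveable name structure collected above, e.g.
    #   nest.pack_sequence_as([["a", "b"], ["c"]], [t1, t2, t3])
    #   ==> [[t1, t2], [t3]]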

    restore_ops = {}
    for saveable, restored_tensors in zip(self._saveable_objects,
                                          structured_restored_tensors):
      restore_ops[saveable.name] = saveable.restore(
          restored_tensors, restored_shapes=None)
    return restore_ops


class _DCheckpointRestoreCoordinator(util._CheckpointRestoreCoordinator):  # pylint: disable=protected-access
  """Holds the status of an object-based checkpoint load."""

  def __init__(self, mesh: layout.Mesh, **kwargs):
    super().__init__(**kwargs)
    self._mesh = mesh

  def restore_saveables(
      self,
      tensor_saveables: Dict[str, saveable_object.SaveableObject],
      python_positions: List[restore_lib.CheckpointPosition],
      registered_savers: Optional[Dict[str, Dict[str, base.Trackable]]] = None,
      reader: py_checkpoint_reader.NewCheckpointReader = None
  ) -> Optional[List[ops.Operation]]:
    """Runs or builds restore operations for SaveableObjects.

    Args:
      tensor_saveables: `SaveableObject`s which correspond to Tensors.
      python_positions: `CheckpointPosition`s which correspond to `PythonState`
        Trackables bound to the checkpoint.
      registered_savers: A dict mapping saver names -> object names ->
        Trackables. This argument is not implemented for DTensorCheckpoint.
      reader: A `CheckpointReader`. One is created lazily if None.

    Returns:
      When graph building, a list of restore operations, either cached or newly
      created, to restore `tensor_saveables`.
    """
    del registered_savers

    restore_ops = []
    # Eagerly run restorations for Python state.
    if python_positions:
      # Lazily create the NewCheckpointReader, since this requires file access
      # and we may not have any Python saveables.
      if reader is None:
        reader = py_checkpoint_reader.NewCheckpointReader(self.save_path_string)
      for position in python_positions:
        key = position.object_proto.attributes[0].checkpoint_key
        position.trackable.deserialize(reader.get_tensor(key))

    # If we have new SaveableObjects, extract and cache restore ops.
    if tensor_saveables:
      validated_saveables = saveable_object_util.validate_and_slice_inputs(
          tensor_saveables)
      validated_names = set(saveable.name for saveable in validated_saveables)
      if set(tensor_saveables.keys()) != validated_names:
        raise AssertionError(
            ("Saveable keys changed when validating. Got back %s, was "
             "expecting %s") % (tensor_saveables.keys(), validated_names))
      # DTensor change: Use _DSaver, which restores on DTensor with the
      # customized DTensorRestoreV2 op.
      new_restore_ops = _DSaver(self._mesh, validated_saveables).restore(
          self.save_path_tensor, self.options)
      if not context.executing_eagerly():
        for name, restore_op in sorted(new_restore_ops.items()):
          restore_ops.append(restore_op)
          assert name not in self.restore_ops_by_name
          self.restore_ops_by_name[name] = restore_op
    return restore_ops


class DTrackableSaver(util.TrackableSaver):
  """A DTensor trackable saver that uses _SingleDeviceSaver."""

  def __init__(self, mesh: layout.Mesh, graph_view):
    super(DTrackableSaver, self).__init__(graph_view)
    self._mesh = mesh

  def _gather_saveables(self, object_graph_tensor=None):
    # Since the base Checkpoint class does not return SaveableObjects, re-use
    # the saveables cache or generate new Saveables.
    (serialized_tensors, feed_additions, registered_savers,
     graph_proto) = self._gather_serialized_tensors(object_graph_tensor)

    saveables_dict = self._saveables_cache
    if saveables_dict is None:
      # Get and remove the object graph tensor from `serialized_tensors`,
      # because the function `serialized_tensors_to_saveable_cache` isn't
      # equipped to handle it.
      object_graph_tensor = serialized_tensors.pop(
          None)[base.OBJECT_GRAPH_PROTO_KEY]
      saveables_dict = (
          saveable_object_util.serialized_tensors_to_saveable_cache(
              serialized_tensors))
    named_saveable_objects = []
    for saveable_by_name in saveables_dict.values():
      for saveables in saveable_by_name.values():
        named_saveable_objects.extend(saveables)
    named_saveable_objects.append(
        base.NoRestoreSaveable(
            tensor=object_graph_tensor,
            name=base.OBJECT_GRAPH_PROTO_KEY))
    return (named_saveable_objects, graph_proto, feed_additions,
            registered_savers)

  def _save_cached_when_graph_building(self,
                                       file_prefix,
                                       object_graph_tensor,
                                       options,
                                       update_ckpt_state=False):
    """Creates or retrieves save ops, overriding the parent's private method.

    Args:
      file_prefix: The prefix for saved checkpoint files.
      object_graph_tensor: A `Tensor` to which the current object graph will be
        fed.
      options: `CheckpointOptions` object.
      update_ckpt_state: Optional bool flag. Indicates whether the internal
        checkpoint state needs to be updated. This is used for async
        checkpointing, which DTrackableSaver currently does not support.
        TODO(chienchunh): Implement async checkpoint for DTrackableSaver.

    Returns:
      A two-element tuple with a filename tensor and a feed_dict of tensors to
      feed when running it (if graph building). The feed dict contains the
      current object graph and any Python state to be saved in the
      checkpoint. When executing eagerly only the first argument is meaningful.
    """
    (named_saveable_objects, graph_proto, feed_additions,
     unused_registered_savers) = self._gather_saveables(
         object_graph_tensor=object_graph_tensor)
    if (self._last_save_object_graph != graph_proto
        # When executing eagerly, we need to re-create SaveableObjects each
        # time save() is called so they pick up new Tensors passed to their
        # constructors. That means the Saver needs to be copied with a new
        # var_list.
        or context.executing_eagerly() or ops.inside_function()):
      # This is needed to avoid MultiDeviceSaver creating unnecessary MergeV2
      # ops in DTensor. It is an issue when saving TPU Variables on a host CPU
      # mesh, given our limited expressiveness in the API and hard-coded logic
      # in broadcasting -- for a small constant Tensor with no extra
      # information, we place it on the first registered mesh (a.k.a. the
      # default mesh).
      saver = _DSaver(self._mesh, named_saveable_objects)
      save_op = saver.save(file_prefix, options=options)
      with ops.device("/cpu:0"):
        with ops.control_dependencies([save_op]):
          self._cached_save_operation = array_ops.identity(file_prefix)
      self._last_save_object_graph = graph_proto
    return self._cached_save_operation, feed_additions

  # TODO(b/180466245): Use proper mesh placement semantic.
  def restore(self, save_path, options=None):
    """Restores a training checkpoint with host mesh placement."""
    options = options or checkpoint_options.CheckpointOptions()
    if save_path is None:
      return util.InitializationOnlyStatus(self._graph_view, ops.uid())
    reader = py_checkpoint_reader.NewCheckpointReader(save_path)
    graph_building = not context.executing_eagerly()
    if graph_building:
      dtype_map = None
    else:
      dtype_map = reader.get_variable_to_dtype_map()
    try:
      object_graph_string = reader.get_tensor(base.OBJECT_GRAPH_PROTO_KEY)
    except errors_impl.NotFoundError:
      # The object graph proto does not exist in this checkpoint. Try the
      # name-based compatibility mode.
      restore_coordinator = util._NameBasedRestoreCoordinator(  # pylint: disable=protected-access
          save_path=save_path,
          dtype_map=dtype_map)
      if not graph_building:
        for existing_trackable in self._graph_view.list_objects():
          # pylint: disable=protected-access
          existing_trackable._maybe_initialize_trackable()
          existing_trackable._name_based_restores.add(restore_coordinator)
          existing_trackable._name_based_attribute_restore(restore_coordinator)
          # pylint: enable=protected-access
      return util.NameBasedSaverStatus(
          restore_coordinator, graph_view=self._graph_view)

    if graph_building:
      if self._file_prefix_placeholder is None:
        # DTensor change: provide a hint for mesh broadcasting to put the input
        # onto the host mesh.
        self._file_prefix_placeholder = api.pack(
            [constant_op.constant("model")] * self._mesh.num_local_devices(),
            layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
      file_prefix_tensor = self._file_prefix_placeholder
      file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
    else:
      # DTensor change: provide a hint for mesh broadcasting to put the input
      # onto the host mesh.
      file_prefix_tensor = api.pack(
          [constant_op.constant(save_path)] * self._mesh.num_local_devices(),
          layout.Layout.replicated(self._mesh.host_mesh(), rank=0))
      file_prefix_feed_dict = None
    object_graph_proto = trackable_object_graph_pb2.TrackableObjectGraph()
    object_graph_proto.ParseFromString(object_graph_string)
    # DTensor change: Hook the proper DSaver in restore.
    checkpoint = _DCheckpointRestoreCoordinator(
        mesh=self._mesh,
        object_graph_proto=object_graph_proto,
        save_path=save_path,
        save_path_tensor=file_prefix_tensor,
        reader=reader,
        restore_op_cache=self._restore_op_cache,
        graph_view=self._graph_view,
        options=options,
        saveables_cache=self._saveables_cache)
    restore_lib.CheckpointPosition(
        checkpoint=checkpoint, proto_id=0).restore(self._graph_view.root)

    # Attached dependencies are not attached to the root, so they should be
    # restored separately.
    if self._graph_view.attached_dependencies:
      for ref in self._graph_view.attached_dependencies:
        if ref.name == "root":
          # The root dependency is automatically added to attached
          # dependencies -- this can be ignored since it maps back to the
          # root object.
          continue
        proto_id = None
        # Find the proto ID of the attached dependency (if it is in the proto).
        for proto_ref in object_graph_proto.nodes[0].children:
          if proto_ref.local_name == ref.name:
            proto_id = proto_ref.node_id
            break

        if proto_id in checkpoint.object_by_proto_id:
          # The object has already been restored. This can happen when there's
          # an indirect connection from the attached object to the root.
          continue

        restore_lib.CheckpointPosition(
            checkpoint=checkpoint, proto_id=proto_id).restore(ref.ref)

    load_status = util.CheckpointLoadStatus(
        checkpoint,
        graph_view=self._graph_view,
        feed_dict=file_prefix_feed_dict)
    return load_status


@deprecation.deprecated(
    date=None,
    instructions="Please use tf.train.Checkpoint instead of DTensorCheckpoint. "
    "DTensor is integrated with tf.train.Checkpoint and it can be "
    "used out of the box to save and restore dtensors.")
@tf_export("experimental.dtensor.DTensorCheckpoint", v1=[])
class DTensorCheckpoint(util.Checkpoint):
  """Manages saving/restoring trackable values to disk, for DTensor."""

  def __init__(self, mesh: layout.Mesh, root=None, **kwargs):
    super(DTensorCheckpoint, self).__init__(root=root, **kwargs)
    self._mesh = mesh

    saver_root = self
    attached_dependencies = None
    self._save_counter = None  # Created lazily for restore-on-create.
    self._save_assign_op = None

    if root:
      util._assert_trackable(root, "root")
      saver_root = root
      attached_dependencies = []

      # All keyword arguments (including root itself) are set as children
      # of root.
      kwargs["root"] = root
      root._maybe_initialize_trackable()

      self._save_counter = data_structures.NoDependency(
          root._lookup_dependency("save_counter"))
      self._root = data_structures.NoDependency(root)

    for k, v in sorted(kwargs.items(), key=lambda item: item[0]):
      setattr(self, k, v)

      # Call getattr instead of directly using v because setattr converts
      # v to a Trackable data structure when v is a list/dict/tuple.
      converted_v = getattr(self, k)
      util._assert_trackable(converted_v, k)

      if root:
        # Make sure that root doesn't already have dependencies with these names.
        attached_dependencies = attached_dependencies or []
        child = root._lookup_dependency(k)
        if child is None:
          attached_dependencies.append(base.TrackableReference(k, converted_v))
        elif child != converted_v:
          raise ValueError(
              "Cannot create a Checkpoint with keyword argument {name} if "
              "root.{name} already exists.".format(name=k))
    # DTensor change:
    # Override the parent's saver with a DTrackableSaver that uses
    # _SingleDeviceSaver.
    self._saver = DTrackableSaver(
        mesh,
        graph_view_lib.ObjectGraphView(
            weakref.ref(saver_root),
            attached_dependencies=attached_dependencies))
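
# Illustrative usage sketch (added; not part of the original module). Unlike
# `tf.train.Checkpoint`, the deprecated class above takes the mesh explicitly:
#
#   ckpt = DTensorCheckpoint(mesh=mesh, v=my_dvariable)
#   save_path = ckpt.save("/tmp/dtensor_training")
#   ckpt.restore(save_path)
#
# where `mesh` is a `layout.Mesh` and `my_dvariable` is a DTensor-backed
# variable created on that mesh (both assumed to exist for illustration).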