Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/checkpoint/functional_saver.py: 20%

210 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Saves and restore variables inside traced @tf.functions.""" 

16 

17from tensorflow.core.protobuf import saver_pb2 

18from tensorflow.python.checkpoint import checkpoint_options 

19from tensorflow.python.eager import context 

20from tensorflow.python.eager import def_function 

21from tensorflow.python.framework import constant_op 

22from tensorflow.python.framework import dtypes 

23from tensorflow.python.framework import ops 

24from tensorflow.python.framework import tensor_spec 

25from tensorflow.python.framework import tensor_util 

26from tensorflow.python.ops import array_ops 

27from tensorflow.python.ops import gen_io_ops 

28from tensorflow.python.ops import io_ops 

29from tensorflow.python.ops import string_ops 

30from tensorflow.python.saved_model import registration 

31from tensorflow.python.trackable import trackable_utils 

32from tensorflow.python.training.saving import saveable_object 

33from tensorflow.python.training.saving import saveable_object_util 

34from tensorflow.python.util import nest 

35from tensorflow.python.util import object_identity 

36 

37 

class _SingleDeviceSaver(object):
  """Saves and restores checkpoint tensors assigned to a single device."""

  __slots__ = ["_tensor_slice_dict"]

  def __init__(self, tensor_slice_dict):
    """Creates a saver for one device's tensors.

    Args:
      tensor_slice_dict: A dict mapping checkpoint key -> slice_spec -> tensor.
    """
    self._tensor_slice_dict = tensor_slice_dict

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensors = []
    slice_specs = []
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        if not isinstance(tensor, saveable_object.SaveSpec):
          # Plain tensor: saved directly under its checkpoint key.
          tensor_names.append(checkpoint_key)
          tensors.append(tensor)
          slice_specs.append(slice_spec)
          continue
        # Read the SaveSpec's tensor exactly once (it may be lazily computed).
        value = tensor.tensor
        # A value of `None` indicates that this SaveableObject gets recorded
        # in the object graph, but that no value is saved in the checkpoint.
        if value is not None:
          tensor_names.append(tensor.name)
          tensors.append(value)
          slice_specs.append(tensor.slice_spec)
    # Prefer the explicitly requested I/O device; otherwise co-locate the
    # write with the first tensor's host CPU, falling back to cpu:0.
    save_device = options.experimental_io_device
    if not save_device and tensors:
      save_device = saveable_object_util.set_cpu0(tensors[0].device)
    save_device = save_device or "cpu:0"
    with ops.device(save_device):
      return io_ops.save_v2(file_prefix, tensor_names, slice_specs, tensors)

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      A restored tensor dict (maps checkpoint_key -> slice_spec -> tensor).
    """
    options = options or checkpoint_options.CheckpointOptions()
    tensor_names = []
    tensor_dtypes = []
    slice_specs = []

    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec, tensor in tensor_slices.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          name, spec = tensor.name, tensor.slice_spec
        else:
          name, spec = checkpoint_key, slice_spec
        tensor_names.append(name)
        slice_specs.append(spec)
        tensor_dtypes.append(tensor.dtype)

    restore_device = options.experimental_io_device or "cpu:0"
    with ops.device(restore_device):
      restored_tensors = io_ops.restore_v2(
          file_prefix, tensor_names, slice_specs, tensor_dtypes)

    # RestoreV2 outputs arrive in the same order the names were requested, so
    # walk the slice dict again and consume the results in order.
    restored = iter(restored_tensors)
    restored_tensor_dict = {}
    for checkpoint_key, tensor_slices in self._tensor_slice_dict.items():
      for slice_spec in tensor_slices:
        restored_tensor_dict.setdefault(checkpoint_key, {})[slice_spec] = next(
            restored)
    return restored_tensor_dict

124 

125 

def sharded_filename(filename_tensor, shard, num_shards):
  """Builds the filename for one shard of a sharded checkpoint.

  Args:
    filename_tensor: A string tensor holding the base filename.
    shard: Integer. The shard for the filename.
    num_shards: An int Tensor for the total number of shards.

  Returns:
    A string tensor with the shard information appended.
  """
  # Delegates to the C++ op so the suffix format matches reads elsewhere.
  return gen_io_ops.sharded_filename(filename_tensor, shard, num_shards)

138 

139 

def registered_saver_filename(filename_tensor, saver_name):
  """Returns the checkpoint filename used by the registered saver."""
  suffix = constant_op.constant(f"-{saver_name}")
  return string_ops.string_join([filename_tensor, suffix])

143 

144 

def _get_mapped_registered_save_fn(fn, trackables, call_with_mapped_captures):
  """Converts the function to a python or tf.function with a single file arg."""

  def save_fn(file_prefix):
    return fn(trackables=trackables, file_prefix=file_prefix)

  if call_with_mapped_captures is None:
    return save_fn

  # Trace once so the caller can invoke the concrete function with its
  # captures remapped.
  concrete = def_function.function(save_fn, autograph=False).get_concrete_function(
      file_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

  def save_fn_with_replaced_captures(file_prefix):
    return call_with_mapped_captures(concrete, [file_prefix])

  return save_fn_with_replaced_captures

161 

162 

def _get_mapped_registered_restore_fn(fn, trackables,
                                      call_with_mapped_captures):
  """Converts the function to a python or tf.function with a single file arg."""

  def restore_fn(merged_prefix):
    return fn(trackables=trackables, merged_prefix=merged_prefix)

  if call_with_mapped_captures is None:
    return restore_fn

  # Trace once so the caller can invoke the concrete function with its
  # captures remapped.
  concrete = def_function.function(
      restore_fn, autograph=False).get_concrete_function(
          merged_prefix=tensor_spec.TensorSpec(shape=(), dtype=dtypes.string))

  def restore_fn_with_replaced_captures(merged_prefix):
    return call_with_mapped_captures(concrete, [merged_prefix])

  return restore_fn_with_replaced_captures

180 

181 

def _restore_noop(*args, **kwargs):
  """No-op restore function, used when a tensor has no associated Trackable."""
  del args, kwargs  # Accepts any call signature; nothing to restore.
  return None

183 

184 

class MultiDeviceSaver(object):
  """Saves checkpoints directly from multiple devices.

  Note that this is a low-level utility which stores Tensors in the keys
  specified by `SaveableObject`s. Higher-level utilities for object-based
  checkpointing are built on top of it.
  """

  def __init__(self,
               serialized_tensors,
               registered_savers=None,
               call_with_mapped_captures=None):
    """Specify a list of `SaveableObject`s to save and restore.

    Args:
      serialized_tensors: A dictionary mapping `Trackable` to a tensor dict,
        which maps checkpoint_key -> (slice_spec ->) -> Tensor/SaveSpec. The
        `Trackable` key is used to get the `restore_from_tensors` function,
        and may be `None` if the tensor is not meant to be restored.
      registered_savers: A dictionary mapping `registration.RegisteredSaver`
        namedtuples to a dictionary of named Trackables. The keys of the
        Trackable dictionary are string names that uniquely identify the
        Trackable in the checkpoint.
      call_with_mapped_captures: Optional callable taking a traced
        ConcreteFunction and its argument list, which invokes the function
        with its captured inputs remapped; `None` means registered
        save/restore functions are called as plain Python functions.
    """
    # Keep these two data structures so that we can map restored tensors to
    # the Trackable restore functions.
    self._keys_to_restore_fn = {}
    self._restore_fn_to_keys = {}

    # Extract serialized tensors and separate by device.
    tensors_by_device = {}  # device -> checkpoint key -> (slice_spec ->) tensor

    for obj, tensor_dict in serialized_tensors.items():
      restore_fn = _restore_noop if obj is None else obj._restore_from_tensors

      # Divide tensor_dict by device.
      for checkpoint_key, maybe_tensor in tensor_dict.items():
        if not isinstance(maybe_tensor, dict):
          # Make sure that maybe_tensor is structured as {slice_spec -> tensor}.
          maybe_tensor = {"": maybe_tensor}

        for slice_spec, tensor in maybe_tensor.items():
          if (checkpoint_key, slice_spec) in self._keys_to_restore_fn:
            raise ValueError(
                "Received multiple tensors with the same checkpoint key and "
                "slice spec. This is invalid because one will overwrite the "
                "other in the checkpoint. This indicates a bug in the "
                "Checkpoint key-generation.")
          self._keys_to_restore_fn[(checkpoint_key, slice_spec)] = restore_fn
          self._restore_fn_to_keys.setdefault(restore_fn, []).append(
              (checkpoint_key, slice_spec))

          # Group tensors by the CPU of the device that hosts them.
          host_device = saveable_object_util.set_cpu0(tensor.device)
          (tensors_by_device
           .setdefault(host_device, {})
           .setdefault(checkpoint_key, {})[slice_spec]) = tensor
    self._single_device_savers = {
        device: _SingleDeviceSaver(tensor_slice_dict)
        for device, tensor_slice_dict in tensors_by_device.items()}

    self._registered_savers = {}
    if registered_savers:
      for registered_name, trackables in registered_savers.items():
        save_fn = _get_mapped_registered_save_fn(
            registration.get_save_function(registered_name),
            trackables, call_with_mapped_captures)
        restore_fn = _get_mapped_registered_restore_fn(
            registration.get_restore_function(registered_name),
            trackables, call_with_mapped_captures)
        self._registered_savers[registered_name] = (save_fn, restore_fn)

  @classmethod
  def from_saveables(cls, saveables, registered_savers=None,
                     call_with_mapped_captures=None):
    """Builds a MultiDeviceSaver from legacy `SaveableObject`s."""
    serialized_tensors = object_identity.ObjectIdentityDictionary()
    for saveable in saveables:
      # Wrap each SaveableObject in a Trackable-compatible converter.
      trackable = saveable_object_util.SaveableCompatibilityConverter(
          saveable, saveables=[saveable])
      serialized_tensors[trackable] = trackable._serialize_to_tensors()  # pylint: disable=protected-access
    return cls(serialized_tensors, registered_savers, call_with_mapped_captures)

  def to_proto(self):
    """Serializes to a SaverDef referencing the current graph."""
    filename_tensor = array_ops.placeholder(
        shape=[], dtype=dtypes.string, name="saver_filename")
    save_tensor = self._traced_save(filename_tensor)
    restore_op = self._traced_restore(filename_tensor).op
    return saver_pb2.SaverDef(
        filename_tensor_name=filename_tensor.name,
        save_tensor_name=save_tensor.name,
        restore_op_name=restore_op.name,
        version=saver_pb2.SaverDef.V2)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_save(self, file_prefix):
    # Traced save; returns `file_prefix` once the save op has run.
    save_op = self.save(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies([save_op]):
        return array_ops.identity(file_prefix)

  @def_function.function(
      input_signature=(tensor_spec.TensorSpec(shape=(), dtype=dtypes.string),),
      autograph=False)
  def _traced_restore(self, file_prefix):
    # Traced restore; returns `file_prefix` once all restore ops have run.
    restore_ops = self.restore(file_prefix)
    with ops.device("cpu:0"):
      with ops.control_dependencies(restore_ops.values()):
        return array_ops.identity(file_prefix)

  def save(self, file_prefix, options=None):
    """Save the saveable objects to a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix to
        save under.
      options: Optional `CheckpointOptions` object.
    Returns:
      An `Operation`, or None when executing eagerly.
    """
    options = options or checkpoint_options.CheckpointOptions()

    # IMPLEMENTATION DETAILS: most clients should skip.
    #
    # Suffix for any well-formed "checkpoint_prefix", when sharded.
    # Transformations:
    # * Users pass in "save_path" in save() and restore(). Say "myckpt".
    # * checkpoint_prefix gets fed <save_path><sharded_suffix>.
    #
    # Example:
    #   During runtime, a temporary directory is first created, which contains
    #   files
    #
    #     <train dir>/myckpt_temp/
    #        part-?????-of-?????{.index, .data-00000-of-00001}
    #
    #   Before .save() finishes, they will be (hopefully, atomically) renamed to
    #
    #     <train dir>/
    #        myckpt{.index, .data-?????-of-?????}
    #
    #   Filesystems with eventual consistency (such as S3), don't need a
    #   temporary location. Using a temporary directory in those cases might
    #   cause situations where files are not available during copy.
    #
    # Users only need to interact with the user-specified prefix, which is
    # "<train dir>/myckpt" in this case. Save() and Restore() work with the
    # prefix directly, instead of any physical pathname. (On failure and
    # subsequent restore, an outdated and orphaned temporary directory can be
    # safely removed.)
    with ops.device("CPU"):
      sharded_suffix = array_ops.where(
          string_ops.regex_full_match(file_prefix, "^s3://.*"),
          constant_op.constant(".part"),
          constant_op.constant("_temp/part"))
      tmp_checkpoint_prefix = string_ops.string_join(
          [file_prefix, sharded_suffix])
      registered_paths = {
          saver_name: registered_saver_filename(file_prefix, saver_name)
          for saver_name in self._registered_savers
      }

    def save_fn():
      saved_prefixes = []
      # Save with the registered savers. These run before default savers due to
      # the API contract.
      for saver_name, (save_fn, _) in self._registered_savers.items():
        maybe_saved_prefixes = save_fn(registered_paths[saver_name])
        if maybe_saved_prefixes is not None:
          flattened_saved_prefixes = nest.flatten(maybe_saved_prefixes)
          if not all(
              tensor_util.is_tf_type(x) and x.dtype == dtypes.string
              for x in flattened_saved_prefixes):
            raise ValueError(
                "Registered saver must return a (maybe empty) list of "
                f"string type tensors. Got {maybe_saved_prefixes}.")
          saved_prefixes.extend(flattened_saved_prefixes)

      # (Default saver) Save with single device savers.
      num_shards = len(self._single_device_savers)
      sharded_saves = []
      num_shards_tensor = constant_op.constant(num_shards, name="num_shards")
      last_device = None
      for shard, (device, saver) in enumerate(
          sorted(self._single_device_savers.items())):
        last_device = device
        with ops.device(saveable_object_util.set_cpu0(device)):
          shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard,
                                          num_shards_tensor)
        saved_prefixes.append(shard_prefix)
        with ops.device(device):
          # _SingleDeviceSaver will use the CPU device when necessary, but
          # initial read operations should be placed on the SaveableObject's
          # device.
          sharded_saves.append(saver.save(shard_prefix, options))

      with ops.control_dependencies(sharded_saves):
        # Merge on the io_device if specified, otherwise co-locates the merge op
        # with the last device used.
        merge_device = (
            options.experimental_io_device or
            saveable_object_util.set_cpu0(last_device))
        with ops.device(merge_device):
          # V2 format write path consists of a metadata merge step. Once
          # merged, attempts to delete the temporary directory,
          # "<user-fed prefix>_temp".
          return gen_io_ops.merge_v2_checkpoints(
              saved_prefixes, file_prefix, delete_old_dirs=True)

    # Since this will cause a function re-trace on each save, limit this to the
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers. Note that the retrace is needed to ensure we pick up the
    # latest values of options like experimental_io_device.
    if context.executing_eagerly() and len(self._single_device_savers) > 1:
      # Run the whole sharded save inside one tf.function invocation; the
      # merge result is discarded here (save returns None when eager, per the
      # docstring).
      @def_function.function(jit_compile=False)
      def tf_function_save():
        save_fn()
      tf_function_save()
    else:
      return save_fn()

  def restore(self, file_prefix, options=None):
    """Restore the saveable objects from a checkpoint with `file_prefix`.

    Args:
      file_prefix: A string or scalar string Tensor containing the prefix for
        files to read from.
      options: Optional `CheckpointOptions` object.

    Returns:
      When not run eagerly or when saving on a single device, returns a
      dictionary mapping from SaveableObject names to restore operations;
      otherwise, returns an empty dict.
    """
    options = options or checkpoint_options.CheckpointOptions()

    def restore_fn():
      restore_fn_inputs = {}
      restore_fn_input_count = {
          fn: len(keys) for fn, keys in self._restore_fn_to_keys.items()}

      restore_ops = {}
      # Sort by device name to avoid propagating non-deterministic dictionary
      # ordering in some Python versions.
      for device, saver in sorted(self._single_device_savers.items()):
        with ops.device(device):
          # Load values from checkpoint
          restored_tensor_dict = saver.restore(file_prefix, options)

          # Map restored tensors to the corresponding restore_fn, and check
          # whether all of that function's inputs have been loaded. Call
          # `restore_fn` if that is the case.
          for checkpoint_key, slice_and_tensor in restored_tensor_dict.items():
            for slice_spec, tensor in slice_and_tensor.items():
              restore_fn = self._keys_to_restore_fn[(checkpoint_key,
                                                     slice_spec)]

              # Processing the returned restored_tensor_dict to prepare for the
              # Trackable `restore` function. The `restore` function expects a
              # map of `string name (checkpoint_key) -> Tensor`. Unless there is
              # a slice_spec, in which case the map will be of
              # `string name (checkpoint_key)-> slice_spec -> Tensor`.
              if slice_spec:
                (restore_fn_inputs.setdefault(restore_fn, {}).setdefault(
                    checkpoint_key, {})[slice_spec]) = tensor
              else:
                restore_fn_inputs.setdefault(restore_fn,
                                             {})[checkpoint_key] = tensor
              restore_fn_input_count[restore_fn] -= 1

              if restore_fn_input_count[restore_fn] == 0:
                restored_tensors = {}
                # Extracts the substring after the "/.ATTRIBUTES/" in the
                # ckpt_key from restore_fn_inputs[restore_fn] to
                # restored_tensors. For example, if restore_fn_input[restore_fn]
                # is dict { "/.ATTRIBUTES/a": Tensor}, restored_tensors will be
                # changed to dict {"a": Tensor}
                for ckpt_key, tensor in restore_fn_inputs[restore_fn].items():
                  restored_tensors[trackable_utils.extract_local_name(
                      ckpt_key)] = tensor
                ret = restore_fn(restored_tensors)
                if isinstance(ret, dict):
                  restore_ops.update(ret)
      # Run registered restore methods after the default restore ops.
      for _, (_, restore_fn) in self._registered_savers.items():
        restore_fn(file_prefix)
      return restore_ops

    has_custom_device_saver = any([
        context.is_custom_device(d) for d in self._single_device_savers.keys()
    ])
    # Since this will cause a function re-trace on each restore, limit this to
    # cases where it is needed: eager and when there are multiple tasks/single
    # device savers or any single device saver is a custom device. Note that the
    # retrace is needed to ensure we pick up the latest values of options like
    # experimental_io_device.
    #
    # We run in a function when there is a custom device saver because custom
    # devices, such as DTensor, usually do a sharded save and restore.
    # Doing a sharded save and restore requires knowledge about what shards
    # of variables we are restoring to. In practice, this means that custom
    # devices need the AssignVariableOps along with the Restore op within the
    # same graph to infer shapes and shard specs for Restore op.
    if context.executing_eagerly() and (len(self._single_device_savers) > 1 or
                                        has_custom_device_saver):
      @def_function.function(jit_compile=False, autograph=False)
      def tf_function_restore():
        restore_fn()
        return {}

      restore_ops = tf_function_restore()
    else:
      restore_ops = restore_fn()

    return restore_ops