Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/saving/saveable_object_util.py: 18%

396 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for working with and creating SaveableObjects.""" 

16import functools 

17 

18from tensorflow.python.checkpoint import saveable_compat 

19from tensorflow.python.client import session 

20from tensorflow.python.eager import context 

21 

22from tensorflow.python.framework import constant_op 

23from tensorflow.python.framework import device as pydev 

24from tensorflow.python.framework import dtypes 

25from tensorflow.python.framework import ops 

26from tensorflow.python.framework import tensor_util 

27 

28from tensorflow.python.ops import array_ops 

29from tensorflow.python.ops import gen_control_flow_ops 

30from tensorflow.python.ops import ref_variable 

31from tensorflow.python.ops import resource_variable_ops 

32from tensorflow.python.ops import state_ops 

33from tensorflow.python.ops import variables 

34from tensorflow.python.platform import tf_logging as logging 

35from tensorflow.python.trackable import base as trackable 

36from tensorflow.python.trackable import python_state 

37from tensorflow.python.trackable import trackable_utils 

38from tensorflow.python.training.saving import saveable_object 

39from tensorflow.python.types import core 

40from tensorflow.python.util import compat 

41from tensorflow.python.util import nest 

42from tensorflow.python.util import object_identity 

43from tensorflow.python.util.tf_export import tf_export 

44 

# Op names which identify TF1 reference variables.
_REF_VARIABLE_OPS = frozenset(["Variable", "VariableV2", "AutoReloadVariable"])

# Op names which identify variable reads which should be saved: every
# reference-variable op plus the resource-variable handle and read ops. Derived
# from `_REF_VARIABLE_OPS` so the two sets cannot drift apart, and frozen like
# it so module-level state cannot be mutated by accident.
_VARIABLE_OPS = _REF_VARIABLE_OPS | frozenset(
    ["VarHandleOp", "ReadVariableOp"])

53 

54 

def set_cpu0(device_string):
  """Returns `device_string` retargeted to device type CPU, index 0.

  Custom devices are returned unchanged. Otherwise the string is parsed into a
  `DeviceSpec`, its device type and index are replaced with "CPU" and 0, and
  the spec is serialized back to a string.

  Args:
    device_string: A device string.

  Returns:
    A device string.
  """
  if not context.is_custom_device(device_string):
    cpu_spec = pydev.DeviceSpec.from_string(device_string).replace(
        device_type="CPU", device_index=0)
    device_string = cpu_spec.to_string()
  return device_string

71 

72 

class ReferenceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles reference variables."""

  def __init__(self, var, slice_spec, name):
    save_spec = saveable_object.SaveSpec(
        var, slice_spec, name, dtype=var.dtype)
    super(ReferenceVariableSaveable, self).__init__(var, [save_spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Assigns the first restored tensor to the wrapped reference variable.

    Args:
      restored_tensors: Sequence of restored tensors; only the first is used.
      restored_shapes: Optional sequence of shapes. When given, the tensor is
        reshaped to `restored_shapes[0]` and shape validation is disabled.

    Returns:
      The result of `state_ops.assign`.
    """
    value = restored_tensors[0]
    reshaped = restored_shapes is not None
    if reshaped:
      value = array_ops.reshape(value, restored_shapes[0])
    # Validate the shape only when no explicit shape was restored and the
    # variable's static shape is fully known.
    check_shape = (not reshaped) and self.op.get_shape().is_fully_defined()
    return state_ops.assign(self.op, value, validate_shape=check_shape)

89 

90 

class ResourceVariableSaveable(saveable_object.SaveableObject):
  """SaveableObject implementation that handles ResourceVariables.

  Accepts either a resource variable or a Tensor produced by reading one.
  """

  def __init__(self, var, slice_spec, name):
    self._var_device = var.device
    self._var_shape = var.shape
    if isinstance(var, ops.Tensor):
      # A read of the variable: its first input is the variable handle.
      self.handle_op = var.op.inputs[0]
      save_value = var
    elif resource_variable_ops.is_resource_variable(var):

      def _make_cpu_reader(variable):
        """Returns a callable that reads `variable` and copies it to CPU."""

        def read_to_cpu():
          with ops.device(variable.device):
            if context.executing_eagerly() and not variable.is_initialized():
              # A SaveSpec tensor value of `None` indicates that the variable
              # is uninitialized.
              return None
            # Read the variable without making a copy to limit memory usage.
            current = variable.read_value_no_copy()
            # Copy to CPU on the same machine first so that variables placed
            # on non-CPU devices can be checkpointed.
            with ops.device("/device:CPU:0"):
              return array_ops.identity(current)

        return read_to_cpu

      self.handle_op = var.handle
      save_value = _make_cpu_reader(var)
    else:
      raise ValueError(
          "Saveable is neither a resource variable nor a read operation."
          f" Got: {repr(var)}")
    spec = saveable_object.SaveSpec(save_value, slice_spec, name,
                                    dtype=var.dtype, device=var.device)
    super(ResourceVariableSaveable, self).__init__(var, [spec], name)

  def restore(self, restored_tensors, restored_shapes):
    """Restores tensors. Raises ValueError if incompatible shape found."""
    value = restored_tensors[0]
    if restored_shapes is not None:
      value = array_ops.reshape(value, restored_shapes[0])
    # Move the restored value onto the variable's device before assigning.
    with ops.device(self._var_device):
      value = array_ops.identity(value)
      try:
        return resource_variable_ops.shape_safe_assign_variable_handle(
            self.handle_op, self._var_shape, value)
      except ValueError as e:
        raise ValueError(
            f"Received incompatible tensor with shape {value.shape} "
            f"when attempting to restore variable with shape {self._var_shape} "
            f"and name {self.name}.") from e

145 

146 

def _tensor_comes_from_variable(v):
  """Returns True iff `v` is a Tensor whose op type is in `_VARIABLE_OPS`."""
  if not isinstance(v, ops.Tensor):
    return False
  return v.op.type in _VARIABLE_OPS

149 

150 

def saveable_objects_for_op(op, name):
  """Create `SaveableObject`s from an operation.

  Args:
    op: A variable, operation, or SaveableObject to coerce into a
      SaveableObject.
    name: A string name for the SaveableObject.

  Yields:
    `SaveableObject`s which together save/restore `op`.

  Raises:
    TypeError: If `name` is not a string.
    ValueError: For operations with no known conversion to SaveableObject.
  """
  if not isinstance(name, str):
    raise TypeError(
        "names_to_saveables must be a dict mapping string names to "
        f"trackable operations. Name is not a string: {name}")
  if isinstance(op, saveable_object.SaveableObject):
    yield op
  elif isinstance(op, (list, tuple, variables.PartitionedVariable)):
    if isinstance(op, variables.PartitionedVariable):
      op = list(op)
    # A set of slices: every non-SaveableObject entry must be a sliced
    # Variable and all slices must belong to the same full tensor.
    slice_name = None
    # pylint: disable=protected-access
    for variable in op:
      if isinstance(variable, saveable_object.SaveableObject):
        yield variable
        continue
      if not isinstance(variable, variables.Variable):
        raise ValueError(f"Slices must all be Variables: {variable}")
      if not variable._save_slice_info:
        raise ValueError(f"Slices must all be slices: {variable}")
      if slice_name is None:
        slice_name = variable._save_slice_info.full_name
      elif slice_name != variable._save_slice_info.full_name:
        raise ValueError(
            f"Slices must all be from the same tensor: {slice_name} != "
            f"{variable._save_slice_info.full_name}")
      if variable.op.type in _REF_VARIABLE_OPS:
        yield ReferenceVariableSaveable(
            variable, variable._save_slice_info.spec, name)
      else:
        yield ResourceVariableSaveable(variable, variable._save_slice_info.spec,
                                       name)
    # pylint: enable=protected-access
  elif isinstance(op, trackable.Trackable) and not isinstance(
      op, variables.Variable):
    for attr, factory in saveable_objects_from_trackable(
        op, tf1_saver=True).items():
      if attr in (trackable.VARIABLE_VALUE_KEY,
                  trackable_utils.SERIALIZE_TO_TENSORS_NAME):
        # Keep original name for classes masquerading as variables and
        # Trackables that define _serialize_to_tensors.
        full_name = name
      else:
        full_name = name + "_" + attr
      # Use a distinct name here: rebinding `op` (the parameter) inside this
      # loop, as before, made the code hard to follow.
      attr_saveable = factory(full_name) if callable(factory) else factory
      yield from saveable_objects_for_op(attr_saveable, attr_saveable.name)
  elif isinstance(op, resource_variable_ops.BaseResourceVariable):
    # A resource variable: in graph mode save through its graph element.
    if op._in_graph_mode:  # pylint: disable=protected-access
      variable = op._graph_element  # pylint: disable=protected-access
    else:
      variable = op
    yield ResourceVariableSaveable(variable, "", name)
  else:
    # Any other variable or tensor; only supported while graph building.
    if context.executing_eagerly():
      raise ValueError("Can only save/restore ResourceVariables when "
                       f"executing eagerly, got type: {type(op)}.")

    variable = ops.convert_to_tensor(op, as_ref=True)
    if not _tensor_comes_from_variable(variable):
      raise TypeError(
          "names_to_saveables must be a dict mapping string "
          f"names to Tensors/Variables. Not a variable: {variable}")
    if variable.op.type in _REF_VARIABLE_OPS:
      yield ReferenceVariableSaveable(variable, "", name)
    else:
      yield ResourceVariableSaveable(variable, "", name)

238 

239 

def op_list_to_dict(op_list, convert_variable_to_tensor=True):
  """Create a dictionary of names to operation lists.

  Only needed when variable names matter, e.g. when saving or restoring a TF1
  name-based checkpoint (in TF2 this can be reached from
  `tf.train.Checkpoint.restore` when loading such a checkpoint).

  Args:
    op_list: A (nested) list, tuple, or set of Variables or SaveableObjects.
    convert_variable_to_tensor: Whether or not to convert single Variables
      with no slice info into Tensors.

  Returns:
    A dictionary of names to the operations that must be saved under
    that name. Variables with save_slice_info are grouped together in a list
    under their full name, in the order they are visited (sorted by `name`).

  Raises:
    TypeError: If the type of op_list or its elements is not supported.
    ValueError: If at least two saveables share the same name.
  """
  if not isinstance(op_list, (list, tuple, set)):
    raise TypeError("Variables to save should be passed in a dict or a "
                    f"list. Got {op_list}")
  # Cast to a list first so that sets are supported.
  flat_ops = nest.flatten(list(op_list))
  # Converting ResourceVariables to Tensors adds read ops to the graph;
  # visiting the ops in name order keeps the resulting graph deterministic.
  flat_ops = sorted(flat_ops, key=lambda item: item.name)
  result = {}
  # pylint: disable=protected-access
  for var in flat_ops:
    is_variable_object = isinstance(
        var, (resource_variable_ops.BaseResourceVariable,
              ref_variable.RefVariable))

    if isinstance(var, saveable_object.SaveableObject):
      result[var.name] = var
      continue

    if isinstance(var, variables.PartitionedVariable):
      if var.name in result:
        raise ValueError(
            f"At least two variables have the same name: {var.name}")
      result[var.name] = var
      continue

    if isinstance(var, variables.Variable) and var._save_slice_info:
      slice_key = var._save_slice_info.full_name
      group = result.setdefault(slice_key, [])
      if not isinstance(group, list):
        raise ValueError("Mixing slices and non-slices with the same name: "
                         f"{slice_key}")
      group.append(var)
      continue

    if isinstance(var, trackable.Trackable) and not is_variable_object:
      factories = saveable_objects_from_trackable(var, tf1_saver=True)
      saveables_for_trackable = [
          factory() if callable(factory) else factory
          for factory in factories.values()]
      result.update(op_list_to_dict(saveables_for_trackable))
      continue

    # Reference and resource variables expose `_in_graph_mode`; Tensors seen
    # while graph building do not, and are treated as graph-mode objects.
    if getattr(var, "_in_graph_mode", True):
      if convert_variable_to_tensor:
        if isinstance(var, resource_variable_ops.BaseResourceVariable):
          var = var._graph_element  # pylint: disable=protected-access
        else:
          var = ops.convert_to_tensor(var, as_ref=True)
        if not _tensor_comes_from_variable(var):
          raise TypeError(f"Variable to save is not a Variable: {var}")
      if var.op.type == "ReadVariableOp":
        graph_key = var.op.inputs[0].op.name
      else:
        graph_key = var.op.name
      if graph_key in result:
        raise ValueError(
            f"At least two variables have the same name: {graph_key}")
      result[graph_key] = var
    else:
      if not isinstance(var, resource_variable_ops.BaseResourceVariable):
        raise ValueError(
            "Can only save/restore ResourceVariables when eager execution "
            f"is enabled. Got type: {type(var)}.")
      registered = result.setdefault(var._shared_name, var)
      if registered is not var:
        raise ValueError(
            "Two different ResourceVariable objects with the same "
            f"shared_name '{var._shared_name}' were passed to the Saver. This"
            " likely means that they were created in different Graphs or "
            "isolated contexts, and may not be checkpointed together.")

  # pylint: enable=protected-access
  return result

334 

335 

def _add_saveable(saveables, seen_ops, saveable):
  """Appends `saveable` to `saveables`, rejecting an already-seen op.

  Args:
    saveables: List to append the SaveableObject to.
    seen_ops: Set of the ops of the saveables already processed; the op of
      `saveable` is added to it. A `None` op is never treated as a duplicate.
    saveable: The saveable.

  Raises:
    ValueError: If the saveable's (non-None) op has already been processed.
  """
  saveable_op = saveable.op
  if saveable_op is not None and saveable_op in seen_ops:
    raise ValueError("The same saveable will be restored with two names: "
                     f"{saveable.name}")
  saveables.append(saveable)
  seen_ops.add(saveable_op)

353 

354 

def validate_and_slice_inputs(names_to_saveables):
  """Returns the SaveableObjects that a Saver will use.

  Args:
    names_to_saveables: A dict (k, v) where k is the name of an operation and
      v is an operation to save or a BaseSaverBuilder.Saver.

  Returns:
    A list of SaveableObjects, produced by visiting the entries in key order.

  Raises:
    TypeError: If any of the keys are not strings or any of the
      values are not one of Tensor or Variable or a trackable operation.
    ValueError: If the same operation is given in more than one value
      (this also applies to slices of SlicedVariables).
  """
  result = []
  seen = object_identity.ObjectIdentitySet()
  # Sort on the names only; the values themselves need not be comparable.
  for name in sorted(names_to_saveables):
    for saveable in saveable_objects_for_op(names_to_saveables[name], name):
      _add_saveable(result, seen, saveable)
  return result

379 

380 

def validate_saveables_for_saved_model(saveables, obj):
  """Makes sure SaveableObjects are compatible with SavedModel.

  Args:
    saveables: Sequence of SaveableObjects generated for `obj`.
    obj: The object the saveables belong to.

  Returns:
    An empty list when `obj` is a `PythonState` (after logging a warning) or
    when any saveable is a `NoRestoreSaveable`; otherwise `saveables`
    unchanged.
  """
  if isinstance(obj, python_state.PythonState):
    # `warning` is the canonical logger method (`warn` is a legacy alias);
    # %-style args defer formatting to the logger.
    logging.warning(
        "Note that object %s stores python values into the checkpoint. "
        "These values will not be restored when loading the SavedModel "
        "into python.", obj)
    return []
  if any(isinstance(saveable, trackable.NoRestoreSaveable)
         for saveable in saveables):
    return []
  return saveables

393 

394 

class RestoredSaveableObject(saveable_object.SaveableObject):
  """SaveableObject restored from SavedModel using the traced save/restore."""

  def __init__(self, names_and_slices, save_function, restore_function, name):
    self.save_function = save_function
    self.restore_function = restore_function

    if not tensor_util.is_tf_type(name):
      # Create the name constant outside any function-building graph.
      with ops.init_scope():
        name_tensor = constant_op.constant(name)
    else:
      name_tensor = name
    saved_infos = save_function(name_tensor)
    # Pair each (suffix, slice_spec) with the tensor the save function produced.
    specs = [
        saveable_object.SaveSpec(info["tensor"], slice_spec, name + suffix)
        for (suffix, slice_spec), info in zip(names_and_slices, saved_infos)
    ]
    super(RestoredSaveableObject, self).__init__(None, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    del restored_shapes  # unused
    num_specs = len(self.specs)
    args = [restored_tensors[index] for index in range(num_specs)]
    return self.restore_function(*args)

418 

419 

def recreate_saveable_objects(saveable_fn_by_name, temp_session):
  """Returns a dict of SaveableObject factories generated from loaded fns.

  Args:
    saveable_fn_by_name: Dict mapping a saveable attribute name to a
      `(save_fn, restore_fn)` pair of loaded functions.
    temp_session: Mutable one-element sequence. When graph building without a
      default session, `temp_session[0]` is reused if set, and otherwise a new
      `session.Session()` is stored in it.

  Returns:
    A dict mapping each key of `saveable_fn_by_name` to a `functools.partial`
    of `RestoredSaveableObject`, which still needs a `name` argument.
  """
  # `RestoredSaveableObject` zips `names_and_slices` with the tensors returned
  # by its own save function. Each factory must therefore receive only the
  # (name, slice_spec) pairs produced by its own save function. A single list
  # shared across all attributes would pair every attribute after the first
  # with the first attribute's names and slice specs.
  names_and_slices_by_attr = {}

  with ops.init_scope():

    for attr_name, (save_fn, _) in saveable_fn_by_name.items():
      names_and_slices = []
      for tensor_info in save_fn(""):
        name = tensor_info["name"]
        slice_spec = tensor_info["slice_spec"]
        if not context.executing_eagerly():
          sess = ops.get_default_session()
          if sess is None:
            if temp_session[0] is not None:
              sess = temp_session[0]
            else:
              sess = temp_session[0] = session.Session()
          name, slice_spec = sess.run([name, slice_spec])
        names_and_slices.append((
            _convert_to_string(name),
            _convert_to_string(slice_spec)))
      names_and_slices_by_attr[attr_name] = names_and_slices

  saveable_factories = {}
  for name, (save_fn, restore_fn) in saveable_fn_by_name.items():
    saveable_factories[name] = functools.partial(
        RestoredSaveableObject,
        names_and_slices=names_and_slices_by_attr[name],
        save_function=save_fn,
        restore_function=restore_fn)
  return saveable_factories

451 

452 

def create_saveable_object(name, key, factory, call_with_mapped_captures):
  """Creates a SaveableObject while potentially in a different graph.

  The frozen saver built for a SavedModel places its save and restore ops in a
  separate graph. RestoredSaveableObject saves and restores through
  tf.functions, so those functions' captures must be remapped into that graph.

  Args:
    name: Name of SaveableObject factory.
    key: Checkpoint key of this SaveableObject.
    factory: Factory method for creating the SaveableObject.
    call_with_mapped_captures: Helper that calls a tf.function while remapping
      the captures, or None to call `factory` directly.

  Returns:
    a SaveableObject.
  """
  if call_with_mapped_captures is None:
    return factory(name=key)

  if name == trackable_utils.SERIALIZE_TO_TENSORS_NAME:
    return factory(name=key,
                   call_with_mapped_captures=call_with_mapped_captures)

  if not is_factory_for_restored_saveable_object(factory):
    return factory(name=key)

  # Wrap the traced functions so their captures are remapped on every call.
  traced_save = factory.keywords["save_function"]

  def save_fn(checkpoint_name):
    return call_with_mapped_captures(traced_save, [checkpoint_name])

  traced_restore = factory.keywords["restore_function"]

  def restore_fn(*restored_tensors):
    return call_with_mapped_captures(traced_restore, restored_tensors)

  return factory(save_function=save_fn, restore_function=restore_fn, name=key)

490 

491 

def is_factory_for_restored_saveable_object(factory):
  """Returns True iff `factory` is a `functools.partial` of RestoredSaveableObject."""
  if not isinstance(factory, functools.partial):
    return False
  return factory.func is RestoredSaveableObject

495 

496 

497@tf_export("__internal__.tracking.saveable_objects_from_trackable", v1=[]) 

498def saveable_objects_from_trackable(obj, tf1_saver=False): 

499 """Returns SaveableObject factory dict from a Trackable. 

500 

501 Args: 

502 obj: A `Trackable` 

503 tf1_saver: Boolean, whether this is being called from a TF1 Saver ( 

504 `tf.compat.v1.train.Saver`). When this is True, the SaveableObject will 

505 be generated from `obj`'s legacy `_gather_saveables_for_checkpoint` fn. 

506 When saving with TF2, `Trackable._serialize_from_tensors` is preferred. 

507 

508 Returns: 

509 A dict mapping attribute names to SaveableObject factories (callables that 

510 produce a SaveableObject). 

511 """ 

512 if isinstance(obj, python_state.PythonState): 

513 return { 

514 python_state.PYTHON_STATE: 

515 functools.partial( 

516 _PythonStringStateSaveable, 

517 state_callback=obj.serialize, 

518 restore_callback=obj.deserialize) 

519 } 

520 

521 if tf1_saver: 

522 saveable_factories = obj._gather_saveables_for_checkpoint() # pylint: disable=protected-access 

523 if saveable_factories: 

524 return saveable_factories 

525 

526 if trackable_has_serialize_to_tensor(obj): 

527 

528 def create_saveable(name="", call_with_mapped_captures=None): 

529 save_fn = obj._serialize_to_tensors # pylint: disable=protected-access 

530 if (call_with_mapped_captures and 

531 isinstance(save_fn, core.ConcreteFunction)): 

532 tensor_dict = call_with_mapped_captures(save_fn, []) 

533 else: 

534 tensor_dict = save_fn() 

535 

536 specs = [] 

537 local_names = [] 

538 for tensor_name, maybe_tensor in tensor_dict.items(): 

539 local_names.append(tensor_name) 

540 

541 if not isinstance(maybe_tensor, dict): 

542 maybe_tensor = {"": maybe_tensor} 

543 

544 spec_name = name + trackable_utils.escape_local_name(tensor_name) 

545 # Create separate specs for each slice spec. 

546 for slice_spec, tensor in maybe_tensor.items(): 

547 if isinstance(tensor, saveable_object.SaveSpec): 

548 spec = tensor 

549 spec.name = spec_name 

550 spec.slice_spec = slice_spec 

551 else: 

552 spec = saveable_object.SaveSpec(tensor, slice_spec, spec_name) 

553 specs.append(spec) 

554 

555 return TrackableSaveable( 

556 obj=obj, 

557 specs=specs, 

558 name=name, 

559 local_names=local_names, 

560 prefix=saveable_compat.get_saveable_name(obj) or "", 

561 call_with_mapped_captures=call_with_mapped_captures) 

562 

563 return {trackable_utils.SERIALIZE_TO_TENSORS_NAME: create_saveable} 

564 else: 

565 return obj._gather_saveables_for_checkpoint() # pylint: disable=protected-access 

566 

567 

class TrackableSaveable(saveable_object.SaveableObject):
  """A SaveableObject that defines `Trackable` checkpointing steps.

  Restoring forwards the restored tensors, keyed by local name, to the
  trackable's `_restore_from_tensors`.
  """

  def __init__(self, obj, specs, name, local_names, prefix,
               call_with_mapped_captures=None):
    self._prefix = prefix
    self._local_names = local_names
    self._trackable = obj
    self._call_with_mapped_captures = call_with_mapped_captures
    super(TrackableSaveable, self).__init__(obj, specs, name)

  def restore(self, restored_tensors, restored_shapes):
    """Calls the trackable's `_restore_from_tensors` with restored tensors.

    Args:
      restored_tensors: Sequence of restored tensors. Element `n` is mapped to
        `self._local_names[n]`.
      restored_shapes: Unused.

    Returns:
      The restore function's result, or a `no_op` when it returns None. When
      graph building with a reference-variable spec, the restore function's
      result is returned directly.
    """
    del restored_shapes  # Unused.
    # NOTE(review): positional mapping assumes one spec per local name (no
    # multi-slice tensors); confirm for sliced `_serialize_to_tensors` output.
    restored_tensor_dict = {}
    for n, local_name in enumerate(self._local_names):
      restored_tensor_dict[local_name] = restored_tensors[n]

    restore_fn = self._trackable._restore_from_tensors  # pylint: disable=protected-access

    # When restoring a RefVariable, call the restore function directly.
    # A generator (rather than a list) lets `any` stop at the first match.
    # pylint: disable=protected-access
    if not ops.executing_eagerly_outside_functions() and any(
        spec._tensor.op.type in _REF_VARIABLE_OPS
        for spec in self.specs
        if isinstance(spec._tensor, ops.Tensor)):
      return restore_fn(restored_tensor_dict)
    # pylint: enable=protected-access

    if (self._call_with_mapped_captures and
        isinstance(restore_fn, core.ConcreteFunction)):
      ret = self._call_with_mapped_captures(restore_fn, [restored_tensor_dict])
    else:
      ret = restore_fn(restored_tensor_dict)
    if ret is not None:
      return ret
    return gen_control_flow_ops.no_op()

  def get_proto_names_and_checkpoint_keys(self):
    """Returns `(prefix + local_name, spec.name)` pairs, zipped by position."""
    return [(self._prefix + local_name, spec.name)
            for local_name, spec in zip(self._local_names, self.specs)]

608 

609 

class _PythonStringStateSaveable(saveable_object.SaveableObject):
  """Saves Python state in a checkpoint."""

  def __init__(self, name, state_callback, restore_callback):
    """Configure saving.

    Args:
      name: The checkpoint key to write to.
      state_callback: A function taking no arguments which returns a string.
        This function is run every time a checkpoint is written.
      restore_callback: A function taking a Python string, used to restore
        state.
    """

    def _state_callback_wrapper():
      # Invoke the user callback under `ops.init_scope()`.
      with ops.init_scope():
        return state_callback()

    self._state_callback = _state_callback_wrapper
    self._restore_callback = restore_callback
    # Placeholder string tensor on CPU; its value is fed at save time.
    with ops.device("/cpu:0"):
      self._save_string = constant_op.constant("", dtype=dtypes.string)
    string_spec = saveable_object.SaveSpec(
        self._save_string, "", name, dtype=dtypes.string)
    super(_PythonStringStateSaveable, self).__init__(
        self._save_string, [string_spec], name)

  def feed_dict_additions(self):
    """When running a graph, indicates fresh state to feed."""
    current_state = self._state_callback()
    return {self._save_string: current_state}

  def freeze(self):
    """Create a frozen `SaveableObject` which saves the current state."""

    def _constant_state():
      return constant_op.constant(self._state_callback(), dtype=dtypes.string)

    return trackable.NoRestoreSaveable(
        tensor=_constant_state,
        dtype=dtypes.string,
        name=self.name,
        device="cpu:0")

652 

653 

def trackable_has_serialize_to_tensor(obj):
  """Returns whether obj's class has `_serialize_to_tensors` defined.

  An instance-level `_serialize_to_tensors` (e.g. on restored objects) also
  counts. Otherwise the MRO is walked: the first class defining either
  `_serialize_to_tensors` (True) or `_gather_saveables_for_checkpoint`
  (False) decides, and reaching `trackable.Trackable` yields False.
  """
  try:
    defined_on_instance = "_serialize_to_tensors" in obj.__dict__
  except (AttributeError, TypeError):
    # Data structure proxy wrappers don't have __dict__.
    defined_on_instance = False
  if defined_on_instance:
    return True

  # Walking the MRO means a class that still overrides
  # `_gather_saveables_for_checkpoint` keeps using it even when a parent
  # class has already migrated to `_serialize_to_tensors`.
  for klass in type(obj).mro():
    if klass is trackable.Trackable:
      # Base case: the base implementation of _serialize_to_tensors raises.
      return False
    class_attrs = klass.__dict__
    if "_serialize_to_tensors" in class_attrs:
      return True
    if "_gather_saveables_for_checkpoint" in class_attrs:
      return False
  return False

678 

679 

def _convert_to_string(x):
  """Returns `tensor_util.constant_value(x)` converted with `compat.as_str`."""
  constant = tensor_util.constant_value(x)
  return compat.as_str(constant)

682 

683 

class SaveableCompatibilityConverter(trackable.Trackable):
  """Converts object's `SaveableObjects` to functions used in TF2 checkpointing.

  Wraps a Trackable object's `SaveableObjects` and exposes save and restore
  methods with the same signatures as `Trackable._serialize_to_tensors` and
  `Trackable._restore_from_tensors`. This class also produces a method for
  filling the object proto.
  """

  __slots__ = ("_obj", "_saveables")

  def __init__(self, obj, saveables):
    """Constructor.

    Args:
      obj: A Trackable object.
      saveables: A list of saveables for `obj`.
    """
    self._obj = obj
    self._saveables = saveables

  @property
  def obj(self):
    return self._obj

  @property
  def saveables(self):
    """Returns a list of SaveableObjects generated from the Trackable object."""
    return self._saveables

  def _serialize_to_tensors(self):
    """Returns a dict of tensors to serialize."""
    return saveable_object_to_tensor_dict(self.saveables)

  def _restore_from_tensors(self, restored_tensors):
    """Returns the restore ops defined in the Saveables."""
    # The restored tensor keys must match exactly the local names of every
    # spec across all saveables before the saveables are restored.
    expected_keys = [
        trackable_utils.extract_local_name(_convert_to_string(spec.name))
        for saveable in self.saveables
        for spec in saveable.specs
    ]
    if set(expected_keys) != restored_tensors.keys():
      raise ValueError(f"Could not restore object {self._obj} because not all "
                       "expected tensors were in the checkpoint."
                       f"\n\tExpected: {expected_keys}"
                       f"\n\tGot: {list(restored_tensors.keys())}")

    restore = saveable_object_to_restore_fn(self.saveables)
    return restore(restored_tensors)

735 

736 

def saveable_object_to_tensor_dict(saveables):
  """Converts a list of SaveableObjects to a tensor dictionary.

  Specs with an empty slice spec map `name -> tensor`; sliced specs map
  `name -> {slice_spec: tensor}`.
  """
  tensor_dict = {}
  all_specs = (spec for saveable in saveables for spec in saveable.specs)
  for spec in all_specs:
    key = _convert_to_string(spec.name)
    slice_spec = _convert_to_string(spec.slice_spec)
    # A tensor dict cannot hold callable tensor values (needed for
    # uninitialized variables), so the SaveSpec itself is stored instead.
    value = spec if callable(spec._tensor) else spec._tensor  # pylint: disable=protected-access
    if not slice_spec:
      tensor_dict[key] = value
    else:
      tensor_dict.setdefault(key, {})[slice_spec] = value
  return tensor_dict

752 

753 

def saveable_object_to_restore_fn(saveables):
  """Generates `Trackable._restore_from_tensors` from SaveableObjects.

  The returned function looks up each spec's tensor by local name (and slice
  spec), calls every saveable's `restore`, and returns a dict mapping saveable
  names to the results.
  """

  def _restore_from_tensors(restored_tensors):
    restore_ops = {}

    for saveable in saveables:
      tensors_for_saveable = []
      for spec in saveable.specs:
        local_name = trackable_utils.extract_local_name(
            _convert_to_string(spec.name))
        slice_spec = _convert_to_string(spec.slice_spec)

        entry = restored_tensors[local_name]
        # Unsliced entries are stored directly; treat them as the "" slice.
        by_slice = entry if isinstance(entry, dict) else {"": entry}
        tensors_for_saveable.append(by_slice[slice_spec])
      restore_ops[saveable.name] = saveable.restore(
          tensors_for_saveable, restored_shapes=None)
    return restore_ops

  return _restore_from_tensors

776 

777 

def serialized_tensors_to_saveable_cache(serialized_tensors):
  """Converts a tensor dict to a SaveableObject cache.

  Args:
    serialized_tensors: Map from Trackable to a tensor dict. The tensor dict
      maps checkpoint key (-> slice_spec) -> Tensor

  Returns:
    A dict mapping Trackable objects to a map from local savable name to
    SaveableObject. Objects with an empty tensor dict are skipped.
  """
  saveables_cache = object_identity.ObjectIdentityWeakKeyDictionary()

  for obj, tensor_dict in serialized_tensors.items():
    if not tensor_dict:
      continue

    if isinstance(obj, SaveableCompatibilityConverter):
      # Reuse the original SaveableObjects, keyed on the wrapped object.
      trackable_obj = obj.obj
      per_name = saveables_cache[trackable_obj] = {}
      for saveable in obj.saveables:
        per_name[trackable_utils.extract_local_name(saveable.name)] = [saveable]
      continue

    # The local names and prefix let the generated SaveableObject call
    # `Trackable._restore_from_tensors()`.
    prefix = saveable_compat.get_saveable_name(obj) or ""
    specs = []
    local_names = []
    for checkpoint_key, value in tensor_dict.items():
      # Normalize to a {slice_spec: tensor} mapping.
      slices = value if isinstance(value, dict) else {"": value}

      for slice_spec, tensor in slices.items():
        if isinstance(tensor, saveable_object.SaveSpec):
          specs.append(tensor)
        else:
          specs.append(
              saveable_object.SaveSpec(tensor, slice_spec, checkpoint_key))
        local_names.append(
            trackable_utils.extract_local_name(checkpoint_key, prefix))

    object_name = trackable_utils.extract_object_name(next(iter(tensor_dict)))
    saveables_cache[obj] = {
        trackable_utils.SERIALIZE_TO_TENSORS_NAME: [
            TrackableSaveable(obj, specs, object_name,
                              local_names=local_names, prefix=prefix)
        ]
    }
  return saveables_cache