Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/registration/registration.py: 55%

97 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2021 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Serialization Registration for SavedModel. 

16 

17revived_types registration will be migrated to this infrastructure. 

18 

19See the Advanced saving section in go/savedmodel-configurability. 

20This API is approved for TF internal use only. 

21""" 

22import collections 

23import re 

24 

25from tensorflow.python.util import tf_inspect 

26 

27 

28# Only allow valid file/directory characters 

29_VALID_REGISTERED_NAME = re.compile(r"^[a-zA-Z0-9._-]+$") 

30 

31 

32class _PredicateRegistry(object): 

33 """Registry with predicate-based lookup. 

34 

35 See the documentation for `register_checkpoint_saver` and 

36 `register_serializable` for reasons why predicates are required over a 

37 class-based registry. 

38 

39 Since this class is used for global registries, each object must be registered 

40 to unique names (an error is raised if there are naming conflicts). The lookup 

41 searches the predicates in reverse order, so that later-registered predicates 

42 are executed first. 

43 """ 

44 __slots__ = ("_registry_name", "_registered_map", "_registered_predicates", 

45 "_registered_names") 

46 

47 def __init__(self, name): 

48 self._registry_name = name 

49 # Maps registered name -> object 

50 self._registered_map = {} 

51 # Maps registered name -> predicate 

52 self._registered_predicates = {} 

53 # Stores names in the order of registration 

54 self._registered_names = [] 

55 

56 @property 

57 def name(self): 

58 return self._registry_name 

59 

60 def register(self, package, name, predicate, candidate): 

61 """Registers a candidate object under the package, name and predicate.""" 

62 if not isinstance(package, str) or not isinstance(name, str): 

63 raise TypeError( 

64 f"The package and name registered to a {self.name} must be strings, " 

65 f"got: package={type(package)}, name={type(name)}") 

66 if not callable(predicate): 

67 raise TypeError( 

68 f"The predicate registered to a {self.name} must be callable, " 

69 f"got: {type(predicate)}") 

70 registered_name = package + "." + name 

71 if not _VALID_REGISTERED_NAME.match(registered_name): 

72 raise ValueError( 

73 f"Invalid registered {self.name}. Please check that the package and " 

74 f"name follow the regex '{_VALID_REGISTERED_NAME.pattern}': " 

75 f"(package='{package}', name='{name}')") 

76 if registered_name in self._registered_map: 

77 raise ValueError( 

78 f"The name '{registered_name}' has already been registered to a " 

79 f"{self.name}. Found: {self._registered_map[registered_name]}") 

80 

81 self._registered_map[registered_name] = candidate 

82 self._registered_predicates[registered_name] = predicate 

83 self._registered_names.append(registered_name) 

84 

85 def lookup(self, obj): 

86 """Looks up the registered object using the predicate. 

87 

88 Args: 

89 obj: Object to pass to each of the registered predicates to look up the 

90 registered object. 

91 Returns: 

92 The object registered with the first passing predicate. 

93 Raises: 

94 LookupError if the object does not match any of the predicate functions. 

95 """ 

96 return self._registered_map[self.get_registered_name(obj)] 

97 

98 def name_lookup(self, registered_name): 

99 """Looks up the registered object using the registered name.""" 

100 try: 

101 return self._registered_map[registered_name] 

102 except KeyError: 

103 raise LookupError(f"The {self.name} registry does not have name " 

104 f"'{registered_name}' registered.") 

105 

106 def get_registered_name(self, obj): 

107 for registered_name in reversed(self._registered_names): 

108 predicate = self._registered_predicates[registered_name] 

109 if predicate(obj): 

110 return registered_name 

111 raise LookupError(f"Could not find matching {self.name} for {type(obj)}.") 

112 

113 def get_predicate(self, registered_name): 

114 try: 

115 return self._registered_predicates[registered_name] 

116 except KeyError: 

117 raise LookupError(f"The {self.name} registry does not have name " 

118 f"'{registered_name}' registered.") 

119 

120 def get_registrations(self): 

121 return self._registered_predicates 

122 

# Global registry of serializable classes, populated by `register_serializable`.
_class_registry = _PredicateRegistry(name="serializable class")
# Global registry of custom checkpoint savers, populated by
# `register_checkpoint_saver`.
_saver_registry = _PredicateRegistry(name="checkpoint saver")

125 

126 

def get_registered_class_name(obj):
  """Returns the registered serializable-class name for `obj`, or None.

  Args:
    obj: Object tested against the predicates in the class registry.

  Returns:
    The "package.name" string of the most recently registered class whose
    predicate accepts `obj`, or None if no predicate matches.
  """
  try:
    found_name = _class_registry.get_registered_name(obj)
  except LookupError:
    found_name = None
  return found_name

132 

133 

def get_registered_class(registered_name):
  """Returns the class registered under `registered_name`, or None.

  Args:
    registered_name: The "package.name" string used at registration time.

  Returns:
    The registered class, or None if nothing is registered under that name.
  """
  try:
    registered_cls = _class_registry.name_lookup(registered_name)
  except LookupError:
    registered_cls = None
  return registered_cls

139 

140 

def register_serializable(package="Custom", name=None, predicate=None):
  """Decorator for registering a serializable class.

  THIS METHOD IS STILL EXPERIMENTAL AND MAY CHANGE AT ANY TIME.

  Registered classes will be saved with a name generated by combining the
  `package` and `name` arguments. When loading a SavedModel, modules saved with
  this registered name will be created using the `_deserialize_from_proto`
  method.

  By default, the registration matches instances of the decorated class and of
  its subclasses (the default predicate is `isinstance(obj, cls)`). Because
  later registrations are checked first, a subclass registered afterwards takes
  precedence over its base class. To use a different matching rule, for
  example exact type matching, pass the `predicate` argument:

  ```python
  class A(tf.Module):
    pass

  register_serializable(
      package="Example", predicate=lambda obj: type(obj) is A)(A)
  ```

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class's name will be used.
    predicate: An optional function that takes a single Trackable argument, and
      determines whether that object should be serialized with this `package`
      and `name`. The default predicate checks whether the object is an
      instance of the registered class (subclasses included). Predicates are
      executed in the reverse order that they are added (later registrations
      are checked first).

  Returns:
    A decorator that registers the decorated class with the passed names and
    predicate. The decorator raises TypeError if applied to a non-class, and
    propagates the TypeError/ValueError raised by the registry for invalid or
    duplicate names.
  """
  def decorator(arg):
    """Registers a class with the serialization framework."""
    if not tf_inspect.isclass(arg):
      raise TypeError("Registered serializable must be a class: {}".format(arg))

    class_name = name if name is not None else arg.__name__
    # Use a local instead of rebinding the closed-over `predicate` via
    # `nonlocal`: otherwise, reusing one decorator on several classes would
    # give every later class the default predicate built for the first one.
    class_predicate = predicate
    if class_predicate is None:
      class_predicate = lambda x: isinstance(x, arg)
    _class_registry.register(package, class_name, class_predicate, arg)
    return arg

  return decorator

190 

191 

192RegisteredSaver = collections.namedtuple( 

193 "RegisteredSaver", ["name", "predicate", "save_fn", "restore_fn"]) 

194_REGISTERED_SAVERS = {} 

195_REGISTERED_SAVER_NAMES = [] # Stores names in the order of registration 

196 

197 

def register_checkpoint_saver(package="Custom",
                              name=None,
                              predicate=None,
                              save_fn=None,
                              restore_fn=None,
                              strict_predicate_restore=True):
  """Registers functions that checkpoint and restore objects with custom steps.

  If you have a class that requires complicated coordination between multiple
  objects when checkpointing, then you will need to register a custom saver
  and restore function. An example of this is a custom Variable class that
  splits the variable across different objects and devices, and needs to write
  checkpoints that are compatible with different configurations of devices.

  The registered save and restore functions are used in checkpoints and
  SavedModel.

  Please make sure you are familiar with the concepts in the [Checkpointing
  guide](https://www.tensorflow.org/guide/checkpoint), and ops used to save the
  V2 checkpoint format:

  * io_ops.SaveV2
  * io_ops.MergeV2Checkpoints
  * io_ops.RestoreV2

  **Predicate**

  The predicate is a filter that will run on every `Trackable` object connected
  to the root object. This function determines whether a `Trackable` should use
  the registered functions.

  Example: `lambda x: isinstance(x, CustomClass)`

  **Custom save function**

  This is how checkpoint saving works normally:
  1. Gather all of the Trackables with saveable values.
  2. For each Trackable, gather all of the saveable tensors.
  3. Save checkpoint shards (grouping tensors by device) with SaveV2.
  4. Merge the shards with MergeV2Checkpoints. This combines all of the shards'
     metadata, and renames them to follow the standard shard pattern.

  When a saver is registered, Trackables that pass the registered `predicate`
  are automatically marked as having saveable values. Next, the custom save
  function replaces steps 2 and 3 of the saving process. Finally, the shards
  returned by the custom save function are merged with the other shards.

  The save function takes in a dictionary of `Trackables` and a `file_prefix`
  string. The function should save checkpoint shards using the SaveV2 op, and
  return a list of the shard prefixes. SaveV2 is currently required to work
  correctly, because the code merges all of the returned shards, and the
  `restore_fn` will only be given the prefix of the merged checkpoint. If you
  need to be able to save and restore from unmerged shards, please file a
  feature request.

  Specification and example of the save function:

  ```
  def save_fn(trackables, file_prefix):
    # trackables: A dictionary mapping unique string identifiers to trackables
    # file_prefix: A unique file prefix generated using the registered name.
    ...
    # Gather the tensors to save.
    ...
    io_ops.SaveV2(file_prefix, tensor_names, shapes_and_slices, tensors)
    return file_prefix  # Returns a tensor or a list of string tensors
  ```

  The save function is executed before the unregistered save ops.

  **Custom restore function**

  Normal checkpoint restore behavior:
  1. Gather all of the Trackables that have saveable values.
  2. For each Trackable, get the names of the desired tensors to extract from
     the checkpoint.
  3. Use RestoreV2 to read the saved values, and pass the restored tensors to
     the corresponding Trackables.

  The custom restore function replaces steps 2 and 3.

  The restore function also takes a dictionary of `Trackables` and a
  `merged_prefix` string. The `merged_prefix` is different from the
  `file_prefix`, since it contains the renamed shard paths. To read from the
  merged checkpoint, you must use `RestoreV2(merged_prefix, ...)`.

  Specification:

  ```
  def restore_fn(trackables, merged_prefix):
    # trackables: A dictionary mapping unique string identifiers to Trackables
    # merged_prefix: File prefix of the merged shard names.

    restored_tensors = io_ops.restore_v2(
        merged_prefix, tensor_names, shapes_and_slices, dtypes)
    ...
    # Restore the checkpoint values for the given Trackables.
  ```

  The restore function is executed after the non-registered restore ops.

  Args:
    package: Optional, the package that this class belongs to.
    name: (Required) The name of this saver, which is saved to the checkpoint.
      When a checkpoint is restored, the name and package are used to find the
      matching restore function. The name and package are also used to
      generate a unique file prefix that is passed to the save_fn.
    predicate: (Required) A function that returns a boolean indicating whether a
      `Trackable` object should be checkpointed with this function. Predicates
      are executed in the reverse order that they are added (later registrations
      are checked first).
    save_fn: (Required) A function that takes a dictionary of trackables and a
      file prefix as the arguments, writes the checkpoint shards for the given
      Trackables, and returns the list of shard prefixes.
    restore_fn: (Required) A function that takes a dictionary of trackables and
      a file prefix as the arguments and restores the trackable values.
    strict_predicate_restore: If this is `True` (default), then an error will be
      raised if the predicate fails during checkpoint restoration. If this is
      `False`, checkpoint restoration will skip running the restore function.
      This value is generally set to `False` when the predicate does not pass on
      the Trackables after being saved/loaded from SavedModel.

  Raises:
    TypeError: if `save_fn` or `restore_fn` is not callable, if `package` or
      `name` is not a string (e.g. `name` was left as None), or if `predicate`
      is not callable.
    ValueError: if the combined `package.name` contains characters other than
      letters, digits, `.`, `_` and `-`, or if the package and name are already
      registered.
  """
  if not callable(save_fn):
    raise TypeError(f"The save_fn must be callable, got: {type(save_fn)}")
  if not callable(restore_fn):
    raise TypeError(f"The restore_fn must be callable, got: {type(restore_fn)}")

  # Stored as (save_fn, restore_fn, strict_predicate_restore); the accessors
  # `get_save_function`, `get_restore_function` and
  # `get_strict_predicate_restore` index into this tuple.
  _saver_registry.register(package, name, predicate,
                           (save_fn, restore_fn, strict_predicate_restore))

329 

330 

def get_registered_saver_name(trackable):
  """Returns the name of the registered saver to use with Trackable.

  Args:
    trackable: Object tested against the registered saver predicates.

  Returns:
    The registered name of the most recently registered saver whose predicate
    accepts `trackable`, or None if no saver matches.
  """
  try:
    saver_name = _saver_registry.get_registered_name(trackable)
  except LookupError:
    return None
  else:
    return saver_name

337 

338 

def get_save_function(registered_name):
  """Returns the save function registered under `registered_name`.

  Raises:
    LookupError: If no checkpoint saver is registered under that name.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  # Slot 0 of the stored (save_fn, restore_fn, strict_predicate_restore).
  return saver_entry[0]

342 

343 

def get_restore_function(registered_name):
  """Returns the restore function registered under `registered_name`.

  Raises:
    LookupError: If no checkpoint saver is registered under that name.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  # Slot 1 of the stored (save_fn, restore_fn, strict_predicate_restore).
  return saver_entry[1]

347 

348 

def get_strict_predicate_restore(registered_name):
  """Returns the `strict_predicate_restore` flag registered with the saver.

  When the flag is True, restoration is expected to fail if the saver's
  predicate does not pass. When it is False, the restore function may be
  skipped instead.

  Raises:
    LookupError: If no checkpoint saver is registered under that name.
  """
  saver_entry = _saver_registry.name_lookup(registered_name)
  # Slot 2 of the stored (save_fn, restore_fn, strict_predicate_restore).
  return saver_entry[2]

352 

353 

def validate_restore_function(trackable, registered_name):
  """Validates whether the trackable can be restored with the saver.

  When using a checkpoint saved with a registered saver, that same saver must
  also be registered when loading. The name of that saver is saved to the
  checkpoint and set in the `registered_name` arg.

  Args:
    trackable: A `Trackable` object.
    registered_name: String name of the expected registered saver. This argument
      should be set using the name saved in a checkpoint.

  Raises:
    ValueError if the saver could not be found, or if the predicate associated
    with the saver does not pass.
  """
  try:
    # A single lookup: the predicate map has exactly the same keys as the
    # saver map, so a missing name raises LookupError here.
    predicate = _saver_registry.get_predicate(registered_name)
  except LookupError:
    # The LookupError carries no extra information; keep the traceback clean.
    raise ValueError(
        f"Error when restoring object {trackable} from checkpoint. This "
        "object was saved using a registered saver named "
        f"'{registered_name}', but this saver cannot be found in the "
        "current context.") from None
  if not predicate(trackable):
    raise ValueError(
        f"Object {trackable} was saved with the registered saver named "
        f"'{registered_name}'. However, this saver cannot be used to restore the "
        "object because the predicate does not pass.")