Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/nested_structure_coder.py: 39%

252 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Module that encodes (decodes) nested structures into (from) protos. 

16 

17The intended use is to serialize everything needed to restore a `Function` that 

18was saved into a SavedModel. This may include concrete function inputs and 

19outputs, signatures, function specs, etc. 

20 

21Example use: 

22# Encode into proto. 

23signature_proto = nested_structure_coder.encode_structure( 

24 function.input_signature) 

25# Decode into a Python object. 

26restored_signature = nested_structure_coder.decode_proto(signature_proto) 

27""" 

28 

29import collections 

30import functools 

31import warnings 

32 

33from tensorflow.core.protobuf import struct_pb2 

34from tensorflow.python.framework import dtypes 

35from tensorflow.python.framework import type_spec_registry 

36from tensorflow.python.types import internal 

37from tensorflow.python.util import compat 

38from tensorflow.python.util import nest 

39from tensorflow.python.util.compat import collections_abc 

40from tensorflow.python.util.tf_export import tf_export 

41 

42 

class NotEncodableError(Exception):
  """Signals that none of the registered codecs could handle an object."""

45 

46 

def register_codec(x):
  """Adds a codec to the module-level codec registry.

  Codecs are consulted in the reverse of their registration order, so a codec
  registered later takes precedence over the ones registered before it.

  Args:
    x: Codec instance to register. It has to provide the `can_encode`,
      `do_encode`, `can_decode` and `do_decode` methods; the `_*Codec` classes
      in this module show the expected shape.
  """
  _codecs.append(x)

56 

57 

def _get_encoders():
  """Returns a `(can_encode, do_encode)` pair for every registered codec."""
  encoder_pairs = []
  for codec in _codecs:
    encoder_pairs.append((codec.can_encode, codec.do_encode))
  return encoder_pairs

60 

61 

def _get_decoders():
  """Returns a `(can_decode, do_decode)` pair for every registered codec."""
  decoder_pairs = []
  for codec in _codecs:
    decoder_pairs.append((codec.can_decode, codec.do_decode))
  return decoder_pairs

64 

65 

66def _map_structure(pyobj, coders): 

67 # Iterate through the codecs in the reverse order they were registered in, 

68 # as the most specific codec should be checked first. 

69 for can, do in reversed(coders): 

70 if can(pyobj): 

71 recursion_fn = functools.partial(_map_structure, coders=coders) 

72 return do(pyobj, recursion_fn) 

73 raise NotEncodableError( 

74 f"No encoder for object {str(pyobj)} of type {type(pyobj)}.") 

75 

76 

@tf_export("__internal__.saved_model.encode_structure", v1=[])
def encode_structure(nested_structure):
  """Serializes a nested structure of encodable values into a proto.

  Args:
    nested_structure: The structure to encode.

  Returns:
    The proto produced by the matching codec.

  Raises:
    NotEncodableError: If some value in the structure has no encoder.
  """
  encoders = _get_encoders()
  return _map_structure(nested_structure, encoders)

91 

92 

def can_encode(nested_structure):
  """Reports whether a nested structure can be encoded into a proto.

  The check is performed by actually encoding the structure and discarding
  the result.

  Args:
    nested_structure: The structure to test.

  Returns:
    True if the nested structure can be encoded, False if encoding raised
    `NotEncodableError`.
  """
  try:
    encode_structure(nested_structure)
  except NotEncodableError:
    return False
  else:
    return True

107 

108 

@tf_export("__internal__.saved_model.decode_proto", v1=[])
def decode_proto(proto):
  """Rebuilds the nested structure represented by a proto.

  Args:
    proto: The proto to decode.

  Returns:
    The decoded structure.

  Raises:
    NotEncodableError: If no registered decoder accepts some value in `proto`.
  """
  decoders = _get_decoders()
  return _map_structure(proto, decoders)

123 

124 

125class _ListCodec: 

126 """Codec for lists.""" 

127 

128 def can_encode(self, pyobj): 

129 return isinstance(pyobj, list) 

130 

131 def do_encode(self, list_value, encode_fn): 

132 encoded_list = struct_pb2.StructuredValue() 

133 encoded_list.list_value.CopyFrom(struct_pb2.ListValue()) 

134 for element in list_value: 

135 encoded_list.list_value.values.add().CopyFrom(encode_fn(element)) 

136 return encoded_list 

137 

138 def can_decode(self, value): 

139 return value.HasField("list_value") 

140 

141 def do_decode(self, value, decode_fn): 

142 return [decode_fn(element) for element in value.list_value.values] 

143 

144 

def _is_tuple(obj):
  """Returns True iff `obj` is a tuple that is not a namedtuple."""
  return isinstance(obj, tuple) and not _is_named_tuple(obj)

147 

148 

149def _is_named_tuple(instance): 

150 """Returns True iff `instance` is a `namedtuple`. 

151 

152 Args: 

153 instance: An instance of a Python object. 

154 

155 Returns: 

156 True if `instance` is a `namedtuple`. 

157 """ 

158 if not isinstance(instance, tuple): 

159 return False 

160 return (hasattr(instance, "_fields") and 

161 isinstance(instance._fields, collections_abc.Sequence) and 

162 all(isinstance(f, str) for f in instance._fields)) 

163 

164 

165class _TupleCodec: 

166 """Codec for tuples.""" 

167 

168 def can_encode(self, pyobj): 

169 return _is_tuple(pyobj) 

170 

171 def do_encode(self, tuple_value, encode_fn): 

172 encoded_tuple = struct_pb2.StructuredValue() 

173 encoded_tuple.tuple_value.CopyFrom(struct_pb2.TupleValue()) 

174 for element in tuple_value: 

175 encoded_tuple.tuple_value.values.add().CopyFrom(encode_fn(element)) 

176 return encoded_tuple 

177 

178 def can_decode(self, value): 

179 return value.HasField("tuple_value") 

180 

181 def do_decode(self, value, decode_fn): 

182 return tuple(decode_fn(element) for element in value.tuple_value.values) 

183 

184 

185class _DictCodec: 

186 """Codec for dicts.""" 

187 

188 def can_encode(self, pyobj): 

189 return isinstance(pyobj, dict) 

190 

191 def do_encode(self, dict_value, encode_fn): 

192 encoded_dict = struct_pb2.StructuredValue() 

193 encoded_dict.dict_value.CopyFrom(struct_pb2.DictValue()) 

194 for key, value in dict_value.items(): 

195 encoded_dict.dict_value.fields[key].CopyFrom(encode_fn(value)) 

196 return encoded_dict 

197 

198 def can_decode(self, value): 

199 return value.HasField("dict_value") 

200 

201 def do_decode(self, value, decode_fn): 

202 return {key: decode_fn(val) for key, val in value.dict_value.fields.items()} 

203 

204 

205class _NamedTupleCodec: 

206 """Codec for namedtuples. 

207 

208 Encoding and decoding a namedtuple reconstructs a namedtuple with a different 

209 actual Python type, but with the same `typename` and `fields`. 

210 """ 

211 

212 def can_encode(self, pyobj): 

213 return _is_named_tuple(pyobj) 

214 

215 def do_encode(self, named_tuple_value, encode_fn): 

216 encoded_named_tuple = struct_pb2.StructuredValue() 

217 encoded_named_tuple.named_tuple_value.CopyFrom(struct_pb2.NamedTupleValue()) 

218 encoded_named_tuple.named_tuple_value.name = \ 

219 named_tuple_value.__class__.__name__ 

220 for key in named_tuple_value._fields: 

221 pair = encoded_named_tuple.named_tuple_value.values.add() 

222 pair.key = key 

223 pair.value.CopyFrom(encode_fn(named_tuple_value._asdict()[key])) 

224 return encoded_named_tuple 

225 

226 def can_decode(self, value): 

227 return value.HasField("named_tuple_value") 

228 

229 def do_decode(self, value, decode_fn): 

230 key_value_pairs = value.named_tuple_value.values 

231 items = [(pair.key, decode_fn(pair.value)) for pair in key_value_pairs] 

232 named_tuple_type = collections.namedtuple(value.named_tuple_value.name, 

233 [item[0] for item in items]) 

234 return named_tuple_type(**dict(items)) 

235 

236 

237class _Float64Codec: 

238 """Codec for floats.""" 

239 

240 def can_encode(self, pyobj): 

241 return isinstance(pyobj, float) 

242 

243 def do_encode(self, float64_value, encode_fn): 

244 del encode_fn 

245 value = struct_pb2.StructuredValue() 

246 value.float64_value = float64_value 

247 return value 

248 

249 def can_decode(self, value): 

250 return value.HasField("float64_value") 

251 

252 def do_decode(self, value, decode_fn): 

253 del decode_fn 

254 return value.float64_value 

255 

256 

257class _Int64Codec: 

258 """Codec for Python integers (limited to 64 bit values).""" 

259 

260 def can_encode(self, pyobj): 

261 return not isinstance(pyobj, bool) and isinstance(pyobj, int) 

262 

263 def do_encode(self, int_value, encode_fn): 

264 del encode_fn 

265 value = struct_pb2.StructuredValue() 

266 value.int64_value = int_value 

267 return value 

268 

269 def can_decode(self, value): 

270 return value.HasField("int64_value") 

271 

272 def do_decode(self, value, decode_fn): 

273 del decode_fn 

274 return int(value.int64_value) 

275 

276 

277class _StringCodec: 

278 """Codec for strings. 

279 

280 See StructuredValue.string_value in proto/struct.proto for more detailed 

281 explanation. 

282 """ 

283 

284 def can_encode(self, pyobj): 

285 return isinstance(pyobj, str) 

286 

287 def do_encode(self, string_value, encode_fn): 

288 del encode_fn 

289 value = struct_pb2.StructuredValue() 

290 value.string_value = string_value 

291 return value 

292 

293 def can_decode(self, value): 

294 return value.HasField("string_value") 

295 

296 def do_decode(self, value, decode_fn): 

297 del decode_fn 

298 return compat.as_str(value.string_value) 

299 

300 

301class _NoneCodec: 

302 """Codec for None.""" 

303 

304 def can_encode(self, pyobj): 

305 return pyobj is None 

306 

307 def do_encode(self, none_value, encode_fn): 

308 del encode_fn, none_value 

309 value = struct_pb2.StructuredValue() 

310 value.none_value.CopyFrom(struct_pb2.NoneValue()) 

311 return value 

312 

313 def can_decode(self, value): 

314 return value.HasField("none_value") 

315 

316 def do_decode(self, value, decode_fn): 

317 del decode_fn, value 

318 return None 

319 

320 

321class _BoolCodec: 

322 """Codec for booleans.""" 

323 

324 def can_encode(self, pyobj): 

325 return isinstance(pyobj, bool) 

326 

327 def do_encode(self, bool_value, encode_fn): 

328 del encode_fn 

329 value = struct_pb2.StructuredValue() 

330 value.bool_value = bool_value 

331 return value 

332 

333 def can_decode(self, value): 

334 return value.HasField("bool_value") 

335 

336 def do_decode(self, value, decode_fn): 

337 del decode_fn 

338 return value.bool_value 

339 

340 

341class _TensorTypeCodec: 

342 """Codec for `TensorType`.""" 

343 

344 def can_encode(self, pyobj): 

345 return isinstance(pyobj, dtypes.DType) 

346 

347 def do_encode(self, tensor_dtype_value, encode_fn): 

348 del encode_fn 

349 encoded_tensor_type = struct_pb2.StructuredValue() 

350 encoded_tensor_type.tensor_dtype_value = tensor_dtype_value.as_datatype_enum 

351 return encoded_tensor_type 

352 

353 def can_decode(self, value): 

354 return value.HasField("tensor_dtype_value") 

355 

356 def do_decode(self, value, decode_fn): 

357 del decode_fn 

358 return dtypes.DType(value.tensor_dtype_value) 

359 

360 

class BuiltInTypeSpecCodec:
  """Codec for a single built-in `TypeSpec` class.

  Built-in TypeSpec's that do not require a custom codec implementation
  register themselves by instantiating this class and passing it to
  register_codec.

  Attributes:
    type_spec_class: The built-in TypeSpec class that the
      codec is instantiated for.
    type_spec_proto_enum: The proto enum value for the built-in TypeSpec class.
  """

  # Class-level registries shared by every instance; `__init__` consults them
  # to reject duplicate classes and duplicate enum values.
  _BUILT_IN_TYPE_SPEC_PROTOS = []
  _BUILT_IN_TYPE_SPECS = []

  def __init__(self, type_spec_class, type_spec_proto_enum):
    """Validates the arguments and records the class/enum pairing.

    Args:
      type_spec_class: Subclass of `internal.TypeSpec` handled by this codec.
      type_spec_proto_enum: Integer enum value in the range 1..10 identifying
        the class in the proto.

    Raises:
      ValueError: If the class is not a `TypeSpec` subclass, if the class or
        the enum value has already been used by another instance, or if the
        enum value is not an int between 1 and 10.
    """
    if not issubclass(type_spec_class, internal.TypeSpec):
      raise ValueError(
          f"The type '{type_spec_class}' does not subclass tf.TypeSpec.")

    if type_spec_class in self._BUILT_IN_TYPE_SPECS:
      raise ValueError(
          f"The type '{type_spec_class}' already has an instantiated codec.")

    if type_spec_proto_enum in self._BUILT_IN_TYPE_SPEC_PROTOS:
      raise ValueError(
          f"The proto value '{type_spec_proto_enum}' is already registered."
      )

    enum_is_valid = (
        isinstance(type_spec_proto_enum, int)
        and 0 < type_spec_proto_enum <= 10)
    if not enum_is_valid:
      raise ValueError(f"The proto value '{type_spec_proto_enum}' is invalid.")

    self.type_spec_class = type_spec_class
    self.type_spec_proto_enum = type_spec_proto_enum

    self._BUILT_IN_TYPE_SPECS.append(type_spec_class)
    self._BUILT_IN_TYPE_SPEC_PROTOS.append(type_spec_proto_enum)

  def can_encode(self, pyobj):
    """Returns true if `pyobj` is an instance of this codec's TypeSpec class."""
    return isinstance(pyobj, self.type_spec_class)

  def do_encode(self, type_spec_value, encode_fn):
    """Returns a proto holding the serialized state of `type_spec_value`."""
    # pylint: disable=protected-access
    type_state = type_spec_value._serialize()
    flat_components = nest.flatten(
        type_spec_value._component_specs, expand_composites=True)
    # pylint: enable=protected-access
    encoded_type_spec = struct_pb2.StructuredValue()
    type_spec_proto = struct_pb2.TypeSpecProto(
        type_spec_class=self.type_spec_proto_enum,
        type_state=encode_fn(type_state),
        type_spec_class_name=self.type_spec_class.__name__,
        num_flat_components=len(flat_components))
    encoded_type_spec.type_spec_value.CopyFrom(type_spec_proto)
    return encoded_type_spec

  def can_decode(self, value):
    """Returns true if `value` holds a type spec tagged with this codec's enum."""
    if not value.HasField("type_spec_value"):
      return False
    return value.type_spec_value.type_spec_class == self.type_spec_proto_enum

  def do_decode(self, value, decode_fn):
    """Returns the built-in `TypeSpec` rebuilt from the proto `value`."""
    decoded_state = decode_fn(value.type_spec_value.type_state)
    # pylint: disable=protected-access
    return self.type_spec_class._deserialize(decoded_state)

434 

435 

436# TODO(b/238903802): Use TraceType serialization and specific protos. 

437class _TypeSpecCodec: 

438 """Codec for `tf.TypeSpec`.""" 

439 

440 # Mapping from enum value to type (TypeSpec subclass). 

441 # Must leave this for backwards-compatibility until all external usages 

442 # have been removed. 

443 TYPE_SPEC_CLASS_FROM_PROTO = { 

444 } 

445 

446 # Mapping from type (TypeSpec subclass) to enum value. 

447 TYPE_SPEC_CLASS_TO_PROTO = dict( 

448 (cls, enum) for (enum, cls) in TYPE_SPEC_CLASS_FROM_PROTO.items()) 

449 

450 def can_encode(self, pyobj): 

451 """Returns true if `pyobj` can be encoded as a TypeSpec.""" 

452 if type(pyobj) in self.TYPE_SPEC_CLASS_TO_PROTO: # pylint: disable=unidiomatic-typecheck 

453 return True 

454 

455 # Check if it's a registered type. 

456 if isinstance(pyobj, internal.TypeSpec): 

457 try: 

458 type_spec_registry.get_name(type(pyobj)) 

459 return True 

460 except ValueError: 

461 return False 

462 

463 return False 

464 

465 def do_encode(self, type_spec_value, encode_fn): 

466 """Returns an encoded proto for the given `tf.TypeSpec`.""" 

467 type_spec_class = self.TYPE_SPEC_CLASS_TO_PROTO.get(type(type_spec_value)) 

468 type_spec_class_name = type(type_spec_value).__name__ 

469 

470 if type_spec_class is None: 

471 type_spec_class_name = type_spec_registry.get_name(type(type_spec_value)) 

472 type_spec_class = struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC 

473 # Support for saving registered TypeSpecs is currently experimental. 

474 # Issue a warning to indicate the limitations. 

475 warnings.warn("Encoding a StructuredValue with type %s; loading this " 

476 "StructuredValue will require that this type be " 

477 "imported and registered." % type_spec_class_name) 

478 

479 type_state = type_spec_value._serialize() # pylint: disable=protected-access 

480 num_flat_components = len( 

481 nest.flatten(type_spec_value._component_specs, expand_composites=True)) # pylint: disable=protected-access 

482 encoded_type_spec = struct_pb2.StructuredValue() 

483 encoded_type_spec.type_spec_value.CopyFrom( 

484 struct_pb2.TypeSpecProto( 

485 type_spec_class=type_spec_class, 

486 type_state=encode_fn(type_state), 

487 type_spec_class_name=type_spec_class_name, 

488 num_flat_components=num_flat_components)) 

489 return encoded_type_spec 

490 

491 def can_decode(self, value): 

492 """Returns true if `value` can be decoded into a `tf.TypeSpec`.""" 

493 return value.HasField("type_spec_value") 

494 

495 def do_decode(self, value, decode_fn): 

496 """Returns the `tf.TypeSpec` encoded by the proto `value`.""" 

497 type_spec_proto = value.type_spec_value 

498 type_spec_class_enum = type_spec_proto.type_spec_class 

499 class_name = type_spec_proto.type_spec_class_name 

500 

501 if type_spec_class_enum == struct_pb2.TypeSpecProto.REGISTERED_TYPE_SPEC: 

502 try: 

503 type_spec_class = type_spec_registry.lookup(class_name) 

504 except ValueError as e: 

505 raise ValueError( 

506 f"The type '{class_name}' has not been registered. It must be " 

507 "registered before you load this object (typically by importing " 

508 "its module).") from e 

509 else: 

510 if type_spec_class_enum not in self.TYPE_SPEC_CLASS_FROM_PROTO: 

511 raise ValueError( 

512 f"The type '{class_name}' is not supported by this version of " 

513 "TensorFlow. (The object you are loading must have been created " 

514 "with a newer version of TensorFlow.)") 

515 type_spec_class = self.TYPE_SPEC_CLASS_FROM_PROTO[type_spec_class_enum] 

516 

517 # pylint: disable=protected-access 

518 return type_spec_class._deserialize(decode_fn(type_spec_proto.type_state)) 

519 

520 

# Registry of codecs. `_map_structure` walks the coder pairs derived from this
# list in reverse, so entries nearer the end (and codecs appended later through
# `register_codec`) are tried first.
_codecs = [
    _ListCodec(),
    _TupleCodec(),
    _NamedTupleCodec(),
    _StringCodec(),
    _Float64Codec(),
    _NoneCodec(),
    _Int64Codec(),
    _BoolCodec(),
    _DictCodec(),
    _TypeSpecCodec(),
    _TensorTypeCodec(),
]