Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/saved_model/tracing_utils.py: 21%

29 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tracing utilities used by SavedModel.""" 

16 

17from tensorflow.python.checkpoint import saveable_compat 

18from tensorflow.python.checkpoint import tensor_callable 

19from tensorflow.python.eager import def_function 

20from tensorflow.python.eager import function as defun 

21 

22 

def trace_save_and_restore(obj):
  """Traces `Trackable` serialize- and restore-from-tensors functions.

  `obj._serialize_to_tensors` and `obj._restore_from_tensors` are each used
  as-is when they are already `ConcreteFunction`s. Otherwise they are wrapped
  in a `def_function.function` and traced into a concrete function.

  When `saveable_compat.get_saveable_name(obj)` returns a non-empty legacy
  name, the traced save function prefixes every output key with that name.
  The traced restore function removes that many leading characters from every
  incoming key before calling `obj._restore_from_tensors`.

  Args:
    obj: A `Trackable` object.

  Returns:
    A `(concrete_save, concrete_restore)` tuple of concrete functions:
      concrete_save: `obj._serialize_to_tensors` itself if it is already a
        `ConcreteFunction`, otherwise a zero-argument traced wrapper that
        returns the (possibly key-prefixed) dict produced by it.
      concrete_restore: `obj._restore_from_tensors` itself if it is already a
        `ConcreteFunction`, otherwise a traced wrapper whose input structure
        is `concrete_save.structured_outputs`.

  Raises:
    NotImplementedError: If, while tracing the save function, the dict
      returned by `obj._serialize_to_tensors` contains a
      `tensor_callable.Callable` value.
  """
  legacy_name = saveable_compat.get_saveable_name(obj)

  obj_save_fn = obj._serialize_to_tensors  # pylint: disable=protected-access
  obj_restore_fn = obj._restore_from_tensors  # pylint: disable=protected-access

  if isinstance(obj_save_fn, defun.ConcreteFunction):
    # Already traced: reuse it unchanged. No legacy-name prefix is added on
    # this path.
    concrete_save = obj_save_fn
  else:
    @def_function.function
    def save_fn():
      tensor_dict = obj_save_fn()
      # Callable values cannot be exported; fail at trace time with a clear
      # message instead of producing an unusable SavedModel.
      if any(isinstance(v, tensor_callable.Callable)
             for v in tensor_dict.values()):
        raise NotImplementedError(
            f"Unable to export SavedModel with object of type {type(obj)} "
            "because it returns a Callable in `_serialize_to_tensors`. "
            "If you need this functionality please file a feature request.")

      if legacy_name:
        # If there is a legacy saveable name, prepend it to every key
        # (restore_fn below strips it again).
        return {f"{legacy_name}{key}": value
                for key, value in tensor_dict.items()}
      return tensor_dict

    concrete_save = save_fn.get_concrete_function()

  if isinstance(obj_restore_fn, defun.ConcreteFunction):
    concrete_restore = obj_restore_fn
  else:
    @def_function.function
    def restore_fn(restored_tensors):
      if legacy_name:
        # Do the opposite operation of save_fn(): drop the legacy-name prefix.
        # NOTE(review): this slices unconditionally, assuming every key
        # carries the prefix. If `concrete_save` was a pre-existing
        # ConcreteFunction (which gets no prefix above), keys would be
        # truncated. Confirm that this mixed case cannot occur.
        restored_tensors = {key[len(legacy_name):]: value
                            for key, value in restored_tensors.items()}
      obj_restore_fn(restored_tensors)

    # Trace the restore function using the save function's output structure
    # as its argument, so it accepts exactly what concrete_save produces.
    concrete_restore = restore_fn.get_concrete_function(
        concrete_save.structured_outputs)

  return concrete_save, concrete_restore