Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/capture/restore_captures.py: 24%

50 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2023 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=unidiomatic-typecheck 

16"""A shim layer for working with functions exported/restored from saved models. 

17 

18This functionality should ultimately be moved into a first-class core API. 

19""" 

20 

21import warnings 

22 

23from tensorflow.python.framework import dtypes 

24from tensorflow.python.framework import ops 

25from tensorflow.python.framework import tensor_util 

26from tensorflow.python.ops import handle_data_util 

27from tensorflow.python.ops import resource_variable_ops 

28from tensorflow.python.ops import variables as variables_lib 

29from tensorflow.python.trackable import asset 

30from tensorflow.python.trackable import resource 

31 

32 

def get_tensor_from_node(node):
  """Resolves a saved model graph node into a tensor to be captured.

  Args:
    node: a tensor, variable, or resource to be resolved into a capturable
      tensor.

  Returns:
    A single capturable object (not a list), chosen by the first matching
    rule, checked in this order:
      * the node itself, if it sets a truthy `is_distributed_variable`,
        `is_distributed_table` or `is_sharded_variable` attribute (such
        objects are passed through unchanged; presumably they provide their
        own capture/restore logic -- TODO confirm);
      * `node.handle`, if the node is a resource variable;
      * `node.asset_path`, if the node is an `asset.Asset`;
      * the node itself, if it is already a TF tensor type;
      * `node.resource_handle`, if the node is a
        `resource.CapturableResource`.

  Raises:
    ValueError: if the node cannot be converted into a tensor.
  """
  with ops.init_scope():
    # TODO(b/210144904): Use __tf_tensor__ instead of `is_[...]` checks
    # The flags are tested in order and the first truthy one wins; all three
    # cases return the node untouched.
    if any(
        getattr(node, flag, False)
        for flag in (
            "is_distributed_variable",
            "is_distributed_table",
            "is_sharded_variable",
        )
    ):
      return node
    if resource_variable_ops.is_resource_variable(node):
      return node.handle
    if isinstance(node, asset.Asset):
      return node.asset_path
    if tensor_util.is_tf_type(node):
      return node
    if isinstance(node, resource.CapturableResource):
      # Note: this executes restored functions in the CapturableResource.
      return node.resource_handle
    raise ValueError(f"Cannot convert node {node} to tensor.")

63 

64 

def restore_captures(concrete_function, inputs):
  """Restores captures for the concrete function.

  Used at deserialization time. For functions that are being deserialized,
  saved model restores objects that tensors were captured from, but functions
  only know about their tensors -- object information is destroyed by tracing.
  This additional logic extracts the tensors which the function originally
  captured.

  Args:
    concrete_function: the concrete function for which to restore captures.
    inputs: a list of tensors or other Python objects (such as variables)
      which contain tensors that were originally captured by the function.
      Each entry is resolved with `get_tensor_from_node` and paired
      positionally with the trailing `len(inputs)` entries of
      `concrete_function.inputs`.

  Side effects:
    Calls `concrete_function.set_variables` with the variable-typed inputs,
    rewires captures on `concrete_function.graph` (`replace_capture` /
    `capture`), calls `concrete_function.set_external_captures` with the
    restored captures, and emits a warning if any restored capture is None.
  """
  bound_inputs = [get_tensor_from_node(obj) for obj in inputs]
  # pylint: disable=g-complex-comprehension
  bound_variables = [
      obj
      for obj in inputs
      if isinstance(
          obj,
          (variables_lib.Variable, resource_variable_ops.BaseResourceVariable),
      )
  ]
  # TODO(b/205010575): This is only injecting the captured inputs into the
  # concrete function, note that we did not modify the FuncGraph
  # itself.
  captured_inputs_list = []
  concrete_function.set_variables(bound_variables)
  # Guard is required: with no inputs, `seq[-0:]` would select the *whole*
  # input list instead of an empty tail.
  if bound_inputs:
    internal_captures = concrete_function.inputs[-len(bound_inputs) :]
    for bound_input, internal_capture in zip(bound_inputs, internal_captures):
      # Distributed inputs have special logic for capturing, so we call their
      # custom restoration methods
      if hasattr(bound_input, "__tf_experimental_restore_capture__"):
        captured_inputs_list.append(
            bound_input.__tf_experimental_restore_capture__(
                concrete_function, internal_capture
            )
        )
      else:
        captured_inputs_list.append(bound_input)
        concrete_function.graph.replace_capture(bound_input, internal_capture)
        if internal_capture.dtype == dtypes.resource:
          if resource_variable_ops.is_resource_variable(bound_input):
            try:
              handle = bound_input.handle
            except ValueError:
              # For mirrored variables we'll copy handle data for components
              # as they get captured.
              pass
            else:
              handle_data_util.copy_handle_data(handle, internal_capture)
          else:
            # TODO(b/213451747): Remove need to call copy_handle_data
            handle_data_util.copy_handle_data(bound_input, internal_capture)
        # Setting "captures" first means "capture" won't create a new
        # placeholder for this input.
        concrete_function.graph.capture(bound_input)

  # A custom restore hook may return None; warn rather than fail.
  if any(inp is None for inp in captured_inputs_list):
    warnings.warn(
        "Trying to load ShardedVariables using tf.saved_model.load. "
        "This won't work if using a tf.distribute.Strategy, and may "
        "use excess memory if not using a Strategy. Ignore this "
        "warning if using tf.keras.models.load_model."
    )
  concrete_function.set_external_captures(captured_inputs_list)