Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/graph_to_function_def.py: 14%

90 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================= 

15"""Utility to convert a Graph to a FunctionDef.""" 

16 

17import re 

18 

19from tensorflow.core.framework import function_pb2 

20from tensorflow.core.framework import op_def_pb2 

21from tensorflow.python.framework import op_def_registry 

22 

23 

24def _make_argname_from_tensor_name(name): 

25 return re.sub(":0$", "", name).replace(":", "_o") 

26 

27 

def _tensor_to_argdef(t, name=None, used_names=None):
  """Builds an `OpDef.ArgDef` describing tensor `t`.

  Args:
    t: The tensor to describe; its `name` and `dtype` attributes are read.
    name: Explicit argument name. When given it is used verbatim and
      `used_names` is neither consulted nor updated.
    used_names: Optional set of names already taken. When `name` is None and
      this set is provided, a name that collides gets a "_U<k>" suffix, using
      the smallest k >= 0 not already in the set. The chosen name is then
      added to the set, whether or not it collided.

  Returns:
    An `op_def_pb2.OpDef.ArgDef` with `name` and `type` populated.
  """
  arg = op_def_pb2.OpDef.ArgDef()
  if name is not None:
    arg.name = name
  else:
    candidate = _make_argname_from_tensor_name(t.name)
    if used_names is not None:
      if candidate in used_names:
        base = candidate
        suffix = 0
        candidate = "%s_U%d" % (base, suffix)
        while candidate in used_names:
          suffix += 1
          candidate = "%s_U%d" % (base, suffix)
      used_names.add(candidate)
    arg.name = candidate
  arg.type = t.dtype.as_datatype_enum
  return arg

47 

48 

49def _is_in_placeholders(op, func_arg_placeholders): 

50 """Checks whether any output of this op is in func_arg_placeholders.""" 

51 return op.values() and any(x.name in func_arg_placeholders 

52 for x in op.values()) 

53 

54 

55def _get_node_def(op): 

56 return op.node_def # pylint: disable=protected-access 

57 

58 

59def _get_op_def(op): 

60 return op.op_def or op_def_registry.get(op.type) 

61 

62 

def _create_input_dict(function_graph,
                       func_arg_placeholders,
                       initial_value=None):
  """Create a mapping from graph tensor names to function tensor names.

  Starts from a copy of `initial_value` (or an empty dict), then visits every
  operation returned by `function_graph.get_operations()`:
    * An op with an output listed in `func_arg_placeholders` maps its op name
      to itself.
    * Any other op maps each output tensor name to the string
      "<op name>:<output arg name>:<index within that arg>". The number of
      tensors per output arg is read from the node attribute named by the
      arg's `number_attr` (its `.i`) or `type_list_attr` (length of its type
      list), and is 1 otherwise. The op name itself is mapped to the same
      string as the op's first output.

  Args:
    function_graph: Graph whose operations are scanned.
    func_arg_placeholders: Collection of tensor names that are function
      arguments.
    initial_value: Optional mapping used to seed the result. It is copied,
      not modified.

  Returns:
    The populated dict.
  """

  def _arg_count(arg_def, attrs):
    # Number of output tensors produced by one output arg of an op.
    if arg_def.number_attr:
      return attrs[arg_def.number_attr].i
    if arg_def.type_list_attr:
      return len(attrs[arg_def.type_list_attr].list.type)
    return 1

  input_dict = {} if initial_value is None else dict(initial_value)
  for op in function_graph.get_operations():
    if _is_in_placeholders(op, func_arg_placeholders):
      input_dict[op.name] = op.name
      continue
    op_def = _get_op_def(op)
    attrs = _get_node_def(op).attr
    output_tensors = op.values()
    # Position of the current tensor in the op's flat list of outputs,
    # counted across all output args.
    flat_index = 0
    for arg_def in op_def.output_arg:
      for i in range(_arg_count(arg_def, attrs)):
        mapped = "%s:%s:%d" % (op.name, arg_def.name, i)
        input_dict[output_tensors[flat_index].name] = mapped
        if flat_index == 0:
          input_dict[op.name] = mapped
        flat_index += 1
  return input_dict

92 

93 

def _add_op_node(op, func, input_dict):
  """Converts an op to a function def node and adds it to `func`.

  A copy of the op's NodeDef is appended to `func.node_def`. In that copy,
  every input that does not start with "^" is replaced by its value in
  `input_dict`; inputs starting with "^" are left unchanged. If `op.op_def`
  is present and marked stateful, `func.signature.is_stateful` is set to
  True.

  Args:
    op: The operation to add.
    func: The `FunctionDef` being built. It is mutated in place.
    input_dict: Mapping from graph input names to function-body names.

  Raises:
    AssertionError: if a non-"^" input is missing from `input_dict`, while
      assertions are enabled. Under `python -O` the lookup raises KeyError
      instead.
  """
  # extend() copies the message, so the node edited below is func's own copy
  # and the op's NodeDef is left untouched. See:
  # https://developers.google.com/protocol-buffers/docs/reference/python-generated#repeated-message-fields
  func.node_def.extend([_get_node_def(op)])
  copied = func.node_def[-1]
  for idx, graph_input in enumerate(list(copied.input)):
    if graph_input.startswith("^"):
      continue
    assert graph_input in input_dict, ("%s missing from %s" %
                                       (graph_input, input_dict.items()))
    copied.input[idx] = input_dict[graph_input]
  # The function is stateful if any of its operations are stateful.
  # NOTE(mrry): The "Const" node typically does not have an `OpDef` associated
  # with it, so we assume any nodes without an `OpDef` are stateless.
  # TODO(skyewm): Remove the `is not None` test after we transition to the C
  # API.
  own_op_def = op.op_def
  if own_op_def is not None and own_op_def.is_stateful:
    func.signature.is_stateful = True

115 

116 

def graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
  """Returns `graph` as a `FunctionDef` protocol buffer.

  This method creates a [`FunctionDef`](
  https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
  protocol buffer that contains all the ops in `operations`. The
  operations become the body of the function.

  The arguments `inputs` and `outputs` will be listed as the inputs
  and outputs tensors of the function. They must be lists of
  tensors present in the graph. The lists can optionally be empty.

  Args:
    graph: Graph.
    operations: the operations to put in the function. Must be a subset of
      the operations in the graph.
    inputs: List of tensors. Inputs to the function.
    outputs: List of tensors. Outputs of the function.
    out_names: Optional list of string names for the outputs. If None or
      empty, output names are derived from the output tensor names.

  Returns:
    A FunctionDef protocol buffer.

  Raises:
    ValueError: if out_names is non-empty and its length differs from
      `outputs`, or if out_names contains duplicate names.
  """
  func = function_pb2.FunctionDef()
  func.signature.name = "_"
  used_names = set()
  func.signature.input_arg.extend(
      [_tensor_to_argdef(i, used_names=used_names) for i in inputs])
  # Initializes the input map with all placeholder input tensors.
  initial_dict = {
      t.name: arg.name for t, arg in zip(inputs, func.signature.input_arg)
  }

  # The ValueError below promises that an empty out_names is accepted, so
  # treat `[]` like None. Previously `[]` was rejected whenever `outputs` was
  # non-empty.
  auto_named_outputs = out_names is None or len(out_names) == 0
  if auto_named_outputs:
    # Output names are uniquified independently of the input names.
    used_names = set()
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, used_names=used_names) for o in outputs])
  elif len(outputs) != len(out_names):
    raise ValueError(
        f"out_names must be either empty or equal in size to outputs. "
        f"len(out_names) = {len(out_names)} len(outputs) = {len(outputs)}")
  elif len(out_names) != len(set(out_names)):
    raise ValueError(
        f"Must not have duplicates in out_names. Received: {out_names}")
  else:
    func.signature.output_arg.extend(
        [_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])

  func_arg_placeholders = {i.name for i in inputs}
  input_dict = _create_input_dict(graph, func_arg_placeholders,
                                  initial_value=initial_dict)

  for op in operations:
    # Ops that feed function arguments are replaced by the arguments
    # themselves, so they are not copied into the body.
    if _is_in_placeholders(op, func_arg_placeholders):
      continue
    _add_op_node(op, func, input_dict)

  if auto_named_outputs:
    for arg, o in zip(func.signature.output_arg, outputs):
      func.ret[arg.name] = input_dict[o.name]
  else:
    for o, n in zip(outputs, out_names):
      func.ret[n] = input_dict[o.name]

  return func