Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/lift_to_graph.py: 14%

133 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15# pylint: disable=unidiomatic-typecheck 

16"""Utility to lift subgraphs.""" 

17 

18import collections 

19 

20from tensorflow.python.framework import func_graph 

21from tensorflow.python.framework import ops 

22from tensorflow.python.ops import array_ops 

23from tensorflow.python.ops import op_selector 

24from tensorflow.python.ops import resource_variable_ops 

25from tensorflow.python.util import compat 

26from tensorflow.python.util import object_identity 

27from tensorflow.python.util.tf_export import tf_export 

28 

29 

# Alias of `op_selector.UnliftableError`, exposed under this module's namespace.
UnliftableError = op_selector.UnliftableError

31 

32 

def _as_operation(op_or_tensor):
  """Returns the op behind a `Tensor`; any other argument is returned as is."""
  return (op_or_tensor.op
          if isinstance(op_or_tensor, ops.Tensor) else op_or_tensor)

37 

38 

def _constant_inputs(op_or_tensor):
  """Checks that every graph input is a `Const` op without control inputs.

  Args:
    op_or_tensor: An op, or a Tensor whose producing op is inspected.

  Returns:
    True when each input reported by `op_selector.graph_inputs` has type
    "Const" and an empty `control_inputs`; True as well when there are no
    inputs at all. False otherwise.
  """
  for graph_input in op_selector.graph_inputs(_as_operation(op_or_tensor)):
    input_op = _as_operation(graph_input)
    if input_op.type != u"Const" or input_op.control_inputs:
      return False
  return True

43 

44 

# A deferred data-input fix-up: input number `input_index` of `copied_op` has
# to be re-pointed at the copy of `old_graph_tensor` once that copy exists.
_InputMutation = collections.namedtuple(
    typename="_InputMutation",
    field_names=["copied_op", "input_index", "old_graph_tensor"])

50 

51 

# A deferred control-input fix-up: `copied_op` has to receive the copy of
# `old_graph_op` as a control input once that copy exists.
_ControlMutation = collections.namedtuple(
    typename="_ControlMutation",
    field_names=["copied_op", "old_graph_op"])

57 

58 

def _copy_non_source(op, graph, op_map, base_graph):
  """Recreates `op` in `graph` and records the copy in `op_map`.

  The inputs of `op` are normally expected to be present in `op_map` already.
  When one is not (for example with v1 while_loops, where the graph contains
  a cycle), a placeholder is wired in instead and the caller is told which
  inputs still have to be patched afterwards.

  Args:
    op: The op to be copied.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
      It is updated with the copied op and each of its outputs.
    base_graph: The graph we're copying from, for any necessary functions.

  Returns:
    A tuple of (required_inputs, required_control_inputs):
      required_inputs:
        A list of `_InputMutation` tuples, one per input of the copied op
        that currently points at a placeholder and must be updated once
        `old_graph_tensor` has been copied.
      required_control_inputs:
        A list of `_ControlMutation` tuples, one per control input that must
        be added to the copied op once `old_graph_op` has been copied.
  """
  # (input_index, old tensor) pairs for data inputs that have no copy yet.
  pending_inputs = []
  new_inputs = []
  for position, old_tensor in enumerate(op.inputs):
    new_tensor = op_map.get(old_tensor, None)
    if new_tensor is None:
      # The producer has not been copied yet because of a loop in the graph.
      # Use a stand-in placeholder and remember to rewire this input later.
      new_tensor = array_ops.placeholder(
          name="unused_control_flow_input",
          shape=old_tensor.shape,
          dtype=old_tensor.dtype)
      pending_inputs.append((position, old_tensor))
    new_inputs.append(new_tensor)

  # Control inputs are either available right away or deferred entirely.
  pending_control_ops = []
  new_control_inputs = []
  for old_control_op in op.control_inputs:
    new_control_op = op_map.get(old_control_op, None)
    if new_control_op is not None:
      new_control_inputs.append(new_control_op)
    else:
      pending_control_ops.append(old_control_op)

  with ops.control_dependencies(new_control_inputs):
    with ops.device(op.device):
      # If the op type names a function of the source graph, register that
      # function with the destination graph unless it is already there.
      # pylint: disable=protected-access
      function_def = base_graph._functions.get(op.type, None)
      if (function_def is not None and
          compat.as_str(function_def.name) not in graph._functions):
        function_def.add_to_graph(graph)
      # pylint: enable=protected-access

      # Attributes starting with `_class` or `_tpu_replicate` are dropped
      # (b/128981532). `_tpu_replicate` marks an op built inside a
      # tpu_replicate context; lifting the op to another graph also lifts it
      # into another context.
      kept_attrs = {}
      for attr_name, attr_value in op.node_def.attr.items():
        if attr_name.startswith("_class"):
          continue
        if attr_name.startswith("_tpu_replicate"):
          continue
        kept_attrs[attr_name] = attr_value

      copied_op = graph.create_op(
          op_type=op.type,
          inputs=new_inputs,
          dtypes=[output.dtype for output in op.outputs],
          attrs=kept_attrs,
          name=op.name)

  # Register the op itself and each of its outputs under their old identity.
  op_map[op] = copied_op
  for output_index, old_output in enumerate(op.outputs):
    op_map[old_output] = copied_op.outputs[output_index]

  # The mutation records can only be built now that `copied_op` exists.
  required_inputs = [
      _InputMutation(copied_op=copied_op,
                     input_index=position,
                     old_graph_tensor=old_tensor)
      for position, old_tensor in pending_inputs]
  required_control_inputs = [
      _ControlMutation(copied_op=copied_op, old_graph_op=old_control_op)
      for old_control_op in pending_control_ops]
  return (required_inputs, required_control_inputs)

140 

141 

def _copy_source(s, graph, op_map, handle_captures, inverse_captures,
                 base_graph):
  """Creates a stand-in for source tensor `s` in `graph`.

  Exactly one of three strategies is used, checked in this order:

  1) `handle_captures` is true and `s` is a key of `inverse_captures`: the
     tensor or variable that `s` captured is captured again in `graph`.

  2) `s` is produced by a "PlaceholderWithDefault" op whose inputs are all
     constants: the op producing the default value is copied and a
     placeholder-with-default is built on top of the copy.

  3) Otherwise a plain placeholder with the dtype and shape of `s` is made.

  In every case, resource handle data found on `s` is transferred to the new
  tensor, and both `s` and `s.op` are registered in `op_map`.

  Args:
    s: The source of interest.
    graph: The destination graph.
    op_map: A dict mapping ops and tensors in the old graph to the new one.
      Updated in place.
    handle_captures: A boolean indicating whether to re-capture `s` in the
      new graph or simply create a vanilla placeholder.
    inverse_captures: A dict mapping `s` back to the Tensor or Variable that
      it captures.
    base_graph: The graph being copied from.

  Raises:
    AssertionError: If copying the default value of a
      "PlaceholderWithDefault" source reports inputs or control inputs that
      are not available yet.
  """
  source_op = s.op
  if handle_captures and s in inverse_captures:
    replacement = graph.capture(inverse_captures[s], name=source_op.name)
  elif source_op.type == "PlaceholderWithDefault" and _constant_inputs(s):
    # Bring the default value over first, then rebuild the placeholder on it.
    default_tensor = source_op.inputs[0]
    missing_inputs, missing_control_inputs = _copy_non_source(
        op=default_tensor.op,
        graph=graph,
        op_map=op_map,
        base_graph=base_graph)
    if missing_inputs or missing_control_inputs:
      raise AssertionError(
          "Could not copy source node {} because it has inputs."
          .format(default_tensor))

    with ops.device(source_op.device):
      replacement = array_ops.placeholder_with_default(
          input=op_map[default_tensor], shape=s.shape, name=source_op.name)
  else:
    with ops.device(source_op.device):
      replacement = array_ops.placeholder(
          dtype=s.dtype, shape=s.shape, name=source_op.name)

  # Carry resource handle shape/type information over when `s` has any.
  handle_data = resource_variable_ops.get_resource_handle_data(s)
  if handle_data.shape_and_type:
    resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
        replacement,
        handle_data,
        graph_mode=True)

  op_map[s] = replacement
  # The producing op is registered too, so that ops which reference it
  # through a control dependency can be resolved.
  op_map[source_op] = replacement.op

200 

201 

@tf_export("__internal__.lift_to_graph", v1=[])
def lift_to_graph(tensors,
                  graph,
                  sources=None,
                  disallowed_placeholders=None,
                  add_sources=False,
                  handle_captures=False,
                  base_graph=None,
                  op_map=None):
  """Copies `tensors` and, recursively, everything they depend on to `graph`.

  Args:
    tensors: The Tensors to lift. Entries that are `ResourceVariable`s are
      not copied; they are mapped to themselves in the result.
    graph: The graph to lift to.
    sources: Optional sequence of nodes to start from. If omitted, the whole
      subgraph which feeds into the lifted tensors is copied.
    disallowed_placeholders: An optional set of ops which may not appear in
      the lifted graph. Defaults to all placeholders.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.
    handle_captures: A boolean indicating whether to re-capture sources in
      the new graph or simply create vanilla placeholders for them.
    base_graph: The graph from which to lift ops. When not given, the graph
      of the first non-variable entry of `tensors` is used.
    op_map: A map containing all the existing nodes that have already been
      lifted to the destination graph, so they won't be lifted and copied
      again.

  Returns:
    The `op_map` dictionary, mapping ops and tensors of the source graph to
    their counterparts in `graph`.

  Raises:
    UnliftableError: If a placeholder blocks lifting.
  """
  # Split the request in a single pass: variables vs. tensors to be copied.
  variable_init_tensors = []
  init_tensors = []
  for candidate in tensors:
    if isinstance(candidate, resource_variable_ops.ResourceVariable):
      bucket = variable_init_tensors
    else:
      bucket = init_tensors
    bucket.append(candidate)

  if not base_graph:
    base_graph = init_tensors[0].graph
  if not op_map:
    op_map = object_identity.ObjectIdentityDictionary()

  sources = object_identity.ObjectIdentitySet(sources or [])
  visited_ops = {source.op for source in sources}
  op_outputs = collections.defaultdict(set)

  # Step 1: collect the subgraph lying between the requested tensors and the
  # sources; `map_subgraph` may report additional sources along the way.
  for init_tensor in init_tensors:
    extra_sources = op_selector.map_subgraph(
        init_tensor=init_tensor,
        sources=sources,
        disallowed_placeholders=disallowed_placeholders,
        visited_ops=visited_ops,
        op_outputs=op_outputs,
        add_sources=add_sources)
    sources.update(extra_sources)

  # Step 2: attempt a topological sort, consumers first. An op is scheduled
  # only once all of its recorded consumers have been emitted.
  ops_to_copy = []
  emitted_ops = set()
  stack = []
  for init_tensor in init_tensors:
    if not op_outputs[_as_operation(init_tensor)]:
      stack.append(_as_operation(init_tensor))
  pending_ops = set(stack)
  while pending_ops:
    while stack:
      current_op = stack.pop()
      if current_op in emitted_ops:
        continue
      emitted_ops.add(current_op)
      ops_to_copy.append(current_op)
      for producer in op_selector.graph_inputs(current_op):
        # TPUReplicateMetadata nodes have no registered kernels, so they are
        # never lifted out of the function.
        if producer.type == "TPUReplicateMetadata":
          continue
        pending_ops.add(producer)
        consumers_emitted = all(
            consumer in emitted_ops for consumer in op_outputs[producer])
        if consumers_emitted and producer not in sources:
          stack.append(producer)
    pending_ops -= emitted_ops
    if pending_ops:
      # Anything left over should only be there because the graph has a
      # loop. No topological order exists in that case; keep copying from an
      # arbitrary remaining op and rely on post-hoc mutations below.
      stack.append(next(iter(pending_ops)))

  # A failed topological sort can later lead to copying an op whose inputs
  # have not been copied. Soften that by moving ops without graph inputs
  # (constants, for instance) to the end of `ops_to_copy`, i.e. to the front
  # of the order in which ops are actually copied.
  def _has_no_graph_inputs(candidate_op):
    return len(op_selector.graph_inputs(candidate_op)) == 0

  ops_to_copy.sort(key=_has_no_graph_inputs)

  # Lifting between two FuncGraphs also requires re-capturing tensors, so
  # build the internal -> external capture lookup.
  inverse_captures = object_identity.ObjectIdentityDictionary()
  internal_captures = []
  if (isinstance(base_graph, func_graph.FuncGraph) and
      isinstance(graph, func_graph.FuncGraph)):
    for outer_tensor, inner_tensor in base_graph.captures:
      inverse_captures[inner_tensor] = outer_tensor
    internal_captures = base_graph.internal_captures

  # Step 3: `ops_to_copy` is in reverse (pseudo-)topological order. Recreate
  # its ops in the destination graph, sources first.
  with graph.as_default():
    for variable in variable_init_tensors:
      op_map[variable] = variable

    source_ops = set()

    def _lift_source(source_tensor):
      """Records the source's op and creates its stand-in in `graph`."""
      source_ops.add(source_tensor.op)
      _copy_source(
          s=source_tensor,
          graph=graph,
          op_map=op_map,
          handle_captures=handle_captures,
          inverse_captures=inverse_captures,
          base_graph=base_graph)

    # Captured sources go first, in the order of the original graph; the
    # remaining sources follow.
    for captured in internal_captures:
      if captured in sources:
        sources.remove(captured)
        _lift_source(captured)
    for remaining_source in sources:
      _lift_source(remaining_source)

    input_mutations = []
    control_mutations = []
    for pending_op in reversed(ops_to_copy):
      already_handled = pending_op in source_ops or pending_op in op_map
      if already_handled:
        continue
      deferred_inputs, deferred_control_inputs = _copy_non_source(
          op=pending_op, graph=graph, op_map=op_map, base_graph=base_graph)
      input_mutations.extend(deferred_inputs)
      control_mutations.extend(deferred_control_inputs)

    # Step 4: patch the copied graph to re-create any loops that existed in
    # the source graph due to v1 while_loops.
    #
    # pylint: disable=protected-access
    with graph._mutation_lock():
      for input_fix in input_mutations:
        input_fix.copied_op._update_input(
            input_fix.input_index, op_map[input_fix.old_graph_tensor])
      for control_fix in control_mutations:
        # TPUReplicateMetadata nodes have no registered kernels, so they are
        # never lifted out of the function.
        if control_fix.old_graph_op.type != "TPUReplicateMetadata":
          control_fix.copied_op._add_control_input(
              op_map[control_fix.old_graph_op])
    # pylint: enable=protected-access

    return op_map