Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/op_selector.py: 11%

194 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Tools for selecting ops in a graph.""" 

16 

17from tensorflow.python.framework import ops 

18from tensorflow.python.util import object_identity 

19 

20 

def is_differentiable(op):
  """Return whether a gradient is registered for the op type of `op`.

  Args:
    op: an object exposing `op_def.name` (presumably a `tf.Operation`).
  Returns:
    `True` if the gradient registry lookup for `op.op_def.name` yields a
    non-None entry; `False` if the entry is `None` or a `LookupError` is raised.
  """
  try:
    registered = ops._gradient_registry.lookup(op.op_def.name)  # pylint: disable=protected-access
  except LookupError:
    return False
  return registered is not None

26 

27 

def is_iterable(obj):
  """Return true if the object is iterable.

  A `tf.Tensor` is always reported as not iterable, whatever `iter()` would do
  with it. For anything else the answer is whether `iter(obj)` succeeds.

  Args:
    obj: any Python object.
  Returns:
    `True` if `iter(obj)` does not raise, `False` otherwise.
  """
  if isinstance(obj, ops.Tensor):
    return False
  try:
    iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  else:
    return True

37 

38 

def concatenate_unique(la, lb):
  """Append to `la` every element of `lb` that `la` does not already hold.

  `la` is modified in place. New elements are appended in the order in which
  they first appear in `lb`; duplicates inside `lb` are added only once.

  Args:
    la: List of hashable Python objects; extended in place.
    lb: Iterable of hashable Python objects.
  Returns:
    `la`: The same list object, now containing the missing elements of `lb`.
  """
  seen = set(la)
  for candidate in lb:
    if candidate in seen:
      continue
    la.append(candidate)
    seen.add(candidate)
  return la

56 

57 

def get_tensors(graph):
  """Collect the output tensors of every operation in `graph`.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`: the outputs of each op, in the order returned by
    `graph.get_operations()`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  tensors = []
  for operation in graph.get_operations():
    tensors.extend(operation.outputs)
  return tensors

74 

75 

def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the single graph shared by every element of `tops`.

  Args:
    tops: iterable of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph, which is returned as is.
    check_types: type or iterable of types every element must be an instance
      of. If None, the types (tf.Operation, tf.Tensor) are used.
    none_if_empty: if True, return None instead of raising when no graph could
      be determined (e.g. `tops` is empty).
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if `tops` is not iterable or an element is not an instance of
      `check_types`.
    ValueError: if the elements belong to different graphs, or if no graph was
      found and `none_if_empty` is False.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))

  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif not is_iterable(check_types):
    # A single type was given: wrap it so it can be joined in error messages.
    check_types = (check_types,)

  unique_graph = None
  for element in tops:
    if not isinstance(element, check_types):
      expected = ", ".join([str(t) for t in check_types])
      raise TypeError(
          "Expected a type in ({}), got: {}".format(expected, type(element)))
    if unique_graph is None:
      unique_graph = element.graph
      continue
    # Graphs are compared through their keys rather than by object identity.
    if unique_graph._graph_key != element.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError(
          "Operation {} does not belong to given graph".format(element))

  if unique_graph is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return unique_graph

112 

113 

def check_graphs(*args):
  """Verify that all the arguments live in the same graph.

  Arguments whose `graph` attribute is `None` are skipped. The first non-None
  graph becomes the reference, and later graphs are compared to it by identity.

  Args:
    *args: objects exposing a `graph` attribute.
  Raises:
    ValueError: if all the elements do not belong to the same graph.
  """
  reference = None
  for index, item in enumerate(args):
    item_graph = item.graph
    if item_graph is None:
      continue
    if reference is None:
      reference = item_graph
    elif item_graph is not reference:
      raise ValueError(f"args[{index}] does not belong to the same graph as "
                       "other arguments.")

129 

130 

def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert ts to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` is a `tf.Graph` while `allow_graph` is `False`, or if
      `check_graph` is `True` and an element has a type that is not accepted.
    ValueError: if `check_graph` is `True` and the elements do not all belong
      to the same graph (raised by `get_unique_graph`).
  """
  if isinstance(ts, ops.Graph):
    if not allow_graph:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    return get_tensors(ts)

  if not is_iterable(ts):
    ts = [ts]
  if not ts:
    return []
  if check_graph:
    # Passing None makes get_unique_graph accept both operations and tensors.
    if ignore_ops:
      accepted_types = None
    else:
      accepted_types = ops.Tensor
    get_unique_graph(ts, check_types=accepted_types)
  return [t for t in ts if isinstance(t, ops.Tensor)]

159 

160 

def get_generating_ops(ts):
  """Return the op producing each tensor in `ts`.

  The result has one entry per tensor, in the same order; an op producing
  several of the tensors therefore appears several times.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  return [tensor.op for tensor in make_list_of_t(ts, allow_graph=False)]

173 

174 

def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in ts.

  Each consuming op appears once, in the order in which it is first met while
  going through the tensors and their consumers.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if ts cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  # Companion set so the duplicate check does not rescan the `tops` list for
  # every consumer; `tops` alone keeps the first-seen ordering.
  seen = set()
  for t in ts:
    for op in t.consumers():
      if op not in seen:
        seen.add(op)
        tops.append(op)
  return tops

192 

193 

def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert ops to a list of `tf.Operation`.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if True, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`.
  Raises:
    TypeError: if `tops` is a `tf.Graph` while `allow_graph` is `False`, or if
      `check_graph` is `True` and an element has a type that is not accepted.
    ValueError: if `check_graph` is `True` and the elements do not all belong
      to the same graph (raised by `get_unique_graph`).
  """
  if isinstance(tops, ops.Graph):
    if not allow_graph:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
    return tops.get_operations()

  if not is_iterable(tops):
    tops = [tops]
  if not tops:
    return []
  if check_graph:
    # Passing None makes get_unique_graph accept both operations and tensors.
    if ignore_ts:
      accepted_types = None
    else:
      accepted_types = ops.Operation
    get_unique_graph(tops, check_types=accepted_types)
  return [op for op in tops if isinstance(op, ops.Operation)]

224 

225 

226def _get_inputs(op, only_differentiable): 

227 op_inputs = op.inputs 

228 if only_differentiable: 

229 return op_inputs if is_differentiable(op) else [] 

230 else: 

231 return op_inputs 

232 

233 

def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors. One-shot iterables (e.g.
      generators) are supported.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along within_ops,
      in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
      Ignored when `only_differentiable` is True.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A Python list of all the `tf.Operation` behind `seed_ops`: the seed ops
    first (when `inclusive`), followed by the ops found by each successive
    wave of the walk. Each op appears once.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  control_inputs = control_inputs and (not only_differentiable)

  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]
  else:
    # Materialize first: peeking at the first element of a one-shot iterator
    # would consume it and silently drop that seed from the walk.
    seed_ops = list(seed_ops)

  if not seed_ops:
    # Empty iterable.
    return []

  # The type of the first element decides how the seeds are interpreted.
  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
  # The seeds are intersected with `within_ops` only when it is truthy.
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  # `visited` mirrors `result` so membership tests do not rescan the list;
  # `result` alone carries the ordering.
  visited = set(result)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in visited and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in visited and is_within(new_op):
            new_wave.add(new_op)
    # Every op in `new_wave` is unique and absent from `result`, so a plain
    # extend is enough to keep `result` free of duplicates.
    result.extend(new_wave)
    visited.update(new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result

311 

312 

class UnliftableError(Exception):
  """Error signalling that a Tensor cannot be lifted out of its graph."""

  # Class-level marker telling autograph to let this error through without
  # rewriting it.
  ag_pass_through = True

318 

319 

def _as_operation(op_or_tensor):
  """Return the op of a `tf.Tensor`; return any other argument unchanged."""
  is_tensor = isinstance(op_or_tensor, ops.Tensor)
  return op_or_tensor.op if is_tensor else op_or_tensor

324 

325 

def graph_inputs(op):
  """Return the ops feeding `op`: data-input producers, then control inputs.

  Args:
    op: an object exposing `inputs` (items with an `op` attribute) and
      `control_inputs`.
  Returns:
    A new list with the `op` of every input, followed by every control input.
  """
  feeders = [tensor.op for tensor in op.inputs]
  feeders.extend(op.control_inputs)
  return feeders

328 

329 

def show_path(from_op, tensors, sources):
  """Find one path from `from_op` to any of `tensors`, ignoring `sources`.

  The graph is walked backward from the ops of `tensors` through their inputs
  and control inputs, without entering the ops producing `sources`.

  Args:
    from_op: A `tf.Operation` (a `tf.Tensor` is replaced by its op).
    tensors: A `tf.Operation`, a `tf.Tensor`, or a list thereof.
    sources: A list of `tf.Tensor`.

  Returns:
    A python string containing the path, or "??" if none is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op

  if not isinstance(tensors, list):
    tensors = [tensors]

  final_ops = [_as_operation(tensor) for tensor in tensors]

  visited_ops = {source.op for source in sources}
  pending = list(final_ops)
  # Maps an op to one op consuming it, recorded when the walk discovers it.
  consumer_of = {}
  while pending:
    op = pending.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      # Follow the recorded consumers forward until one of the final ops.
      chain = [op]
      cursor = op
      while cursor not in final_ops:
        cursor = consumer_of[cursor]
        chain.append(cursor)
      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(chain))
    for inp in graph_inputs(op):
      if inp not in visited_ops and inp not in sources:
        consumer_of[inp] = op
        pending.append(inp)
  return "??"

370 

371 

# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor:  A Tensor or Operation where the subgraph terminates.
    sources:  A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  pending = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while pending:
    op = pending.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    if disallowed_placeholders is not None and op in disallowed_placeholders:
      unliftable = True
    elif op.type == "Placeholder":
      # With no explicit disallow set, any placeholder is an error unless
      # add_sources permits it; its outputs are recorded either way.
      unliftable = disallowed_placeholders is None and not add_sources
      extra_sources.update(op.outputs)
    else:
      unliftable = False

    if unliftable:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s" %
          (repr(init_tensor), repr(op), show_path(op, init_tensor, sources)))

    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      # NOTE(review): `sources or extra_sources` selects a single container:
      # `sources` when it is truthy, `extra_sources` otherwise.
      if inp not in visited_ops and inp not in (sources or extra_sources):
        pending.append(inp)

  return extra_sources