Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/traverse.py: 31%
26 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Helpers to traverse the Dataset dependency structure."""
16import queue
18from tensorflow.python.framework import dtypes
# Op types that are allowlisted by name, regardless of what they output.
OP_TYPES_ALLOWLIST = ["DummyIterationCounter"]
# Any op whose output is a variant tensor is allowlisted. That is broader than
# strictly necessary, but the other dataset _inputs() traversal strategies
# cannot handle function inputs that capture dataset variants.
TENSOR_TYPES_ALLOWLIST = [dtypes.variant]
def _traverse(dataset, op_filter_fn):
  """Traverse a dataset graph, returning nodes matching `op_filter_fn`.

  Starting from the op that produces `dataset`'s variant tensor, walks
  breadth-first through the ops that produce each visited op's inputs.

  Args:
    dataset: Object whose `_variant_tensor.op` is the root of the traversal.
    op_filter_fn: Callable that takes an op and returns a truthy value when
      the op should be included in the result.

  Returns:
    A list of the matching ops in breadth-first order. Each op is visited, and
    therefore reported, at most once even when it is reachable along several
    input paths.
  """
  root_op = dataset._variant_tensor.op  # pylint: disable=protected-access
  result = []
  bfs_q = queue.Queue()
  bfs_q.put(root_op)
  # Ops are recorded as seen when they are enqueued, not when they are
  # dequeued. Otherwise an op shared by several consumers is queued once per
  # consumer, gets processed repeatedly and shows up multiple times in
  # `result`. A set keeps each membership test constant-time.
  seen = {root_op}
  while not bfs_q.empty():
    op = bfs_q.get()
    if op_filter_fn(op):
      result.append(op)
    for input_tensor in op.inputs:
      input_op = input_tensor.op
      if input_op not in seen:
        seen.add(input_op)
        bfs_q.put(input_op)
  return result
def obtain_capture_by_value_ops(dataset):
  """Collects the allowlisted ops that were used to build `dataset`.

  Allowlisted ops are stateful ops which are known to be safe to capture by
  value. An op qualifies when the dtype of its first output is listed in
  `TENSOR_TYPES_ALLOWLIST` or when its op type is listed in
  `OP_TYPES_ALLOWLIST`.

  Args:
    dataset: Dataset whose construction graph should be searched.

  Returns:
    A list of the qualifying ops found by `_traverse` when walking the graph
    that produces `dataset`.
  """

  def _is_capturable(candidate):
    # The dtype test runs first, matching the short-circuit order callers
    # have always observed.
    first_output_dtype = candidate.outputs[0].dtype
    if first_output_dtype in TENSOR_TYPES_ALLOWLIST:
      return True
    return candidate.type in OP_TYPES_ALLOWLIST

  return _traverse(dataset, _is_capturable)
def obtain_all_variant_tensor_ops(dataset):
  """Collects every variant-producing op that was used to build `dataset`.

  A dataset is the result of a series of transformations, each of which may
  contribute zero or more Dataset ops that produce a dataset variant tensor.
  This function gathers all such ops.

  Args:
    dataset: Dataset whose construction graph should be searched.

  Returns:
    A list of the ops, found by `_traverse`, whose first output has dtype
    `dtypes.variant`.
  """

  def _produces_variant(candidate):
    return candidate.outputs[0].dtype == dtypes.variant

  return _traverse(dataset, _produces_variant)