Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/util/traverse.py: 31%

26 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Helpers to traverse the Dataset dependency structure.""" 

import collections
import queue

from tensorflow.python.framework import dtypes

19 

20 

# Op types collected by `obtain_capture_by_value_ops` even when their first
# output is not one of the allowlisted tensor types below.
OP_TYPES_ALLOWLIST = ["DummyIterationCounter"]

# We allowlist all ops that produce variant tensors as output. This is a bit
# of overkill, but the other dataset `_inputs()` traversal strategies can't
# cover the case of function inputs that capture dataset variants.
TENSOR_TYPES_ALLOWLIST = [dtypes.variant]

26 

27 

def _traverse(dataset, op_filter_fn):
  """Traverse a dataset graph, returning nodes matching `op_filter_fn`.

  Performs a breadth-first walk over the producer ops reachable from
  `dataset`'s variant tensor by following each op's `op.inputs`. Every
  reachable op is visited exactly once. This holds even when an op feeds
  several consumers (e.g. a diamond-shaped graph), appears more than once
  among a single op's inputs, or the graph contains a cycle.

  Args:
    dataset: A dataset exposing a `_variant_tensor` attribute whose `.op` is
      the root of the traversal.
    op_filter_fn: Predicate called once per reachable op. Ops for which it
      returns a truthy value are collected.

  Returns:
    A list of the matching ops, without duplicates, in breadth-first order
    starting from the dataset's own op.
  """
  root = dataset._variant_tensor.op  # pylint: disable=protected-access
  result = []
  # Ops are marked as seen when they are *enqueued*, not when they are
  # dequeued. Otherwise an op reachable along several paths would be
  # enqueued, filtered and reported once per path. The set keeps each
  # membership test O(1). NOTE(review): this assumes op objects are hashable
  # by identity, as graph ops are.
  seen = {root}
  bfs_q = collections.deque([root])
  while bfs_q:
    op = bfs_q.popleft()
    if op_filter_fn(op):
      result.append(op)
    for tensor in op.inputs:
      input_op = tensor.op
      if input_op not in seen:
        seen.add(input_op)
        bfs_q.append(input_op)
  return result

44 

45 

def obtain_capture_by_value_ops(dataset):
  """Finds the allowlisted ops used to construct `dataset`.

  Allowlisted ops are stateful ops which are known to be safe to capture by
  value. An op reachable from `dataset` is selected when its first output's
  dtype is listed in `TENSOR_TYPES_ALLOWLIST`, or when its op type is listed
  in `OP_TYPES_ALLOWLIST`.

  Args:
    dataset: Dataset to find allowlisted stateful ops for.

  Returns:
    A list of the ops reachable from `dataset` that match either allowlist.
  """

  def _should_capture(op):
    # The dtype check runs first, matching the original `or` evaluation order.
    if op.outputs[0].dtype in TENSOR_TYPES_ALLOWLIST:
      return True
    return op.type in OP_TYPES_ALLOWLIST

  return _traverse(dataset, op_filter_fn=_should_capture)

65 

66 

def obtain_all_variant_tensor_ops(dataset):
  """Collects every op that produced a variant tensor while building `dataset`.

  Constructing a dataset chains together a series of transformations. Each
  transformation may add zero or more ops whose first output is a dataset
  variant tensor. This walks the graph reachable from `dataset` and returns
  all such ops.

  Args:
    dataset: Dataset whose variant-producing ops should be collected.

  Returns:
    A list of the ops reachable from `dataset` whose first output has dtype
    `dtypes.variant`.
  """

  def _produces_variant(op):
    first_output = op.outputs[0]
    return first_output.dtype == dtypes.variant

  return _traverse(dataset, op_filter_fn=_produces_variant)