Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/auto_control_deps_utils.py: 16%

80 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Utilities for AutomaticControlDependencies.""" 

16 

17from tensorflow.python.framework import dtypes 

18from tensorflow.python.util import object_identity 

19 

# Name of the op attr holding the indices into `op.inputs` of resource inputs
# that the op only reads (the helpers below treat every other resource input
# of such an op as written).
READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
# Op types added via `register_read_only_resource_op`. Every resource input of
# an op whose `op.type` is in this set is treated as read-only.
RESOURCE_READ_OPS = set()


# Attr name; not referenced elsewhere in this module — presumably read by
# callers dealing with collective ops. TODO(review): confirm.
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"

25 

26 

def register_read_only_resource_op(op_type):
  """Registers `op_type` as an op type that never updates its resources.

  After registration, every resource input of an op whose `op.type` equals
  `op_type` is treated as read-only by the helpers in this module.

  Args:
    op_type: Op type to register; compared against `op.type`.
  """
  registry = RESOURCE_READ_OPS
  registry.add(op_type)

30 

31 

def get_read_only_resource_input_indices_graph(func_graph):
  """Returns sorted list of read-only resource indices in func_graph.inputs.

  A resource input of `func_graph` is considered read-only when none of the
  ops consuming it may write to it. An op may write to a handle if the handle
  feeds any resource input of that op that is not reported read-only by
  `_get_read_only_resource_input_indices_op`. A resource input with no
  consumers is reported as read-only.

  Args:
    func_graph: Graph whose `inputs` are inspected (presumably a FuncGraph).

  Returns:
    Ascending list of indices into `func_graph.inputs`.
  """
  result = []
  # A cache shared across all graph inputs, since one op may consume several
  # of them. Operation -> ObjectIdentitySet of resource handles the op may
  # write to.
  op_written_resource_inputs = {}
  for input_index, t in enumerate(func_graph.inputs):
    if t.dtype != dtypes.resource:
      continue
    read_only = True
    for op in t.consumers():
      writes = op_written_resource_inputs.get(op)
      if writes is None:
        read_only_indices = set(_get_read_only_resource_input_indices_op(op))
        # Track written handles rather than read-only ones so that a handle
        # fed to the same op at both a read-only and a writable input is
        # (conservatively) treated as written.
        writes = object_identity.ObjectIdentitySet([
            inp for i, inp in enumerate(op.inputs)
            if inp.dtype == dtypes.resource and i not in read_only_indices
        ])
        op_written_resource_inputs[op] = writes
      if t in writes:
        read_only = False
        break
    if read_only:
      result.append(input_index)
  return result

57 

58 

def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  Args:
    op: Operation.

  Returns:
    Ascending list of indices into `op.inputs`:
      - every resource input, if `op.type` was registered with
        `register_read_only_resource_op`;
      - empty, if the `READ_ONLY_RESOURCE_INPUTS_ATTR` attr is not set;
      - otherwise the resource inputs matched, in order, against the indices
        listed in that attr.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively report no read-only inputs, i.e. treat
    # every resource input as written.
    return []

  # Walk op.inputs and the attr's indices in lockstep. This assumes the attr
  # lists indices in ascending order, each referring to a resource input.
  read_only_index = 0
  result = []
  for i, t in enumerate(op.inputs):
    if read_only_index >= len(read_only_input_indices):
      break
    if t.dtype != dtypes.resource:
      continue
    if i == read_only_input_indices[read_only_index]:
      result.append(i)
      read_only_index += 1

  return result

83 

84 

def get_read_write_resource_inputs(op):
  """Partitions the resource handles in `op.inputs` by access mode.

  Args:
    op: Operation

  Returns:
    A 2-tuple `(reads, writes)` of ObjectIdentitySets: `reads` holds the
    read-only resource handles and `writes` the read-write resource handles
    found in `op.inputs`.
  """
  reads = object_identity.ObjectIdentitySet()
  writes = object_identity.ObjectIdentitySet()

  def resource_inputs():
    # (index, handle) pairs for every resource-typed input of `op`.
    return ((i, t) for i, t in enumerate(op.inputs)
            if t.dtype == dtypes.resource)

  # Ops registered as read-only never write any of their resources.
  if op.type in RESOURCE_READ_OPS:
    reads.update(t for _, t in resource_inputs())
    return reads, writes

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Without the attr, conservatively treat every resource input as written.
    writes.update(t for _, t in resource_inputs())
    return reads, writes

  # Match resource inputs against the attr's indices in order; this assumes
  # the attr lists indices in ascending order.
  num_read_only = len(read_only_input_indices)
  cursor = 0
  for i, t in resource_inputs():
    is_read_only = (cursor < num_read_only and
                    read_only_input_indices[cursor] == i)
    if is_read_only:
      cursor += 1
      reads.add(t)
    else:
      writes.add(t)
  return reads, writes

122 

123 

def _op_writes_to_resource(handle, op):
  """Returns whether op writes to resource handle.

  Args:
    handle: Resource handle. Must be an input of `op`.
    op: Operation.

  Returns:
    False if op is a read-only op registered using
    `register_read_only_resource_op`, or if every input of `op` that receives
    `handle` is at one of the indices in the `READ_ONLY_RESOURCE_INPUTS_ATTR`
    attr of the op. True otherwise, including when the attr is not set.

  Raises:
    ValueError: if `handle` is not an input of `op` (checked only when `op`
      is not a registered read-only op).
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  # Validates that `handle` is an input of `op`; raises ValueError otherwise.
  _input_index(op, handle)
  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set. Conservatively assume that the resource is written to.
    return True
  # `handle` may feed several inputs of `op`; it is written to if any of those
  # inputs is not declared read-only (checking only the first would miss it).
  return any(i not in read_only_input_indices
             for i, t in enumerate(op.inputs) if t is handle)

149 

150 

def _input_index(op, handle):
  """Returns the index of the first input of `op` that is `handle`.

  Inputs are compared by identity (`is`), not equality.

  Args:
    op: Operation.
    handle: Resource handle.

  Returns:
    The smallest index `i` such that `op.inputs[i] is handle`.

  Raises:
    ValueError: If `handle` is not found in `op.inputs`.
  """
  for i, t in enumerate(op.inputs):
    if handle is t:
      return i
  raise ValueError(f"{handle!s} not in list of inputs for op: {op!r}")