Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/framework/auto_control_deps_utils.py: 16%
80 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities for AutomaticControlDependencies."""
17from tensorflow.python.framework import dtypes
18from tensorflow.python.util import object_identity
# Name of the op attr that the helpers in this module look up with
# `op.get_attr` to find which of an op's resource inputs are read-only.
READ_ONLY_RESOURCE_INPUTS_ATTR = "_read_only_resource_inputs"
# Op types whose resource inputs are all treated as read-only. Populated via
# `register_read_only_resource_op`.
RESOURCE_READ_OPS = set()

# NOTE(review): not referenced anywhere in this module; presumably the name of
# an op attr consumed by other modules -- confirm against its users.
COLLECTIVE_MANAGER_IDS = "_collective_manager_ids"
def register_read_only_resource_op(op_type):
  """Registers `op_type` as an op that never updates the resources it touches.

  Args:
    op_type: Op type to add to the module-level `RESOURCE_READ_OPS` set.
  """
  RESOURCE_READ_OPS.add(op_type)
def get_read_only_resource_input_indices_graph(func_graph):
  """Returns the indices of `func_graph.inputs` that are read-only resources.

  A resource input counts as read-only when every op consuming it lists it
  among that op's read-only resource inputs. A resource input without any
  consumers is therefore reported as read-only as well.

  Args:
    func_graph: Graph object exposing `inputs`, a sequence of tensors.

  Returns:
    List of indices into `func_graph.inputs`, in ascending order.
  """
  # Memo so each consuming op is analysed at most once:
  # Operation -> ObjectIdentitySet of its read-only resource handles.
  read_only_handles_by_op = {}

  def _consumer_only_reads(consumer, tensor):
    """Returns whether `consumer` treats `tensor` as a read-only input."""
    handles = read_only_handles_by_op.get(consumer)
    if handles is None:
      read_only_positions = _get_read_only_resource_input_indices_op(consumer)
      handles = object_identity.ObjectIdentitySet(
          [consumer.inputs[pos] for pos in read_only_positions])
      read_only_handles_by_op[consumer] = handles
    return tensor in handles

  read_only_indices = []
  for position, tensor in enumerate(func_graph.inputs):
    if tensor.dtype != dtypes.resource:
      continue
    # `all` stops at the first consumer that may write to the resource.
    if all(_consumer_only_reads(consumer, tensor)
           for consumer in tensor.consumers()):
      read_only_indices.append(position)
  return read_only_indices
def _get_read_only_resource_input_indices_op(op):
  """Returns sorted list of read-only resource indices in op.inputs.

  Args:
    op: Operation.

  Returns:
    For an op type registered in `RESOURCE_READ_OPS`, the indices of all
    resource inputs. Otherwise the resource-input indices that match the op's
    `READ_ONLY_RESOURCE_INPUTS_ATTR` attr, or an empty list when that attr is
    not set.
  """
  if op.type in RESOURCE_READ_OPS:
    return [i for i, t in enumerate(op.inputs) if t.dtype == dtypes.resource]

  try:
    read_only_input_indices = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr was not set, so no input is known to be read-only.
    return []

  # Walk the inputs and the attr in lockstep; the attr entries are consumed in
  # order, so this relies on them being listed in ascending order.
  num_read_only = len(read_only_input_indices)
  read_only_index = 0
  result = []
  for i, t in enumerate(op.inputs):
    if read_only_index >= num_read_only:
      # Every declared read-only index has been matched.
      break
    if t.dtype != dtypes.resource:
      continue
    if i == read_only_input_indices[read_only_index]:
      result.append(i)
      read_only_index += 1

  return result
def get_read_write_resource_inputs(op):
  """Splits the resource inputs of `op` into read-only and read-write sets.

  Args:
    op: Operation

  Returns:
    A 2-tuple `(reads, writes)` of ObjectIdentitySets. `reads` holds the
    resource handles in `op.inputs` that are only read; `writes` holds the
    remaining resource handles, which are treated as read-write.
  """
  read_handles = object_identity.ObjectIdentitySet()
  write_handles = object_identity.ObjectIdentitySet()

  if op.type in RESOURCE_READ_OPS:
    # Registered read-only op: every resource input is a read.
    read_handles.update(
        [tensor for tensor in op.inputs if tensor.dtype == dtypes.resource])
    return (read_handles, write_handles)

  try:
    declared_read_only = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # No attr on the op: treat every resource input as a write.
    write_handles.update(
        [tensor for tensor in op.inputs if tensor.dtype == dtypes.resource])
    return (read_handles, write_handles)

  # Cursor into `declared_read_only`; it only advances when an input position
  # matches the entry it currently points at.
  cursor = 0
  num_declared = len(declared_read_only)
  for position, tensor in enumerate(op.inputs):
    if tensor.dtype != dtypes.resource:
      continue
    if cursor < num_declared and position == declared_read_only[cursor]:
      read_handles.add(tensor)
      cursor += 1
    else:
      write_handles.add(tensor)
  return (read_handles, write_handles)
def _op_writes_to_resource(handle, op):
  """Tells whether `op` may write to the resource behind `handle`.

  Args:
    handle: Resource handle. Must be an input of `op`.
    op: Operation.

  Returns:
    False when `op` is of a type registered through
    `register_read_only_resource_op`, or when the position of `handle` in
    `op.inputs` is listed in the op's `READ_ONLY_RESOURCE_INPUTS_ATTR` attr.
    True in every other case, including when the attr is not set.

  Raises:
    ValueError: if `handle` is not an input of `op`. This is only checked for
      ops whose type is not registered as read-only.
  """
  if op.type in RESOURCE_READ_OPS:
    return False
  # Resolved outside the `try` so a missing handle still raises ValueError
  # rather than being mistaken for a missing attr.
  handle_position = _input_index(op, handle)
  try:
    declared_read_only = op.get_attr(READ_ONLY_RESOURCE_INPUTS_ATTR)
  except ValueError:
    # Attr absent: conservatively report the resource as written to.
    return True
  else:
    return handle_position not in declared_read_only
151def _input_index(op, handle):
152 """Returns the index of `handle` in `op.inputs`.
154 Args:
155 op: Operation.
156 handle: Resource handle.
158 Returns:
159 Index in `op.inputs` receiving the resource `handle`.
161 Raises:
162 ValueError: If handle and its replicated input are both not found in
163 `op.inputs`.
164 """
165 for i, t in enumerate(op.inputs):
166 if handle is t:
167 return i
168 raise ValueError(f"{handle!s} not in list of inputs for op: {op!r}")