Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/core/function/capture/capture_container.py: 32%
177 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""FuncGraph and related functionality."""
17import collections as py_collections
18import functools
19from typing import Any, Callable, Hashable, Mapping, Optional
21from tensorflow.core.function import trace_type
22from tensorflow.python import pywrap_tfe
23from tensorflow.python.framework import dtypes
24from tensorflow.python.types import core
25from tensorflow.python.util import object_identity
28_EAGER_CONST_THRESHOLD = 128
class MutationAwareDict(py_collections.OrderedDict):
  """An OrderedDict that records whether it has been mutated.

  Every mutating operation (item assignment/deletion, `pop`, `popitem`,
  `clear`, `move_to_end`) sets the `mutated` flag to True. Consumers reset the
  flag to False after rebuilding state derived from the dict's contents or
  order, and check it later to learn whether that state is stale. The flag
  starts as True so the first consumer always builds its derived state.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    self._mutated = True

  def pop(self, key, default=None):
    """Removes `key` and returns its value, or `default` if it is absent.

    Unlike `dict.pop`, a missing key never raises KeyError because `default`
    itself defaults to None. The flag is set even when nothing is removed.
    """
    self._mutated = True
    return super().pop(key, default)

  def popitem(self, last=True):
    """Removes and returns a (key, value) pair (LIFO if `last`, else FIFO)."""
    # Overridden explicitly: OrderedDict.popitem is implemented in C and is not
    # guaranteed to route through the overridden __delitem__, so without this
    # the removal could go unnoticed.
    self._mutated = True
    return super().popitem(last=last)

  def move_to_end(self, key, last=True):
    """Moves `key` to the end (or to the front if `last` is False)."""
    # Reordering does not touch __setitem__/__delitem__, but it changes the
    # iteration order that derived caches are built from.
    self._mutated = True
    return super().move_to_end(key, last=last)

  def __setitem__(self, key, value):
    self._mutated = True
    return super().__setitem__(key, value)

  def __delitem__(self, key):
    self._mutated = True
    return super().__delitem__(key)

  def clear(self):
    self._mutated = True
    return super().clear()

  @property
  def mutated(self):
    """True if the dict may have changed since the flag was last reset."""
    return self._mutated

  @mutated.setter
  def mutated(self, value):
    self._mutated = value
class FunctionCaptures(object):
  """A container for all capture usages within FuncGraph.

  Two kinds of captures are tracked. Each kind is stored in three parallel
  ordered dicts that share the same capture key:

  * By-value captures (`_by_val_internal` / `_by_val_external` /
    `_by_val_tracetype`): an external value, the internal graph tensor that
    stands in for it, and the TraceType of the external value. Captures
    created by this class are keyed by `id()` of the external value.
  * By-ref captures (`_by_ref_internal` / `_by_ref_external` /
    `_by_ref_tracetype`): the external entry is expected to be a
    zero-argument callable (`get_by_ref_snapshot` calls it) that yields the
    current external value. It is paired with its internal placeholder and
    its TraceType.
  """

  def __init__(self):
    self._by_ref_internal = py_collections.OrderedDict()
    self._by_ref_external = py_collections.OrderedDict()
    self._by_ref_tracetype = py_collections.OrderedDict()
    self._by_val_internal = MutationAwareDict()
    self._by_val_external = MutationAwareDict()
    self._by_val_tracetype = py_collections.OrderedDict()
    # Cached list of (external, internal) by-value pairs. It is rebuilt
    # lazily by `by_val_capture_tuples` whenever either by-value dict reports
    # a mutation. It is initialised here so the attribute always exists.
    self._tuple_cache = []

    # Set of external ops on which the graph has a control dependency
    self.control = object_identity.ObjectIdentitySet()

  def clear(self):
    """Removes all by-value and by-ref captures (`control` is left as is)."""
    self._by_ref_internal.clear()
    self._by_ref_external.clear()
    self._by_ref_tracetype.clear()
    self._by_val_internal.clear()
    self._by_val_external.clear()
    # Also drop by-value trace types. Otherwise `capture_types` would keep
    # reporting captures that no longer exist.
    self._by_val_tracetype.clear()

  def capture_by_value(
      self,
      graph: Any,
      tensor: core.Tensor,
      name: Optional[str] = None
  ) -> core.Tensor:
    """Captures `tensor` by value if it's external to `graph`.

    * Eager values (`core.Value`) whose dtype is in `dtypes.TF_VALUE_DTYPES`
      and that have at most `_EAGER_CONST_THRESHOLD` elements are captured as
      graph constants. If `_capture_as_const` returns None, a placeholder is
      used instead.
    * Other eager values (large tensors, resources) are captured with a
      placeholder.
    * Graph tensors that belong to a different graph are validated and then
      delegated to `graph._capture_helper`.
    * Tensors that already belong to `graph` are returned unchanged.

    Eager captures are looked up by `id(tensor)`, so capturing the same eager
    value again returns the internal tensor created the first time.

    Args:
      graph: The FuncGraph that captures this tensor.
      tensor: Tensor. May be from this FuncGraph or a different graph.
      name: Optional name if a placeholder is created.

    Returns:
      Tensor from this FuncGraph.

    Raises:
      InaccessibleTensorError: if any tensors are accessed in a manner that
      bypasses the mechanisms required for the data dependencies to be correctly
      wired.
    """
    if isinstance(tensor, core.Value):
      if name is None:
        # A unique (within the program execution) integer.
        name = str(pywrap_tfe.TFE_Py_UID())

      # Small EagerTensors are captured with Const ops
      if (tensor.dtype in dtypes.TF_VALUE_DTYPES and
          functools.reduce(lambda a, b: a*b, tensor.shape, 1) <=
          _EAGER_CONST_THRESHOLD):
        graph_const = self.by_val_internal.get(id(tensor))
        if graph_const is None:
          graph_const = tensor._capture_as_const(name)  # pylint: disable=protected-access
          if graph_const is None:
            # Some eager tensors, e.g. parallel tensors, are not convertible to
            # a single constant. We'll use a placeholder for this case.
            # `_create_placeholder_helper` already registers the capture,
            # appends the placeholder to `graph.inputs` and records the tape.
            # Return its result directly so the placeholder is not registered
            # (and appended to `graph.inputs`) a second time below.
            return self._create_placeholder_helper(graph, tensor, name)
          self.add_or_replace(
              key=id(tensor),
              external=tensor,
              internal=graph_const,
              is_by_ref=False)
          graph.inputs.append(graph_const)
        graph_const._record_tape(tensor)  # pylint: disable=protected-access
        return graph_const

      # Large EagerTensors and resources are captured with Placeholder ops
      return self._create_placeholder_helper(graph, tensor, name)

    if tensor.graph is not graph:
      graph._validate_in_scope(tensor)  # pylint: disable=protected-access
      if name is None:
        assert tensor.op is not None, (
            tensor.__class__,
            dir(tensor),
            tensor.__class__.__name__,
        )
        name = tensor.op.name
      # cond/while graphs override _capture_helper() so cannot call
      # self.create_placeholder_helper() here directly.
      return graph._capture_helper(tensor, name)  # pylint: disable=protected-access
    return tensor

  def add_or_replace(
      self,
      key: Hashable,
      external: Any,
      internal: core.Tensor,
      tracetype: Any = None,
      is_by_ref: bool = False) -> None:
    """Replaces an already existing capture, otherwise adds it.

    Args:
      key: Capture key shared by the internal/external/tracetype dicts.
      external: The external value. For by-ref captures, a callable.
      internal: The internal graph tensor standing in for `external`.
      tracetype: TraceType of the capture. For by-value captures, if None it
        is computed from `external` with `trace_type.from_value`. For by-ref
        captures it is stored as given, even if None.
      is_by_ref: Whether to store the capture as by-ref instead of by-value.
    """
    if is_by_ref:
      self._by_ref_external[key] = external
      self._by_ref_internal[key] = internal
      self._by_ref_tracetype[key] = tracetype
    else:
      self._by_val_internal[key] = internal
      self._by_val_external[key] = external
      if tracetype is not None:
        self._by_val_tracetype[key] = tracetype
      else:
        self._by_val_tracetype[key] = trace_type.from_value(external)

  def pop(self,
          key: Hashable,
          is_by_ref: bool = False) -> Any:
    """Removes the capture stored under `key`.

    Returns:
      An `(external, internal, tracetype)` tuple. Each entry is None if it
      was absent; a missing key never raises.
    """
    if is_by_ref:
      return (self._by_ref_external.pop(key, None),
              self._by_ref_internal.pop(key, None),
              self._by_ref_tracetype.pop(key, None))
    else:
      return (self._by_val_external.pop(key, None),
              self._by_val_internal.pop(key, None),
              self._by_val_tracetype.pop(key, None))

  def reset_captures(self, tensors, placeholders):
    """Set the captures with the provided list of captures & placeholder.

    All by-value captures are replaced by the pairs `zip(tensors,
    placeholders)`, keyed by `id(external)`. By-ref captures are untouched.
    Fresh dict objects are created, so references previously obtained via the
    `by_val_*` properties no longer reflect this container.
    """
    self._by_val_external = MutationAwareDict()
    self._by_val_internal = MutationAwareDict()
    self._by_val_tracetype = MutationAwareDict()
    for external, internal in zip(tensors, placeholders):
      key = id(external)
      self._by_val_external[key] = external
      self._by_val_internal[key] = internal
      self._by_val_tracetype[key] = trace_type.from_value(external)

  # TODO(panzf): make the method public after supporting lam() returns
  # non-tensor values. Currently, this method is only used by
  # FuncGraph._experimental_capture_side_input_by_ref(), which contains the
  # logics for converting non-tensor values to tensor.
  def _capture_by_ref(self,
                      graph: Any,
                      lam: Callable[[], Any],
                      key: Optional[Hashable] = None) -> Any:
    """Used during tracing process to create/retrieve by-ref captures.

    If `key` is already captured, the existing internal placeholder is
    returned and `lam` is not called. If `key` is None, the smallest integer
    key not less than the current number of by-ref captures is chosen.

    Args:
      graph: The FuncGraph that captures this tensor.
      lam: A callable that takes no arguments and returns tensor captures.
      key: A hashable identifier.

    Returns:
      Tensor from this FuncGraph.
    """
    # Check if the capture exists in self._by_ref
    if key is not None and key in self._by_ref_internal:
      return self._by_ref_internal[key]
    if key is None:
      key = len(self._by_ref_internal)
      while key in self._by_ref_internal:
        key += 1

    # Call `lam` once now to derive the TraceType and internal placeholder.
    value_nested = lam()
    capture_trace_type = trace_type.from_value(value_nested)
    ctx = trace_type.InternalPlaceholderContext(graph)
    internal = capture_trace_type.placeholder_value(ctx)

    def lam_fn():
      # Re-evaluates `lam` on each call and converts the result to tensors
      # according to the TraceType fixed at capture time.
      # pytype: disable=attribute-error
      value = lam()
      return capture_trace_type._to_tensors(value)  # pylint: disable=protected-access
      # pytype: enable=attribute-error

    self._by_ref_external[key] = lam_fn
    self._by_ref_internal[key] = internal
    self._by_ref_tracetype[key] = capture_trace_type
    return self._by_ref_internal[key]

  def merge_by_ref_with(self, other: "FunctionCaptures") -> None:
    """Add by-ref captures from `other` to `self` if not exist.

    Only the external entries and trace types are copied. The internal
    placeholders of `other` are not.
    """
    assert isinstance(other, FunctionCaptures)
    for key in other.by_ref_external:
      if key not in self._by_ref_external:
        self._by_ref_external[key] = other.by_ref_external[key]
        self._by_ref_tracetype[key] = other.by_ref_tracetype[key]

  def get_by_ref_snapshot(self) -> Mapping[Hashable, Any]:
    """Get a snapshot of current values of by-ref captures.

    Each by-ref external callable is invoked. If it raises AttributeError or
    RuntimeError, the stored TraceType is used as that key's value instead.
    """
    snapshot = {}
    for key in self._by_ref_external:
      func = self._by_ref_external[key]
      try:
        value = func()
      except (AttributeError, RuntimeError):
        # b/269680071 In case of by-ref captures are unavailable at dispatch
        # time, use the predefined trace_type instead.
        value = self._by_ref_tracetype[key]
      snapshot[key] = value
    return snapshot

  def _create_placeholder_helper(
      self,
      graph: Any,
      tensor: core.Tensor,
      name: str):
    """A helper function to create capture placeholder.

    Reuses the by-value capture stored under `id(tensor)` if one exists.
    Otherwise it builds a placeholder from `tensor`'s TraceType, registers it
    as a by-value capture and appends it to `graph.inputs`. In both cases the
    tape is recorded before the placeholder is returned.
    """
    placeholder = self._by_val_internal.get(id(tensor))
    if placeholder is None:
      tracing_ctx = trace_type.InternalTracingContext()
      spec = trace_type.from_value(tensor, tracing_ctx)
      spec._name = name  # pylint: disable=protected-access
      # Packed eager values carry their composite device name along.
      if isinstance(tensor, core.Value) and tensor.is_packed:
        composite_device_name = tensor.device
      else:
        composite_device_name = None
      placeholder_ctx = trace_type.InternalPlaceholderContext(
          graph,
          with_none_control_dependencies=True,
          composite_device_name=composite_device_name)
      placeholder = spec.placeholder_value(placeholder_ctx)
      self.add_or_replace(
          key=id(tensor),
          external=tensor,
          internal=placeholder,
          is_by_ref=False)
      graph.inputs.append(placeholder)
    placeholder._record_tape(tensor)  # pylint: disable=protected-access
    return placeholder

  def _recompute_tuple_cache(self):
    """Rebuilds `_tuple_cache` as (external, internal) pairs in key order."""
    assert len(self._by_val_internal) == len(self._by_val_external)
    self._tuple_cache = []
    for key in self._by_val_internal:
      assert key in self._by_val_external
      internal = self._by_val_internal[key]
      external = self._by_val_external[key]
      self._tuple_cache.append((external, internal))

  @property
  def capture_types(self):
    """Maps capture key -> TraceType; by-ref entries win on key collision."""
    return {**self._by_val_tracetype, **self._by_ref_tracetype}

  @property
  def by_val_capture_tuples(self):
    """List of (external, internal) by-value pairs, recomputed only on change."""
    if self._by_val_internal.mutated or self._by_val_external.mutated:
      self._recompute_tuple_cache()
      self._by_val_internal.mutated = False
      self._by_val_external.mutated = False
    return self._tuple_cache

  @property
  def by_ref_internal(self):
    return self._by_ref_internal

  @property
  def by_ref_external(self):
    return self._by_ref_external

  @property
  def by_ref_tracetype(self):
    return self._by_ref_tracetype

  @property
  def by_val_internal(self):
    return self._by_val_internal

  @property
  def by_val_external(self):
    return self._by_val_external

  @property
  def by_val_tracetype(self):
    return self._by_val_tracetype