Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/eager/execute.py: 4%
156 statements
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""

from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_conversion_registry
from tensorflow.python.framework import tensor_shape
from tensorflow.python.types import core as core_types
from tensorflow.python.util import compat


def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs

  Raises:
    An exception on error.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
                                        inputs, attrs, num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      e.message += " name: " + name
    raise core._status_to_exception(e) from None
  except TypeError as e:
    keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
    if keras_symbolic_tensors:
      raise core._SymbolicException(
          "Inputs to eager execution function cannot be Keras symbolic "
          "tensors, but found {}".format(keras_symbolic_tensors))
    raise e
  # pylint: enable=protected-access
  return tensors
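
# Illustrative sketch (not part of the original module): generated op wrappers
# call quick_execute with a flat list of eager tensors and a tuple of
# alternating attr names and values. Assuming eager tensors `x` and `y` and an
# eager context obtained via context.context(), a call for a binary op might
# look roughly like:
#
#   from tensorflow.python.eager import context
#   ctx = context.context()
#   outputs = quick_execute(
#       b"AddV2", 1, [x, y], ("T", x.dtype.as_datatype_enum), ctx)
#   result = outputs[0]
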
def execute_with_cancellation(op_name,
                              num_outputs,
                              inputs,
                              attrs,
                              ctx,
                              cancellation_manager,
                              name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    cancellation_manager: a `CancellationManager` object that can be used to
      cancel the operation.
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs

  Raises:
    An exception on error.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
                                                  op_name, inputs, attrs,
                                                  cancellation_manager._impl,
                                                  num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      e.message += " name: " + name
    raise core._status_to_exception(e) from None
  except TypeError as e:
    keras_symbolic_tensors = [x for x in inputs if _is_keras_symbolic_tensor(x)]
    if keras_symbolic_tensors:
      raise core._SymbolicException(
          "Inputs to eager execution function cannot be Keras symbolic "
          "tensors, but found {}".format(keras_symbolic_tensors))
    raise e
  # pylint: enable=protected-access
  return tensors
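
# Illustrative sketch (not part of the original module): the cancellable
# variant takes a CancellationManager whose underlying handle is passed to the
# runtime. Assuming eager tensors `x` and `y`, usage might look roughly like:
#
#   from tensorflow.python.eager import cancellation, context
#   manager = cancellation.CancellationManager()
#   outputs = execute_with_cancellation(
#       b"AddV2", 1, [x, y], ("T", x.dtype.as_datatype_enum),
#       context.context(), manager)
#   # Calling manager.start_cancel() from another thread cancels a pending op.
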
def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Monkey-patch to execute to enable execution callbacks."""
  tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  for callback in ctx.op_callbacks:
    callback(op_name, tuple(inputs), attrs, tensors, name)

  return tensors


execute = quick_execute


def must_record_gradient():
  """Import backprop if you want gradients recorded."""
  return False


def record_gradient(unused_op_name, unused_inputs, unused_attrs,
                    unused_outputs):
  """Import backprop if you want gradients recorded."""
  pass


def make_float(v, arg_name):
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def make_int(v, arg_name):
  if isinstance(v, str):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))


def make_str(v, arg_name):
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def make_bool(v, arg_name):
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def make_type(v, arg_name):
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  i = v.as_datatype_enum
  return i
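
# Illustrative sketch (not part of the original module): the make_* helpers
# above coerce Python values into the representations expected for op attrs
# and raise TypeError when coercion is not possible. For example:
#
#   make_float(2, "alpha")          # -> 2.0
#   make_int(3.0, "axis")           # -> 3
#   make_str("foo", "mode")         # -> b"foo"
#   make_bool(True, "transpose_a")  # -> True
#   make_type("float32", "T")       # -> dtypes.float32.as_datatype_enum
#   make_int("3", "axis")           # raises TypeError (strings are rejected)
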
def make_shape(v, arg_name):
  """Convert v into a list.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    positions where the dimensions are unknown).
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s." % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s." %
                     (arg_name, e))
  if shape.ndims is None:
    return None
  else:
    return shape.as_list()
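
# Illustrative sketch (not part of the original module): make_shape returns a
# plain Python list (with None for unknown dimensions) or None when the rank
# itself is unknown. For example:
#
#   make_shape([1, None, 3], "shape")                      # -> [1, None, 3]
#   make_shape(tensor_shape.TensorShape([2, 2]), "shape")  # -> [2, 2]
#   make_shape(tensor_shape.TensorShape(None), "shape")    # -> None
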
def make_tensor(v, arg_name):
  """Ensure v is a TensorProto."""
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  elif isinstance(v, str):
    pb = tensor_pb2.TensorProto()
    text_format.Merge(v, pb)
    return pb
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'." %
      (repr(v), arg_name))
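
# Illustrative sketch (not part of the original module): make_tensor accepts a
# TensorProto or its text-format serialization. A scalar float constant could
# be expressed as:
#
#   pb = make_tensor("dtype: DT_FLOAT tensor_shape { } float_val: 1.5", "value")
#   # pb is a tensor_pb2.TensorProto parsed from the text format above.
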
def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors."""
  del ctx  # Unused
  if (not l) and (default_dtype is not None):
    return default_dtype, []  # List is empty; assume default dtype.
  for x in l:
    if not isinstance(x, core_types.Value):
      break
  else:  # note: intentional for-else
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(t, core_types.Value):
      dtype = t.dtype
      break

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for t in l:
      tensor = None
      # First see if we can get a valid dtype with the default conversion and
      # check whether it matches one of the allowed dtypes. Some ops like
      # ConcatV2 may not list allowed dtypes, in which case we should skip this.
      if dtype is None and allowed_dtypes:
        tensor = tensor_conversion_registry.convert(t)
        # If we did not match an allowed dtype, try again with the default
        # dtype. This could be because we have an empty tensor and thus we
        # picked the wrong type.
        if tensor.dtype not in allowed_dtypes:
          tensor = None

      if tensor is None:
        tensor = tensor_conversion_registry.convert(
            t, dtype, preferred_dtype=default_dtype
        )

      ret.append(tensor)
      if dtype is None:
        dtype = tensor.dtype
  else:
    ret = [tensor_conversion_registry.convert(t, dtype) for t in l]

  # TODO(slebedev): consider removing this as it leaks a Keras concept.
  # pylint: disable=protected-access
  keras_symbolic_tensors = [x for x in ret if _is_keras_symbolic_tensor(x)]
  if keras_symbolic_tensors:
    raise core._SymbolicException(
        "Using symbolic output of a Keras layer during eager execution "
        "{}".format(keras_symbolic_tensors))
  # pylint: enable=protected-access
  return dtype.as_datatype_enum, ret
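
# Illustrative sketch (not part of the original module): args_to_matching_eager
# converts a mixed sequence to eager tensors that share one dtype and returns
# that dtype's enum alongside the converted tensors. Assuming an eager context
# `ctx` from context.context(), a call might look roughly like:
#
#   attr_t, tensors = args_to_matching_eager(
#       [1.0, 2, 3], ctx, [dtypes.float32, dtypes.float64], dtypes.float32)
#   # attr_t is dtypes.float32.as_datatype_enum; all three values become
#   # float32 tensors, matching the dtype inferred from the first element.
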
def convert_to_mixed_eager_tensors(values, ctx):
  """Converts `values` to eager tensors and returns their datatype enums."""
  del ctx  # Unused
  v = [tensor_conversion_registry.convert(t) for t in values]
  types = [t._datatype_enum() for t in v]  # pylint: disable=protected-access
  return types, v


def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors."""
  del ctx  # Unused
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  lists_ret = [[]]
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(lists[0]), len(l), lists[0], l))
    lists_ret.append([])

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor, use that dtype.
    for l in lists:
      if isinstance(l[i], core_types.Value):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype.
      lists_ret[0].append(tensor_conversion_registry.convert(lists[0][i]))
      dtype = lists_ret[0][i].dtype
      for j in range(1, len(lists)):
        lists_ret[j].append(
            tensor_conversion_registry.convert(lists[j][i], dtype=dtype)
        )
    else:
      # Convert everything to the found dtype.
      for j in range(len(lists)):
        lists_ret[j].append(
            tensor_conversion_registry.convert(lists[j][i], dtype=dtype)
        )
    types.append(dtype.as_datatype_enum)
  return types, lists_ret
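
# Illustrative sketch (not part of the original module): the function above
# matches dtypes per position across the input lists, so the i-th elements of
# all lists are converted to a common dtype. Assuming an eager context `ctx`:
#
#   types, (xs, ys) = args_to_mixed_eager_tensors([[1, 2.0], [3, 4.0]], ctx)
#   # types[0] is the datatype enum shared by xs[0] and ys[0] (int32 here);
#   # types[1] is the enum shared by xs[1] and ys[1] (float32 here).
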
def _is_keras_symbolic_tensor(x):
  return hasattr(x, "graph") and getattr(x.graph, "name", None) == "keras_graph"