Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gen_functional_ops.py: 9%
644 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Python wrappers around TensorFlow ops.
3This file is MACHINE GENERATED! Do not edit.
4"""
6import collections
8from tensorflow.python import pywrap_tfe as pywrap_tfe
9from tensorflow.python.eager import context as _context
10from tensorflow.python.eager import core as _core
11from tensorflow.python.eager import execute as _execute
12from tensorflow.python.framework import dtypes as _dtypes
13from tensorflow.security.fuzzing.py import annotation_types as _atypes
15from tensorflow.python.framework import op_def_registry as _op_def_registry
16from tensorflow.python.framework import ops as _ops
17from tensorflow.python.framework import op_def_library as _op_def_library
18from tensorflow.python.util.deprecation import deprecated_endpoints
19from tensorflow.python.util import dispatch as _dispatch
20from tensorflow.python.util.tf_export import tf_export
22from typing import TypeVar
def case(branch_index, input, Tout, branches, output_shapes=[], name=None):
  r"""An n-way switch statement which calls a single branch function.

  An n-way switch statement, implementing the following:
  ```
  switch (branch_index) {
    case 0:
      output = branches[0](input);
      break;
    case 1:
      output = branches[1](input);
      break;
    ...
    case [[nbranches-1]]:
    default:
      output = branches[nbranches-1](input);
      break;
  }
  ```

  Args:
    branch_index: A `Tensor` of type `int32`.
      The branch selector, an int32 Tensor.
    input: A list of `Tensor` objects.
      A list of input tensors passed to the branch function.
    Tout: A list of `tf.DTypes`. A list of output types.
    branches: A list of functions decorated with @Defun that has length `>= 1`.
      A list of functions each of which takes 'inputs' and returns a list of
      tensors, whose types are the same as what every other branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: dispatch straight to the C eager runtime; argument order
      # here is part of the FastPathExecute calling convention.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "Case", name, branch_index, input, "Tout", Tout, "branches",
        branches, "output_shapes", output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path could not handle these inputs; retry via Python fallback.
      pass
    try:
      return case_eager_fallback(
          branch_index, input, Tout=Tout, branches=branches,
          output_shapes=output_shapes, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "Case", branch_index=branch_index, input=input, Tout=Tout,
                branches=branches, output_shapes=output_shapes, name=name)
  _result = _outputs[:]
  if not _result:
    # Op with no outputs: return the Operation itself.
    return _op
  if _execute.must_record_gradient():
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "branches", _op.get_attr("branches"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "Case", _inputs_flat, _attrs, _result)
  return _result

# Raw-op export: exposes the op under tf.raw_ops with TF-managed dispatch.
Case = tf_export("raw_ops.Case")(_ops.to_raw_op(case))
def case_eager_fallback(branch_index, input, Tout, branches, output_shapes, name, ctx):
  # Eager-mode fallback for the "Case" op: validate/normalize attributes in
  # Python, convert inputs to eager tensors, then run through the generic
  # executor. Used when the C fast path raises _FallbackException.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # Inputs may have mixed dtypes; each element's dtype is recorded in Tin.
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  branch_index = _ops.convert_to_tensor(branch_index, _dtypes.int32)
  _inputs_flat = [branch_index] + list(input)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "branches", branches,
            "output_shapes", output_shapes)
  _result = _execute.execute(b"Case", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "Case", _inputs_flat, _attrs, _result)
  return _result
def device_index(device_names, name=None):
  r"""Return the index of device the op runs.

  Given a list of device names, this operation returns the index of the device
  this op runs. The length of the list is returned in two cases:
  (1) Device does not exist in the given device list.
  (2) It is in XLA compilation.

  Args:
    device_names: A list of `strings`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: run the op directly through the C eager API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "DeviceIndex", name, "device_names", device_names)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return device_index_eager_fallback(
          device_names=device_names, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(device_names, (list, tuple)):
    raise TypeError(
        "Expected list for 'device_names' argument to "
        "'device_index' Op, not %r." % device_names)
  device_names = [_execute.make_str(_s, "device_names") for _s in device_names]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "DeviceIndex", device_names=device_names, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("device_names", _op.get_attr("device_names"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "DeviceIndex", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

# Raw-op export: exposes the op under tf.raw_ops with TF-managed dispatch.
DeviceIndex = tf_export("raw_ops.DeviceIndex")(_ops.to_raw_op(device_index))
def device_index_eager_fallback(device_names, name, ctx):
  # Eager-mode fallback for "DeviceIndex": validate the attr list, then run
  # the op through the generic eager executor and unwrap its single output.
  if not isinstance(device_names, (list, tuple)):
    raise TypeError(
        "Expected list for 'device_names' argument to "
        "'device_index' Op, not %r." % device_names)
  device_names = [_execute.make_str(_dn, "device_names")
                  for _dn in device_names]
  flat_inputs = []
  op_attrs = ("device_names", device_names)
  outputs = _execute.execute(b"DeviceIndex", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("DeviceIndex", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result
def fake_param(dtype, shape, name=None):
  r""" This op is used as a placeholder in If branch functions. It doesn't provide a
  valid output when run, so must either be removed (e.g. replaced with a
  function input) or guaranteed not to be used (e.g. if mirroring an
  intermediate output needed for the gradient computation of the other branch).

  Args:
    dtype: A `tf.DType`. The type of the output.
    shape: A `tf.TensorShape` or list of `ints`.
      The purported shape of the output. This is only used for shape inference;
      the output will not necessarily have this shape. Can be a partial shape.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: run the op directly through the C eager API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "FakeParam", name, "dtype", dtype, "shape", shape)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return fake_param_eager_fallback(
          dtype=dtype, shape=shape, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  dtype = _execute.make_type(dtype, "dtype")
  shape = _execute.make_shape(shape, "shape")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "FakeParam", dtype=dtype, shape=shape, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("dtype", _op._get_attr_type("dtype"), "shape",
              _op.get_attr("shape"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "FakeParam", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

# Raw-op export: exposes the op under tf.raw_ops with TF-managed dispatch.
FakeParam = tf_export("raw_ops.FakeParam")(_ops.to_raw_op(fake_param))
def fake_param_eager_fallback(dtype, shape, name, ctx):
  # Eager-mode fallback for "FakeParam": normalize the two attrs, execute
  # through the generic eager executor, and unwrap the single output.
  op_attrs = ("dtype", _execute.make_type(dtype, "dtype"),
              "shape", _execute.make_shape(shape, "shape"))
  flat_inputs = []
  outputs = _execute.execute(b"FakeParam", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("FakeParam", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result
def _for(start, limit, delta, input, body, name=None):
  r"""Applies a for loop.

  ```python
   output = input;
   for i in range(start, limit, delta)
     output = body(i, output);
  ```

  Args:
    start: A `Tensor` of type `int32`. The lower bound. An int32
    limit: A `Tensor` of type `int32`. The upper bound. An int32
    delta: A `Tensor` of type `int32`. The increment. An int32
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    body: A function decorated with @Defun.
      A function that takes a list of tensors (int32, T) and returns another
      list of tensors (T).
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: dispatch straight to the C eager runtime.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "For", name, start, limit, delta, input, "body", body)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return _for_eager_fallback(
          start, limit, delta, input, body=body, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "For", start=start, limit=limit, delta=delta, input=input, body=body,
               name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op.get_attr("T"), "body", _op.get_attr("body"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "For", _inputs_flat, _attrs, _result)
  return _result

# Raw-op export; the leading underscore on _for avoids shadowing the Python
# `for` keyword while the public name stays "For" under tf.raw_ops.
For = tf_export("raw_ops.For")(_ops.to_raw_op(_for))
def _for_eager_fallback(start, limit, delta, input, body, name, ctx):
  # Eager-mode fallback for "For": convert the loop-carried tensors and the
  # three int32 bounds, then run through the generic eager executor. The op
  # returns one tensor per loop-carried input, hence num_outputs=len(input).
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  bounds = [_ops.convert_to_tensor(v, _dtypes.int32)
            for v in (start, limit, delta)]
  flat_inputs = bounds + list(input)
  op_attrs = ("T", _attr_T, "body", body)
  outputs = _execute.execute(b"For", len(input), inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("For", flat_inputs, op_attrs, outputs)
  return outputs
def _if(cond, input, Tout, then_branch, else_branch, output_shapes=[], name=None):
  r"""output = cond ? then_branch(input) : else_branch(input)

  Args:
    cond: A `Tensor`.
      A Tensor. If the tensor is a scalar of non-boolean type, the
      scalar is converted to a boolean according to the
      following rule: if the scalar is a numerical value, non-zero means
      `True` and zero means False; if the scalar is a string, non-empty
      means `True` and empty means `False`. If the tensor is not a scalar,
      being empty means False and being non-empty means True.
    input: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    then_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what else_branch returns.
    else_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what then_branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: dispatch straight to the C eager runtime.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "If", name, cond, input, "Tout", Tout, "then_branch",
        then_branch, "else_branch", else_branch, "output_shapes",
        output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return _if_eager_fallback(
          cond, input, Tout=Tout, then_branch=then_branch,
          else_branch=else_branch, output_shapes=output_shapes, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "If", cond=cond, input=input, Tout=Tout, then_branch=then_branch,
              else_branch=else_branch, output_shapes=output_shapes, name=name)
  _result = _outputs[:]
  if not _result:
    # Op with no outputs: return the Operation itself.
    return _op
  if _execute.must_record_gradient():
    _attrs = ("Tcond", _op._get_attr_type("Tcond"), "Tin",
              _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "then_branch", _op.get_attr("then_branch"), "else_branch",
              _op.get_attr("else_branch"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "If", _inputs_flat, _attrs, _result)
  return _result

# Raw-op export; the leading underscore on _if avoids shadowing the Python
# `if` keyword while the public name stays "If" under tf.raw_ops.
If = tf_export("raw_ops.If")(_ops.to_raw_op(_if))
def _if_eager_fallback(cond, input, Tout, then_branch, else_branch, output_shapes, name, ctx):
  # Eager-mode fallback for the "If" op: validate/normalize attributes,
  # convert inputs to eager tensors, and run through the generic executor.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # cond keeps whatever dtype it converts to; that dtype becomes attr Tcond.
  _attr_Tcond, (cond,) = _execute.args_to_matching_eager([cond], ctx, [])
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = [cond] + list(input)
  _attrs = ("Tcond", _attr_Tcond, "Tin", _attr_Tin, "Tout", Tout,
            "then_branch", then_branch, "else_branch", else_branch,
            "output_shapes", output_shapes)
  _result = _execute.execute(b"If", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "If", _inputs_flat, _attrs, _result)
  return _result
def partitioned_call(args, Tout, f, config="", config_proto="", executor_type="", name=None):
  r"""returns `f(inputs)`, where `f`'s body is placed and partitioned.

  Asynchronously executes a function, potentially across multiple devices but
  within a single process. The kernel places and partitions a given function's
  underlying graph, and executes each of the partitioned subgraphs as a function.

  Args:
    args: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    f: A function decorated with @Defun.
      A function that takes 'args', a list of tensors, and returns 'output',
      another list of tensors. Input and output types are specified by 'Tin'
      and 'Tout'. The function body of f will be placed and partitioned across
      devices, setting this op apart from the regular Call op.
    config: An optional `string`. Defaults to `""`.
    config_proto: An optional `string`. Defaults to `""`.
    executor_type: An optional `string`. Defaults to `""`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: dispatch straight to the C eager runtime.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "PartitionedCall", name, args, "Tout", Tout, "f", f, "config",
        config, "config_proto", config_proto, "executor_type", executor_type)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return partitioned_call_eager_fallback(
          args, Tout=Tout, f=f, config=config, config_proto=config_proto,
          executor_type=executor_type, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "PartitionedCall", args=args, Tout=Tout, f=f, config=config,
                           config_proto=config_proto,
                           executor_type=executor_type, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"), "config", _op.get_attr("config"),
              "config_proto", _op.get_attr("config_proto"), "executor_type",
              _op.get_attr("executor_type"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "PartitionedCall", _inputs_flat, _attrs, _result)
  return _result

# Raw-op export: exposes the op under tf.raw_ops with TF-managed dispatch.
PartitionedCall = tf_export("raw_ops.PartitionedCall")(_ops.to_raw_op(partitioned_call))
def partitioned_call_eager_fallback(args, Tout, f, config, config_proto, executor_type, name, ctx):
  # Eager-mode fallback for "PartitionedCall": validate/normalize attributes
  # and run through the generic eager executor.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  # Inputs may have mixed dtypes; each element's dtype is recorded in Tin.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  _inputs_flat = list(args)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f, "config", config,
            "config_proto", config_proto, "executor_type", executor_type)
  _result = _execute.execute(b"PartitionedCall", len(Tout),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "PartitionedCall", _inputs_flat, _attrs, _result)
  return _result
def remote_call(target, args, Tout, f, name=None):
  r"""Runs function `f` on a remote device indicated by `target`.

  Args:
    target: A `Tensor` of type `string`.
      A fully specified device name where we want to run the function.
    args: A list of `Tensor` objects. A list of arguments for the function.
    Tout: A list of `tf.DTypes` that has length `>= 1`.
      The type list for the return values.
    f: A function decorated with @Defun. The function to run remotely.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: dispatch straight to the C eager runtime.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "RemoteCall", name, target, args, "Tout", Tout, "f", f)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return remote_call_eager_fallback(
          target, args, Tout=Tout, f=f, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'remote_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "RemoteCall", target=target, args=args, Tout=Tout, f=f, name=name)
  _result = _outputs[:]
  if not _result:
    # Op with no outputs: return the Operation itself.
    return _op
  if _execute.must_record_gradient():
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "RemoteCall", _inputs_flat, _attrs, _result)
  return _result

# Raw-op export: exposes the op under tf.raw_ops with TF-managed dispatch.
RemoteCall = tf_export("raw_ops.RemoteCall")(_ops.to_raw_op(remote_call))
def remote_call_eager_fallback(target, args, Tout, f, name, ctx):
  # Eager-mode fallback for "RemoteCall": validate/normalize attributes,
  # convert inputs to eager tensors, and run through the generic executor.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'remote_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  # Inputs may have mixed dtypes; each element's dtype is recorded in Tin.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  target = _ops.convert_to_tensor(target, _dtypes.string)
  _inputs_flat = [target] + list(args)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f)
  _result = _execute.execute(b"RemoteCall", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "RemoteCall", _inputs_flat, _attrs, _result)
  return _result
def stateful_partitioned_call(args, Tout, f, config="", config_proto="", executor_type="", name=None):
  r"""returns `f(inputs)`, where `f`'s body is placed and partitioned.

  Args:
    args: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    f: A function decorated with @Defun.
      A function that takes 'args', a list of tensors, and returns 'output',
      another list of tensors. Input and output types are specified by 'Tin'
      and 'Tout'. The function body of f will be placed and partitioned across
      devices, setting this op apart from the regular Call op. This op is
      stateful.
    config: An optional `string`. Defaults to `""`.
    config_proto: An optional `string`. Defaults to `""`.
    executor_type: An optional `string`. Defaults to `""`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: dispatch straight to the C eager runtime.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatefulPartitionedCall", name, args, "Tout", Tout, "f", f,
        "config", config, "config_proto", config_proto, "executor_type",
        executor_type)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; retry via the Python fallback below.
      pass
    try:
      return stateful_partitioned_call_eager_fallback(
          args, Tout=Tout, f=f, config=config, config_proto=config_proto,
          executor_type=executor_type, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateful_partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatefulPartitionedCall", args=args, Tout=Tout, f=f, config=config,
                                   config_proto=config_proto,
                                   executor_type=executor_type, name=name)
  _result = _outputs[:]
  if not _result:
    # Op with no outputs: return the Operation itself.
    return _op
  if _execute.must_record_gradient():
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"), "config", _op.get_attr("config"),
              "config_proto", _op.get_attr("config_proto"), "executor_type",
              _op.get_attr("executor_type"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatefulPartitionedCall", _inputs_flat, _attrs, _result)
  return _result

# Raw-op export: exposes the op under tf.raw_ops with TF-managed dispatch.
StatefulPartitionedCall = tf_export("raw_ops.StatefulPartitionedCall")(_ops.to_raw_op(stateful_partitioned_call))
def stateful_partitioned_call_eager_fallback(args, Tout, f, config, config_proto, executor_type, name, ctx):
  # Eager-mode fallback for "StatefulPartitionedCall": validate/normalize
  # attributes and run through the generic eager executor.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateful_partitioned_call' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if config is None:
    config = ""
  config = _execute.make_str(config, "config")
  if config_proto is None:
    config_proto = ""
  config_proto = _execute.make_str(config_proto, "config_proto")
  if executor_type is None:
    executor_type = ""
  executor_type = _execute.make_str(executor_type, "executor_type")
  # Inputs may have mixed dtypes; each element's dtype is recorded in Tin.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  _inputs_flat = list(args)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f, "config", config,
            "config_proto", config_proto, "executor_type", executor_type)
  _result = _execute.execute(b"StatefulPartitionedCall", len(Tout),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "StatefulPartitionedCall", _inputs_flat, _attrs, _result)
  return _result
723def stateless_case(branch_index, input, Tout, branches, output_shapes=[], name=None):
724 r"""An n-way switch statement which calls a single branch function.
726 An n-way switch statement, implementing the following:
727 ```
728 switch (branch_index) {
729 case 0:
730 output = branches[0](input);
731 break;
732 case 1:
733 output = branches[1](input);
734 break;
735 ...
736 case [[nbranches-1]]:
737 default:
738 output = branches[nbranches-1](input);
739 break;
740 }
741 ```
743 This should only be used when the none of branches has stateful ops.
745 Args:
746 branch_index: A `Tensor` of type `int32`.
747 The branch selector, an int32 Tensor.
748 input: A list of `Tensor` objects.
749 A list of input tensors passed to the branch function.
750 Tout: A list of `tf.DTypes`. A list of output types.
751 branches: A list of functions decorated with @Defun that has length `>= 1`.
752 A list of functions each of which takes 'inputs' and returns a list of
753 tensors, whose types are the same as what every other branch returns.
754 output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
755 name: A name for the operation (optional).
757 Returns:
758 A list of `Tensor` objects of type `Tout`.
759 """
760 _ctx = _context._context or _context.context()
761 tld = _ctx._thread_local_data
762 if tld.is_eager:
763 try:
764 _result = pywrap_tfe.TFE_Py_FastPathExecute(
765 _ctx, "StatelessCase", name, branch_index, input, "Tout", Tout,
766 "branches", branches, "output_shapes", output_shapes)
767 return _result
768 except _core._NotOkStatusException as e:
769 _ops.raise_from_not_ok_status(e, name)
770 except _core._FallbackException:
771 pass
772 try:
773 return stateless_case_eager_fallback(
774 branch_index, input, Tout=Tout, branches=branches,
775 output_shapes=output_shapes, name=name, ctx=_ctx)
776 except _core._SymbolicException:
777 pass # Add nodes to the TensorFlow graph.
778 # Add nodes to the TensorFlow graph.
779 if not isinstance(Tout, (list, tuple)):
780 raise TypeError(
781 "Expected list for 'Tout' argument to "
782 "'stateless_case' Op, not %r." % Tout)
783 Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
784 if not isinstance(branches, (list, tuple)):
785 raise TypeError(
786 "Expected list for 'branches' argument to "
787 "'stateless_case' Op, not %r." % branches)
788 if output_shapes is None:
789 output_shapes = []
790 if not isinstance(output_shapes, (list, tuple)):
791 raise TypeError(
792 "Expected list for 'output_shapes' argument to "
793 "'stateless_case' Op, not %r." % output_shapes)
794 output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
795 _, _, _op, _outputs = _op_def_library._apply_op_helper(
796 "StatelessCase", branch_index=branch_index, input=input, Tout=Tout,
797 branches=branches, output_shapes=output_shapes,
798 name=name)
799 _result = _outputs[:]
800 if _execute.must_record_gradient():
801 _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
802 "branches", _op.get_attr("branches"), "output_shapes",
803 _op.get_attr("output_shapes"))
804 _inputs_flat = _op.inputs
805 _execute.record_gradient(
806 "StatelessCase", _inputs_flat, _attrs, _result)
807 return _result
# Raw-op export: `tf.raw_ops.StatelessCase` maps directly to this wrapper.
StatelessCase = tf_export("raw_ops.StatelessCase")(_ops.to_raw_op(stateless_case))
def stateless_case_eager_fallback(branch_index, input, Tout, branches, output_shapes, name, ctx):
  """Slow-path eager execution of the StatelessCase op.

  Invoked from `stateless_case` when the `TFE_Py_FastPathExecute` fast path
  raises `_FallbackException`.  Validates the list-valued attributes,
  converts the inputs to eager tensors, and dispatches through
  `_execute.execute`.
  """
  # Attribute validation mirrors the graph-mode path in `stateless_case`.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_case' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if not isinstance(branches, (list, tuple)):
    raise TypeError(
        "Expected list for 'branches' argument to "
        "'stateless_case' Op, not %r." % branches)
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_case' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # Inputs may have heterogeneous dtypes; "Tin" records the dtype list.
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  # The branch selector is always an int32 tensor (see op docstring).
  branch_index = _ops.convert_to_tensor(branch_index, _dtypes.int32)
  _inputs_flat = [branch_index] + list(input)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "branches", branches,
  "output_shapes", output_shapes)
  # Number of outputs equals len(Tout), the declared output-type list.
  _result = _execute.execute(b"StatelessCase", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "StatelessCase", _inputs_flat, _attrs, _result)
  return _result
def stateless_if(cond, input, Tout, then_branch, else_branch, output_shapes=[], name=None):
  r"""output = cond ? then_branch(input) : else_branch(input)

  Args:
    cond: A `Tensor`.
      A Tensor. If the tensor is a scalar of non-boolean type, the
      scalar is converted to a boolean according to the
      following rule: if the scalar is a numerical value, non-zero means
      `True` and zero means False; if the scalar is a string, non-empty
      means `True` and empty means `False`. If the tensor is not a scalar,
      being empty means False and being non-empty means True.

      This should only be used when the if then/else body functions do not
      have stateful ops.
    input: A list of `Tensor` objects. A list of input tensors.
    Tout: A list of `tf.DTypes`. A list of output types.
    then_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what else_branch returns.
    else_branch: A function decorated with @Defun.
      A function that takes 'inputs' and returns a list of tensors, whose
      types are the same as what then_branch returns.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: execute the op directly through the C API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatelessIf", name, cond, input, "Tout", Tout, "then_branch",
        then_branch, "else_branch", else_branch, "output_shapes",
        output_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the arguments; retry via the Python fallback below.
      pass
    try:
      return stateless_if_eager_fallback(
          cond, input, Tout=Tout, then_branch=then_branch,
          else_branch=else_branch, output_shapes=output_shapes, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatelessIf", cond=cond, input=input, Tout=Tout,
                       then_branch=then_branch, else_branch=else_branch,
                       output_shapes=output_shapes, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Read attrs back from the created op so recorded values are canonical.
    _attrs = ("Tcond", _op._get_attr_type("Tcond"), "Tin",
              _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"),
              "then_branch", _op.get_attr("then_branch"), "else_branch",
              _op.get_attr("else_branch"), "output_shapes",
              _op.get_attr("output_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatelessIf", _inputs_flat, _attrs, _result)
  return _result
# Raw-op export: `tf.raw_ops.StatelessIf` maps directly to this wrapper.
StatelessIf = tf_export("raw_ops.StatelessIf")(_ops.to_raw_op(stateless_if))
def stateless_if_eager_fallback(cond, input, Tout, then_branch, else_branch, output_shapes, name, ctx):
  """Slow-path eager execution of the StatelessIf op.

  Invoked from `stateless_if` when the `TFE_Py_FastPathExecute` fast path
  raises `_FallbackException`.  Validates list attributes, converts the
  inputs to eager tensors, and dispatches through `_execute.execute`.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'stateless_if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_if' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  # `cond` may be any dtype (see op docstring); infer "Tcond" from the value.
  _attr_Tcond, (cond,) = _execute.args_to_matching_eager([cond], ctx, [])
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = [cond] + list(input)
  _attrs = ("Tcond", _attr_Tcond, "Tin", _attr_Tin, "Tout", Tout,
  "then_branch", then_branch, "else_branch", else_branch, "output_shapes",
  output_shapes)
  _result = _execute.execute(b"StatelessIf", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "StatelessIf", _inputs_flat, _attrs, _result)
  return _result
def stateless_while(input, cond, body, output_shapes=[], parallel_iterations=10, name=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      A function takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.

      This should only be used when the while condition and body functions
      do not have stateful ops.
    body: A function decorated with @Defun.
      A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    parallel_iterations: An optional `int`. Defaults to `10`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: execute the op directly through the C API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "StatelessWhile", name, input, "cond", cond, "body", body,
        "output_shapes", output_shapes, "parallel_iterations",
        parallel_iterations)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the arguments; retry via the Python fallback below.
      pass
    try:
      return stateless_while_eager_fallback(
          input, cond=cond, body=body, output_shapes=output_shapes,
          parallel_iterations=parallel_iterations, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if parallel_iterations is None:
    parallel_iterations = 10
  parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "StatelessWhile", input=input, cond=cond, body=body,
                          output_shapes=output_shapes,
                          parallel_iterations=parallel_iterations, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Read attrs back from the created op so recorded values are canonical.
    _attrs = ("T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body",
              _op.get_attr("body"), "output_shapes",
              _op.get_attr("output_shapes"), "parallel_iterations",
              _op._get_attr_int("parallel_iterations"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "StatelessWhile", _inputs_flat, _attrs, _result)
  return _result
# Raw-op export: `tf.raw_ops.StatelessWhile` maps directly to this wrapper.
StatelessWhile = tf_export("raw_ops.StatelessWhile")(_ops.to_raw_op(stateless_while))
def stateless_while_eager_fallback(input, cond, body, output_shapes, parallel_iterations, name, ctx):
  """Slow-path eager execution of the StatelessWhile op.

  Invoked from `stateless_while` when the `TFE_Py_FastPathExecute` fast path
  raises `_FallbackException`.  Normalizes attributes, converts the inputs
  to eager tensors, and dispatches through `_execute.execute`.
  """
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'stateless_while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if parallel_iterations is None:
    parallel_iterations = 10
  parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = list(input)
  _attrs = ("T", _attr_T, "cond", cond, "body", body, "output_shapes",
  output_shapes, "parallel_iterations", parallel_iterations)
  # Loop carries its variables through, so output count == input count.
  _result = _execute.execute(b"StatelessWhile", len(input),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "StatelessWhile", _inputs_flat, _attrs, _result)
  return _result
def symbolic_gradient(input, Tout, f, name=None):
  r"""Computes the gradient function for function f via backpropagation.

  Args:
    input: A list of `Tensor` objects. a list of input tensors of size N + M;
    Tout: A list of `tf.DTypes` that has length `>= 1`.
      the type list for the input list.
    f: A function decorated with @Defun.
      The function we want to compute the gradient for.

      The function 'f' must be a numerical function which takes N inputs and
      produces M outputs. Its gradient function 'g', which is computed by
      this SymbolicGradient op is a function taking N + M inputs and
      produces N outputs.

      I.e. if we have
         (y1, y2, ..., y_M) = f(x1, x2, ..., x_N),
      then, g is
         (dL/dx1, dL/dx2, ..., dL/dx_N) = g(x1, x2, ..., x_N,
                                            dL/dy1, dL/dy2, ..., dL/dy_M),

      where L is a scalar-value function of (x1, x2, ..., xN) (e.g., the
      loss function). dL/dx_i is the partial derivative of L with respect
      to x_i.

      (Needs some math expert to say the comment above better.)
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: execute the op directly through the C API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "SymbolicGradient", name, input, "Tout", Tout, "f", f)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the arguments; retry via the Python fallback below.
      pass
    try:
      return symbolic_gradient_eager_fallback(
          input, Tout=Tout, f=f, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'symbolic_gradient' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "SymbolicGradient", input=input, Tout=Tout, f=f, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("Tin", _op.get_attr("Tin"), "Tout", _op.get_attr("Tout"), "f",
              _op.get_attr("f"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "SymbolicGradient", _inputs_flat, _attrs, _result)
  return _result
# Raw-op export: `tf.raw_ops.SymbolicGradient` maps directly to this wrapper.
SymbolicGradient = tf_export("raw_ops.SymbolicGradient")(_ops.to_raw_op(symbolic_gradient))
def symbolic_gradient_eager_fallback(input, Tout, f, name, ctx):
  """Slow-path eager execution of the SymbolicGradient op.

  Invoked from `symbolic_gradient` when the `TFE_Py_FastPathExecute` fast
  path raises `_FallbackException`.  Validates `Tout`, converts the inputs
  to eager tensors, and dispatches through `_execute.execute`.
  """
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'symbolic_gradient' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  _attr_Tin, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = list(input)
  _attrs = ("Tin", _attr_Tin, "Tout", Tout, "f", f)
  _result = _execute.execute(b"SymbolicGradient", len(Tout),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "SymbolicGradient", _inputs_flat, _attrs, _result)
  return _result
def to_bool(input, name=None):
  r"""Converts a tensor to a scalar predicate.

  Converts a tensor to a scalar predicate with the following rules:

  - For 0D tensors, truthiness is determined by comparing against a "zero"
    value. For numerical types it is the obvious zero. For strings it is the
    empty string.

  - For >0D tensors, truthiness is determined by looking at the number of
    elements. If has zero elements, then the result is false. Otherwise the
    result is true.

  This matches the behavior of If and While for determining if a tensor counts
  as true/false for a branch condition.

  Args:
    input: A `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bool`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: execute the op directly through the C API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "ToBool", name, input)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the arguments; retry via the Python fallback below.
      pass
    try:
      return to_bool_eager_fallback(
          input, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "ToBool", input=input, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "ToBool", _inputs_flat, _attrs, _result)
  # ToBool produces exactly one output; unwrap the single-element list.
  _result, = _result
  return _result
# Raw-op export: `tf.raw_ops.ToBool` maps directly to this wrapper.
ToBool = tf_export("raw_ops.ToBool")(_ops.to_raw_op(to_bool))
def to_bool_eager_fallback(input, name, ctx):
  """Slow-path eager execution of the ToBool op.

  Invoked from `to_bool` when the `TFE_Py_FastPathExecute` fast path raises
  `_FallbackException`.  Infers "T" from the input value and dispatches
  through `_execute.execute`.
  """
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T)
  # ToBool always produces a single bool tensor.
  _result = _execute.execute(b"ToBool", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "ToBool", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def _while(input, cond, body, output_shapes=[], parallel_iterations=10, name=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      A function takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
    body: A function decorated with @Defun.
      A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified
      by T.
    output_shapes: An optional list of shapes (each a `tf.TensorShape` or list of `ints`). Defaults to `[]`.
    parallel_iterations: An optional `int`. Defaults to `10`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: execute the op directly through the C API.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "While", name, input, "cond", cond, "body", body,
        "output_shapes", output_shapes, "parallel_iterations",
        parallel_iterations)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the arguments; retry via the Python fallback below.
      pass
    try:
      return _while_eager_fallback(
          input, cond=cond, body=body, output_shapes=output_shapes,
          parallel_iterations=parallel_iterations, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if parallel_iterations is None:
    parallel_iterations = 10
  parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "While", input=input, cond=cond, body=body,
                 output_shapes=output_shapes,
                 parallel_iterations=parallel_iterations, name=name)
  _result = _outputs[:]
  # With zero outputs (empty loop-variable list) return the Operation itself.
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Read attrs back from the created op so recorded values are canonical.
    _attrs = ("T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body",
              _op.get_attr("body"), "output_shapes",
              _op.get_attr("output_shapes"), "parallel_iterations",
              _op._get_attr_int("parallel_iterations"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "While", _inputs_flat, _attrs, _result)
  return _result
# Raw-op export: `tf.raw_ops.While` maps directly to this wrapper.
While = tf_export("raw_ops.While")(_ops.to_raw_op(_while))
def _while_eager_fallback(input, cond, body, output_shapes, parallel_iterations, name, ctx):
  """Slow-path eager execution of the While op.

  Invoked from `_while` when the `TFE_Py_FastPathExecute` fast path raises
  `_FallbackException`.  Normalizes attributes, converts the inputs to
  eager tensors, and dispatches through `_execute.execute`.
  """
  if output_shapes is None:
    output_shapes = []
  if not isinstance(output_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'output_shapes' argument to "
        "'while' Op, not %r." % output_shapes)
  output_shapes = [_execute.make_shape(_s, "output_shapes") for _s in output_shapes]
  if parallel_iterations is None:
    parallel_iterations = 10
  parallel_iterations = _execute.make_int(parallel_iterations, "parallel_iterations")
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = list(input)
  _attrs = ("T", _attr_T, "cond", cond, "body", body, "output_shapes",
  output_shapes, "parallel_iterations", parallel_iterations)
  # Loop carries its variables through, so output count == input count.
  _result = _execute.execute(b"While", len(input), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "While", _inputs_flat, _attrs, _result)
  return _result