Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gen_collective_ops.py: 8%
751 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Python wrappers around TensorFlow ops.
3This file is MACHINE GENERATED! Do not edit.
4"""
6import collections
8from tensorflow.python import pywrap_tfe as pywrap_tfe
9from tensorflow.python.eager import context as _context
10from tensorflow.python.eager import core as _core
11from tensorflow.python.eager import execute as _execute
12from tensorflow.python.framework import dtypes as _dtypes
13from tensorflow.security.fuzzing.py import annotation_types as _atypes
15from tensorflow.python.framework import op_def_registry as _op_def_registry
16from tensorflow.python.framework import ops as _ops
17from tensorflow.python.framework import op_def_library as _op_def_library
18from tensorflow.python.util.deprecation import deprecated_endpoints
19from tensorflow.python.util import dispatch as _dispatch
20from tensorflow.python.util.tf_export import tf_export
22from typing import TypeVar
def collective_all_to_all_v2(input, group_size, group_key, instance_key, ordering_token, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually exchanges multiple tensors of identical type and shape.

  Runs the `CollectiveAllToAllV2` op. In eager mode the C fast path
  (`TFE_Py_FastPathExecute`) is tried first; if it raises
  `_core._FallbackException`, `collective_all_to_all_v2_eager_fallback` is
  used; if that raises `_core._SymbolicException`, the op is added to the
  TensorFlow graph instead.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    TypeError: On the fallback and graph paths, if `ordering_token` is not a
      list or tuple.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: inputs positionally, then attrs as flat name/value pairs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveAllToAllV2", name, input, group_size, group_key,
        instance_key, ordering_token, "communication_hint",
        communication_hint, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # The fast path could not handle these arguments; try the Python
      # eager fallback below.
      pass
    try:
      return collective_all_to_all_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_all_to_all_v2' Op, not %r." % ordering_token)
  # NOTE(review): computed but never read on this graph path.
  _attr_Nordering_token = len(ordering_token)
  # Optional attrs: None selects the op default, then the value is coerced
  # and validated by the _execute.make_* helpers.
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveAllToAllV2", input=input, group_size=group_size,
                                group_key=group_key,
                                instance_key=instance_key,
                                ordering_token=ordering_token,
                                communication_hint=communication_hint,
                                timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs are read back from the created op and passed as a flat
    # (name, value, name, value, ...) tuple.
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"), "Nordering_token",
              _op._get_attr_int("Nordering_token"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAllToAllV2", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_all_to_all_v2 as "raw_ops.CollectiveAllToAllV2".
CollectiveAllToAllV2 = tf_export("raw_ops.CollectiveAllToAllV2")(
    _ops.to_raw_op(collective_all_to_all_v2))
def collective_all_to_all_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, communication_hint, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_all_to_all_v2`.

  Called by `collective_all_to_all_v2` when the C fast path raises
  `_core._FallbackException`. Normalizes the attrs, converts every input to
  a tensor of the required dtype, and executes `CollectiveAllToAllV2`
  directly in `ctx`.

  Args:
    input, group_size, group_key, instance_key, ordering_token,
    communication_hint, timeout_seconds, name: As for
      `collective_all_to_all_v2`; `None` for an optional attr selects its
      default.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple.
  """
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_all_to_all_v2' Op, not %r." % ordering_token)
  # Number of ordering tokens, passed as the Nordering_token attr below.
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # Convert `input`; its dtype becomes attr T and must be one of the listed types.
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  # Ordering tokens are appended as a flat tail after the four fixed inputs.
  _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token)
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
  "timeout_seconds", timeout_seconds, "Nordering_token",
  _attr_Nordering_token)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveAllToAllV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAllToAllV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def collective_all_to_all_v3(input, communicator, group_assignment, timeout_seconds=0, name=None):
  r"""Mutually exchanges multiple tensors of identical type and shape.

  Runs the `CollectiveAllToAllV3` op, which takes a `communicator` resource
  and a `group_assignment` tensor instead of group/instance keys. In eager
  mode the C fast path is tried first; if it raises
  `_core._FallbackException`, `collective_all_to_all_v3_eager_fallback` is
  used; if that raises `_core._SymbolicException`, the op is added to the
  graph instead.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    communicator: A `Tensor` of type `resource`.
    group_assignment: A `Tensor` of type `int32`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: inputs positionally, then attrs as name/value pairs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveAllToAllV3", name, input, communicator,
        group_assignment, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_all_to_all_v3_eager_fallback(
          input, communicator, group_assignment,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveAllToAllV3", input=input, communicator=communicator,
                                group_assignment=group_assignment,
                                timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved on the created op, as a flat name/value tuple.
    _attrs = ("T", _op._get_attr_type("T"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAllToAllV3", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_all_to_all_v3 as "raw_ops.CollectiveAllToAllV3".
CollectiveAllToAllV3 = tf_export("raw_ops.CollectiveAllToAllV3")(
    _ops.to_raw_op(collective_all_to_all_v3))
def collective_all_to_all_v3_eager_fallback(input, communicator, group_assignment, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_all_to_all_v3`.

  Called by `collective_all_to_all_v3` when the C fast path raises
  `_core._FallbackException`. Normalizes `timeout_seconds`, converts the
  inputs to tensors of the required dtypes, and executes
  `CollectiveAllToAllV3` directly in `ctx`.

  Args:
    input, communicator, group_assignment, timeout_seconds, name: As for
      `collective_all_to_all_v3`; `timeout_seconds=None` selects `0`.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.
  """
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # Convert `input`; its dtype becomes attr T and must be one of the listed types.
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  communicator = _ops.convert_to_tensor(communicator, _dtypes.resource)
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  _inputs_flat = [input, communicator, group_assignment]
  _attrs = ("T", _attr_T, "timeout_seconds", timeout_seconds)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveAllToAllV3", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAllToAllV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
# Named result of collective_assign_group_v2: (group_size, group_key).
_CollectiveAssignGroupV2Output = collections.namedtuple(
    "CollectiveAssignGroupV2", "group_size group_key")
def collective_assign_group_v2(group_assignment, device_index, base_key, name=None):
  r"""Assign group keys based on group assignment.

  Runs the `CollectiveAssignGroupV2` op. In eager mode the C fast path is
  tried first; if it raises `_core._FallbackException`,
  `collective_assign_group_v2_eager_fallback` is used; if that raises
  `_core._SymbolicException`, the op is added to the graph instead. On
  every path the two outputs are wrapped in
  `_CollectiveAssignGroupV2Output`.

  Args:
    group_assignment: A `Tensor` of type `int32`.
    device_index: A `Tensor` of type `int32`.
    base_key: A `Tensor` of type `int32`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (group_size, group_key).

    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: this op has only tensor inputs and no attrs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveAssignGroupV2", name, group_assignment, device_index,
        base_key)
      _result = _CollectiveAssignGroupV2Output._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_assign_group_v2_eager_fallback(
          group_assignment, device_index, base_key, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveAssignGroupV2", group_assignment=group_assignment,
                                   device_index=device_index,
                                   base_key=base_key, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # No attrs to record for this op.
    _attrs = ()
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAssignGroupV2", _inputs_flat, _attrs, _result)
  # Two outputs: expose them by name.
  _result = _CollectiveAssignGroupV2Output._make(_result)
  return _result
# Publish the raw-op form of collective_assign_group_v2 as "raw_ops.CollectiveAssignGroupV2".
CollectiveAssignGroupV2 = tf_export("raw_ops.CollectiveAssignGroupV2")(
    _ops.to_raw_op(collective_assign_group_v2))
def collective_assign_group_v2_eager_fallback(group_assignment, device_index, base_key, name, ctx):
  """Eager-mode fallback for `collective_assign_group_v2`.

  Called by `collective_assign_group_v2` when the C fast path raises
  `_core._FallbackException`. Converts the three inputs to `int32` tensors
  and executes `CollectiveAssignGroupV2` directly in `ctx`.

  Args:
    group_assignment, device_index, base_key, name: As for
      `collective_assign_group_v2`.
    ctx: The eager context to execute in.

  Returns:
    A `_CollectiveAssignGroupV2Output` of (group_size, group_key).
  """
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  device_index = _ops.convert_to_tensor(device_index, _dtypes.int32)
  base_key = _ops.convert_to_tensor(base_key, _dtypes.int32)
  _inputs_flat = [group_assignment, device_index, base_key]
  # The op has no attrs.
  _attrs = None
  # Execute expecting 2 outputs.
  _result = _execute.execute(b"CollectiveAssignGroupV2", 2,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAssignGroupV2", _inputs_flat, _attrs, _result)
  _result = _CollectiveAssignGroupV2Output._make(_result)
  return _result
def collective_bcast_recv(T, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Receives a tensor value broadcast from another device.

  Runs the `CollectiveBcastRecv` op. It takes no tensor inputs: the dtype,
  group parameters and shape are all attrs. In eager mode the C fast path
  is tried first; if it raises `_core._FallbackException`,
  `collective_bcast_recv_eager_fallback` is used; if that raises
  `_core._SymbolicException`, the op is added to the graph instead.

  Args:
    T: A `tf.DType` from: `tf.bool, tf.float32, tf.half, tf.float64, tf.int32, tf.int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `T`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: no tensor inputs, so only attr name/value pairs follow `name`.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastRecv", name, "T", T, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "shape", shape,
        "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_bcast_recv_eager_fallback(
          T=T, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  # Coerce/validate every attr with the matching _execute.make_* helper.
  T = _execute.make_type(T, "T")
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastRecv", T=T, group_size=group_size,
                               group_key=group_key, instance_key=instance_key,
                               shape=shape,
                               communication_hint=communication_hint,
                               timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved on the created op, as a flat name/value tuple.
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastRecv", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_bcast_recv as "raw_ops.CollectiveBcastRecv".
CollectiveBcastRecv = tf_export("raw_ops.CollectiveBcastRecv")(
    _ops.to_raw_op(collective_bcast_recv))
def collective_bcast_recv_eager_fallback(T, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_bcast_recv`.

  Called by `collective_bcast_recv` when the C fast path raises
  `_core._FallbackException`. Coerces/validates all attrs and executes
  `CollectiveBcastRecv` directly in `ctx`.

  Args:
    T, group_size, group_key, instance_key, shape, communication_hint,
    timeout_seconds, name: As for `collective_bcast_recv`; `None` for an
      optional attr selects its default.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.
  """
  T = _execute.make_type(T, "T")
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # This op takes no tensor inputs; everything is passed as attrs.
  _inputs_flat = []
  _attrs = ("T", T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "shape", shape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveBcastRecv", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastRecv", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def collective_bcast_recv_v2(group_size, group_key, instance_key, shape, T, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Receives a tensor value broadcast from another device.

  Runs the `CollectiveBcastRecvV2` op. Unlike `collective_bcast_recv`,
  `group_size`, `group_key`, `instance_key` and `shape` are tensor inputs,
  and `T` is an attr. In eager mode the C fast path is tried first; if it
  raises `_core._FallbackException`, `collective_bcast_recv_v2_eager_fallback`
  is used; if that raises `_core._SymbolicException`, the op is added to the
  graph instead.

  Args:
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
    T: A `tf.DType` from: `tf.bool, tf.float32, tf.half, tf.float64, tf.int32, tf.int64`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `T`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: tensor inputs positionally, then attr name/value pairs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastRecvV2", name, group_size, group_key,
        instance_key, shape, "T", T, "communication_hint", communication_hint,
        "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_bcast_recv_v2_eager_fallback(
          group_size, group_key, instance_key, shape, T=T,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  T = _execute.make_type(T, "T")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastRecvV2", group_size=group_size, group_key=group_key,
                                 instance_key=instance_key, shape=shape, T=T,
                                 communication_hint=communication_hint,
                                 timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Tshape is the inferred dtype of the `shape` input.
    _attrs = ("T", _op._get_attr_type("T"), "Tshape",
              _op._get_attr_type("Tshape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastRecvV2", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_bcast_recv_v2 as "raw_ops.CollectiveBcastRecvV2".
CollectiveBcastRecvV2 = tf_export("raw_ops.CollectiveBcastRecvV2")(
    _ops.to_raw_op(collective_bcast_recv_v2))
def collective_bcast_recv_v2_eager_fallback(group_size, group_key, instance_key, shape, T, communication_hint, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_bcast_recv_v2`.

  Called by `collective_bcast_recv_v2` when the C fast path raises
  `_core._FallbackException`. Normalizes the attrs, converts the inputs to
  tensors, and executes `CollectiveBcastRecvV2` directly in `ctx`.

  Args:
    group_size, group_key, instance_key, shape, T, communication_hint,
    timeout_seconds, name: As for `collective_bcast_recv_v2`; `None` for an
      optional attr selects its default.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.
  """
  T = _execute.make_type(T, "T")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # `shape` may be int32 or int64; its dtype becomes attr Tshape. The
  # trailing _dtypes.int32 is presumably the default when no dtype can be
  # inferred (e.g. a plain Python list) — confirm against args_to_matching_eager.
  _attr_Tshape, (shape,) = _execute.args_to_matching_eager([shape], ctx, [_dtypes.int32, _dtypes.int64, ], _dtypes.int32)
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  _inputs_flat = [group_size, group_key, instance_key, shape]
  _attrs = ("T", T, "Tshape", _attr_Tshape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveBcastRecvV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastRecvV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def collective_bcast_send(input, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Broadcasts a tensor value to one or more other devices.

  Runs the `CollectiveBcastSend` op; `input` is the only tensor input and
  the group parameters and shape are attrs. In eager mode the C fast path
  is tried first; if it raises `_core._FallbackException`,
  `collective_bcast_send_eager_fallback` is used; if that raises
  `_core._SymbolicException`, the op is added to the graph instead.

  Args:
    input: A `Tensor`. Must be one of the following types: `bool`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: `input` positionally, then attr name/value pairs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastSend", name, input, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "shape", shape,
        "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_bcast_send_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  # Coerce/validate every attr with the matching _execute.make_* helper.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastSend", input=input, group_size=group_size,
                               group_key=group_key, instance_key=instance_key,
                               shape=shape,
                               communication_hint=communication_hint,
                               timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved on the created op, as a flat name/value tuple.
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastSend", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_bcast_send as "raw_ops.CollectiveBcastSend".
CollectiveBcastSend = tf_export("raw_ops.CollectiveBcastSend")(
    _ops.to_raw_op(collective_bcast_send))
def collective_bcast_send_eager_fallback(input, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_bcast_send`.

  Called by `collective_bcast_send` when the C fast path raises
  `_core._FallbackException`. Coerces/validates the attrs, converts
  `input`, and executes `CollectiveBcastSend` directly in `ctx`.

  Args:
    input, group_size, group_key, instance_key, shape, communication_hint,
    timeout_seconds, name: As for `collective_bcast_send`; `None` for an
      optional attr selects its default.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.
  """
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # Convert `input`; its dtype becomes attr T and must be one of the listed types.
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bool, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "shape", shape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveBcastSend", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastSend", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def collective_bcast_send_v2(input, group_size, group_key, instance_key, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Broadcasts a tensor value to one or more other devices.

  Runs the `CollectiveBcastSendV2` op. Unlike `collective_bcast_send`, the
  group parameters are tensor inputs and there is no `shape` argument. In
  eager mode the C fast path is tried first; if it raises
  `_core._FallbackException`, `collective_bcast_send_v2_eager_fallback` is
  used; if that raises `_core._SymbolicException`, the op is added to the
  graph instead.

  Args:
    input: A `Tensor`. Must be one of the following types: `bool`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: tensor inputs positionally, then attr name/value pairs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveBcastSendV2", name, input, group_size, group_key,
        instance_key, "communication_hint", communication_hint,
        "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_bcast_send_v2_eager_fallback(
          input, group_size, group_key, instance_key,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveBcastSendV2", input=input, group_size=group_size,
                                 group_key=group_key,
                                 instance_key=instance_key,
                                 communication_hint=communication_hint,
                                 timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved on the created op, as a flat name/value tuple.
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastSendV2", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_bcast_send_v2 as "raw_ops.CollectiveBcastSendV2".
CollectiveBcastSendV2 = tf_export("raw_ops.CollectiveBcastSendV2")(
    _ops.to_raw_op(collective_bcast_send_v2))
def collective_bcast_send_v2_eager_fallback(input, group_size, group_key, instance_key, communication_hint, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_bcast_send_v2`.

  Called by `collective_bcast_send_v2` when the C fast path raises
  `_core._FallbackException`. Normalizes the attrs, converts the inputs to
  tensors, and executes `CollectiveBcastSendV2` directly in `ctx`.

  Args:
    input, group_size, group_key, instance_key, communication_hint,
    timeout_seconds, name: As for `collective_bcast_send_v2`; `None` for an
      optional attr selects its default.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.
  """
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # Convert `input`; its dtype becomes attr T and must be one of the listed types.
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bool, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  _inputs_flat = [input, group_size, group_key, instance_key]
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
  "timeout_seconds", timeout_seconds)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveBcastSendV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastSendV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def collective_gather(input, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually accumulates multiple tensors of identical type and shape.

  Runs the `CollectiveGather` op; `input` is the only tensor input and the
  group parameters and shape are attrs. In eager mode the C fast path is
  tried first; if it raises `_core._FallbackException`,
  `collective_gather_eager_fallback` is used; if that raises
  `_core._SymbolicException`, the op is added to the graph instead.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: `input` positionally, then attr name/value pairs.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "CollectiveGather", name, input, "group_size", group_size,
        "group_key", group_key, "instance_key", instance_key, "shape", shape,
        "communication_hint", communication_hint, "timeout_seconds",
        timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      # The runtime reported an error status; convert it via the ops helper.
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these arguments; try the Python eager fallback.
      pass
    try:
      return collective_gather_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Symbolic inputs: fall through to graph construction below.
  # Add nodes to the TensorFlow graph.
  # Coerce/validate every attr with the matching _execute.make_* helper.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "CollectiveGather", input=input, group_size=group_size,
                            group_key=group_key, instance_key=instance_key,
                            shape=shape,
                            communication_hint=communication_hint,
                            timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved on the created op, as a flat name/value tuple.
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveGather", _inputs_flat, _attrs, _result)
  # Single-output op: unpack the one-element output list.
  _result, = _result
  return _result
# Publish the raw-op form of collective_gather as "raw_ops.CollectiveGather".
CollectiveGather = tf_export("raw_ops.CollectiveGather")(
    _ops.to_raw_op(collective_gather))
def collective_gather_eager_fallback(input, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  """Eager-mode fallback for `collective_gather`.

  Called by `collective_gather` when the C fast path raises
  `_core._FallbackException`. Coerces/validates the attrs, converts
  `input`, and executes `CollectiveGather` directly in `ctx`.

  Args:
    input, group_size, group_key, instance_key, shape, communication_hint,
    timeout_seconds, name: As for `collective_gather`; `None` for an
      optional attr selects its default.
    ctx: The eager context to execute in.

  Returns:
    The single output `Tensor` of the op.
  """
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  # Convert `input`; its dtype becomes attr T and must be one of the listed
  # types (note: no bool/bfloat16 for this op).
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
  "instance_key", instance_key, "shape", shape, "communication_hint",
  communication_hint, "timeout_seconds", timeout_seconds)
  # Execute expecting 1 output.
  _result = _execute.execute(b"CollectiveGather", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveGather", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
def collective_gather_v2(input, group_size, group_key, instance_key, ordering_token, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually accumulates multiple tensors of identical type and shape.

  In eager mode the op is first attempted with
  `pywrap_tfe.TFE_Py_FastPathExecute`, then with
  `collective_gather_v2_eager_fallback`. Otherwise, or when the fallback raises
  `_core._SymbolicException`, a `CollectiveGatherV2` node is added to the
  current graph.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple when building the
      graph node.
  """
  eager_ctx = _context._context or _context.context()
  if eager_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          eager_ctx, "CollectiveGatherV2", name, input, group_size, group_key,
          instance_key, ordering_token, "communication_hint",
          communication_hint, "timeout_seconds", timeout_seconds)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return collective_gather_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=eager_ctx)
    except _core._SymbolicException:
      pass  # Fall through and add nodes to the TensorFlow graph.
  # Graph mode.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_gather_v2' Op, not %r." % ordering_token)
  # `Nordering_token` is not passed explicitly here. It is read back from the
  # built op below, so the old unused `len(ordering_token)` local is dropped.
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  _, _, op, op_outputs = _op_def_library._apply_op_helper(
      "CollectiveGatherV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  outputs = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = (
        "T", op._get_attr_type("T"),
        "communication_hint", op.get_attr("communication_hint"),
        "timeout_seconds", op.get_attr("timeout_seconds"),
        "Nordering_token", op._get_attr_int("Nordering_token"),
    )
    _execute.record_gradient("CollectiveGatherV2", op.inputs, attrs, outputs)
  (result,) = outputs
  return result
# Wrap `collective_gather_v2` with `_ops.to_raw_op` and export the result
# under the name "raw_ops.CollectiveGatherV2".
CollectiveGatherV2 = tf_export("raw_ops.CollectiveGatherV2")(
    _ops.to_raw_op(collective_gather_v2))
def collective_gather_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, communication_hint, timeout_seconds, name, ctx):
  """Runs `CollectiveGatherV2` eagerly through the generic Python path.

  Args:
    input: Value convertible to a tensor of type `float32`, `half`,
      `float64`, `int32` or `int64`.
    group_size: Value converted to an `int32` tensor.
    group_key: Value converted to an `int32` tensor.
    instance_key: Value converted to an `int32` tensor.
    ordering_token: List or tuple of values converted to `resource` tensors.
    communication_hint: String attribute; `None` means `"auto"`.
    timeout_seconds: Float attribute; `None` means `0`.
    name: Optional name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The single output tensor of the op.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple.
  """
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_gather_v2' Op, not %r." % ordering_token)
  # The token count is sent as the `Nordering_token` attr.
  num_ordering_tokens = len(ordering_token)
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  supported_dtypes = [_dtypes.float32, _dtypes.half, _dtypes.float64,
                      _dtypes.int32, _dtypes.int64]
  dtype_attr, (input,) = _execute.args_to_matching_eager(
      [input], ctx, supported_dtypes)
  group_size, group_key, instance_key = [
      _ops.convert_to_tensor(value, _dtypes.int32)
      for value in (group_size, group_key, instance_key)]
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  flat_inputs = [input, group_size, group_key, instance_key, *ordering_token]
  attrs = (
      "T", dtype_attr,
      "communication_hint", communication_hint,
      "timeout_seconds", timeout_seconds,
      "Nordering_token", num_ordering_tokens,
  )
  outputs = _execute.execute(b"CollectiveGatherV2", 1, inputs=flat_inputs,
                             attrs=attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("CollectiveGatherV2", flat_inputs, attrs, outputs)
  (result,) = outputs
  return result
def collective_initialize_communicator(group_key, rank, group_size, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Initializes a group for collective operations.

  In eager mode the op is first attempted with
  `pywrap_tfe.TFE_Py_FastPathExecute`, then with
  `collective_initialize_communicator_eager_fallback`. Otherwise, or when the
  fallback raises `_core._SymbolicException`, a
  `CollectiveInitializeCommunicator` node is added to the current graph.

  Args:
    group_key: A `Tensor` of type `int32`.
    rank: A `Tensor` of type `int32`.
    group_size: A `Tensor` of type `int32`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `resource`.
  """
  eager_ctx = _context._context or _context.context()
  if eager_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          eager_ctx, "CollectiveInitializeCommunicator", name, group_key, rank,
          group_size, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return collective_initialize_communicator_eager_fallback(
          group_key, rank, group_size, communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=eager_ctx)
    except _core._SymbolicException:
      pass  # Fall through and add nodes to the TensorFlow graph.
  # Graph mode.
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  _, _, op, op_outputs = _op_def_library._apply_op_helper(
      "CollectiveInitializeCommunicator", group_key=group_key, rank=rank,
      group_size=group_size, communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  outputs = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = (
        "communication_hint", op.get_attr("communication_hint"),
        "timeout_seconds", op.get_attr("timeout_seconds"),
    )
    _execute.record_gradient(
        "CollectiveInitializeCommunicator", op.inputs, attrs, outputs)
  (result,) = outputs
  return result
# Wrap `collective_initialize_communicator` with `_ops.to_raw_op` and export
# the result under the name "raw_ops.CollectiveInitializeCommunicator".
CollectiveInitializeCommunicator = tf_export(
    "raw_ops.CollectiveInitializeCommunicator")(
        _ops.to_raw_op(collective_initialize_communicator))
def collective_initialize_communicator_eager_fallback(group_key, rank, group_size, communication_hint, timeout_seconds, name, ctx):
  """Runs `CollectiveInitializeCommunicator` eagerly via the Python path.

  Args:
    group_key: Value converted to an `int32` tensor.
    rank: Value converted to an `int32` tensor.
    group_size: Value converted to an `int32` tensor.
    communication_hint: String attribute; `None` means `"auto"`.
    timeout_seconds: Float attribute; `None` means `0`.
    name: Optional name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The single output tensor of the op.
  """
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  # Convert in the same order as the op's inputs.
  group_key, rank, group_size = [
      _ops.convert_to_tensor(value, _dtypes.int32)
      for value in (group_key, rank, group_size)]
  flat_inputs = [group_key, rank, group_size]
  attrs = (
      "communication_hint", communication_hint,
      "timeout_seconds", timeout_seconds,
  )
  outputs = _execute.execute(b"CollectiveInitializeCommunicator", 1,
                             inputs=flat_inputs, attrs=attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveInitializeCommunicator", flat_inputs, attrs, outputs)
  (result,) = outputs
  return result
def collective_reduce(input, group_size, group_key, instance_key, merge_op, final_op, subdiv_offsets, wait_for=[], communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  In eager mode the op is first attempted with
  `pywrap_tfe.TFE_Py_FastPathExecute`, then with
  `collective_reduce_eager_fallback`. Otherwise, or when the fallback raises
  `_core._SymbolicException`, a `CollectiveReduce` node is added to the
  current graph.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    subdiv_offsets: A list of `ints`.
    wait_for: An optional list of `ints`. Defaults to `[]`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    TypeError: If `subdiv_offsets` or `wait_for` is not a list or tuple when
      building the graph node.
  """
  # NOTE: `wait_for=[]` is a shared mutable default. This function only reads
  # it or rebinds the local name; it never mutates it in place.
  eager_ctx = _context._context or _context.context()
  if eager_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          eager_ctx, "CollectiveReduce", name, input, "group_size", group_size,
          "group_key", group_key, "instance_key", instance_key, "merge_op",
          merge_op, "final_op", final_op, "subdiv_offsets", subdiv_offsets,
          "wait_for", wait_for, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, merge_op=merge_op, final_op=final_op,
          subdiv_offsets=subdiv_offsets, wait_for=wait_for,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=eager_ctx)
    except _core._SymbolicException:
      pass  # Fall through and add nodes to the TensorFlow graph.
  # Graph mode: coerce every attribute before creating the node.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if not isinstance(subdiv_offsets, (list, tuple)):
    raise TypeError(
        "Expected list for 'subdiv_offsets' argument to "
        "'collective_reduce' Op, not %r." % subdiv_offsets)
  subdiv_offsets = [
      _execute.make_int(offset, "subdiv_offsets") for offset in subdiv_offsets]
  wait_for = [] if wait_for is None else wait_for
  if not isinstance(wait_for, (list, tuple)):
    raise TypeError(
        "Expected list for 'wait_for' argument to "
        "'collective_reduce' Op, not %r." % wait_for)
  wait_for = [_execute.make_int(entry, "wait_for") for entry in wait_for]
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  _, _, op, op_outputs = _op_def_library._apply_op_helper(
      "CollectiveReduce", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key, merge_op=merge_op,
      final_op=final_op, subdiv_offsets=subdiv_offsets, wait_for=wait_for,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  outputs = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = (
        "T", op._get_attr_type("T"),
        "group_size", op._get_attr_int("group_size"),
        "group_key", op._get_attr_int("group_key"),
        "instance_key", op._get_attr_int("instance_key"),
        "merge_op", op.get_attr("merge_op"),
        "final_op", op.get_attr("final_op"),
        "subdiv_offsets", op.get_attr("subdiv_offsets"),
        "wait_for", op.get_attr("wait_for"),
        "communication_hint", op.get_attr("communication_hint"),
        "timeout_seconds", op.get_attr("timeout_seconds"),
    )
    _execute.record_gradient("CollectiveReduce", op.inputs, attrs, outputs)
  (result,) = outputs
  return result
# Wrap `collective_reduce` with `_ops.to_raw_op` and export the result under
# the name "raw_ops.CollectiveReduce".
CollectiveReduce = tf_export("raw_ops.CollectiveReduce")(
    _ops.to_raw_op(collective_reduce))
def collective_reduce_eager_fallback(input, group_size, group_key, instance_key, merge_op, final_op, subdiv_offsets, wait_for, communication_hint, timeout_seconds, name, ctx):
  """Runs `CollectiveReduce` eagerly through the generic Python path.

  Args:
    input: Value convertible to a tensor of type `bfloat16`, `float32`,
      `half`, `float64`, `int32` or `int64`.
    group_size: Integer attribute (coerced by `_execute.make_int`).
    group_key: Integer attribute (coerced by `_execute.make_int`).
    instance_key: Integer attribute (coerced by `_execute.make_int`).
    merge_op: String attribute.
    final_op: String attribute.
    subdiv_offsets: List or tuple of integer attributes.
    wait_for: List or tuple of integer attributes; `None` means `[]`.
    communication_hint: String attribute; `None` means `"auto"`.
    timeout_seconds: Float attribute; `None` means `0`.
    name: Optional name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The single output tensor of the op.

  Raises:
    TypeError: If `subdiv_offsets` or `wait_for` is not a list or tuple.
  """
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if not isinstance(subdiv_offsets, (list, tuple)):
    raise TypeError(
        "Expected list for 'subdiv_offsets' argument to "
        "'collective_reduce' Op, not %r." % subdiv_offsets)
  subdiv_offsets = [
      _execute.make_int(offset, "subdiv_offsets") for offset in subdiv_offsets]
  wait_for = [] if wait_for is None else wait_for
  if not isinstance(wait_for, (list, tuple)):
    raise TypeError(
        "Expected list for 'wait_for' argument to "
        "'collective_reduce' Op, not %r." % wait_for)
  wait_for = [_execute.make_int(entry, "wait_for") for entry in wait_for]
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  supported_dtypes = [_dtypes.bfloat16, _dtypes.float32, _dtypes.half,
                      _dtypes.float64, _dtypes.int32, _dtypes.int64]
  dtype_attr, (input,) = _execute.args_to_matching_eager(
      [input], ctx, supported_dtypes)
  flat_inputs = [input]
  attrs = (
      "T", dtype_attr,
      "group_size", group_size,
      "group_key", group_key,
      "instance_key", instance_key,
      "merge_op", merge_op,
      "final_op", final_op,
      "subdiv_offsets", subdiv_offsets,
      "wait_for", wait_for,
      "communication_hint", communication_hint,
      "timeout_seconds", timeout_seconds,
  )
  outputs = _execute.execute(b"CollectiveReduce", 1, inputs=flat_inputs,
                             attrs=attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("CollectiveReduce", flat_inputs, attrs, outputs)
  (result,) = outputs
  return result
def collective_reduce_scatter_v2(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint="auto", timeout_seconds=0, max_subdivs_per_device=-1, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape and scatters the result.

  In eager mode the op is first attempted with
  `pywrap_tfe.TFE_Py_FastPathExecute`, then with
  `collective_reduce_scatter_v2_eager_fallback`. Otherwise, or when the
  fallback raises `_core._SymbolicException`, a `CollectiveReduceScatterV2`
  node is added to the current graph.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    max_subdivs_per_device: An optional `int`. Defaults to `-1`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple when building the
      graph node.
  """
  eager_ctx = _context._context or _context.context()
  if eager_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          eager_ctx, "CollectiveReduceScatterV2", name, input, group_size,
          group_key, instance_key, ordering_token, "merge_op", merge_op,
          "final_op", final_op, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds, "max_subdivs_per_device",
          max_subdivs_per_device)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_scatter_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          merge_op=merge_op, final_op=final_op,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds,
          max_subdivs_per_device=max_subdivs_per_device, name=name,
          ctx=eager_ctx)
    except _core._SymbolicException:
      pass  # Fall through and add nodes to the TensorFlow graph.
  # Graph mode.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_scatter_v2' Op, not %r." % ordering_token)
  # `Nordering_token` is not passed explicitly here. It is read back from the
  # built op below, so the old unused `len(ordering_token)` local is dropped.
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  max_subdivs_per_device = _execute.make_int(
      -1 if max_subdivs_per_device is None else max_subdivs_per_device,
      "max_subdivs_per_device")
  _, _, op, op_outputs = _op_def_library._apply_op_helper(
      "CollectiveReduceScatterV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, merge_op=merge_op, final_op=final_op,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds,
      max_subdivs_per_device=max_subdivs_per_device, name=name)
  outputs = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = (
        "T", op._get_attr_type("T"),
        "merge_op", op.get_attr("merge_op"),
        "final_op", op.get_attr("final_op"),
        "communication_hint", op.get_attr("communication_hint"),
        "timeout_seconds", op.get_attr("timeout_seconds"),
        "Nordering_token", op._get_attr_int("Nordering_token"),
        "max_subdivs_per_device", op._get_attr_int("max_subdivs_per_device"),
    )
    _execute.record_gradient(
        "CollectiveReduceScatterV2", op.inputs, attrs, outputs)
  (result,) = outputs
  return result
# Wrap `collective_reduce_scatter_v2` with `_ops.to_raw_op` and export the
# result under the name "raw_ops.CollectiveReduceScatterV2".
CollectiveReduceScatterV2 = tf_export("raw_ops.CollectiveReduceScatterV2")(
    _ops.to_raw_op(collective_reduce_scatter_v2))
def collective_reduce_scatter_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint, timeout_seconds, max_subdivs_per_device, name, ctx):
  """Runs `CollectiveReduceScatterV2` eagerly through the generic Python path.

  Args:
    input: Value convertible to a tensor of type `bfloat16`, `float32`,
      `half`, `float64`, `int32` or `int64`.
    group_size: Value converted to an `int32` tensor.
    group_key: Value converted to an `int32` tensor.
    instance_key: Value converted to an `int32` tensor.
    ordering_token: List or tuple of values converted to `resource` tensors.
    merge_op: String attribute.
    final_op: String attribute.
    communication_hint: String attribute; `None` means `"auto"`.
    timeout_seconds: Float attribute; `None` means `0`.
    max_subdivs_per_device: Integer attribute; `None` means `-1`.
    name: Optional name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The single output tensor of the op.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple.
  """
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_scatter_v2' Op, not %r." % ordering_token)
  # The token count is sent as the `Nordering_token` attr.
  num_ordering_tokens = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  max_subdivs_per_device = _execute.make_int(
      -1 if max_subdivs_per_device is None else max_subdivs_per_device,
      "max_subdivs_per_device")
  supported_dtypes = [_dtypes.bfloat16, _dtypes.float32, _dtypes.half,
                      _dtypes.float64, _dtypes.int32, _dtypes.int64]
  dtype_attr, (input,) = _execute.args_to_matching_eager(
      [input], ctx, supported_dtypes)
  group_size, group_key, instance_key = [
      _ops.convert_to_tensor(value, _dtypes.int32)
      for value in (group_size, group_key, instance_key)]
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  flat_inputs = [input, group_size, group_key, instance_key, *ordering_token]
  attrs = (
      "T", dtype_attr,
      "merge_op", merge_op,
      "final_op", final_op,
      "communication_hint", communication_hint,
      "timeout_seconds", timeout_seconds,
      "Nordering_token", num_ordering_tokens,
      "max_subdivs_per_device", max_subdivs_per_device,
  )
  outputs = _execute.execute(b"CollectiveReduceScatterV2", 1,
                             inputs=flat_inputs, attrs=attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveReduceScatterV2", flat_inputs, attrs, outputs)
  (result,) = outputs
  return result
def collective_reduce_v2(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint="auto", timeout_seconds=0, max_subdivs_per_device=-1, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  In eager mode the op is first attempted with
  `pywrap_tfe.TFE_Py_FastPathExecute`, then with
  `collective_reduce_v2_eager_fallback`. Otherwise, or when the fallback
  raises `_core._SymbolicException`, a `CollectiveReduceV2` node is added to
  the current graph.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    max_subdivs_per_device: An optional `int`. Defaults to `-1`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple when building the
      graph node.
  """
  eager_ctx = _context._context or _context.context()
  if eager_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          eager_ctx, "CollectiveReduceV2", name, input, group_size, group_key,
          instance_key, ordering_token, "merge_op", merge_op, "final_op",
          final_op, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds, "max_subdivs_per_device",
          max_subdivs_per_device)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          merge_op=merge_op, final_op=final_op,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds,
          max_subdivs_per_device=max_subdivs_per_device, name=name,
          ctx=eager_ctx)
    except _core._SymbolicException:
      pass  # Fall through and add nodes to the TensorFlow graph.
  # Graph mode.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_v2' Op, not %r." % ordering_token)
  # `Nordering_token` is not passed explicitly here. It is read back from the
  # built op below, so the old unused `len(ordering_token)` local is dropped.
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  max_subdivs_per_device = _execute.make_int(
      -1 if max_subdivs_per_device is None else max_subdivs_per_device,
      "max_subdivs_per_device")
  _, _, op, op_outputs = _op_def_library._apply_op_helper(
      "CollectiveReduceV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, merge_op=merge_op, final_op=final_op,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds,
      max_subdivs_per_device=max_subdivs_per_device, name=name)
  outputs = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = (
        "T", op._get_attr_type("T"),
        "merge_op", op.get_attr("merge_op"),
        "final_op", op.get_attr("final_op"),
        "communication_hint", op.get_attr("communication_hint"),
        "timeout_seconds", op.get_attr("timeout_seconds"),
        "Nordering_token", op._get_attr_int("Nordering_token"),
        "max_subdivs_per_device", op._get_attr_int("max_subdivs_per_device"),
    )
    _execute.record_gradient("CollectiveReduceV2", op.inputs, attrs, outputs)
  (result,) = outputs
  return result
# Wrap `collective_reduce_v2` with `_ops.to_raw_op` and export the result
# under the name "raw_ops.CollectiveReduceV2".
CollectiveReduceV2 = tf_export("raw_ops.CollectiveReduceV2")(
    _ops.to_raw_op(collective_reduce_v2))
def collective_reduce_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint, timeout_seconds, max_subdivs_per_device, name, ctx):
  """Runs `CollectiveReduceV2` eagerly through the generic Python path.

  Args:
    input: Value convertible to a tensor of type `bfloat16`, `float32`,
      `half`, `float64`, `int32` or `int64`.
    group_size: Value converted to an `int32` tensor.
    group_key: Value converted to an `int32` tensor.
    instance_key: Value converted to an `int32` tensor.
    ordering_token: List or tuple of values converted to `resource` tensors.
    merge_op: String attribute.
    final_op: String attribute.
    communication_hint: String attribute; `None` means `"auto"`.
    timeout_seconds: Float attribute; `None` means `0`.
    max_subdivs_per_device: Integer attribute; `None` means `-1`.
    name: Optional name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The single output tensor of the op.

  Raises:
    TypeError: If `ordering_token` is not a list or tuple.
  """
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_v2' Op, not %r." % ordering_token)
  # The token count is sent as the `Nordering_token` attr.
  num_ordering_tokens = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  communication_hint = _execute.make_str(
      "auto" if communication_hint is None else communication_hint,
      "communication_hint")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  max_subdivs_per_device = _execute.make_int(
      -1 if max_subdivs_per_device is None else max_subdivs_per_device,
      "max_subdivs_per_device")
  supported_dtypes = [_dtypes.bfloat16, _dtypes.float32, _dtypes.half,
                      _dtypes.float64, _dtypes.int32, _dtypes.int64]
  dtype_attr, (input,) = _execute.args_to_matching_eager(
      [input], ctx, supported_dtypes)
  group_size, group_key, instance_key = [
      _ops.convert_to_tensor(value, _dtypes.int32)
      for value in (group_size, group_key, instance_key)]
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  flat_inputs = [input, group_size, group_key, instance_key, *ordering_token]
  attrs = (
      "T", dtype_attr,
      "merge_op", merge_op,
      "final_op", final_op,
      "communication_hint", communication_hint,
      "timeout_seconds", timeout_seconds,
      "Nordering_token", num_ordering_tokens,
      "max_subdivs_per_device", max_subdivs_per_device,
  )
  outputs = _execute.execute(b"CollectiveReduceV2", 1, inputs=flat_inputs,
                             attrs=attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("CollectiveReduceV2", flat_inputs, attrs, outputs)
  (result,) = outputs
  return result
def collective_reduce_v3(input, communicator, group_assignment, reduction, timeout_seconds=0, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  In eager mode the op is first attempted with
  `pywrap_tfe.TFE_Py_FastPathExecute`, then with
  `collective_reduce_v3_eager_fallback`. Otherwise, or when the fallback
  raises `_core._SymbolicException`, a `CollectiveReduceV3` node is added to
  the current graph.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    communicator: A `Tensor` of type `resource`.
    group_assignment: A `Tensor` of type `int32`.
    reduction: A `string` from: `"Min", "Max", "Mul", "Add"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  eager_ctx = _context._context or _context.context()
  if eager_ctx._thread_local_data.is_eager:
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          eager_ctx, "CollectiveReduceV3", name, input, communicator,
          group_assignment, "reduction", reduction, "timeout_seconds",
          timeout_seconds)
    except _core._NotOkStatusException as status_error:
      _ops.raise_from_not_ok_status(status_error, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_v3_eager_fallback(
          input, communicator, group_assignment, reduction=reduction,
          timeout_seconds=timeout_seconds, name=name, ctx=eager_ctx)
    except _core._SymbolicException:
      pass  # Fall through and add nodes to the TensorFlow graph.
  # Graph mode.
  reduction = _execute.make_str(reduction, "reduction")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  _, _, op, op_outputs = _op_def_library._apply_op_helper(
      "CollectiveReduceV3", input=input, communicator=communicator,
      group_assignment=group_assignment, reduction=reduction,
      timeout_seconds=timeout_seconds, name=name)
  outputs = op_outputs[:]
  if _execute.must_record_gradient():
    attrs = (
        "T", op._get_attr_type("T"),
        "reduction", op.get_attr("reduction"),
        "timeout_seconds", op.get_attr("timeout_seconds"),
    )
    _execute.record_gradient("CollectiveReduceV3", op.inputs, attrs, outputs)
  (result,) = outputs
  return result
# Wrap `collective_reduce_v3` with `_ops.to_raw_op` and export the result
# under the name "raw_ops.CollectiveReduceV3".
CollectiveReduceV3 = tf_export("raw_ops.CollectiveReduceV3")(
    _ops.to_raw_op(collective_reduce_v3))
def collective_reduce_v3_eager_fallback(input, communicator, group_assignment, reduction, timeout_seconds, name, ctx):
  """Runs `CollectiveReduceV3` eagerly through the generic Python path.

  Args:
    input: Value convertible to a tensor of type `bfloat16`, `float32`,
      `half`, `float64`, `int32` or `int64`.
    communicator: Value converted to a `resource` tensor.
    group_assignment: Value converted to an `int32` tensor.
    reduction: String attribute.
    timeout_seconds: Float attribute; `None` means `0`.
    name: Optional name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The single output tensor of the op.
  """
  reduction = _execute.make_str(reduction, "reduction")
  timeout_seconds = _execute.make_float(
      0 if timeout_seconds is None else timeout_seconds, "timeout_seconds")
  supported_dtypes = [_dtypes.bfloat16, _dtypes.float32, _dtypes.half,
                      _dtypes.float64, _dtypes.int32, _dtypes.int64]
  dtype_attr, (input,) = _execute.args_to_matching_eager(
      [input], ctx, supported_dtypes)
  communicator = _ops.convert_to_tensor(communicator, _dtypes.resource)
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  flat_inputs = [input, communicator, group_assignment]
  attrs = (
      "T", dtype_attr,
      "reduction", reduction,
      "timeout_seconds", timeout_seconds,
  )
  outputs = _execute.execute(b"CollectiveReduceV3", 1, inputs=flat_inputs,
                             attrs=attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("CollectiveReduceV3", flat_inputs, attrs, outputs)
  (result,) = outputs
  return result