Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gen_collective_ops.py: 8%

751 statements  


1"""Python wrappers around TensorFlow ops. 

2 

3This file is MACHINE GENERATED! Do not edit. 

4""" 

5 

6import collections 

7 

8from tensorflow.python import pywrap_tfe as pywrap_tfe 

9from tensorflow.python.eager import context as _context 

10from tensorflow.python.eager import core as _core 

11from tensorflow.python.eager import execute as _execute 

12from tensorflow.python.framework import dtypes as _dtypes 

13from tensorflow.security.fuzzing.py import annotation_types as _atypes 

14 

15from tensorflow.python.framework import op_def_registry as _op_def_registry 

16from tensorflow.python.framework import ops as _ops 

17from tensorflow.python.framework import op_def_library as _op_def_library 

18from tensorflow.python.util.deprecation import deprecated_endpoints 

19from tensorflow.python.util import dispatch as _dispatch 

20from tensorflow.python.util.tf_export import tf_export 

21 

22from typing import TypeVar 

23 

def collective_all_to_all_v2(input, group_size, group_key, instance_key, ordering_token, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually exchanges multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveAllToAllV2", name, input, group_size, group_key,
          instance_key, ordering_token, "communication_hint",
          communication_hint, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_all_to_all_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_all_to_all_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveAllToAllV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"), "Nordering_token",
              _op._get_attr_int("Nordering_token"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAllToAllV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveAllToAllV2 = tf_export("raw_ops.CollectiveAllToAllV2")(_ops.to_raw_op(collective_all_to_all_v2))


def collective_all_to_all_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, communication_hint, timeout_seconds, name, ctx):
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_all_to_all_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token)
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
            "timeout_seconds", timeout_seconds, "Nordering_token",
            _attr_Nordering_token)
  _result = _execute.execute(b"CollectiveAllToAllV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAllToAllV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
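
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). Every wrapper in this
# module follows the same pattern: try the eager fast path, fall back to the
# `*_eager_fallback` helper, and otherwise add a node to the graph. The
# hypothetical helper below shows one way to exercise the graph path for
# CollectiveAllToAllV2 without running the collective (actually executing it
# would require every member of the two-device group to participate); the
# group/instance keys are made up for illustration.
def _example_build_collective_all_to_all_v2():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    x = tf.constant([1.0, 2.0, 3.0, 4.0])
    out = tf.raw_ops.CollectiveAllToAllV2(
        input=x,
        group_size=tf.constant(2),   # two participants exchange equal slices
        group_key=tf.constant(1),
        instance_key=tf.constant(1),
        ordering_token=[])           # no explicit ordering dependencies
  return out  # graph tensor; same dtype/shape as `input`
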

def collective_all_to_all_v3(input, communicator, group_assignment, timeout_seconds=0, name=None):
  r"""Mutually exchanges multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    communicator: A `Tensor` of type `resource`.
    group_assignment: A `Tensor` of type `int32`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveAllToAllV3", name, input, communicator,
          group_assignment, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_all_to_all_v3_eager_fallback(
          input, communicator, group_assignment,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveAllToAllV3", input=input, communicator=communicator,
      group_assignment=group_assignment, timeout_seconds=timeout_seconds,
      name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAllToAllV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveAllToAllV3 = tf_export("raw_ops.CollectiveAllToAllV3")(_ops.to_raw_op(collective_all_to_all_v3))


def collective_all_to_all_v3_eager_fallback(input, communicator, group_assignment, timeout_seconds, name, ctx):
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  communicator = _ops.convert_to_tensor(communicator, _dtypes.resource)
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  _inputs_flat = [input, communicator, group_assignment]
  _attrs = ("T", _attr_T, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveAllToAllV3", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAllToAllV3", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


_CollectiveAssignGroupV2Output = collections.namedtuple(
    "CollectiveAssignGroupV2",
    ["group_size", "group_key"])


def collective_assign_group_v2(group_assignment, device_index, base_key, name=None):
  r"""Assign group keys based on group assignment.

  Args:
    group_assignment: A `Tensor` of type `int32`.
    device_index: A `Tensor` of type `int32`.
    base_key: A `Tensor` of type `int32`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (group_size, group_key).

    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveAssignGroupV2", name, group_assignment,
          device_index, base_key)
      _result = _CollectiveAssignGroupV2Output._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_assign_group_v2_eager_fallback(
          group_assignment, device_index, base_key, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveAssignGroupV2", group_assignment=group_assignment,
      device_index=device_index, base_key=base_key, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ()
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveAssignGroupV2", _inputs_flat, _attrs, _result)
  _result = _CollectiveAssignGroupV2Output._make(_result)
  return _result


CollectiveAssignGroupV2 = tf_export("raw_ops.CollectiveAssignGroupV2")(_ops.to_raw_op(collective_assign_group_v2))


def collective_assign_group_v2_eager_fallback(group_assignment, device_index, base_key, name, ctx):
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  device_index = _ops.convert_to_tensor(device_index, _dtypes.int32)
  base_key = _ops.convert_to_tensor(base_key, _dtypes.int32)
  _inputs_flat = [group_assignment, device_index, base_key]
  _attrs = None
  _result = _execute.execute(b"CollectiveAssignGroupV2", 2,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveAssignGroupV2", _inputs_flat, _attrs, _result)
  _result = _CollectiveAssignGroupV2Output._make(_result)
  return _result
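
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). CollectiveAssignGroupV2
# appears to derive a (group_size, group_key) pair from a group-assignment
# matrix and the calling device's index; the wrapper returns them as the
# `_CollectiveAssignGroupV2Output` namedtuple defined above. A hypothetical
# graph-construction example (keys and assignment values are made up):
def _example_build_collective_assign_group_v2():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    # Two groups of two devices each: devices 0,1 form one group, 2,3 the other.
    group_assignment = tf.constant([[0, 1], [2, 3]])
    out = tf.raw_ops.CollectiveAssignGroupV2(
        group_assignment=group_assignment,
        device_index=tf.constant(0),
        base_key=tf.constant(1000))
  return out.group_size, out.group_key  # both `int32` graph tensors
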

def collective_bcast_recv(T, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Receives a tensor value broadcast from another device.

  Args:
    T: A `tf.DType` from: `tf.bool, tf.float32, tf.half, tf.float64, tf.int32, tf.int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `T`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveBcastRecv", name, "T", T, "group_size", group_size,
          "group_key", group_key, "instance_key", instance_key, "shape", shape,
          "communication_hint", communication_hint, "timeout_seconds",
          timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_recv_eager_fallback(
          T=T, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  T = _execute.make_type(T, "T")
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveBcastRecv", T=T, group_size=group_size, group_key=group_key,
      instance_key=instance_key, shape=shape,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastRecv", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveBcastRecv = tf_export("raw_ops.CollectiveBcastRecv")(_ops.to_raw_op(collective_bcast_recv))


def collective_bcast_recv_eager_fallback(T, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  T = _execute.make_type(T, "T")
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _inputs_flat = []
  _attrs = ("T", T, "group_size", group_size, "group_key", group_key,
            "instance_key", instance_key, "shape", shape, "communication_hint",
            communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastRecv", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastRecv", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
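
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). Unlike the V2 ops, the
# V1 broadcast ops take `group_size`/`group_key`/`instance_key`, the dtype `T`
# and the `shape` as compile-time attributes rather than as tensors. A
# hypothetical graph that pairs a sender with a receiver (keys, shapes and
# device names are made up; executing it would need both devices to run their
# side of the broadcast):
def _example_build_collective_bcast_pair():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    with tf.device("/device:CPU:0"):
      sent = tf.raw_ops.CollectiveBcastSend(
          input=tf.ones([2, 2]), group_size=2, group_key=7, instance_key=42,
          shape=[2, 2])
    with tf.device("/device:CPU:1"):
      received = tf.raw_ops.CollectiveBcastRecv(
          T=tf.float32, group_size=2, group_key=7, instance_key=42,
          shape=[2, 2])
  return sent, received
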

def collective_bcast_recv_v2(group_size, group_key, instance_key, shape, T, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Receives a tensor value broadcast from another device.

  Args:
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
    T: A `tf.DType` from: `tf.bool, tf.float32, tf.half, tf.float64, tf.int32, tf.int64`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `T`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveBcastRecvV2", name, group_size, group_key,
          instance_key, shape, "T", T, "communication_hint",
          communication_hint, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_recv_v2_eager_fallback(
          group_size, group_key, instance_key, shape, T=T,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  T = _execute.make_type(T, "T")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveBcastRecvV2", group_size=group_size, group_key=group_key,
      instance_key=instance_key, shape=shape, T=T,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "Tshape",
              _op._get_attr_type("Tshape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastRecvV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveBcastRecvV2 = tf_export("raw_ops.CollectiveBcastRecvV2")(_ops.to_raw_op(collective_bcast_recv_v2))


def collective_bcast_recv_v2_eager_fallback(group_size, group_key, instance_key, shape, T, communication_hint, timeout_seconds, name, ctx):
  T = _execute.make_type(T, "T")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_Tshape, (shape,) = _execute.args_to_matching_eager([shape], ctx, [_dtypes.int32, _dtypes.int64, ], _dtypes.int32)
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  _inputs_flat = [group_size, group_key, instance_key, shape]
  _attrs = ("T", T, "Tshape", _attr_Tshape, "communication_hint",
            communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastRecvV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastRecvV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_bcast_send(input, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Broadcasts a tensor value to one or more other devices.

  Args:
    input: A `Tensor`. Must be one of the following types: `bool`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveBcastSend", name, input, "group_size", group_size,
          "group_key", group_key, "instance_key", instance_key, "shape", shape,
          "communication_hint", communication_hint, "timeout_seconds",
          timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_send_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveBcastSend", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key, shape=shape,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastSend", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveBcastSend = tf_export("raw_ops.CollectiveBcastSend")(_ops.to_raw_op(collective_bcast_send))


def collective_bcast_send_eager_fallback(input, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bool, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
            "instance_key", instance_key, "shape", shape, "communication_hint",
            communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastSend", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastSend", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_bcast_send_v2(input, group_size, group_key, instance_key, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Broadcasts a tensor value to one or more other devices.

  Args:
    input: A `Tensor`. Must be one of the following types: `bool`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveBcastSendV2", name, input, group_size, group_key,
          instance_key, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_bcast_send_v2_eager_fallback(
          input, group_size, group_key, instance_key,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveBcastSendV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveBcastSendV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveBcastSendV2 = tf_export("raw_ops.CollectiveBcastSendV2")(_ops.to_raw_op(collective_bcast_send_v2))


def collective_bcast_send_v2_eager_fallback(input, group_size, group_key, instance_key, communication_hint, timeout_seconds, name, ctx):
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bool, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  _inputs_flat = [input, group_size, group_key, instance_key]
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
            "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveBcastSendV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveBcastSendV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
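
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). The V2 broadcast ops
# take the group parameters (and, for the receiver, the shape) as runtime
# tensors instead of compile-time attributes, so the same traced function can
# serve different group configurations. A hypothetical graph pairing the two
# sides (keys and device names are made up; execution needs both participants):
def _example_build_collective_bcast_v2_pair():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    group_size = tf.constant(2)
    group_key = tf.constant(7)
    instance_key = tf.constant(42)
    with tf.device("/device:CPU:0"):
      sent = tf.raw_ops.CollectiveBcastSendV2(
          input=tf.ones([2, 2]), group_size=group_size, group_key=group_key,
          instance_key=instance_key)
    with tf.device("/device:CPU:1"):
      received = tf.raw_ops.CollectiveBcastRecvV2(
          group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=tf.constant([2, 2]), T=tf.float32)
  return sent, received
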

def collective_gather(input, group_size, group_key, instance_key, shape, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually accumulates multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    shape: A `tf.TensorShape` or list of `ints`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveGather", name, input, "group_size", group_size,
          "group_key", group_key, "instance_key", instance_key, "shape", shape,
          "communication_hint", communication_hint, "timeout_seconds",
          timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_gather_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, shape=shape,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveGather", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key, shape=shape,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "shape",
              _op.get_attr("shape"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveGather", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveGather = tf_export("raw_ops.CollectiveGather")(_ops.to_raw_op(collective_gather))


def collective_gather_eager_fallback(input, group_size, group_key, instance_key, shape, communication_hint, timeout_seconds, name, ctx):
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  shape = _execute.make_shape(shape, "shape")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
            "instance_key", instance_key, "shape", shape, "communication_hint",
            communication_hint, "timeout_seconds", timeout_seconds)
  _result = _execute.execute(b"CollectiveGather", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveGather", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


def collective_gather_v2(input, group_size, group_key, instance_key, ordering_token, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually accumulates multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveGatherV2", name, input, group_size, group_key,
          instance_key, ordering_token, "communication_hint",
          communication_hint, "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_gather_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_gather_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveGatherV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"), "Nordering_token",
              _op._get_attr_int("Nordering_token"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveGatherV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveGatherV2 = tf_export("raw_ops.CollectiveGatherV2")(_ops.to_raw_op(collective_gather_v2))


def collective_gather_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, communication_hint, timeout_seconds, name, ctx):
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_gather_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token)
  _attrs = ("T", _attr_T, "communication_hint", communication_hint,
            "timeout_seconds", timeout_seconds, "Nordering_token",
            _attr_Nordering_token)
  _result = _execute.execute(b"CollectiveGatherV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveGatherV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
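
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). CollectiveGatherV2
# gathers the tensors contributed by every group member (concatenating along
# the first dimension), so with two members each supplying a [2] vector the
# result is a [4] vector. A hypothetical graph-construction example (keys are
# made up; executing it needs both members to participate):
def _example_build_collective_gather_v2():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    out = tf.raw_ops.CollectiveGatherV2(
        input=tf.constant([1.0, 2.0]),
        group_size=tf.constant(2),
        group_key=tf.constant(1),
        instance_key=tf.constant(1),
        ordering_token=[])
  return out  # graph tensor holding the gathered result from all members
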

def collective_initialize_communicator(group_key, rank, group_size, communication_hint="auto", timeout_seconds=0, name=None):
  r"""Initializes a group for collective operations.

  Args:
    group_key: A `Tensor` of type `int32`.
    rank: A `Tensor` of type `int32`.
    group_size: A `Tensor` of type `int32`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `resource`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveInitializeCommunicator", name, group_key, rank,
          group_size, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_initialize_communicator_eager_fallback(
          group_key, rank, group_size, communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveInitializeCommunicator", group_key=group_key, rank=rank,
      group_size=group_size, communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("communication_hint", _op.get_attr("communication_hint"),
              "timeout_seconds", _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveInitializeCommunicator", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveInitializeCommunicator = tf_export("raw_ops.CollectiveInitializeCommunicator")(_ops.to_raw_op(collective_initialize_communicator))


def collective_initialize_communicator_eager_fallback(group_key, rank, group_size, communication_hint, timeout_seconds, name, ctx):
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  rank = _ops.convert_to_tensor(rank, _dtypes.int32)
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  _inputs_flat = [group_key, rank, group_size]
  _attrs = ("communication_hint", communication_hint, "timeout_seconds",
            timeout_seconds)
  _result = _execute.execute(b"CollectiveInitializeCommunicator", 1,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveInitializeCommunicator", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
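
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). The V3 collectives
# (for example CollectiveAllToAllV3 above) consume the `resource` communicator
# handle produced by CollectiveInitializeCommunicator instead of scalar keys.
# A hypothetical graph wiring the two together (the rank/group values and the
# group assignment are made up; executing it needs every rank to participate):
def _example_build_communicator_and_all_to_all_v3():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    communicator = tf.raw_ops.CollectiveInitializeCommunicator(
        group_key=tf.constant(1), rank=tf.constant(0),
        group_size=tf.constant(2))
    out = tf.raw_ops.CollectiveAllToAllV3(
        input=tf.constant([1.0, 2.0]),
        communicator=communicator,
        group_assignment=tf.constant([0, 1]))
  return out
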

def collective_reduce(input, group_size, group_key, instance_key, merge_op, final_op, subdiv_offsets, wait_for=[], communication_hint="auto", timeout_seconds=0, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: An `int`.
    group_key: An `int`.
    instance_key: An `int`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    subdiv_offsets: A list of `ints`.
    wait_for: An optional list of `ints`. Defaults to `[]`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveReduce", name, input, "group_size", group_size,
          "group_key", group_key, "instance_key", instance_key, "merge_op",
          merge_op, "final_op", final_op, "subdiv_offsets", subdiv_offsets,
          "wait_for", wait_for, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_eager_fallback(
          input, group_size=group_size, group_key=group_key,
          instance_key=instance_key, merge_op=merge_op, final_op=final_op,
          subdiv_offsets=subdiv_offsets, wait_for=wait_for,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if not isinstance(subdiv_offsets, (list, tuple)):
    raise TypeError(
        "Expected list for 'subdiv_offsets' argument to "
        "'collective_reduce' Op, not %r." % subdiv_offsets)
  subdiv_offsets = [_execute.make_int(_i, "subdiv_offsets") for _i in subdiv_offsets]
  if wait_for is None:
    wait_for = []
  if not isinstance(wait_for, (list, tuple)):
    raise TypeError(
        "Expected list for 'wait_for' argument to "
        "'collective_reduce' Op, not %r." % wait_for)
  wait_for = [_execute.make_int(_i, "wait_for") for _i in wait_for]
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveReduce", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key, merge_op=merge_op,
      final_op=final_op, subdiv_offsets=subdiv_offsets, wait_for=wait_for,
      communication_hint=communication_hint,
      timeout_seconds=timeout_seconds, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "group_size",
              _op._get_attr_int("group_size"), "group_key",
              _op._get_attr_int("group_key"), "instance_key",
              _op._get_attr_int("instance_key"), "merge_op",
              _op.get_attr("merge_op"), "final_op", _op.get_attr("final_op"),
              "subdiv_offsets", _op.get_attr("subdiv_offsets"), "wait_for",
              _op.get_attr("wait_for"), "communication_hint",
              _op.get_attr("communication_hint"), "timeout_seconds",
              _op.get_attr("timeout_seconds"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveReduce", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveReduce = tf_export("raw_ops.CollectiveReduce")(_ops.to_raw_op(collective_reduce))


def collective_reduce_eager_fallback(input, group_size, group_key, instance_key, merge_op, final_op, subdiv_offsets, wait_for, communication_hint, timeout_seconds, name, ctx):
  group_size = _execute.make_int(group_size, "group_size")
  group_key = _execute.make_int(group_key, "group_key")
  instance_key = _execute.make_int(instance_key, "instance_key")
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if not isinstance(subdiv_offsets, (list, tuple)):
    raise TypeError(
        "Expected list for 'subdiv_offsets' argument to "
        "'collective_reduce' Op, not %r." % subdiv_offsets)
  subdiv_offsets = [_execute.make_int(_i, "subdiv_offsets") for _i in subdiv_offsets]
  if wait_for is None:
    wait_for = []
  if not isinstance(wait_for, (list, tuple)):
    raise TypeError(
        "Expected list for 'wait_for' argument to "
        "'collective_reduce' Op, not %r." % wait_for)
  wait_for = [_execute.make_int(_i, "wait_for") for _i in wait_for]
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  _inputs_flat = [input]
  _attrs = ("T", _attr_T, "group_size", group_size, "group_key", group_key,
            "instance_key", instance_key, "merge_op", merge_op, "final_op",
            final_op, "subdiv_offsets", subdiv_offsets, "wait_for", wait_for,
            "communication_hint", communication_hint, "timeout_seconds",
            timeout_seconds)
  _result = _execute.execute(b"CollectiveReduce", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveReduce", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
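
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). For CollectiveReduce,
# `merge_op` picks the pairwise reduction ("Add", "Mul", "Min" or "Max") and
# `final_op` either keeps the merged value ("Id") or divides it by the group
# size ("Div"), which is how a mean all-reduce is expressed; `subdiv_offsets=[0]`
# requests a single subdivision. A hypothetical graph-construction example
# (keys made up; executing it needs every group member to participate):
def _example_build_collective_reduce_mean():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    out = tf.raw_ops.CollectiveReduce(
        input=tf.constant([2.0, 4.0]),
        group_size=2, group_key=1, instance_key=1,
        merge_op="Add", final_op="Div",   # Add then Div by group_size => mean
        subdiv_offsets=[0])
  return out
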

def collective_reduce_scatter_v2(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint="auto", timeout_seconds=0, max_subdivs_per_device=-1, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape and scatters the result.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    max_subdivs_per_device: An optional `int`. Defaults to `-1`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveReduceScatterV2", name, input, group_size,
          group_key, instance_key, ordering_token, "merge_op", merge_op,
          "final_op", final_op, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds, "max_subdivs_per_device",
          max_subdivs_per_device)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_scatter_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          merge_op=merge_op, final_op=final_op,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds,
          max_subdivs_per_device=max_subdivs_per_device, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_scatter_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  if max_subdivs_per_device is None:
    max_subdivs_per_device = -1
  max_subdivs_per_device = _execute.make_int(max_subdivs_per_device, "max_subdivs_per_device")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveReduceScatterV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, merge_op=merge_op, final_op=final_op,
      communication_hint=communication_hint, timeout_seconds=timeout_seconds,
      max_subdivs_per_device=max_subdivs_per_device, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "merge_op",
              _op.get_attr("merge_op"), "final_op", _op.get_attr("final_op"),
              "communication_hint", _op.get_attr("communication_hint"),
              "timeout_seconds", _op.get_attr("timeout_seconds"),
              "Nordering_token", _op._get_attr_int("Nordering_token"),
              "max_subdivs_per_device",
              _op._get_attr_int("max_subdivs_per_device"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveReduceScatterV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveReduceScatterV2 = tf_export("raw_ops.CollectiveReduceScatterV2")(_ops.to_raw_op(collective_reduce_scatter_v2))


def collective_reduce_scatter_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint, timeout_seconds, max_subdivs_per_device, name, ctx):
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_scatter_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  if max_subdivs_per_device is None:
    max_subdivs_per_device = -1
  max_subdivs_per_device = _execute.make_int(max_subdivs_per_device, "max_subdivs_per_device")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ])
  group_size = _ops.convert_to_tensor(group_size, _dtypes.int32)
  group_key = _ops.convert_to_tensor(group_key, _dtypes.int32)
  instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32)
  ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource)
  _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token)
  _attrs = ("T", _attr_T, "merge_op", merge_op, "final_op", final_op,
            "communication_hint", communication_hint, "timeout_seconds",
            timeout_seconds, "Nordering_token", _attr_Nordering_token,
            "max_subdivs_per_device", max_subdivs_per_device)
  _result = _execute.execute(b"CollectiveReduceScatterV2", 1,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "CollectiveReduceScatterV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result
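
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the generated file). CollectiveReduceScatterV2
# reduces across the group like CollectiveReduceV2 but leaves each member with
# only its shard of the reduced tensor, so each participant's output is smaller
# than its input. A hypothetical graph-construction example (keys made up;
# executing it needs both members to participate):
def _example_build_collective_reduce_scatter_v2():  # hypothetical, illustrative only
  import tensorflow as tf  # imported lazily to avoid a circular import here
  g = tf.Graph()
  with g.as_default():
    out = tf.raw_ops.CollectiveReduceScatterV2(
        input=tf.ones([4]),              # each of 2 members keeps a [2] shard
        group_size=tf.constant(2),
        group_key=tf.constant(1),
        instance_key=tf.constant(1),
        ordering_token=[],
        merge_op="Add", final_op="Id")
  return out
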

def collective_reduce_v2(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint="auto", timeout_seconds=0, max_subdivs_per_device=-1, name=None):
  r"""Mutually reduces multiple tensors of identical type and shape.

  Args:
    input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`.
    group_size: A `Tensor` of type `int32`.
    group_key: A `Tensor` of type `int32`.
    instance_key: A `Tensor` of type `int32`.
    ordering_token: A list of `Tensor` objects with type `resource`.
    merge_op: A `string` from: `"Min", "Max", "Mul", "Add"`.
    final_op: A `string` from: `"Id", "Div"`.
    communication_hint: An optional `string`. Defaults to `"auto"`.
    timeout_seconds: An optional `float`. Defaults to `0`.
    max_subdivs_per_device: An optional `int`. Defaults to `-1`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
          _ctx, "CollectiveReduceV2", name, input, group_size, group_key,
          instance_key, ordering_token, "merge_op", merge_op, "final_op",
          final_op, "communication_hint", communication_hint,
          "timeout_seconds", timeout_seconds, "max_subdivs_per_device",
          max_subdivs_per_device)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      return collective_reduce_v2_eager_fallback(
          input, group_size, group_key, instance_key, ordering_token,
          merge_op=merge_op, final_op=final_op,
          communication_hint=communication_hint,
          timeout_seconds=timeout_seconds,
          max_subdivs_per_device=max_subdivs_per_device, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
  # Add nodes to the TensorFlow graph.
  if not isinstance(ordering_token, (list, tuple)):
    raise TypeError(
        "Expected list for 'ordering_token' argument to "
        "'collective_reduce_v2' Op, not %r." % ordering_token)
  _attr_Nordering_token = len(ordering_token)
  merge_op = _execute.make_str(merge_op, "merge_op")
  final_op = _execute.make_str(final_op, "final_op")
  if communication_hint is None:
    communication_hint = "auto"
  communication_hint = _execute.make_str(communication_hint, "communication_hint")
  if timeout_seconds is None:
    timeout_seconds = 0
  timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds")
  if max_subdivs_per_device is None:
    max_subdivs_per_device = -1
  max_subdivs_per_device = _execute.make_int(max_subdivs_per_device, "max_subdivs_per_device")
  _, _, _op, _outputs = _op_def_library._apply_op_helper(
      "CollectiveReduceV2", input=input, group_size=group_size,
      group_key=group_key, instance_key=instance_key,
      ordering_token=ordering_token, merge_op=merge_op, final_op=final_op,
      communication_hint=communication_hint, timeout_seconds=timeout_seconds,
      max_subdivs_per_device=max_subdivs_per_device, name=name)
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "merge_op",
              _op.get_attr("merge_op"), "final_op", _op.get_attr("final_op"),
              "communication_hint", _op.get_attr("communication_hint"),
              "timeout_seconds", _op.get_attr("timeout_seconds"),
              "Nordering_token", _op._get_attr_int("Nordering_token"),
              "max_subdivs_per_device",
              _op._get_attr_int("max_subdivs_per_device"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "CollectiveReduceV2", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result


CollectiveReduceV2 = tf_export("raw_ops.CollectiveReduceV2")(_ops.to_raw_op(collective_reduce_v2))


1257def collective_reduce_v2_eager_fallback(input, group_size, group_key, instance_key, ordering_token, merge_op, final_op, communication_hint, timeout_seconds, max_subdivs_per_device, name, ctx): 

1258 if not isinstance(ordering_token, (list, tuple)): 

1259 raise TypeError( 

1260 "Expected list for 'ordering_token' argument to " 

1261 "'collective_reduce_v2' Op, not %r." % ordering_token) 

1262 _attr_Nordering_token = len(ordering_token) 

1263 merge_op = _execute.make_str(merge_op, "merge_op") 

1264 final_op = _execute.make_str(final_op, "final_op") 

1265 if communication_hint is None: 

1266 communication_hint = "auto" 

1267 communication_hint = _execute.make_str(communication_hint, "communication_hint") 

1268 if timeout_seconds is None: 

1269 timeout_seconds = 0 

1270 timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds") 

1271 if max_subdivs_per_device is None: 

1272 max_subdivs_per_device = -1 

1273 max_subdivs_per_device = _execute.make_int(max_subdivs_per_device, "max_subdivs_per_device") 

1274 _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ]) 

1275 group_size = _ops.convert_to_tensor(group_size, _dtypes.int32) 

1276 group_key = _ops.convert_to_tensor(group_key, _dtypes.int32) 

1277 instance_key = _ops.convert_to_tensor(instance_key, _dtypes.int32) 

1278 ordering_token = _ops.convert_n_to_tensor(ordering_token, _dtypes.resource) 

1279 _inputs_flat = [input, group_size, group_key, instance_key] + list(ordering_token) 

1280 _attrs = ("T", _attr_T, "merge_op", merge_op, "final_op", final_op, 

1281 "communication_hint", communication_hint, "timeout_seconds", 

1282 timeout_seconds, "Nordering_token", _attr_Nordering_token, 

1283 "max_subdivs_per_device", max_subdivs_per_device) 

1284 _result = _execute.execute(b"CollectiveReduceV2", 1, inputs=_inputs_flat, 

1285 attrs=_attrs, ctx=ctx, name=name) 

1286 if _execute.must_record_gradient(): 

1287 _execute.record_gradient( 

1288 "CollectiveReduceV2", _inputs_flat, _attrs, _result) 

1289 _result, = _result 

1290 return _result 

1291 

1292 
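# Illustrative usage sketch (an editor-added assumption, not part of the
# generated file): a single-participant all-reduce through the
# collective_reduce_v2 wrapper defined above. group_size=1 and the key
# values are placeholders; with one member, merge_op="Add" / final_op="Id"
# simply returns the input, which makes the argument layout easy to verify.
def _example_collective_reduce_v2():
  import tensorflow as tf
  return collective_reduce_v2(
      input=tf.constant([1.0, 2.0, 3.0]),
      group_size=tf.constant(1),
      group_key=tf.constant(2),
      instance_key=tf.constant(2),
      ordering_token=[],
      merge_op="Add",   # one of "Min", "Max", "Mul", "Add"
      final_op="Id")    # "Id" keeps the merged value; "Div" divides by group_size
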

1293def collective_reduce_v3(input, communicator, group_assignment, reduction, timeout_seconds=0, name=None): 

1294 r"""Mutually reduces multiple tensors of identical type and shape. 

1295 

1296 Args: 

1297 input: A `Tensor`. Must be one of the following types: `bfloat16`, `float32`, `half`, `float64`, `int32`, `int64`. 

1298 communicator: A `Tensor` of type `resource`. 

1299 group_assignment: A `Tensor` of type `int32`. 

1300 reduction: A `string` from: `"Min", "Max", "Mul", "Add"`. 

1301 timeout_seconds: An optional `float`. Defaults to `0`. 

1302 name: A name for the operation (optional). 

1303 

1304 Returns: 

1305 A `Tensor`. Has the same type as `input`. 

1306 """ 

1307 _ctx = _context._context or _context.context() 

1308 tld = _ctx._thread_local_data 

1309 if tld.is_eager: 

1310 try: 

1311 _result = pywrap_tfe.TFE_Py_FastPathExecute( 

1312 _ctx, "CollectiveReduceV3", name, input, communicator, 

1313 group_assignment, "reduction", reduction, "timeout_seconds", 

1314 timeout_seconds) 

1315 return _result 

1316 except _core._NotOkStatusException as e: 

1317 _ops.raise_from_not_ok_status(e, name) 

1318 except _core._FallbackException: 

1319 pass 

1320 try: 

1321 return collective_reduce_v3_eager_fallback( 

1322 input, communicator, group_assignment, reduction=reduction, 

1323 timeout_seconds=timeout_seconds, name=name, ctx=_ctx) 

1324 except _core._SymbolicException: 

1325 pass # Add nodes to the TensorFlow graph. 

1326 # Add nodes to the TensorFlow graph. 

1327 reduction = _execute.make_str(reduction, "reduction") 

1328 if timeout_seconds is None: 

1329 timeout_seconds = 0 

1330 timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds") 

1331 _, _, _op, _outputs = _op_def_library._apply_op_helper( 

1332 "CollectiveReduceV3", input=input, communicator=communicator, 

1333 group_assignment=group_assignment, 

1334 reduction=reduction, 

1335 timeout_seconds=timeout_seconds, name=name) 

1336 _result = _outputs[:] 

1337 if _execute.must_record_gradient(): 

1338 _attrs = ("T", _op._get_attr_type("T"), "reduction", 

1339 _op.get_attr("reduction"), "timeout_seconds", 

1340 _op.get_attr("timeout_seconds")) 

1341 _inputs_flat = _op.inputs 

1342 _execute.record_gradient( 

1343 "CollectiveReduceV3", _inputs_flat, _attrs, _result) 

1344 _result, = _result 

1345 return _result 

1346 

1347CollectiveReduceV3 = tf_export("raw_ops.CollectiveReduceV3")(_ops.to_raw_op(collective_reduce_v3)) 

1348 

1349 

1350def collective_reduce_v3_eager_fallback(input, communicator, group_assignment, reduction, timeout_seconds, name, ctx): 

1351 reduction = _execute.make_str(reduction, "reduction") 

1352 if timeout_seconds is None: 

1353 timeout_seconds = 0 

1354 timeout_seconds = _execute.make_float(timeout_seconds, "timeout_seconds") 

1355 _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.bfloat16, _dtypes.float32, _dtypes.half, _dtypes.float64, _dtypes.int32, _dtypes.int64, ]) 

1356 communicator = _ops.convert_to_tensor(communicator, _dtypes.resource) 

1357 group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32) 

1358 _inputs_flat = [input, communicator, group_assignment] 

1359 _attrs = ("T", _attr_T, "reduction", reduction, "timeout_seconds", 

1360 timeout_seconds) 

1361 _result = _execute.execute(b"CollectiveReduceV3", 1, inputs=_inputs_flat, 

1362 attrs=_attrs, ctx=ctx, name=name) 

1363 if _execute.must_record_gradient(): 

1364 _execute.record_gradient( 

1365 "CollectiveReduceV3", _inputs_flat, _attrs, _result) 

1366 _result, = _result 

1367 return _result 

1368
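

# Illustrative usage sketch (an editor-added assumption, not part of the
# generated file): driving collective_reduce_v3 with a communicator handle.
# The communicator is assumed to come from the CollectiveInitializeCommunicator
# raw op (the argument names passed to it below are an assumption), and the
# [[0]] group_assignment describes a single group containing only rank 0;
# real jobs build larger groups across workers.
def _example_collective_reduce_v3():
  import tensorflow as tf
  communicator = tf.raw_ops.CollectiveInitializeCommunicator(
      group_key=tf.constant(0),
      rank=tf.constant(0),
      group_size=tf.constant(1))
  return collective_reduce_v3(
      input=tf.constant([1.0, 2.0, 3.0]),
      communicator=communicator,
      group_assignment=tf.constant([[0]]),
      reduction="Add")   # one of "Min", "Max", "Mul", "Add"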