Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/compiler/tf2xla/ops/gen_xla_ops.py: 12%

2496 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1"""Python wrappers around TensorFlow ops. 

2 

3This file is MACHINE GENERATED! Do not edit. 

4""" 

5 

6import collections 

7 

8from tensorflow.python import pywrap_tfe as pywrap_tfe 

9from tensorflow.python.eager import context as _context 

10from tensorflow.python.eager import core as _core 

11from tensorflow.python.eager import execute as _execute 

12from tensorflow.python.framework import dtypes as _dtypes 

13from tensorflow.security.fuzzing.py import annotation_types as _atypes 

14 

15from tensorflow.python.framework import op_def_registry as _op_def_registry 

16from tensorflow.python.framework import ops as _ops 

17from tensorflow.python.framework import op_def_library as _op_def_library 

18from tensorflow.python.util.deprecation import deprecated_endpoints 

19from tensorflow.python.util import dispatch as _dispatch 

20from tensorflow.python.util.tf_export import tf_export 

21 

22from typing import TypeVar 

23 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_all_reduce')
def xla_all_reduce(input, group_assignment, reduce_op, mode, name=None):
  r"""Wraps the XLA AllReduce operator

  documented at https://www.tensorflow.org/xla/operation_semantics#allreduce.

  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`, `float32`, `int32`, `uint32`.
      Array or a non-empty tuple of arrays to reduce across replicas.
    group_assignment: A `Tensor` of type `int32`.
      Groups between which the reductions are performed.
    reduce_op: A `string` from: `"Min", "Max", "Mul", "Add", "Mean"`.
      Reduction computation.
    mode: A `string` from: `"CrossReplica", "CrossReplicaAndPartition"`.
      group mode.
      CrossReplica: group_assignment contains replica_id. Each group contains the
        replicas for the current partition.
      CrossReplicaAndPartition: group_assignment contains replica_id. Each group
        contains the replicas for all partitions.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first; it executes the op directly.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaAllReduce", name, input, group_assignment, "reduce_op",
        reduce_op, "mode", mode)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path rejected the inputs; fall through to the slow eager path.
      pass
    try:
      # Give type-based dispatchers (e.g. for extension types) first crack.
      _result = _dispatcher_for_xla_all_reduce(
          (input, group_assignment, reduce_op, mode, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_all_reduce_eager_fallback(
          input, group_assignment, reduce_op=reduce_op, mode=mode, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Fallback dispatch list may still handle arguments the op rejected.
      _result = _dispatch.dispatch(
            xla_all_reduce, (), dict(input=input,
                                     group_assignment=group_assignment,
                                     reduce_op=reduce_op, mode=mode,
                                     name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: only the type-based dispatcher runs before graph building.
    _result = _dispatcher_for_xla_all_reduce(
        (input, group_assignment, reduce_op, mode, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  reduce_op = _execute.make_str(reduce_op, "reduce_op")
  mode = _execute.make_str(mode, "mode")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaAllReduce", input=input, group_assignment=group_assignment,
                        reduce_op=reduce_op, mode=mode, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_all_reduce, (), dict(input=input,
                                   group_assignment=group_assignment,
                                   reduce_op=reduce_op, mode=mode, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record attrs/inputs so the gradient function can replay the op.
    _attrs = ("T", _op._get_attr_type("T"), "reduce_op",
              _op.get_attr("reduce_op"), "mode", _op.get_attr("mode"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaAllReduce", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

112 

# Expose the op as tf.raw_ops.XlaAllReduce (raw-op form of the wrapper above).
XlaAllReduce = tf_export("raw_ops.XlaAllReduce")(_ops.to_raw_op(xla_all_reduce))
# Cache the bound type-based dispatch method used inside xla_all_reduce;
# must come after the decorated function is defined.
_dispatcher_for_xla_all_reduce = xla_all_reduce._tf_type_based_dispatcher.Dispatch

115 

116 

def xla_all_reduce_eager_fallback(input, group_assignment, reduce_op, mode, name, ctx):
  """Eagerly executes XlaAllReduce when the C fast path is unavailable."""
  # Canonicalize the string attrs.
  reduce_op = _execute.make_str(reduce_op, "reduce_op")
  mode = _execute.make_str(mode, "mode")
  # Infer the "T" attr from the input tensor's dtype.
  allowed_types = [_dtypes.half, _dtypes.bfloat16, _dtypes.float32,
                   _dtypes.int32, _dtypes.uint32, ]
  _attr_T, (input,) = _execute.args_to_matching_eager(
      [input], ctx, allowed_types)
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  _inputs_flat = [input, group_assignment]
  _attrs = ("T", _attr_T, "reduce_op", reduce_op, "mode", mode)
  _result = _execute.execute(b"XlaAllReduce", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaAllReduce", _inputs_flat, _attrs, _result)
  # Single output: unwrap the one-element result list.
  _result, = _result
  return _result

131 

# Result container for the two outputs of the XlaBroadcastHelper op.
_XlaBroadcastHelperOutput = collections.namedtuple(
    "XlaBroadcastHelper", "lhs_output rhs_output")

135 

136 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_broadcast_helper')
def xla_broadcast_helper(lhs, rhs, broadcast_dims, name=None):
  r"""Helper operator for performing XLA-style broadcasts

  Broadcasts `lhs` and `rhs` to the same rank, by adding size 1 dimensions to
  whichever of `lhs` and `rhs` has the lower rank, using XLA's broadcasting rules
  for binary operators.

  Args:
    lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the LHS input tensor
    rhs: A `Tensor`. Must have the same type as `lhs`. the RHS input tensor
    broadcast_dims: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      an XLA-style broadcast dimension specification
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (lhs_output, rhs_output).

    lhs_output: A `Tensor`. Has the same type as `lhs`. the broadcasted LHS tensor
    rhs_output: A `Tensor`. Has the same type as `lhs`. the broadcasted RHS tensor
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaBroadcastHelper", name, lhs, rhs, broadcast_dims)
      # Wrap the raw outputs in the named tuple (lhs_output, rhs_output).
      _result = _XlaBroadcastHelperOutput._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      # Type-based dispatchers get first crack before the eager fallback.
      _result = _dispatcher_for_xla_broadcast_helper(
          (lhs, rhs, broadcast_dims, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_broadcast_helper_eager_fallback(
          lhs, rhs, broadcast_dims, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Fallback dispatch list may still handle arguments the op rejected.
      _result = _dispatch.dispatch(
            xla_broadcast_helper, (), dict(lhs=lhs, rhs=rhs,
                                           broadcast_dims=broadcast_dims,
                                           name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_broadcast_helper(
        (lhs, rhs, broadcast_dims, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaBroadcastHelper", lhs=lhs, rhs=rhs, broadcast_dims=broadcast_dims,
                              name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_broadcast_helper, (), dict(lhs=lhs, rhs=rhs,
                                         broadcast_dims=broadcast_dims,
                                         name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record attrs/inputs so the gradient function can replay the op.
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaBroadcastHelper", _inputs_flat, _attrs, _result)
  _result = _XlaBroadcastHelperOutput._make(_result)
  return _result

219 

# Expose the op as tf.raw_ops.XlaBroadcastHelper (raw-op form of the wrapper).
XlaBroadcastHelper = tf_export("raw_ops.XlaBroadcastHelper")(_ops.to_raw_op(xla_broadcast_helper))
# Cache the bound type-based dispatch method used inside xla_broadcast_helper.
_dispatcher_for_xla_broadcast_helper = xla_broadcast_helper._tf_type_based_dispatcher.Dispatch

222 

223 

def xla_broadcast_helper_eager_fallback(lhs, rhs, broadcast_dims, name, ctx):
  """Eagerly executes XlaBroadcastHelper when the C fast path is unavailable."""
  # Infer "T" from lhs/rhs (both must resolve to one common dtype).
  value_types = [
      _dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8,
      _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64,
      _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16,
      _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128,
      _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]
  _attr_T, (lhs, rhs) = _execute.args_to_matching_eager(
      [lhs, rhs], ctx, value_types)
  # Infer "Tindices" from the broadcast dimension specification.
  index_types = [_dtypes.int32, _dtypes.int64, ]
  _attr_Tindices, (broadcast_dims,) = _execute.args_to_matching_eager(
      [broadcast_dims], ctx, index_types)
  _inputs_flat = [lhs, rhs, broadcast_dims]
  _attrs = ("T", _attr_T, "Tindices", _attr_Tindices)
  _result = _execute.execute(b"XlaBroadcastHelper", 2, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaBroadcastHelper", _inputs_flat, _attrs,
                             _result)
  # Two outputs: package them as (lhs_output, rhs_output).
  return _XlaBroadcastHelperOutput._make(_result)

237 

238 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_call_module')
def xla_call_module(args, version, module, Sout, Tout, dim_args_spec=[], platforms=[], name=None):
  r"""Invokes a StableHLO module.

  This op is used with JAX native serialization in a TensorFlow context with
  stability guarantees.

  Args:
    args: A list of `Tensor` objects.
      A list of `Tensor` with possibly different types to be passed as arguments
      to the `module`. These are the actual arguments and do not include the
      platform argument (see `platforms`) nor the dimension arguments (see
      `dim_args_spec`).
    version: An `int`.
      Tracks changes the semantics of the op, to support backwards
      compatibility. Minimum supported version is 2. From
      version 2, the op carries a StableHLO text or bytecode `module`. From
      version 3, the op also supports the `platforms` attribute. From version 4,
      the op carries a StableHLO module with compatibility guarantees.
    module: A `string`.
      A serialized computation, a text or bytecode representation of
      an mlir.Module. The return type must be a tuple if and only if the `Sout` is
      a list with 0 or more than 1 elements. The length of `Tout` and
      `Sout` must match. This op always returns a tuple of results, even if the
      module returns a single result.
    Sout: A list of shapes (each a `tf.TensorShape` or list of `ints`).
      List of output tensor shapes.
    Tout: A list of `tf.DTypes`. List of output tensor data types.
    dim_args_spec: An optional list of `strings`. Defaults to `[]`.
      in presence of dynamic shapes, this is the specification for the
      dimension arguments. In absence of dynamic shapes this list is empty. The
      `module` takes one 0-dimensional integer tensor dimension argument for each
      element of `dim_spec_args`. The dimension arguments come after the platform
      index argument and before the actual arguments. Each specification is a
      string of the form "<arg_idx>.<axis_idx>" that specifies that the value of
      the corresponding dimension argument must be "args[arg_idx].shape[axis_idx]",
      where "args" are the actual array arguments.
    platforms: An optional list of `strings`. Defaults to `[]`.
      the list of platforms supported by `module`. If the list is empty,
      the `module` is platform independent or there should be no platform checking
      or preprocessing. The list can contain the strings "CPU", "CUDA", "ROCM",
      or "TPU".
      If the list is not empty then it is an error to compile this op for a
      platform that does not appear in the list. If the list contains more than
      one platform, then the `module` takes one additional 0-dimensional
      integer-tensor parameter in the first position, encoding the index in
      `platforms` of the current compilation platform.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
  """
  # NOTE(review): the mutable list defaults ([]) are only rebound below, never
  # mutated in place, so they are not shared-state bugs here.
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaCallModule", name, args, "version", version, "module",
        module, "Sout", Sout, "Tout", Tout, "dim_args_spec", dim_args_spec,
        "platforms", platforms)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      # Type-based dispatchers get first crack before the eager fallback.
      _result = _dispatcher_for_xla_call_module(
          (args, version, module, Sout, Tout, dim_args_spec, platforms,
          name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_call_module_eager_fallback(
          args, version=version, module=module, Sout=Sout, Tout=Tout,
          dim_args_spec=dim_args_spec, platforms=platforms, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Fallback dispatch list may still handle arguments the op rejected.
      _result = _dispatch.dispatch(
            xla_call_module, (), dict(args=args, version=version,
                                      module=module, Sout=Sout, Tout=Tout,
                                      dim_args_spec=dim_args_spec,
                                      platforms=platforms, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_call_module(
        (args, version, module, Sout, Tout, dim_args_spec, platforms, name,),
        None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize and validate all attr values before building the op.
  version = _execute.make_int(version, "version")
  module = _execute.make_str(module, "module")
  if not isinstance(Sout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Sout' argument to "
        "'xla_call_module' Op, not %r." % Sout)
  Sout = [_execute.make_shape(_s, "Sout") for _s in Sout]
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'xla_call_module' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if dim_args_spec is None:
    dim_args_spec = []
  if not isinstance(dim_args_spec, (list, tuple)):
    raise TypeError(
        "Expected list for 'dim_args_spec' argument to "
        "'xla_call_module' Op, not %r." % dim_args_spec)
  dim_args_spec = [_execute.make_str(_s, "dim_args_spec") for _s in dim_args_spec]
  if platforms is None:
    platforms = []
  if not isinstance(platforms, (list, tuple)):
    raise TypeError(
        "Expected list for 'platforms' argument to "
        "'xla_call_module' Op, not %r." % platforms)
  platforms = [_execute.make_str(_s, "platforms") for _s in platforms]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaCallModule", args=args, version=version, module=module, Sout=Sout,
                         Tout=Tout, dim_args_spec=dim_args_spec,
                         platforms=platforms, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_call_module, (), dict(args=args, version=version, module=module,
                                    Sout=Sout, Tout=Tout,
                                    dim_args_spec=dim_args_spec,
                                    platforms=platforms, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record attrs/inputs so the gradient function can replay the op.
    _attrs = ("version", _op._get_attr_int("version"), "module",
              _op.get_attr("module"), "Sout", _op.get_attr("Sout"), "Tout",
              _op.get_attr("Tout"), "Tin", _op.get_attr("Tin"),
              "dim_args_spec", _op.get_attr("dim_args_spec"), "platforms",
              _op.get_attr("platforms"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaCallModule", _inputs_flat, _attrs, _result)
  # Variable-arity output: return the full result list.
  return _result

387 

# Expose the op as tf.raw_ops.XlaCallModule (raw-op form of the wrapper).
XlaCallModule = tf_export("raw_ops.XlaCallModule")(_ops.to_raw_op(xla_call_module))
# Cache the bound type-based dispatch method used inside xla_call_module.
_dispatcher_for_xla_call_module = xla_call_module._tf_type_based_dispatcher.Dispatch

390 

391 

def xla_call_module_eager_fallback(args, version, module, Sout, Tout, dim_args_spec, platforms, name, ctx):
  """Eagerly executes XlaCallModule when the C fast path is unavailable."""
  def _require_list(value, attr_name):
    # Reject non-sequence attr values with the op registry's exact message.
    if not isinstance(value, (list, tuple)):
      raise TypeError(
          "Expected list for '%s' argument to "
          "'xla_call_module' Op, not %r." % (attr_name, value))

  # Canonicalize and validate every attr before execution.
  version = _execute.make_int(version, "version")
  module = _execute.make_str(module, "module")
  _require_list(Sout, "Sout")
  Sout = [_execute.make_shape(_s, "Sout") for _s in Sout]
  _require_list(Tout, "Tout")
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  if dim_args_spec is None:
    dim_args_spec = []
  _require_list(dim_args_spec, "dim_args_spec")
  dim_args_spec = [_execute.make_str(_s, "dim_args_spec")
                   for _s in dim_args_spec]
  if platforms is None:
    platforms = []
  _require_list(platforms, "platforms")
  platforms = [_execute.make_str(_s, "platforms") for _s in platforms]
  # args may mix dtypes; "Tin" records the per-argument types.
  _attr_Tin, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  _inputs_flat = list(args)
  _attrs = ("version", version, "module", module, "Sout", Sout, "Tout", Tout,
            "Tin", _attr_Tin, "dim_args_spec", dim_args_spec, "platforms",
            platforms)
  _result = _execute.execute(b"XlaCallModule", len(Tout), inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaCallModule", _inputs_flat, _attrs, _result)
  return _result

429 

430 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_conv')
def xla_conv(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count, dimension_numbers, precision_config, name=None):
  r"""Wraps the XLA ConvGeneralDilated operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
  .

  Args:
    lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the input tensor
    rhs: A `Tensor`. Must have the same type as `lhs`. the kernel tensor
    window_strides: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      the inter-window strides
    padding: A `Tensor`. Must have the same type as `window_strides`.
      the padding to apply at the start and end of each input dimensions
    lhs_dilation: A `Tensor`. Must have the same type as `window_strides`.
      dilation to apply between input elements
    rhs_dilation: A `Tensor`. Must have the same type as `window_strides`.
      dilation to apply between kernel elements
    feature_group_count: A `Tensor`. Must have the same type as `window_strides`.
      number of feature groups for grouped convolution.
    dimension_numbers: A `string`.
      a serialized xla::ConvolutionDimensionNumbers proto.
    precision_config: A `string`. a serialized xla::PrecisionConfig proto.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `lhs`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaConv", name, lhs, rhs, window_strides, padding,
        lhs_dilation, rhs_dilation, feature_group_count, "dimension_numbers",
        dimension_numbers, "precision_config", precision_config)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      # Type-based dispatchers get first crack before the eager fallback.
      _result = _dispatcher_for_xla_conv(
          (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          feature_group_count, dimension_numbers, precision_config, name,),
          None)
      if _result is not NotImplemented:
        return _result
      return xla_conv_eager_fallback(
          lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          feature_group_count, dimension_numbers=dimension_numbers,
          precision_config=precision_config, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Fallback dispatch list may still handle arguments the op rejected.
      _result = _dispatch.dispatch(
            xla_conv, (), dict(lhs=lhs, rhs=rhs,
                               window_strides=window_strides, padding=padding,
                               lhs_dilation=lhs_dilation,
                               rhs_dilation=rhs_dilation,
                               feature_group_count=feature_group_count,
                               dimension_numbers=dimension_numbers,
                               precision_config=precision_config, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_conv(
        (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        feature_group_count, dimension_numbers, precision_config, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaConv", lhs=lhs, rhs=rhs, window_strides=window_strides,
                   padding=padding, lhs_dilation=lhs_dilation,
                   rhs_dilation=rhs_dilation,
                   feature_group_count=feature_group_count,
                   dimension_numbers=dimension_numbers,
                   precision_config=precision_config, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_conv, (), dict(lhs=lhs, rhs=rhs, window_strides=window_strides,
                             padding=padding, lhs_dilation=lhs_dilation,
                             rhs_dilation=rhs_dilation,
                             feature_group_count=feature_group_count,
                             dimension_numbers=dimension_numbers,
                             precision_config=precision_config, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record attrs/inputs so the gradient function can replay the op.
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"), "dimension_numbers",
              _op.get_attr("dimension_numbers"), "precision_config",
              _op.get_attr("precision_config"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaConv", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

541 

# Expose the op as tf.raw_ops.XlaConv (raw-op form of the wrapper above).
XlaConv = tf_export("raw_ops.XlaConv")(_ops.to_raw_op(xla_conv))
# Cache the bound type-based dispatch method used inside xla_conv.
_dispatcher_for_xla_conv = xla_conv._tf_type_based_dispatcher.Dispatch

544 

545 

def xla_conv_eager_fallback(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count, dimension_numbers, precision_config, name, ctx):
  """Eagerly executes XlaConv when the C fast path is unavailable."""
  # Canonicalize the string attrs.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  # Infer "T" from lhs/rhs (both must resolve to one common dtype).
  value_types = [
      _dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8,
      _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64,
      _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16,
      _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128,
      _dtypes.half, _dtypes.uint32, _dtypes.uint64, ]
  _attr_T, (lhs, rhs) = _execute.args_to_matching_eager(
      [lhs, rhs], ctx, value_types)
  # Infer "Tindices" from the stride/padding/dilation/group-count operands.
  index_types = [_dtypes.int32, _dtypes.int64, ]
  _attr_Tindices, (window_strides, padding, lhs_dilation, rhs_dilation,
                   feature_group_count) = _execute.args_to_matching_eager(
      [window_strides, padding, lhs_dilation, rhs_dilation,
       feature_group_count], ctx, index_types)
  _inputs_flat = [lhs, rhs, window_strides, padding, lhs_dilation,
                  rhs_dilation, feature_group_count]
  _attrs = ("T", _attr_T, "Tindices", _attr_Tindices, "dimension_numbers",
            dimension_numbers, "precision_config", precision_config)
  _result = _execute.execute(b"XlaConv", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaConv", _inputs_flat, _attrs, _result)
  # Single output: unwrap the one-element result list.
  _result, = _result
  return _result

563 

564 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_conv_v2')
def xla_conv_v2(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count, dimension_numbers, precision_config, preferred_element_type, batch_group_count=1, name=None):
  r"""Wraps the XLA ConvGeneralDilated operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution
  .

  Args:
    lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      input tensor
    rhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      kernel tensor
    window_strides: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      inter-window strides
    padding: A `Tensor`. Must have the same type as `window_strides`.
      padding to apply at the start and end of each input dimensions
    lhs_dilation: A `Tensor`. Must have the same type as `window_strides`.
      dilation to apply between input elements
    rhs_dilation: A `Tensor`. Must have the same type as `window_strides`.
      dilation to apply between kernel elements
    feature_group_count: A `Tensor`. Must have the same type as `window_strides`.
      number of feature groups for grouped convolution.
    dimension_numbers: A `string`.
      serialized xla::ConvolutionDimensionNumbers proto.
    precision_config: A `string`. serialized xla::PrecisionConfig proto.
    preferred_element_type: A `tf.DType` from: `tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.complex64, tf.int64, tf.qint8, tf.quint8, tf.qint32, tf.bfloat16, tf.qint16, tf.quint16, tf.uint16, tf.complex128, tf.half, tf.uint32, tf.uint64`.
      type of the tensor.
    batch_group_count: An optional `int`. Defaults to `1`.
      number of batch groups or grouped filters.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `preferred_element_type`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaConvV2", name, lhs, rhs, window_strides, padding,
        lhs_dilation, rhs_dilation, feature_group_count, "dimension_numbers",
        dimension_numbers, "precision_config", precision_config,
        "preferred_element_type", preferred_element_type, "batch_group_count",
        batch_group_count)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    try:
      # Type-based dispatchers get first crack before the eager fallback.
      _result = _dispatcher_for_xla_conv_v2(
          (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          feature_group_count, dimension_numbers, precision_config,
          preferred_element_type, batch_group_count, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_conv_v2_eager_fallback(
          lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
          feature_group_count, dimension_numbers=dimension_numbers,
          precision_config=precision_config,
          preferred_element_type=preferred_element_type,
          batch_group_count=batch_group_count, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Fallback dispatch list may still handle arguments the op rejected.
      _result = _dispatch.dispatch(
            xla_conv_v2, (), dict(lhs=lhs, rhs=rhs,
                                  window_strides=window_strides,
                                  padding=padding, lhs_dilation=lhs_dilation,
                                  rhs_dilation=rhs_dilation,
                                  feature_group_count=feature_group_count,
                                  dimension_numbers=dimension_numbers,
                                  precision_config=precision_config,
                                  preferred_element_type=preferred_element_type,
                                  batch_group_count=batch_group_count,
                                  name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_xla_conv_v2(
        (lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
        feature_group_count, dimension_numbers, precision_config,
        preferred_element_type, batch_group_count, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize the attr values before building the op.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type")
  if batch_group_count is None:
    batch_group_count = 1
  batch_group_count = _execute.make_int(batch_group_count, "batch_group_count")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaConvV2", lhs=lhs, rhs=rhs, window_strides=window_strides,
                     padding=padding, lhs_dilation=lhs_dilation,
                     rhs_dilation=rhs_dilation,
                     feature_group_count=feature_group_count,
                     dimension_numbers=dimension_numbers,
                     precision_config=precision_config,
                     preferred_element_type=preferred_element_type,
                     batch_group_count=batch_group_count, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_conv_v2, (), dict(lhs=lhs, rhs=rhs,
                                window_strides=window_strides,
                                padding=padding, lhs_dilation=lhs_dilation,
                                rhs_dilation=rhs_dilation,
                                feature_group_count=feature_group_count,
                                dimension_numbers=dimension_numbers,
                                precision_config=precision_config,
                                preferred_element_type=preferred_element_type,
                                batch_group_count=batch_group_count,
                                name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record attrs/inputs so the gradient function can replay the op.
    # V2 carries separate LhsT/RhsT dtypes plus the preferred output type.
    _attrs = ("LhsT", _op._get_attr_type("LhsT"), "RhsT",
              _op._get_attr_type("RhsT"), "Tindices",
              _op._get_attr_type("Tindices"), "dimension_numbers",
              _op.get_attr("dimension_numbers"), "precision_config",
              _op.get_attr("precision_config"), "preferred_element_type",
              _op._get_attr_type("preferred_element_type"),
              "batch_group_count", _op._get_attr_int("batch_group_count"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaConvV2", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

701 

# Expose XlaConvV2 under tf.raw_ops and capture the type-based dispatch hook
# that the generated xla_conv_v2 wrapper (defined above) consults.
XlaConvV2 = tf_export("raw_ops.XlaConvV2")(_ops.to_raw_op(xla_conv_v2))
_dispatcher_for_xla_conv_v2 = xla_conv_v2._tf_type_based_dispatcher.Dispatch

704 

705 

def xla_conv_v2_eager_fallback(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count, dimension_numbers, precision_config, preferred_element_type, batch_group_count, name, ctx):
  """Slow-path eager execution of XlaConvV2.

  Canonicalizes attrs, infers input dtypes, then runs the op through the
  eager runtime (used when the C fast path raises _FallbackException).
  """
  # Canonicalize/validate attribute values.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type")
  if batch_group_count is None:
    batch_group_count = 1  # default declared by the op definition
  batch_group_count = _execute.make_int(batch_group_count, "batch_group_count")
  # Infer the LhsT/RhsT dtype attrs from the tensor inputs; lhs and rhs may
  # have different dtypes (separate attrs).
  _attr_LhsT, (lhs,) = _execute.args_to_matching_eager([lhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  _attr_RhsT, (rhs,) = _execute.args_to_matching_eager([rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  # All index-like inputs must share one integer dtype (Tindices).
  _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count], ctx, [_dtypes.int32, _dtypes.int64, ])
  (window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count) = _inputs_Tindices
  _inputs_flat = [lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation, feature_group_count]
  _attrs = ("LhsT", _attr_LhsT, "RhsT", _attr_RhsT, "Tindices",
  _attr_Tindices, "dimension_numbers", dimension_numbers, "precision_config",
  precision_config, "preferred_element_type", preferred_element_type,
  "batch_group_count", batch_group_count)
  # Execute the op; XlaConvV2 produces exactly one output tensor.
  _result = _execute.execute(b"XlaConvV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    # Record inputs/attrs so a gradient can be computed for this call.
    _execute.record_gradient(
        "XlaConvV2", _inputs_flat, _attrs, _result)
  _result, = _result  # unpack the single output
  return _result

729 

730 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_custom_call')
def xla_custom_call(args, target_name, backend_config, dtype, shape, name=None):
  r"""Wraps the XLA CustomCall operator

  documented at https://www.tensorflow.org/xla/operation_semantics#customcall.

  Args:
    args: A list of `Tensor` objects.
      A list of `Tensor` with possibly different types.
    target_name: A `string`.
      Name of the function. A call instruction will be emitted which
      targets this symbol name.
    backend_config: A `string`.
      String, used to encode serialized metadata to the backend.
    dtype: A `tf.DType`. Output tensor data type.
    shape: A `tf.TensorShape` or list of `ints`. Output tensor shape.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaCustomCall", name, args, "target_name", target_name,
        "backend_config", backend_config, "dtype", dtype, "shape", shape)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; use the Python paths below.
    try:
      # Type-based dispatch returns NotImplemented when no registered
      # dispatcher handles these argument types.
      _result = _dispatcher_for_xla_custom_call(
          (args, target_name, backend_config, dtype, shape, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_custom_call_eager_fallback(
          args, target_name=target_name, backend_config=backend_config,
          dtype=dtype, shape=shape, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list; re-raise when no dispatcher applies.
      _result = _dispatch.dispatch(
            xla_custom_call, (), dict(args=args, target_name=target_name,
                                      backend_config=backend_config,
                                      dtype=dtype, shape=shape, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still give type-based dispatchers a chance first.
    _result = _dispatcher_for_xla_custom_call(
        (args, target_name, backend_config, dtype, shape, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  target_name = _execute.make_str(target_name, "target_name")
  backend_config = _execute.make_str(backend_config, "backend_config")
  dtype = _execute.make_type(dtype, "dtype")
  shape = _execute.make_shape(shape, "shape")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaCustomCall", args=args, target_name=target_name,
                         backend_config=backend_config, dtype=dtype,
                         shape=shape, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_custom_call, (), dict(args=args, target_name=target_name,
                                    backend_config=backend_config,
                                    dtype=dtype, shape=shape, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture op attrs/inputs so a gradient can be recorded later.
    _attrs = ("target_name", _op.get_attr("target_name"), "backend_config",
              _op.get_attr("backend_config"), "T", _op.get_attr("T"), "dtype",
              _op._get_attr_type("dtype"), "shape", _op.get_attr("shape"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaCustomCall", _inputs_flat, _attrs, _result)
  _result, = _result  # single output tensor
  return _result

# raw_ops endpoint and type-based dispatch hook for xla_custom_call.
XlaCustomCall = tf_export("raw_ops.XlaCustomCall")(_ops.to_raw_op(xla_custom_call))
_dispatcher_for_xla_custom_call = xla_custom_call._tf_type_based_dispatcher.Dispatch

822 

823 

def xla_custom_call_eager_fallback(args, target_name, backend_config, dtype, shape, name, ctx):
  """Slow-path eager execution of XlaCustomCall.

  Canonicalizes attrs, converts `args` (which may mix dtypes) to eager
  tensors, then executes the op directly.
  """
  target_name = _execute.make_str(target_name, "target_name")
  backend_config = _execute.make_str(backend_config, "backend_config")
  dtype = _execute.make_type(dtype, "dtype")
  shape = _execute.make_shape(shape, "shape")
  # `args` is a heterogeneous list; T becomes the list of their dtypes.
  _attr_T, args = _execute.convert_to_mixed_eager_tensors(args, ctx)
  _inputs_flat = list(args)
  _attrs = ("target_name", target_name, "backend_config", backend_config, "T",
  _attr_T, "dtype", dtype, "shape", shape)
  # XlaCustomCall has exactly one output.
  _result = _execute.execute(b"XlaCustomCall", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaCustomCall", _inputs_flat, _attrs, _result)
  _result, = _result  # unpack the single output
  return _result

840 

841 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_custom_call_v2')
def xla_custom_call_v2(operands, call_target_name, backend_config, has_side_effect, result_dtypes, result_shapes, name=None):
  r"""Emits an HLO `CustomCall` operation with multiple outputs.

  As opposed to `XlaCustomCall`, this operation supports multiple outputs.

  See `CustomCall` specification at
  https://tensorflow.org/xla/operation_semantics#customcall,
  and `mhlo.custom_call` specification at
  https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop.

  Args:
    operands: A list of `Tensor` objects.
      A sequence of tensors with possibly different types.
    call_target_name: A `string`.
      Name of the user function. The function signature must conform
      to version 3 of the API, see `API_VERSION_STATUS_RETURNING_UNIFIED`. All
      operands and results assumed to be in the default layout.
    backend_config: A `string`.
      A string that encodes a metadata for the backend.
    has_side_effect: A `bool`.
      Indicates whether the custom call has side effects.
    result_dtypes: A list of `tf.DTypes`. Types of all results.
    result_shapes: A list of shapes (each a `tf.TensorShape` or list of `ints`).
      Shapes of all results.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `result_dtypes`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaCustomCallV2", name, operands, "call_target_name",
        call_target_name, "backend_config", backend_config, "has_side_effect",
        has_side_effect, "result_dtypes", result_dtypes, "result_shapes",
        result_shapes)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; use the Python paths below.
    try:
      # Type-based dispatch; NotImplemented means no dispatcher matched.
      _result = _dispatcher_for_xla_custom_call_v2(
          (operands, call_target_name, backend_config, has_side_effect,
          result_dtypes, result_shapes, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_custom_call_v2_eager_fallback(
          operands, call_target_name=call_target_name,
          backend_config=backend_config, has_side_effect=has_side_effect,
          result_dtypes=result_dtypes, result_shapes=result_shapes, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list; re-raise when no dispatcher applies.
      _result = _dispatch.dispatch(
            xla_custom_call_v2, (), dict(operands=operands,
                                         call_target_name=call_target_name,
                                         backend_config=backend_config,
                                         has_side_effect=has_side_effect,
                                         result_dtypes=result_dtypes,
                                         result_shapes=result_shapes,
                                         name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still give type-based dispatchers a chance first.
    _result = _dispatcher_for_xla_custom_call_v2(
        (operands, call_target_name, backend_config, has_side_effect,
        result_dtypes, result_shapes, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  call_target_name = _execute.make_str(call_target_name, "call_target_name")
  backend_config = _execute.make_str(backend_config, "backend_config")
  has_side_effect = _execute.make_bool(has_side_effect, "has_side_effect")
  if not isinstance(result_dtypes, (list, tuple)):
    raise TypeError(
        "Expected list for 'result_dtypes' argument to "
        "'xla_custom_call_v2' Op, not %r." % result_dtypes)
  result_dtypes = [_execute.make_type(_t, "result_dtypes") for _t in result_dtypes]
  if not isinstance(result_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'result_shapes' argument to "
        "'xla_custom_call_v2' Op, not %r." % result_shapes)
  result_shapes = [_execute.make_shape(_s, "result_shapes") for _s in result_shapes]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaCustomCallV2", operands=operands,
                           call_target_name=call_target_name,
                           backend_config=backend_config,
                           has_side_effect=has_side_effect,
                           result_dtypes=result_dtypes,
                           result_shapes=result_shapes, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_custom_call_v2, (), dict(operands=operands,
                                       call_target_name=call_target_name,
                                       backend_config=backend_config,
                                       has_side_effect=has_side_effect,
                                       result_dtypes=result_dtypes,
                                       result_shapes=result_shapes, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture op attrs/inputs so a gradient can be recorded later.
    _attrs = ("call_target_name", _op.get_attr("call_target_name"),
              "backend_config", _op.get_attr("backend_config"),
              "has_side_effect", _op._get_attr_bool("has_side_effect"),
              "operand_dtypes", _op.get_attr("operand_dtypes"),
              "result_dtypes", _op.get_attr("result_dtypes"), "result_shapes",
              _op.get_attr("result_shapes"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaCustomCallV2", _inputs_flat, _attrs, _result)
  # Multiple outputs: return the full list (no unpacking).
  return _result

# raw_ops endpoint and type-based dispatch hook for xla_custom_call_v2.
XlaCustomCallV2 = tf_export("raw_ops.XlaCustomCallV2")(_ops.to_raw_op(xla_custom_call_v2))
_dispatcher_for_xla_custom_call_v2 = xla_custom_call_v2._tf_type_based_dispatcher.Dispatch

969 

970 

def xla_custom_call_v2_eager_fallback(operands, call_target_name, backend_config, has_side_effect, result_dtypes, result_shapes, name, ctx):
  """Slow-path eager execution of XlaCustomCallV2.

  Canonicalizes attrs (including list validation), converts the mixed-dtype
  operands, then executes the op; returns the list of result tensors.
  """
  call_target_name = _execute.make_str(call_target_name, "call_target_name")
  backend_config = _execute.make_str(backend_config, "backend_config")
  has_side_effect = _execute.make_bool(has_side_effect, "has_side_effect")
  if not isinstance(result_dtypes, (list, tuple)):
    raise TypeError(
        "Expected list for 'result_dtypes' argument to "
        "'xla_custom_call_v2' Op, not %r." % result_dtypes)
  result_dtypes = [_execute.make_type(_t, "result_dtypes") for _t in result_dtypes]
  if not isinstance(result_shapes, (list, tuple)):
    raise TypeError(
        "Expected list for 'result_shapes' argument to "
        "'xla_custom_call_v2' Op, not %r." % result_shapes)
  result_shapes = [_execute.make_shape(_s, "result_shapes") for _s in result_shapes]
  # Operands may mix dtypes; operand_dtypes records each one.
  _attr_operand_dtypes, operands = _execute.convert_to_mixed_eager_tensors(operands, ctx)
  _inputs_flat = list(operands)
  _attrs = ("call_target_name", call_target_name, "backend_config",
  backend_config, "has_side_effect", has_side_effect, "operand_dtypes",
  _attr_operand_dtypes, "result_dtypes", result_dtypes, "result_shapes",
  result_shapes)
  # Number of outputs equals the number of declared result dtypes.
  _result = _execute.execute(b"XlaCustomCallV2", len(result_dtypes),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaCustomCallV2", _inputs_flat, _attrs, _result)
  return _result

998 

999 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_dequantize')
def xla_dequantize(input, min_range, max_range, mode, transpose_output, name=None):
  r"""Takes the packed uint32 input and unpacks the input to uint8 to do

  Dequantization on device.

  Args:
    input: A `Tensor` of type `uint32`.
      Input tensors whose types is uint32, shape is [d0, ..., dn].
    min_range: A `float`.
      The minimum scalar value possibly produced for the input.
    max_range: A `float`.
      The maximum scalar value possibly produced for the input.
    mode: A `string`.
      String to determine the dequantize mode in {"MIN_COMBINED", "MIN_FIRST", "SCALED"}.
    transpose_output: A `bool`.
      Boolean to determine if output is transposed. transpose_output
      is faster when input is large and rank of input is higher than 1.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `bfloat16`.
    Output tensors whose types is bfloat16. If transpose_output is true,
    output shape is [dn * 4, dn-1, ..., d1, d0]. If transpose_output
    is false, output shape is [d0,..., dn * 4].
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaDequantize", name, input, "min_range", min_range,
        "max_range", max_range, "mode", mode, "transpose_output",
        transpose_output)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; use the Python paths below.
    try:
      # Type-based dispatch; NotImplemented means no dispatcher matched.
      _result = _dispatcher_for_xla_dequantize(
          (input, min_range, max_range, mode, transpose_output, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_dequantize_eager_fallback(
          input, min_range=min_range, max_range=max_range, mode=mode,
          transpose_output=transpose_output, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list; re-raise when no dispatcher applies.
      _result = _dispatch.dispatch(
            xla_dequantize, (), dict(input=input, min_range=min_range,
                                     max_range=max_range, mode=mode,
                                     transpose_output=transpose_output,
                                     name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still give type-based dispatchers a chance first.
    _result = _dispatcher_for_xla_dequantize(
        (input, min_range, max_range, mode, transpose_output, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  min_range = _execute.make_float(min_range, "min_range")
  max_range = _execute.make_float(max_range, "max_range")
  mode = _execute.make_str(mode, "mode")
  transpose_output = _execute.make_bool(transpose_output, "transpose_output")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaDequantize", input=input, min_range=min_range,
                         max_range=max_range, mode=mode,
                         transpose_output=transpose_output, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_dequantize, (), dict(input=input, min_range=min_range,
                                   max_range=max_range, mode=mode,
                                   transpose_output=transpose_output,
                                   name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture op attrs/inputs so a gradient can be recorded later.
    _attrs = ("min_range", _op.get_attr("min_range"), "max_range",
              _op.get_attr("max_range"), "mode", _op.get_attr("mode"),
              "transpose_output", _op._get_attr_bool("transpose_output"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaDequantize", _inputs_flat, _attrs, _result)
  _result, = _result  # single output tensor
  return _result

# raw_ops endpoint and type-based dispatch hook for xla_dequantize.
XlaDequantize = tf_export("raw_ops.XlaDequantize")(_ops.to_raw_op(xla_dequantize))
_dispatcher_for_xla_dequantize = xla_dequantize._tf_type_based_dispatcher.Dispatch

1099 

1100 

def xla_dequantize_eager_fallback(input, min_range, max_range, mode, transpose_output, name, ctx):
  """Slow-path eager execution of XlaDequantize.

  Canonicalizes attrs, coerces `input` to uint32 (the op's only accepted
  dtype), then executes the op directly.
  """
  min_range = _execute.make_float(min_range, "min_range")
  max_range = _execute.make_float(max_range, "max_range")
  mode = _execute.make_str(mode, "mode")
  transpose_output = _execute.make_bool(transpose_output, "transpose_output")
  # Fixed input dtype: the op signature declares uint32.
  input = _ops.convert_to_tensor(input, _dtypes.uint32)
  _inputs_flat = [input]
  _attrs = ("min_range", min_range, "max_range", max_range, "mode", mode,
  "transpose_output", transpose_output)
  _result = _execute.execute(b"XlaDequantize", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaDequantize", _inputs_flat, _attrs, _result)
  _result, = _result  # unpack the single output
  return _result

1117 

1118 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_dot')
def xla_dot(lhs, rhs, dimension_numbers, precision_config, name=None):
  r"""Wraps the XLA DotGeneral operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
  .

  Args:
    lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the LHS tensor
    rhs: A `Tensor`. Must have the same type as `lhs`. the RHS tensor
    dimension_numbers: A `string`.
      a serialized xla::DotDimensionNumbers proto.
    precision_config: A `string`. a serialized xla::PrecisionConfig proto.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `lhs`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaDot", name, lhs, rhs, "dimension_numbers",
        dimension_numbers, "precision_config", precision_config)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; use the Python paths below.
    try:
      # Type-based dispatch; NotImplemented means no dispatcher matched.
      _result = _dispatcher_for_xla_dot(
          (lhs, rhs, dimension_numbers, precision_config, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_dot_eager_fallback(
          lhs, rhs, dimension_numbers=dimension_numbers,
          precision_config=precision_config, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list; re-raise when no dispatcher applies.
      _result = _dispatch.dispatch(
            xla_dot, (), dict(lhs=lhs, rhs=rhs,
                              dimension_numbers=dimension_numbers,
                              precision_config=precision_config, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still give type-based dispatchers a chance first.
    _result = _dispatcher_for_xla_dot(
        (lhs, rhs, dimension_numbers, precision_config, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaDot", lhs=lhs, rhs=rhs, dimension_numbers=dimension_numbers,
                  precision_config=precision_config, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_dot, (), dict(lhs=lhs, rhs=rhs,
                            dimension_numbers=dimension_numbers,
                            precision_config=precision_config, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture op attrs/inputs so a gradient can be recorded later.
    _attrs = ("T", _op._get_attr_type("T"), "dimension_numbers",
              _op.get_attr("dimension_numbers"), "precision_config",
              _op.get_attr("precision_config"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaDot", _inputs_flat, _attrs, _result)
  _result, = _result  # single output tensor
  return _result

# raw_ops endpoint and type-based dispatch hook for xla_dot.
XlaDot = tf_export("raw_ops.XlaDot")(_ops.to_raw_op(xla_dot))
_dispatcher_for_xla_dot = xla_dot._tf_type_based_dispatcher.Dispatch

1205 

1206 

def xla_dot_eager_fallback(lhs, rhs, dimension_numbers, precision_config, name, ctx):
  """Slow-path eager execution of XlaDot.

  Canonicalizes attrs, coerces lhs/rhs to a single shared dtype T, then
  executes the op directly.
  """
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  # Unlike XlaDotV2, both inputs must share one dtype attr T.
  _attr_T, _inputs_T = _execute.args_to_matching_eager([lhs, rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  (lhs, rhs) = _inputs_T
  _inputs_flat = [lhs, rhs]
  _attrs = ("T", _attr_T, "dimension_numbers", dimension_numbers,
  "precision_config", precision_config)
  _result = _execute.execute(b"XlaDot", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaDot", _inputs_flat, _attrs, _result)
  _result, = _result  # unpack the single output
  return _result

1222 

1223 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_dot_v2')
def xla_dot_v2(lhs, rhs, dimension_numbers, precision_config, preferred_element_type, name=None):
  r"""Wraps the XLA DotGeneral operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#dotgeneral
  .

  Args:
    lhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the LHS tensor
    rhs: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the RHS tensor
    dimension_numbers: A `string`.
      a serialized xla::DotDimensionNumbers proto.
    precision_config: A `string`. a serialized xla::PrecisionConfig proto.
    preferred_element_type: A `tf.DType` from: `tf.float32, tf.float64, tf.int32, tf.uint8, tf.int16, tf.int8, tf.complex64, tf.int64, tf.qint8, tf.quint8, tf.qint32, tf.bfloat16, tf.qint16, tf.quint16, tf.uint16, tf.complex128, tf.half, tf.uint32, tf.uint64`.
      The type of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `preferred_element_type`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaDotV2", name, lhs, rhs, "dimension_numbers",
        dimension_numbers, "precision_config", precision_config,
        "preferred_element_type", preferred_element_type)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; use the Python paths below.
    try:
      # Type-based dispatch; NotImplemented means no dispatcher matched.
      _result = _dispatcher_for_xla_dot_v2(
          (lhs, rhs, dimension_numbers, precision_config,
          preferred_element_type, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_dot_v2_eager_fallback(
          lhs, rhs, dimension_numbers=dimension_numbers,
          precision_config=precision_config,
          preferred_element_type=preferred_element_type, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list; re-raise when no dispatcher applies.
      _result = _dispatch.dispatch(
            xla_dot_v2, (), dict(lhs=lhs, rhs=rhs,
                                 dimension_numbers=dimension_numbers,
                                 precision_config=precision_config,
                                 preferred_element_type=preferred_element_type,
                                 name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still give type-based dispatchers a chance first.
    _result = _dispatcher_for_xla_dot_v2(
        (lhs, rhs, dimension_numbers, precision_config,
        preferred_element_type, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaDotV2", lhs=lhs, rhs=rhs, dimension_numbers=dimension_numbers,
                    precision_config=precision_config,
                    preferred_element_type=preferred_element_type, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_dot_v2, (), dict(lhs=lhs, rhs=rhs,
                               dimension_numbers=dimension_numbers,
                               precision_config=precision_config,
                               preferred_element_type=preferred_element_type,
                               name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture op attrs/inputs so a gradient can be recorded later.
    _attrs = ("LhsT", _op._get_attr_type("LhsT"), "RhsT",
              _op._get_attr_type("RhsT"), "dimension_numbers",
              _op.get_attr("dimension_numbers"), "precision_config",
              _op.get_attr("precision_config"), "preferred_element_type",
              _op._get_attr_type("preferred_element_type"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaDotV2", _inputs_flat, _attrs, _result)
  _result, = _result  # single output tensor
  return _result

# raw_ops endpoint and type-based dispatch hook for xla_dot_v2.
XlaDotV2 = tf_export("raw_ops.XlaDotV2")(_ops.to_raw_op(xla_dot_v2))
_dispatcher_for_xla_dot_v2 = xla_dot_v2._tf_type_based_dispatcher.Dispatch

1325 

1326 

def xla_dot_v2_eager_fallback(lhs, rhs, dimension_numbers, precision_config, preferred_element_type, name, ctx):
  """Slow-path eager execution of XlaDotV2.

  Canonicalizes attrs and infers independent LhsT/RhsT dtypes (V2 allows
  lhs and rhs to differ), then executes the op directly.
  """
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  precision_config = _execute.make_str(precision_config, "precision_config")
  preferred_element_type = _execute.make_type(preferred_element_type, "preferred_element_type")
  # lhs and rhs dtypes are inferred separately (LhsT / RhsT attrs).
  _attr_LhsT, (lhs,) = _execute.args_to_matching_eager([lhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  _attr_RhsT, (rhs,) = _execute.args_to_matching_eager([rhs], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  _inputs_flat = [lhs, rhs]
  _attrs = ("LhsT", _attr_LhsT, "RhsT", _attr_RhsT, "dimension_numbers",
  dimension_numbers, "precision_config", precision_config,
  "preferred_element_type", preferred_element_type)
  _result = _execute.execute(b"XlaDotV2", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaDotV2", _inputs_flat, _attrs, _result)
  _result, = _result  # unpack the single output
  return _result

1344 

1345 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_dynamic_slice')
def xla_dynamic_slice(input, start_indices, size_indices, name=None):
  r"""Wraps the XLA DynamicSlice operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#dynamicslice
  .

  DynamicSlice extracts a sub-array from the input array at dynamic
  start_indices. The size of the slice in each dimension is passed in
  size_indices, which specify the end point of exclusive slice intervals in each
  dimension -- [start, start + size). The shape of start_indices must have rank 1,
  with dimension size equal to the rank of operand.

  Args:
    input: A `Tensor`. A `Tensor` of type T.
    start_indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      List of N integers containing the slice size for each
      dimension. Each value must be strictly greater than zero, and start + size
      must be less than or equal to the size of the dimension to avoid
      implementation defined behavior.
    size_indices: A `Tensor`. Must have the same type as `start_indices`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaDynamicSlice", name, input, start_indices, size_indices)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_dynamic_slice(
          (input, start_indices, size_indices, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_dynamic_slice_eager_fallback(
          input, start_indices, size_indices, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_dynamic_slice, (), dict(input=input,
                                        start_indices=start_indices,
                                        size_indices=size_indices, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_dynamic_slice(
        (input, start_indices, size_indices, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaDynamicSlice", input=input, start_indices=start_indices,
                           size_indices=size_indices, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_dynamic_slice, (), dict(input=input,
                                      start_indices=start_indices,
                                      size_indices=size_indices, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaDynamicSlice", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

1431 

# Raw-op export (tf.raw_ops.XlaDynamicSlice) and the module-level handle to
# this op's type-based dispatcher, used by the wrapper above.
XlaDynamicSlice = tf_export("raw_ops.XlaDynamicSlice")(_ops.to_raw_op(xla_dynamic_slice))
_dispatcher_for_xla_dynamic_slice = xla_dynamic_slice._tf_type_based_dispatcher.Dispatch

1434 

1435 

def xla_dynamic_slice_eager_fallback(input, start_indices, size_indices, name, ctx):
  """Slow-path eager executor for the XlaDynamicSlice op."""
  # Infer T from `input`; an empty allowed list means any dtype is accepted.
  _attr_T, converted_input = _execute.args_to_matching_eager([input], ctx, [])
  (input,) = converted_input
  # Both index tensors must resolve to one shared integer dtype (Tindices).
  _attr_Tindices, index_tensors = _execute.args_to_matching_eager(
      [start_indices, size_indices], ctx, [_dtypes.int32, _dtypes.int64, ])
  start_indices, size_indices = index_tensors
  flat_inputs = [input, start_indices, size_indices]
  op_attrs = ("T", _attr_T, "Tindices", _attr_Tindices)
  outputs = _execute.execute(b"XlaDynamicSlice", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaDynamicSlice", flat_inputs, op_attrs, outputs)
  # Single-output op: unwrap the one-element result list.
  sliced, = outputs
  return sliced

1449 

1450 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_dynamic_update_slice')
def xla_dynamic_update_slice(input, update, indices, name=None):
  r"""Wraps the XLA DynamicUpdateSlice operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#dynamicupdateslice
  .

  XlaDynamicUpdateSlice generates a result which is the value of the `input`
  operand, with a slice update overwritten at `indices`. The shape of `update`
  determines the shape of the sub-array of the result which is updated. The shape
  of indices must be rank == 1, with dimension size equal to the rank of `input`.

  Handling of out-of-bounds slice indices is implementation-defined.

  Args:
    input: A `Tensor`. A `Tensor` of type T.
    update: A `Tensor`. Must have the same type as `input`.
      A `Tensor` of type T. Same rank as `input`.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      A vector of indices into `input`. Must have length equal to the rank of
      `input`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`. A `Tensor` of type T.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaDynamicUpdateSlice", name, input, update, indices)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_dynamic_update_slice(
          (input, update, indices, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_dynamic_update_slice_eager_fallback(
          input, update, indices, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_dynamic_update_slice, (), dict(input=input, update=update,
                                               indices=indices, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_dynamic_update_slice(
        (input, update, indices, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaDynamicUpdateSlice", input=input, update=update, indices=indices,
                                 name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_dynamic_update_slice, (), dict(input=input, update=update,
                                             indices=indices, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaDynamicUpdateSlice", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

1534 

# Raw-op export (tf.raw_ops.XlaDynamicUpdateSlice) and the module-level handle
# to this op's type-based dispatcher, used by the wrapper above.
XlaDynamicUpdateSlice = tf_export("raw_ops.XlaDynamicUpdateSlice")(_ops.to_raw_op(xla_dynamic_update_slice))
_dispatcher_for_xla_dynamic_update_slice = xla_dynamic_update_slice._tf_type_based_dispatcher.Dispatch

1537 

1538 

def xla_dynamic_update_slice_eager_fallback(input, update, indices, name, ctx):
  """Slow-path eager executor for the XlaDynamicUpdateSlice op."""
  # `input` and `update` must share one dtype T (any dtype accepted).
  _attr_T, matched_pair = _execute.args_to_matching_eager([input, update], ctx, [])
  input, update = matched_pair
  # `indices` resolves to int32 or int64 (Tindices).
  _attr_Tindices, converted_indices = _execute.args_to_matching_eager(
      [indices], ctx, [_dtypes.int32, _dtypes.int64, ])
  (indices,) = converted_indices
  flat_inputs = [input, update, indices]
  op_attrs = ("T", _attr_T, "Tindices", _attr_Tindices)
  outputs = _execute.execute(b"XlaDynamicUpdateSlice", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaDynamicUpdateSlice", flat_inputs, op_attrs,
                             outputs)
  # Single-output op: unwrap the one-element result list.
  updated, = outputs
  return updated

1552 

1553 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_einsum')
def xla_einsum(a, b, equation, name=None):
  r"""An op which supports basic einsum op with 2 inputs and 1 output.

  This op has better TPU performance since it doesn't have explicitly reshape and
  transpose operations as tf.einsum does.

  Args:
    a: A `Tensor`. Must be one of the following types: `complex64`, `bfloat16`, `float32`.
    b: A `Tensor`. Must have the same type as `a`.
    equation: A `string`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `a`.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaEinsum", name, a, b, "equation", equation)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_einsum(
          (a, b, equation, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_einsum_eager_fallback(
          a, b, equation=equation, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_einsum, (), dict(a=a, b=b, equation=equation, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_einsum(
        (a, b, equation, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  equation = _execute.make_str(equation, "equation")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaEinsum", a=a, b=b, equation=equation, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_einsum, (), dict(a=a, b=b, equation=equation, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("equation", _op.get_attr("equation"), "T",
              _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaEinsum", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

1625 

# Raw-op export (tf.raw_ops.XlaEinsum) and the module-level handle to this
# op's type-based dispatcher, used by the wrapper above.
XlaEinsum = tf_export("raw_ops.XlaEinsum")(_ops.to_raw_op(xla_einsum))
_dispatcher_for_xla_einsum = xla_einsum._tf_type_based_dispatcher.Dispatch

1628 

1629 

def xla_einsum_eager_fallback(a, b, equation, name, ctx):
  """Slow-path eager executor for the XlaEinsum op."""
  equation = _execute.make_str(equation, "equation")
  # Both operands must share one dtype T out of the op's supported set.
  _attr_T, matched_operands = _execute.args_to_matching_eager(
      [a, b], ctx, [_dtypes.complex64, _dtypes.bfloat16, _dtypes.float32, ])
  a, b = matched_operands
  flat_inputs = [a, b]
  op_attrs = ("equation", equation, "T", _attr_T)
  outputs = _execute.execute(b"XlaEinsum", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaEinsum", flat_inputs, op_attrs, outputs)
  # Single-output op: unwrap the one-element result list.
  contracted, = outputs
  return contracted

1643 

1644 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_gather')
def xla_gather(operand, start_indices, slice_sizes, dimension_numbers, indices_are_sorted, name=None):
  r"""Wraps the XLA Gather operator documented at

  https://www.tensorflow.org/xla/operation_semantics#gather

  Args:
    operand: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`.
      The array we're gathering from.
    start_indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      Array containing the starting indices of the slices we gather.
    slice_sizes: A `Tensor`. Must have the same type as `start_indices`.
      slice_sizes[i] is the bounds for the slice on dimension i.
    dimension_numbers: A `string`.
      A serialized xla::GatherDimensionNumbers proto.
    indices_are_sorted: A `bool`.
      Boolean indicating if the indices are sorted.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `operand`.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaGather", name, operand, start_indices, slice_sizes,
        "dimension_numbers", dimension_numbers, "indices_are_sorted",
        indices_are_sorted)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_gather(
          (operand, start_indices, slice_sizes, dimension_numbers,
          indices_are_sorted, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_gather_eager_fallback(
          operand, start_indices, slice_sizes,
          dimension_numbers=dimension_numbers,
          indices_are_sorted=indices_are_sorted, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_gather, (), dict(operand=operand, start_indices=start_indices,
                                 slice_sizes=slice_sizes,
                                 dimension_numbers=dimension_numbers,
                                 indices_are_sorted=indices_are_sorted,
                                 name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_gather(
        (operand, start_indices, slice_sizes, dimension_numbers,
        indices_are_sorted, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaGather", operand=operand, start_indices=start_indices,
                     slice_sizes=slice_sizes,
                     dimension_numbers=dimension_numbers,
                     indices_are_sorted=indices_are_sorted, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_gather, (), dict(operand=operand, start_indices=start_indices,
                               slice_sizes=slice_sizes,
                               dimension_numbers=dimension_numbers,
                               indices_are_sorted=indices_are_sorted,
                               name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("dimension_numbers", _op.get_attr("dimension_numbers"),
              "indices_are_sorted", _op._get_attr_bool("indices_are_sorted"),
              "T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaGather", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

1742 

# Raw-op export (tf.raw_ops.XlaGather) and the module-level handle to this
# op's type-based dispatcher, used by the wrapper above.
XlaGather = tf_export("raw_ops.XlaGather")(_ops.to_raw_op(xla_gather))
_dispatcher_for_xla_gather = xla_gather._tf_type_based_dispatcher.Dispatch

1745 

1746 

def xla_gather_eager_fallback(operand, start_indices, slice_sizes, dimension_numbers, indices_are_sorted, name, ctx):
  """Slow-path eager executor for the XlaGather op."""
  # Canonicalize attribute values.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted")
  # Infer T for the operand from the op's supported dtype set.
  operand_types = [
      _dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8,
      _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64,
      _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16,
      _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128,
      _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool,
  ]
  _attr_T, converted_operand = _execute.args_to_matching_eager(
      [operand], ctx, operand_types)
  (operand,) = converted_operand
  # Both index tensors must share one integer dtype (Tindices).
  _attr_Tindices, index_tensors = _execute.args_to_matching_eager(
      [start_indices, slice_sizes], ctx, [_dtypes.int32, _dtypes.int64, ])
  start_indices, slice_sizes = index_tensors
  flat_inputs = [operand, start_indices, slice_sizes]
  op_attrs = ("dimension_numbers", dimension_numbers, "indices_are_sorted",
              indices_are_sorted, "T", _attr_T, "Tindices", _attr_Tindices)
  outputs = _execute.execute(b"XlaGather", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaGather", flat_inputs, op_attrs, outputs)
  # Single-output op: unwrap the one-element result list.
  gathered, = outputs
  return gathered

1763 

1764 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_if')
def xla_if(cond, inputs, then_branch, else_branch, Tout, name=None):
  r"""output = cond ? then_branch(inputs) : else_branch(inputs).

  Args:
    cond: A `Tensor`. A boolean scalar.
    inputs: A list of `Tensor` objects. A list of input tensors.
    then_branch: A function decorated with @Defun.
      A function takes 'inputs' and returns a list of tensors,
      whose types are the same as what else_branch returns.
    else_branch: A function decorated with @Defun.
      A function takes 'inputs' and returns a list of tensors.
      whose types are the same as what then_branch returns.
    Tout: A list of `tf.DTypes`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tout`.
    A list of tensors returned by either then_branch(inputs) or
    else_branch(inputs). The input shapes of the then_branch and
    else_branch must match.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaIf", name, cond, inputs, "then_branch", then_branch,
        "else_branch", else_branch, "Tout", Tout)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_if(
          (cond, inputs, then_branch, else_branch, Tout, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_if_eager_fallback(
          cond, inputs, then_branch=then_branch, else_branch=else_branch,
          Tout=Tout, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_if, (), dict(cond=cond, inputs=inputs,
                             then_branch=then_branch, else_branch=else_branch,
                             Tout=Tout, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_if(
        (cond, inputs, then_branch, else_branch, Tout, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'xla_if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaIf", cond=cond, inputs=inputs, then_branch=then_branch,
                 else_branch=else_branch, Tout=Tout, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_if, (), dict(cond=cond, inputs=inputs, then_branch=then_branch,
                           else_branch=else_branch, Tout=Tout, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  # With an empty Tout there are no output tensors; return the Operation.
  if not _result:
    return _op
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("Tcond", _op._get_attr_type("Tcond"), "then_branch",
              _op.get_attr("then_branch"), "else_branch",
              _op.get_attr("else_branch"), "Tin", _op.get_attr("Tin"), "Tout",
              _op.get_attr("Tout"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaIf", _inputs_flat, _attrs, _result)
  # Multi-output op: return the full list of outputs.
  return _result

1855 

# Raw-op export (tf.raw_ops.XlaIf) and the module-level handle to this op's
# type-based dispatcher, used by the wrapper above.
XlaIf = tf_export("raw_ops.XlaIf")(_ops.to_raw_op(xla_if))
_dispatcher_for_xla_if = xla_if._tf_type_based_dispatcher.Dispatch

1858 

1859 

def xla_if_eager_fallback(cond, inputs, then_branch, else_branch, Tout, name, ctx):
  """Slow-path eager executor for the XlaIf op."""
  # Tout must be an explicit list of dtypes; it also fixes the output count.
  if not isinstance(Tout, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tout' argument to "
        "'xla_if' Op, not %r." % Tout)
  Tout = [_execute.make_type(_t, "Tout") for _t in Tout]
  # The condition's dtype (Tcond) is inferred; any dtype is accepted here.
  _attr_Tcond, converted_cond = _execute.args_to_matching_eager([cond], ctx, [])
  (cond,) = converted_cond
  # Branch inputs may have heterogeneous dtypes (Tin is a type list).
  _attr_Tin, inputs = _execute.convert_to_mixed_eager_tensors(inputs, ctx)
  flat_inputs = [cond] + list(inputs)
  op_attrs = ("Tcond", _attr_Tcond, "then_branch", then_branch,
              "else_branch", else_branch, "Tin", _attr_Tin, "Tout", Tout)
  outputs = _execute.execute(b"XlaIf", len(Tout), inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaIf", flat_inputs, op_attrs, outputs)
  # Multi-output op: return the full list of outputs.
  return outputs

1877 

# Named result type for XlaKeyValueSort's two outputs, so callers can use
# `.sorted_keys` / `.sorted_values` instead of positional indexing.
_XlaKeyValueSortOutput = collections.namedtuple(
    "XlaKeyValueSort",
    ["sorted_keys", "sorted_values"])

1881 

1882 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_key_value_sort')
def xla_key_value_sort(keys, values, name=None):
  r"""Wraps the XLA Sort operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#sort
  .

  Sorts a tensor. Currently only sorts in ascending order are supported.

  Args:
    keys: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `int64`, `bfloat16`, `uint16`, `half`, `uint32`, `uint64`.
      A `Tensor` of type K.
    values: A `Tensor`. A `Tensor` of type V.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (sorted_keys, sorted_values).

    sorted_keys: A `Tensor`. Has the same type as `keys`. A `Tensor` of type K.
    sorted_values: A `Tensor`. Has the same type as `values`. A `Tensor` of type V.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaKeyValueSort", name, keys, values)
      # Wrap the two outputs in the named-tuple result type.
      _result = _XlaKeyValueSortOutput._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_key_value_sort(
          (keys, values, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_key_value_sort_eager_fallback(
          keys, values, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_key_value_sort, (), dict(keys=keys, values=values, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_key_value_sort(
        (keys, values, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaKeyValueSort", keys=keys, values=values, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_key_value_sort, (), dict(keys=keys, values=values, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("K", _op._get_attr_type("K"), "V", _op._get_attr_type("V"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaKeyValueSort", _inputs_flat, _attrs, _result)
  # Wrap the two outputs in the named-tuple result type.
  _result = _XlaKeyValueSortOutput._make(_result)
  return _result

1958 

# Raw-op export (tf.raw_ops.XlaKeyValueSort) and the module-level handle to
# this op's type-based dispatcher, used by the wrapper above.
XlaKeyValueSort = tf_export("raw_ops.XlaKeyValueSort")(_ops.to_raw_op(xla_key_value_sort))
_dispatcher_for_xla_key_value_sort = xla_key_value_sort._tf_type_based_dispatcher.Dispatch

1961 

1962 

def xla_key_value_sort_eager_fallback(keys, values, name, ctx):
  """Slow-path eager executor for the XlaKeyValueSort op."""
  # Keys must be a sortable (non-complex, non-quantized) dtype K.
  key_types = [
      _dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8,
      _dtypes.int16, _dtypes.int8, _dtypes.int64, _dtypes.bfloat16,
      _dtypes.uint16, _dtypes.half, _dtypes.uint32, _dtypes.uint64,
  ]
  _attr_K, converted_keys = _execute.args_to_matching_eager([keys], ctx, key_types)
  (keys,) = converted_keys
  # Values may be any dtype V.
  _attr_V, converted_values = _execute.args_to_matching_eager([values], ctx, [])
  (values,) = converted_values
  flat_inputs = [keys, values]
  op_attrs = ("K", _attr_K, "V", _attr_V)
  outputs = _execute.execute(b"XlaKeyValueSort", 2, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaKeyValueSort", flat_inputs, op_attrs, outputs)
  # Wrap the (sorted_keys, sorted_values) pair in the named-tuple result type.
  return _XlaKeyValueSortOutput._make(outputs)

1975 

1976 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_optimization_barrier')
def xla_optimization_barrier(input, name=None):
  r"""Wraps the XLA OptimizationBarrier operator.

  Documented at https://www.tensorflow.org/xla/operation_semantics#optimizationbarrier.

  Args:
    input: A list of `Tensor` objects. A Tuple of Arrays of any type.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
  """
  # Resolve the currently-active TF context (eager or graph).
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: attempt the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaOptimizationBarrier", name, input)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path refused the inputs; fall through to Python dispatch.
      pass
    try:
      # Type-based API dispatch (e.g. for extension types).
      _result = _dispatcher_for_xla_optimization_barrier(
          (input, name,), None)
      if _result is not NotImplemented:
        return _result
      # Slow-path eager execution in Python.
      return xla_optimization_barrier_eager_fallback(
          input, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance; otherwise re-raise the original.
      _result = _dispatch.dispatch(
            xla_optimization_barrier, (), dict(input=input, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode still honors type-based dispatch before building the op.
    _result = _dispatcher_for_xla_optimization_barrier(
        (input, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaOptimizationBarrier", input=input, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_optimization_barrier, (), dict(input=input, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attrs as resolved during op construction, needed for autodiff.
    _attrs = ("T", _op.get_attr("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaOptimizationBarrier", _inputs_flat, _attrs, _result)
  # List-output op: return the full list of outputs.
  return _result

2042 

# Raw-op export (tf.raw_ops.XlaOptimizationBarrier) and the module-level
# handle to this op's type-based dispatcher, used by the wrapper above.
XlaOptimizationBarrier = tf_export("raw_ops.XlaOptimizationBarrier")(_ops.to_raw_op(xla_optimization_barrier))
_dispatcher_for_xla_optimization_barrier = xla_optimization_barrier._tf_type_based_dispatcher.Dispatch

2045 

2046 

def xla_optimization_barrier_eager_fallback(input, name, ctx):
  """Slow-path eager executor for the XlaOptimizationBarrier op."""
  # Inputs may have heterogeneous dtypes; T is a type list.
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  flat_inputs = list(input)
  op_attrs = ("T", _attr_T)
  # Output count equals input count (the barrier passes tensors through).
  outputs = _execute.execute(b"XlaOptimizationBarrier", len(input),
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaOptimizationBarrier", flat_inputs, op_attrs,
                             outputs)
  return outputs

2058 

2059 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_pad')
def xla_pad(input, padding_value, padding_low, padding_high, padding_interior, name=None):
  r"""Wraps the XLA Pad operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#pad
  .

  Args:
    input: A `Tensor`. A `Tensor` of type T.
    padding_value: A `Tensor`. Must have the same type as `input`.
      A scalar `Tensor` of type T.
    padding_low: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      the padding to apply at the start of each input dimensions. Must
      be a compile-time constant 1D tensor of length equal to rank of input.
    padding_high: A `Tensor`. Must have the same type as `padding_low`.
      the padding to apply at the end of each input dimension. Must
      be a compile-time constant 1D tensor of length equal to rank of input.
    padding_interior: A `Tensor`. Must have the same type as `padding_low`.
      the padding to apply between each input element. Must
      be a compile-time constant 1D tensor of length equal to rank of input,
      containing only non-negative values.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`. A `Tensor` of type T.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaPad", name, input, padding_value, padding_low, padding_high,
        padding_interior)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path could not handle the inputs; try the dispatcher, then the
      # Python eager fallback below.
      pass
    try:
      _result = _dispatcher_for_xla_pad(
          (input, padding_value, padding_low, padding_high, padding_interior,
          name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_pad_eager_fallback(
          input, padding_value, padding_low, padding_high, padding_interior,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers (e.g. for custom types) a chance before
      # re-raising the conversion error.
      _result = _dispatch.dispatch(
            xla_pad, (), dict(input=input, padding_value=padding_value,
                              padding_low=padding_low,
                              padding_high=padding_high,
                              padding_interior=padding_interior, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_pad(
        (input, padding_value, padding_low, padding_high, padding_interior,
        name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaPad", input=input, padding_value=padding_value,
                  padding_low=padding_low, padding_high=padding_high,
                  padding_interior=padding_interior, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_pad, (), dict(input=input, padding_value=padding_value,
                            padding_low=padding_low,
                            padding_high=padding_high,
                            padding_interior=padding_interior, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture inputs/attrs so a gradient function can be looked up later.
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaPad", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2152 

# Raw-op export and cached type-based dispatcher for xla_pad.
XlaPad = tf_export("raw_ops.XlaPad")(_ops.to_raw_op(xla_pad))
_dispatcher_for_xla_pad = xla_pad._tf_type_based_dispatcher.Dispatch

2155 

2156 

def xla_pad_eager_fallback(input, padding_value, padding_low, padding_high, padding_interior, name, ctx):
  """Slow-path eager execution of XlaPad (used when the fast path fails).

  Coerces `input`/`padding_value` to a common dtype T and the three padding
  tensors to a common index dtype (int32/int64), then executes the op.
  """
  _attr_T, _inputs_T = _execute.args_to_matching_eager([input, padding_value], ctx, [])
  (input, padding_value) = _inputs_T
  _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([padding_low, padding_high, padding_interior], ctx, [_dtypes.int32, _dtypes.int64, ])
  (padding_low, padding_high, padding_interior) = _inputs_Tindices
  _inputs_flat = [input, padding_value, padding_low, padding_high, padding_interior]
  _attrs = ("T", _attr_T, "Tindices", _attr_Tindices)
  _result = _execute.execute(b"XlaPad", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaPad", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2171 

2172 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_recv')
def xla_recv(dtype, tensor_name, shape, name=None):
  r"""Receives the named tensor from another XLA computation. Wraps the XLA Recv

  operator documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#recv .

  Args:
    dtype: A `tf.DType`. The type of the tensor.
    tensor_name: A `string`. A string key that identifies the channel.
    shape: A `tf.TensorShape` or list of `ints`. The shape of the tensor.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `dtype`. The tensor to receive.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaRecv", name, "dtype", dtype, "tensor_name", tensor_name,
        "shape", shape)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / Python eager fallback.
      pass
    try:
      _result = _dispatcher_for_xla_recv(
          (dtype, tensor_name, shape, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_recv_eager_fallback(
          dtype=dtype, tensor_name=tensor_name, shape=shape, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle non-convertible arguments.
      _result = _dispatch.dispatch(
            xla_recv, (), dict(dtype=dtype, tensor_name=tensor_name,
                               shape=shape, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_recv(
        (dtype, tensor_name, shape, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize the attr values before building the op.
  dtype = _execute.make_type(dtype, "dtype")
  tensor_name = _execute.make_str(tensor_name, "tensor_name")
  shape = _execute.make_shape(shape, "shape")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaRecv", dtype=dtype, tensor_name=tensor_name, shape=shape,
                   name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_recv, (), dict(dtype=dtype, tensor_name=tensor_name,
                             shape=shape, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("dtype", _op._get_attr_type("dtype"), "tensor_name",
              _op.get_attr("tensor_name"), "shape", _op.get_attr("shape"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaRecv", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2251 

# Raw-op export and cached type-based dispatcher for xla_recv.
XlaRecv = tf_export("raw_ops.XlaRecv")(_ops.to_raw_op(xla_recv))
_dispatcher_for_xla_recv = xla_recv._tf_type_based_dispatcher.Dispatch

2254 

2255 

def xla_recv_eager_fallback(dtype, tensor_name, shape, name, ctx):
  """Slow-path eager execution of XlaRecv (used when the fast path fails).

  XlaRecv has no tensor inputs — everything is carried in attrs.
  """
  dtype = _execute.make_type(dtype, "dtype")
  tensor_name = _execute.make_str(tensor_name, "tensor_name")
  shape = _execute.make_shape(shape, "shape")
  _inputs_flat = []
  _attrs = ("dtype", dtype, "tensor_name", tensor_name, "shape", shape)
  _result = _execute.execute(b"XlaRecv", 1, inputs=_inputs_flat, attrs=_attrs,
                             ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaRecv", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2269 

2270 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce')
def xla_reduce(input, init_value, dimensions_to_reduce, reducer, name=None):
  r"""Wraps the XLA Reduce operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#reduce .

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`.
      the input tensor
    init_value: A `Tensor`. Must have the same type as `input`.
      a scalar representing the initial value for the reduction
    dimensions_to_reduce: A list of `ints`.
      dimension numbers over which to reduce
    reducer: A function decorated with @Defun. a reducer function to apply
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReduce", name, input, init_value, "dimensions_to_reduce",
        dimensions_to_reduce, "reducer", reducer)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / Python eager fallback.
      pass
    try:
      _result = _dispatcher_for_xla_reduce(
          (input, init_value, dimensions_to_reduce, reducer, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_eager_fallback(
          input, init_value, dimensions_to_reduce=dimensions_to_reduce,
          reducer=reducer, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle non-convertible arguments.
      _result = _dispatch.dispatch(
            xla_reduce, (), dict(input=input, init_value=init_value,
                                 dimensions_to_reduce=dimensions_to_reduce,
                                 reducer=reducer, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_reduce(
        (input, init_value, dimensions_to_reduce, reducer, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Validate and canonicalize the list attr before building the op.
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_reduce' Op, not %r." % dimensions_to_reduce)
  dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReduce", input=input, init_value=init_value,
                     dimensions_to_reduce=dimensions_to_reduce,
                     reducer=reducer, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce, (), dict(input=input, init_value=init_value,
                               dimensions_to_reduce=dimensions_to_reduce,
                               reducer=reducer, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "dimensions_to_reduce",
              _op.get_attr("dimensions_to_reduce"), "reducer",
              _op.get_attr("reducer"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReduce", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2358 

# Raw-op export and cached type-based dispatcher for xla_reduce.
XlaReduce = tf_export("raw_ops.XlaReduce")(_ops.to_raw_op(xla_reduce))
_dispatcher_for_xla_reduce = xla_reduce._tf_type_based_dispatcher.Dispatch

2361 

2362 

def xla_reduce_eager_fallback(input, init_value, dimensions_to_reduce, reducer, name, ctx):
  """Slow-path eager execution of XlaReduce (used when the fast path fails)."""
  # Same validation as the graph path: the attr must be a list of ints.
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_reduce' Op, not %r." % dimensions_to_reduce)
  dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce]
  # Coerce input and init_value to a common allowed dtype T.
  _attr_T, _inputs_T = _execute.args_to_matching_eager([input, init_value], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ])
  (input, init_value) = _inputs_T
  _inputs_flat = [input, init_value]
  _attrs = ("T", _attr_T, "dimensions_to_reduce", dimensions_to_reduce,
  "reducer", reducer)
  _result = _execute.execute(b"XlaReduce", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaReduce", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2381 

2382 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce_precision')
def xla_reduce_precision(operand, exponent_bits, mantissa_bits, name=None):
  r"""Wraps the XLA ReducePrecision operator

  documented at https://www.tensorflow.org/xla/operation_semantics#reduceprecision.

  Args:
    operand: A `Tensor`. Must be one of the following types: `bfloat16`, `half`, `float32`, `float64`.
      array of floating-point type.
    exponent_bits: An `int`. number of exponent bits in lower-precision format
    mantissa_bits: An `int`. number of mantissa bits in lower-precision format
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `operand`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReducePrecision", name, operand, "exponent_bits",
        exponent_bits, "mantissa_bits", mantissa_bits)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / Python eager fallback.
      pass
    try:
      _result = _dispatcher_for_xla_reduce_precision(
          (operand, exponent_bits, mantissa_bits, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_precision_eager_fallback(
          operand, exponent_bits=exponent_bits, mantissa_bits=mantissa_bits,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle non-convertible arguments.
      _result = _dispatch.dispatch(
            xla_reduce_precision, (), dict(operand=operand,
                                           exponent_bits=exponent_bits,
                                           mantissa_bits=mantissa_bits,
                                           name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_reduce_precision(
        (operand, exponent_bits, mantissa_bits, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize the int attrs before building the op.
  exponent_bits = _execute.make_int(exponent_bits, "exponent_bits")
  mantissa_bits = _execute.make_int(mantissa_bits, "mantissa_bits")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReducePrecision", operand=operand, exponent_bits=exponent_bits,
                              mantissa_bits=mantissa_bits, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce_precision, (), dict(operand=operand,
                                         exponent_bits=exponent_bits,
                                         mantissa_bits=mantissa_bits,
                                         name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "exponent_bits",
              _op._get_attr_int("exponent_bits"), "mantissa_bits",
              _op._get_attr_int("mantissa_bits"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReducePrecision", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2465 

# Raw-op export and cached type-based dispatcher for xla_reduce_precision.
XlaReducePrecision = tf_export("raw_ops.XlaReducePrecision")(_ops.to_raw_op(xla_reduce_precision))
_dispatcher_for_xla_reduce_precision = xla_reduce_precision._tf_type_based_dispatcher.Dispatch

2468 

2469 

def xla_reduce_precision_eager_fallback(operand, exponent_bits, mantissa_bits, name, ctx):
  """Slow-path eager execution of XlaReducePrecision."""
  exponent_bits = _execute.make_int(exponent_bits, "exponent_bits")
  mantissa_bits = _execute.make_int(mantissa_bits, "mantissa_bits")
  # Operand must be one of the floating-point dtypes accepted by the op.
  _attr_T, (operand,) = _execute.args_to_matching_eager([operand], ctx, [_dtypes.bfloat16, _dtypes.half, _dtypes.float32, _dtypes.float64, ])
  _inputs_flat = [operand]
  _attrs = ("T", _attr_T, "exponent_bits", exponent_bits, "mantissa_bits",
  mantissa_bits)
  _result = _execute.execute(b"XlaReducePrecision", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaReducePrecision", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2484 

2485 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce_scatter')
def xla_reduce_scatter(input, group_assignment, scatter_dimension, reduce_op, name=None):
  r"""Wraps the XLA ReduceScatter operator

  documented at https://www.tensorflow.org/xla/operation_semantics#reducescatter.

  Args:
    input: A `Tensor`. Must be one of the following types: `half`, `bfloat16`, `float32`, `int32`, `uint32`.
      Array or a non-empty tuple of arrays to reduce across replicas.
    group_assignment: A `Tensor` of type `int32`.
      Groups between which the reductions are performed.
    scatter_dimension: A `Tensor` of type `int32`. Dimension to scatter.
    reduce_op: A `string` from: `"Min", "Max", "Mul", "Add", "Mean"`.
      Reduction computation.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReduceScatter", name, input, group_assignment,
        scatter_dimension, "reduce_op", reduce_op)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / Python eager fallback.
      pass
    try:
      _result = _dispatcher_for_xla_reduce_scatter(
          (input, group_assignment, scatter_dimension, reduce_op, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_scatter_eager_fallback(
          input, group_assignment, scatter_dimension, reduce_op=reduce_op,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle non-convertible arguments.
      _result = _dispatch.dispatch(
            xla_reduce_scatter, (), dict(input=input,
                                         group_assignment=group_assignment,
                                         scatter_dimension=scatter_dimension,
                                         reduce_op=reduce_op, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_reduce_scatter(
        (input, group_assignment, scatter_dimension, reduce_op, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize the string attr before building the op.
  reduce_op = _execute.make_str(reduce_op, "reduce_op")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReduceScatter", input=input, group_assignment=group_assignment,
                            scatter_dimension=scatter_dimension,
                            reduce_op=reduce_op, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce_scatter, (), dict(input=input,
                                       group_assignment=group_assignment,
                                       scatter_dimension=scatter_dimension,
                                       reduce_op=reduce_op, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "reduce_op",
              _op.get_attr("reduce_op"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReduceScatter", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2570 

# Raw-op export and cached type-based dispatcher for xla_reduce_scatter.
XlaReduceScatter = tf_export("raw_ops.XlaReduceScatter")(_ops.to_raw_op(xla_reduce_scatter))
_dispatcher_for_xla_reduce_scatter = xla_reduce_scatter._tf_type_based_dispatcher.Dispatch

2573 

2574 

def xla_reduce_scatter_eager_fallback(input, group_assignment, scatter_dimension, reduce_op, name, ctx):
  """Slow-path eager execution of XlaReduceScatter."""
  reduce_op = _execute.make_str(reduce_op, "reduce_op")
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [_dtypes.half, _dtypes.bfloat16, _dtypes.float32, _dtypes.int32, _dtypes.uint32, ])
  # Both auxiliary inputs are fixed int32 tensors.
  group_assignment = _ops.convert_to_tensor(group_assignment, _dtypes.int32)
  scatter_dimension = _ops.convert_to_tensor(scatter_dimension, _dtypes.int32)
  _inputs_flat = [input, group_assignment, scatter_dimension]
  _attrs = ("T", _attr_T, "reduce_op", reduce_op)
  _result = _execute.execute(b"XlaReduceScatter", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaReduceScatter", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2589 

2590 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_reduce_window')
def xla_reduce_window(input, init_value, window_dimensions, window_strides, base_dilations, window_dilations, padding, computation, name=None):
  r"""Wraps the XLA ReduceWindow operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    input: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`.
      the input tensor
    init_value: A `Tensor`. Must have the same type as `input`.
      a scalar representing the initial value for the reduction
    window_dimensions: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      the shape of the window
    window_strides: A `Tensor`. Must have the same type as `window_dimensions`.
      the inter-window strides
    base_dilations: A `Tensor`. Must have the same type as `window_dimensions`.
    window_dilations: A `Tensor`. Must have the same type as `window_dimensions`.
    padding: A `Tensor`. Must have the same type as `window_dimensions`.
      the padding to apply at the start and end of each input dimensions
    computation: A function decorated with @Defun. a reducer function to apply
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaReduceWindow", name, input, init_value, window_dimensions,
        window_strides, base_dilations, window_dilations, padding,
        "computation", computation)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / Python eager fallback.
      pass
    try:
      _result = _dispatcher_for_xla_reduce_window(
          (input, init_value, window_dimensions, window_strides,
          base_dilations, window_dilations, padding, computation, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_reduce_window_eager_fallback(
          input, init_value, window_dimensions, window_strides,
          base_dilations, window_dilations, padding, computation=computation,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle non-convertible arguments.
      _result = _dispatch.dispatch(
            xla_reduce_window, (), dict(input=input, init_value=init_value,
                                        window_dimensions=window_dimensions,
                                        window_strides=window_strides,
                                        base_dilations=base_dilations,
                                        window_dilations=window_dilations,
                                        padding=padding,
                                        computation=computation, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_reduce_window(
        (input, init_value, window_dimensions, window_strides, base_dilations,
        window_dilations, padding, computation, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaReduceWindow", input=input, init_value=init_value,
                           window_dimensions=window_dimensions,
                           window_strides=window_strides,
                           base_dilations=base_dilations,
                           window_dilations=window_dilations, padding=padding,
                           computation=computation, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_reduce_window, (), dict(input=input, init_value=init_value,
                                      window_dimensions=window_dimensions,
                                      window_strides=window_strides,
                                      base_dilations=base_dilations,
                                      window_dilations=window_dilations,
                                      padding=padding,
                                      computation=computation, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"), "Tindices",
              _op._get_attr_type("Tindices"), "computation",
              _op.get_attr("computation"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaReduceWindow", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2694 

# Raw-op export and cached type-based dispatcher for xla_reduce_window.
XlaReduceWindow = tf_export("raw_ops.XlaReduceWindow")(_ops.to_raw_op(xla_reduce_window))
_dispatcher_for_xla_reduce_window = xla_reduce_window._tf_type_based_dispatcher.Dispatch

2697 

2698 

def xla_reduce_window_eager_fallback(input, init_value, window_dimensions, window_strides, base_dilations, window_dilations, padding, computation, name, ctx):
  """Slow-path eager execution of XlaReduceWindow.

  Coerces input/init_value to a common dtype T and the five window/padding
  tensors to a common index dtype (int32/int64) before executing.
  """
  _attr_T, _inputs_T = _execute.args_to_matching_eager([input, init_value], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ])
  (input, init_value) = _inputs_T
  _attr_Tindices, _inputs_Tindices = _execute.args_to_matching_eager([window_dimensions, window_strides, base_dilations, window_dilations, padding], ctx, [_dtypes.int32, _dtypes.int64, ])
  (window_dimensions, window_strides, base_dilations, window_dilations, padding) = _inputs_Tindices
  _inputs_flat = [input, init_value, window_dimensions, window_strides, base_dilations, window_dilations, padding]
  _attrs = ("T", _attr_T, "Tindices", _attr_Tindices, "computation",
  computation)
  _result = _execute.execute(b"XlaReduceWindow", 1, inputs=_inputs_flat,
                             attrs=_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaReduceWindow", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2714 

2715 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_remove_dynamic_dimension_size')
def xla_remove_dynamic_dimension_size(input, dim_index, name=None):
  r"""Inverse of XlaSetDynamicDimensionSize.

  Make an xla bounded dynamic dimension into a static dimension. The bound of the
  size of dimension `dim_index` becomes the static dimension size.

  Args:
    input: A `Tensor`.
    dim_index: A `Tensor` of type `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager fast path: execute directly through the C API.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaRemoveDynamicDimensionSize", name, input, dim_index)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fall through to the dispatcher / Python eager fallback.
      pass
    try:
      _result = _dispatcher_for_xla_remove_dynamic_dimension_size(
          (input, dim_index, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_remove_dynamic_dimension_size_eager_fallback(
          input, dim_index, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Let fallback dispatchers handle non-convertible arguments.
      _result = _dispatch.dispatch(
            xla_remove_dynamic_dimension_size, (), dict(input=input,
                                                        dim_index=dim_index,
                                                        name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: give the type-based dispatcher first refusal.
    _result = _dispatcher_for_xla_remove_dynamic_dimension_size(
        (input, dim_index, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaRemoveDynamicDimensionSize", input=input, dim_index=dim_index,
                                         name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_remove_dynamic_dimension_size, (), dict(input=input,
                                                      dim_index=dim_index,
                                                      name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaRemoveDynamicDimensionSize", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2789 

# Raw-op export and cached type-based dispatcher for
# xla_remove_dynamic_dimension_size.
XlaRemoveDynamicDimensionSize = tf_export("raw_ops.XlaRemoveDynamicDimensionSize")(_ops.to_raw_op(xla_remove_dynamic_dimension_size))
_dispatcher_for_xla_remove_dynamic_dimension_size = xla_remove_dynamic_dimension_size._tf_type_based_dispatcher.Dispatch

2792 

2793 

def xla_remove_dynamic_dimension_size_eager_fallback(input, dim_index, name, ctx):
  """Slow-path eager execution of XlaRemoveDynamicDimensionSize."""
  # No dtype allowlist: any eager tensor dtype is accepted for T.
  _attr_T, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  dim_index = _ops.convert_to_tensor(dim_index, _dtypes.int32)
  _inputs_flat = [input, dim_index]
  _attrs = ("T", _attr_T)
  _result = _execute.execute(b"XlaRemoveDynamicDimensionSize", 1,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaRemoveDynamicDimensionSize", _inputs_flat, _attrs, _result)
  # Single-output op: unwrap the one-element result list.
  _result, = _result
  return _result

2807 

2808 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_replica_id')
def xla_replica_id(name=None):
  r"""Replica ID.

  Args:
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(ctx, "XlaReplicaId", name)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: consult the type-based dispatcher, then take the
    # slow eager path.
    try:
      dispatched = _dispatcher_for_xla_replica_id((name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_replica_id_eager_fallback(name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(xla_replica_id, (), dict(name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_replica_id((name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: add the op to the current graph.
  try:
    _, _, op, outputs = _op_def_library._apply_op_helper(
        "XlaReplicaId", name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(xla_replica_id, (), dict(name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  results = outputs[:]
  if _execute.must_record_gradient():
    attrs = ()
    flat_inputs = op.inputs
    _execute.record_gradient("XlaReplicaId", flat_inputs, attrs, results)
  result, = results
  return result

2872 

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_replica_id.
XlaReplicaId = tf_export("raw_ops.XlaReplicaId")(_ops.to_raw_op(xla_replica_id))
_dispatcher_for_xla_replica_id = xla_replica_id._tf_type_based_dispatcher.Dispatch

2875 

2876 

def xla_replica_id_eager_fallback(name, ctx):
  """Slow eager path for XlaReplicaId: execute the op directly."""
  flat_inputs = []
  op_attrs = None
  outputs = _execute.execute(b"XlaReplicaId", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaReplicaId", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result

2887 

# Structured result type for XlaRngBitGenerator: (output_key, output).
_XlaRngBitGeneratorOutput = collections.namedtuple(
    "XlaRngBitGenerator", ["output_key", "output"])

2891 

2892 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_rng_bit_generator')
def xla_rng_bit_generator(algorithm, initial_state, shape, dtype=_dtypes.uint64, name=None):
  r"""Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: A `Tensor` of type `int32`. The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
    initial_state: A `Tensor` of type `uint64`.
      Initial state for the PRNG algorithm. For THREEFRY, it should be
      a u64[2] and for PHILOX a u64[3].
    shape: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      The output shape of the generated data.
    dtype: An optional `tf.DType` from: `tf.int32, tf.int64, tf.uint32, tf.uint64`. Defaults to `tf.uint64`.
      The type of the tensor.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output_key, output).

    output_key: A `Tensor` of type `uint64`.
    output: A `Tensor` of type `dtype`.
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      fast = pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "XlaRngBitGenerator", name, algorithm, initial_state, shape,
          "dtype", dtype)
      return _XlaRngBitGeneratorOutput._make(fast)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: type-based dispatcher, then the slow eager path.
    try:
      dispatched = _dispatcher_for_xla_rng_bit_generator(
          (algorithm, initial_state, shape, dtype, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_rng_bit_generator_eager_fallback(
          algorithm, initial_state, shape, dtype=dtype, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(
          xla_rng_bit_generator, (), dict(algorithm=algorithm,
                                          initial_state=initial_state,
                                          shape=shape, dtype=dtype,
                                          name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_rng_bit_generator(
        (algorithm, initial_state, shape, dtype, name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: normalize attrs and add the op to the current graph.
  if dtype is None:
    dtype = _dtypes.uint64
  dtype = _execute.make_type(dtype, "dtype")
  try:
    _, _, op, outputs = _op_def_library._apply_op_helper(
        "XlaRngBitGenerator", algorithm=algorithm,
        initial_state=initial_state, shape=shape, dtype=dtype, name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(
        xla_rng_bit_generator, (), dict(algorithm=algorithm,
                                        initial_state=initial_state,
                                        shape=shape, dtype=dtype, name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  results = outputs[:]
  if _execute.must_record_gradient():
    attrs = ("dtype", op._get_attr_type("dtype"), "Tshape",
             op._get_attr_type("Tshape"))
    flat_inputs = op.inputs
    _execute.record_gradient(
        "XlaRngBitGenerator", flat_inputs, attrs, results)
  return _XlaRngBitGeneratorOutput._make(results)

2984 

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_rng_bit_generator.
XlaRngBitGenerator = tf_export("raw_ops.XlaRngBitGenerator")(
    _ops.to_raw_op(xla_rng_bit_generator))
_dispatcher_for_xla_rng_bit_generator = (
    xla_rng_bit_generator._tf_type_based_dispatcher.Dispatch)

2987 

2988 

def xla_rng_bit_generator_eager_fallback(algorithm, initial_state, shape, dtype, name, ctx):
  """Slow eager path for XlaRngBitGenerator."""
  if dtype is None:
    dtype = _dtypes.uint64
  dtype = _execute.make_type(dtype, "dtype")
  attr_tshape, (shape,) = _execute.args_to_matching_eager(
      [shape], ctx, [_dtypes.int32, _dtypes.int64, ], _dtypes.int32)
  algorithm = _ops.convert_to_tensor(algorithm, _dtypes.int32)
  initial_state = _ops.convert_to_tensor(initial_state, _dtypes.uint64)
  flat_inputs = [algorithm, initial_state, shape]
  op_attrs = ("dtype", dtype, "Tshape", attr_tshape)
  outputs = _execute.execute(b"XlaRngBitGenerator", 2, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaRngBitGenerator", flat_inputs, op_attrs, outputs)
  return _XlaRngBitGeneratorOutput._make(outputs)

3005 

3006 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_scatter')
def xla_scatter(operand, scatter_indices, updates, update_computation, dimension_numbers, indices_are_sorted, name=None):
  r"""Wraps the XLA Scatter operator documented at

  https://www.tensorflow.org/xla/operation_semantics#scatter.

  Args:
    operand: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`.
      Array to be scattered into.
    scatter_indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      Array containing the starting indices of the slices that must
      be scattered to.
    updates: A `Tensor`. Must have the same type as `operand`.
      Array containing the values that must be used for scattering.
    update_computation: A function decorated with @Defun.
      Computation to be used for combining the existing values in
      the input array and the updates during scatter.
    dimension_numbers: A `string`.
      A serialized xla::ScatterDimensionNumbers proto.
    indices_are_sorted: A `bool`.
      Boolean indicating if the indices are sorted.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `operand`.
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "XlaScatter", name, operand, scatter_indices, updates,
          "update_computation", update_computation, "dimension_numbers",
          dimension_numbers, "indices_are_sorted", indices_are_sorted)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: type-based dispatcher, then the slow eager path.
    try:
      dispatched = _dispatcher_for_xla_scatter(
          (operand, scatter_indices, updates, update_computation,
           dimension_numbers, indices_are_sorted, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_scatter_eager_fallback(
          operand, scatter_indices, updates,
          update_computation=update_computation,
          dimension_numbers=dimension_numbers,
          indices_are_sorted=indices_are_sorted, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(
          xla_scatter, (), dict(operand=operand,
                                scatter_indices=scatter_indices,
                                updates=updates,
                                update_computation=update_computation,
                                dimension_numbers=dimension_numbers,
                                indices_are_sorted=indices_are_sorted,
                                name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_scatter(
        (operand, scatter_indices, updates, update_computation,
         dimension_numbers, indices_are_sorted, name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: normalize attrs and add the op to the current graph.
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted")
  try:
    _, _, op, outputs = _op_def_library._apply_op_helper(
        "XlaScatter", operand=operand, scatter_indices=scatter_indices,
        updates=updates, update_computation=update_computation,
        dimension_numbers=dimension_numbers,
        indices_are_sorted=indices_are_sorted, name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(
        xla_scatter, (), dict(operand=operand,
                              scatter_indices=scatter_indices,
                              updates=updates,
                              update_computation=update_computation,
                              dimension_numbers=dimension_numbers,
                              indices_are_sorted=indices_are_sorted,
                              name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  results = outputs[:]
  if _execute.must_record_gradient():
    attrs = ("update_computation", op.get_attr("update_computation"),
             "dimension_numbers", op.get_attr("dimension_numbers"),
             "indices_are_sorted", op._get_attr_bool("indices_are_sorted"),
             "T", op._get_attr_type("T"), "Tindices",
             op._get_attr_type("Tindices"))
    flat_inputs = op.inputs
    _execute.record_gradient("XlaScatter", flat_inputs, attrs, results)
  result, = results
  return result

3114 

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_scatter.
XlaScatter = tf_export("raw_ops.XlaScatter")(_ops.to_raw_op(xla_scatter))
_dispatcher_for_xla_scatter = xla_scatter._tf_type_based_dispatcher.Dispatch

3117 

3118 

def xla_scatter_eager_fallback(operand, scatter_indices, updates, update_computation, dimension_numbers, indices_are_sorted, name, ctx):
  """Slow eager path for XlaScatter."""
  dimension_numbers = _execute.make_str(dimension_numbers, "dimension_numbers")
  indices_are_sorted = _execute.make_bool(indices_are_sorted, "indices_are_sorted")
  attr_t, matched_t = _execute.args_to_matching_eager([operand, updates], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, _dtypes.bool, ])
  (operand, updates) = matched_t
  attr_tindices, (scatter_indices,) = _execute.args_to_matching_eager(
      [scatter_indices], ctx, [_dtypes.int32, _dtypes.int64, ])
  flat_inputs = [operand, scatter_indices, updates]
  op_attrs = ("update_computation", update_computation, "dimension_numbers",
              dimension_numbers, "indices_are_sorted", indices_are_sorted,
              "T", attr_t, "Tindices", attr_tindices)
  outputs = _execute.execute(b"XlaScatter", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaScatter", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result

3136 

3137 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_select_and_scatter')
def xla_select_and_scatter(operand, window_dimensions, window_strides, padding, source, init_value, select, scatter, name=None):
  r"""Wraps the XLA SelectAndScatter operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#selectandscatter
  .

  Args:
    operand: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the input tensor
    window_dimensions: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      the shape of the window
    window_strides: A `Tensor`. Must have the same type as `window_dimensions`.
      the inter-window strides
    padding: A `Tensor`. Must have the same type as `window_dimensions`.
      the padding to apply at the start and end of each input dimensions
    source: A `Tensor`. Must have the same type as `operand`.
      a tensor of values to scatter
    init_value: A `Tensor`. Must have the same type as `operand`.
      a scalar representing the initial value for the output tensor
    select: A function decorated with @Defun. a selection function to apply
    scatter: A function decorated with @Defun. a scatter function to apply
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `operand`.
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "XlaSelectAndScatter", name, operand, window_dimensions,
          window_strides, padding, source, init_value, "select", select,
          "scatter", scatter)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: type-based dispatcher, then the slow eager path.
    try:
      dispatched = _dispatcher_for_xla_select_and_scatter(
          (operand, window_dimensions, window_strides, padding, source,
           init_value, select, scatter, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_select_and_scatter_eager_fallback(
          operand, window_dimensions, window_strides, padding, source,
          init_value, select=select, scatter=scatter, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(
          xla_select_and_scatter, (), dict(operand=operand,
                                           window_dimensions=window_dimensions,
                                           window_strides=window_strides,
                                           padding=padding, source=source,
                                           init_value=init_value,
                                           select=select, scatter=scatter,
                                           name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_select_and_scatter(
        (operand, window_dimensions, window_strides, padding, source,
         init_value, select, scatter, name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: add the op to the current graph.
  try:
    _, _, op, outputs = _op_def_library._apply_op_helper(
        "XlaSelectAndScatter", operand=operand,
        window_dimensions=window_dimensions,
        window_strides=window_strides, padding=padding,
        source=source, init_value=init_value,
        select=select, scatter=scatter, name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(
        xla_select_and_scatter, (), dict(operand=operand,
                                         window_dimensions=window_dimensions,
                                         window_strides=window_strides,
                                         padding=padding, source=source,
                                         init_value=init_value,
                                         select=select, scatter=scatter,
                                         name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  results = outputs[:]
  if _execute.must_record_gradient():
    attrs = ("T", op._get_attr_type("T"), "Tindices",
             op._get_attr_type("Tindices"), "select",
             op.get_attr("select"), "scatter", op.get_attr("scatter"))
    flat_inputs = op.inputs
    _execute.record_gradient(
        "XlaSelectAndScatter", flat_inputs, attrs, results)
  result, = results
  return result

3241 

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_select_and_scatter.
XlaSelectAndScatter = tf_export("raw_ops.XlaSelectAndScatter")(
    _ops.to_raw_op(xla_select_and_scatter))
_dispatcher_for_xla_select_and_scatter = (
    xla_select_and_scatter._tf_type_based_dispatcher.Dispatch)

3244 

3245 

def xla_select_and_scatter_eager_fallback(operand, window_dimensions, window_strides, padding, source, init_value, select, scatter, name, ctx):
  """Slow eager path for XlaSelectAndScatter."""
  attr_t, matched_t = _execute.args_to_matching_eager([operand, source, init_value], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  (operand, source, init_value) = matched_t
  attr_tindices, matched_tindices = _execute.args_to_matching_eager(
      [window_dimensions, window_strides, padding], ctx,
      [_dtypes.int32, _dtypes.int64, ])
  (window_dimensions, window_strides, padding) = matched_tindices
  flat_inputs = [operand, window_dimensions, window_strides, padding, source, init_value]
  op_attrs = ("T", attr_t, "Tindices", attr_tindices, "select", select,
              "scatter", scatter)
  outputs = _execute.execute(b"XlaSelectAndScatter", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaSelectAndScatter", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result

3261 

# Structured result type for XlaSelfAdjointEig: (w, v) = (eigenvalues,
# eigenvectors).
_XlaSelfAdjointEigOutput = collections.namedtuple(
    "XlaSelfAdjointEig", ["w", "v"])

3265 

3266 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_self_adjoint_eig')
def xla_self_adjoint_eig(a, lower, max_iter, epsilon, name=None):
  r"""Computes the eigen decomposition of a batch of self-adjoint matrices

  (Note: Only real inputs are supported).

  Computes the eigenvalues and eigenvectors of the innermost N-by-N matrices in
  tensor such that tensor[...,:,:] * v[..., :,i] = e[..., i] * v[...,:,i], for
  i=0...N-1.

  Args:
    a: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the input tensor.
    lower: A `bool`.
      a boolean specifies whether the calculation is done with the lower
      triangular part or the upper triangular part.
    max_iter: An `int`.
      maximum number of sweep update, i.e., the whole lower triangular
      part or upper triangular part based on parameter lower. Heuristically, it has
      been argued that approximately logN sweeps are needed in practice (Ref: Golub &
      van Loan "Matrix Computation").
    epsilon: A `float`. the tolerance ratio.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (w, v).

    w: A `Tensor`. Has the same type as `a`. The eigenvalues in ascending order, each repeated according to its
      multiplicity.
    v: A `Tensor`. Has the same type as `a`. The column v[..., :, i] is the normalized eigenvector corresponding to the
      eigenvalue w[..., i].
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      fast = pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "XlaSelfAdjointEig", name, a, "lower", lower, "max_iter",
          max_iter, "epsilon", epsilon)
      return _XlaSelfAdjointEigOutput._make(fast)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: type-based dispatcher, then the slow eager path.
    try:
      dispatched = _dispatcher_for_xla_self_adjoint_eig(
          (a, lower, max_iter, epsilon, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_self_adjoint_eig_eager_fallback(
          a, lower=lower, max_iter=max_iter, epsilon=epsilon, name=name,
          ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(
          xla_self_adjoint_eig, (), dict(a=a, lower=lower,
                                         max_iter=max_iter, epsilon=epsilon,
                                         name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_self_adjoint_eig(
        (a, lower, max_iter, epsilon, name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: normalize attrs and add the op to the current graph.
  lower = _execute.make_bool(lower, "lower")
  max_iter = _execute.make_int(max_iter, "max_iter")
  epsilon = _execute.make_float(epsilon, "epsilon")
  try:
    _, _, op, outputs = _op_def_library._apply_op_helper(
        "XlaSelfAdjointEig", a=a, lower=lower, max_iter=max_iter,
        epsilon=epsilon, name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(
        xla_self_adjoint_eig, (), dict(a=a, lower=lower, max_iter=max_iter,
                                       epsilon=epsilon, name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  results = outputs[:]
  if _execute.must_record_gradient():
    attrs = ("lower", op._get_attr_bool("lower"), "max_iter",
             op._get_attr_int("max_iter"), "epsilon",
             op.get_attr("epsilon"), "T", op._get_attr_type("T"))
    flat_inputs = op.inputs
    _execute.record_gradient(
        "XlaSelfAdjointEig", flat_inputs, attrs, results)
  return _XlaSelfAdjointEigOutput._make(results)

3364 

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_self_adjoint_eig.
XlaSelfAdjointEig = tf_export("raw_ops.XlaSelfAdjointEig")(
    _ops.to_raw_op(xla_self_adjoint_eig))
_dispatcher_for_xla_self_adjoint_eig = (
    xla_self_adjoint_eig._tf_type_based_dispatcher.Dispatch)

3367 

3368 

def xla_self_adjoint_eig_eager_fallback(a, lower, max_iter, epsilon, name, ctx):
  """Slow eager path for XlaSelfAdjointEig."""
  lower = _execute.make_bool(lower, "lower")
  max_iter = _execute.make_int(max_iter, "max_iter")
  epsilon = _execute.make_float(epsilon, "epsilon")
  attr_t, (a,) = _execute.args_to_matching_eager([a], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  flat_inputs = [a]
  op_attrs = ("lower", lower, "max_iter", max_iter, "epsilon", epsilon, "T",
              attr_t)
  outputs = _execute.execute(b"XlaSelfAdjointEig", 2, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaSelfAdjointEig", flat_inputs, op_attrs, outputs)
  return _XlaSelfAdjointEigOutput._make(outputs)

3384 

3385 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_send')
def xla_send(tensor, tensor_name, name=None):
  r"""Sends the named tensor to another XLA computation. Wraps the XLA Send operator

  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#send .

  Args:
    tensor: A `Tensor`. The tensor to send.
    tensor_name: A `string`. A string key that identifies the channel.
    name: A name for the operation (optional).

  Returns:
    The created Operation.
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "XlaSend", name, tensor, "tensor_name", tensor_name)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: type-based dispatcher, then the slow eager path.
    try:
      dispatched = _dispatcher_for_xla_send(
          (tensor, tensor_name, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_send_eager_fallback(
          tensor, tensor_name=tensor_name, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(
          xla_send, (), dict(tensor=tensor, tensor_name=tensor_name,
                             name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_send(
        (tensor, tensor_name, name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: add the op to the current graph.  The op has no outputs,
  # so the Operation itself is returned.
  tensor_name = _execute.make_str(tensor_name, "tensor_name")
  try:
    _, _, op, _ = _op_def_library._apply_op_helper(
        "XlaSend", tensor=tensor, tensor_name=tensor_name, name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(
        xla_send, (), dict(tensor=tensor, tensor_name=tensor_name,
                           name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  return op

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_send.
XlaSend = tf_export("raw_ops.XlaSend")(_ops.to_raw_op(xla_send))
_dispatcher_for_xla_send = xla_send._tf_type_based_dispatcher.Dispatch

3452 

3453 

def xla_send_eager_fallback(tensor, tensor_name, name, ctx):
  """Slow eager path for XlaSend.

  The op produces no outputs, so None is returned.
  """
  tensor_name = _execute.make_str(tensor_name, "tensor_name")
  attr_t, (tensor,) = _execute.args_to_matching_eager([tensor], ctx, [])
  flat_inputs = [tensor]
  op_attrs = ("T", attr_t, "tensor_name", tensor_name)
  _execute.execute(b"XlaSend", 0, inputs=flat_inputs, attrs=op_attrs,
                   ctx=ctx, name=name)
  return None

3463 

3464 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_set_bound')
def xla_set_bound(input, bound, name=None):
  r"""Set a bound for the given input value as a hint to Xla compiler,

  returns the same value.

  Args:
    input: A `Tensor` of type `int32`.
    bound: A `Tensor` of type `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  ctx = _context._context or _context.context()
  thread_data = ctx._thread_local_data
  if thread_data.is_eager:
    # Eager mode: try the C fast path first.
    try:
      return pywrap_tfe.TFE_Py_FastPathExecute(
          ctx, "XlaSetBound", name, input, bound)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path declined: type-based dispatcher, then the slow eager path.
    try:
      dispatched = _dispatcher_for_xla_set_bound(
          (input, bound, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return xla_set_bound_eager_fallback(
          input, bound, name=name, ctx=ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Legacy fallback dispatch list before re-raising.
      dispatched = _dispatch.dispatch(
          xla_set_bound, (), dict(input=input, bound=bound, name=name))
      if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return dispatched
      raise
  else:
    dispatched = _dispatcher_for_xla_set_bound(
        (input, bound, name,), None)
    if dispatched is not NotImplemented:
      return dispatched
  # Graph mode: add the op to the current graph.
  try:
    _, _, op, outputs = _op_def_library._apply_op_helper(
        "XlaSetBound", input=input, bound=bound, name=name)
  except (TypeError, ValueError):
    dispatched = _dispatch.dispatch(
        xla_set_bound, (), dict(input=input, bound=bound, name=name))
    if dispatched is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return dispatched
    raise
  results = outputs[:]
  if _execute.must_record_gradient():
    attrs = ()
    flat_inputs = op.inputs
    _execute.record_gradient("XlaSetBound", flat_inputs, attrs, results)
  result, = results
  return result

3532 

# Expose the op under tf.raw_ops and capture the generated type-based
# dispatch entry point used by xla_set_bound.
XlaSetBound = tf_export("raw_ops.XlaSetBound")(_ops.to_raw_op(xla_set_bound))
_dispatcher_for_xla_set_bound = xla_set_bound._tf_type_based_dispatcher.Dispatch

3535 

3536 

def xla_set_bound_eager_fallback(input, bound, name, ctx):
  """Slow eager path for XlaSetBound."""
  input = _ops.convert_to_tensor(input, _dtypes.int32)
  bound = _ops.convert_to_tensor(bound, _dtypes.int32)
  flat_inputs = [input, bound]
  op_attrs = None
  outputs = _execute.execute(b"XlaSetBound", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSetBound", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result

3549 

3550 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_set_dynamic_dimension_size')
def xla_set_dynamic_dimension_size(input, dim_index, size, name=None):
  r"""Make a static dimension into a xla bounded dynamic dimension.

  The current static dimension size will become the bound and the second
  operand becomes the dynamic size of the dimension.

  Args:
    input: A `Tensor`.
    dim_index: A `Tensor` of type `int32`.
    size: A `Tensor` of type `int32`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: first try the C++ fast-path executor.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaSetDynamicDimensionSize", name, input, dim_index, size)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path unavailable: consult the type-based dispatcher, then fall
    # back to the slower Python eager execution path.
    try:
      _result = _dispatcher_for_xla_set_dynamic_dimension_size(
          (input, dim_index, size, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_set_dynamic_dimension_size_eager_fallback(
          input, dim_index, size, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give the fallback dispatch list a chance before re-raising.
      _result = _dispatch.dispatch(
            xla_set_dynamic_dimension_size, (), dict(input=input,
                                                     dim_index=dim_index,
                                                     size=size, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher may still intercept the call.
    _result = _dispatcher_for_xla_set_dynamic_dimension_size(
        (input, dim_index, size, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaSetDynamicDimensionSize", input=input, dim_index=dim_index,
                                      size=size, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_set_dynamic_dimension_size, (), dict(input=input,
                                                   dim_index=dim_index,
                                                   size=size, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record the op on the gradient tape so backprop can find it.
    _attrs = ("T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaSetDynamicDimensionSize", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

3625 

# Export the op as a raw-op endpoint (tf.raw_ops.XlaSetDynamicDimensionSize)
# and capture the type-based dispatcher handle from the decorated wrapper.
XlaSetDynamicDimensionSize = tf_export("raw_ops.XlaSetDynamicDimensionSize")(_ops.to_raw_op(xla_set_dynamic_dimension_size))
_dispatcher_for_xla_set_dynamic_dimension_size = xla_set_dynamic_dimension_size._tf_type_based_dispatcher.Dispatch

3628 

3629 

def xla_set_dynamic_dimension_size_eager_fallback(input, dim_index, size, name, ctx):
  """Slow-path eager execution of XlaSetDynamicDimensionSize."""
  # The T attr is inferred from `input`; dim_index and size are fixed int32.
  attr_t, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  dim_index = _ops.convert_to_tensor(dim_index, _dtypes.int32)
  size = _ops.convert_to_tensor(size, _dtypes.int32)
  flat_inputs = [input, dim_index, size]
  op_attrs = ("T", attr_t)
  outputs = _execute.execute(b"XlaSetDynamicDimensionSize", 1,
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSetDynamicDimensionSize", flat_inputs,
                             op_attrs, outputs)
  result, = outputs
  return result

3644 

3645 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_sharding')
def xla_sharding(input, sharding="", unspecified_dims=[], name=None):
  r"""An op which shards the input based on the given sharding attribute. It can

  selectively annotate a subset of tensor dimensions by skipping unspecified_dims,
  and the sharding annotation should be replicated in those dims.

  Args:
    input: A `Tensor`.
    sharding: An optional `string`. Defaults to `""`.
    unspecified_dims: An optional list of `ints`. Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: first try the C++ fast-path executor.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaSharding", name, input, "sharding", sharding,
        "unspecified_dims", unspecified_dims)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path unavailable: consult the type-based dispatcher, then fall
    # back to the slower Python eager execution path.
    try:
      _result = _dispatcher_for_xla_sharding(
          (input, sharding, unspecified_dims, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_sharding_eager_fallback(
          input, sharding=sharding, unspecified_dims=unspecified_dims,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give the fallback dispatch list a chance before re-raising.
      _result = _dispatch.dispatch(
            xla_sharding, (), dict(input=input, sharding=sharding,
                                   unspecified_dims=unspecified_dims,
                                   name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher may still intercept the call.
    _result = _dispatcher_for_xla_sharding(
        (input, sharding, unspecified_dims, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize attrs to their defaults and expected types first.
  if sharding is None:
    sharding = ""
  sharding = _execute.make_str(sharding, "sharding")
  if unspecified_dims is None:
    unspecified_dims = []
  if not isinstance(unspecified_dims, (list, tuple)):
    raise TypeError(
        "Expected list for 'unspecified_dims' argument to "
        "'xla_sharding' Op, not %r." % unspecified_dims)
  unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaSharding", input=input, sharding=sharding,
                       unspecified_dims=unspecified_dims, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_sharding, (), dict(input=input, sharding=sharding,
                                 unspecified_dims=unspecified_dims, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record the op on the gradient tape so backprop can find it.
    _attrs = ("T", _op._get_attr_type("T"), "sharding",
              _op.get_attr("sharding"), "unspecified_dims",
              _op.get_attr("unspecified_dims"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaSharding", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

3733 

# Export the op as a raw-op endpoint (tf.raw_ops.XlaSharding) and capture the
# type-based dispatcher handle from the decorated wrapper.
XlaSharding = tf_export("raw_ops.XlaSharding")(_ops.to_raw_op(xla_sharding))
_dispatcher_for_xla_sharding = xla_sharding._tf_type_based_dispatcher.Dispatch

3736 

3737 

def xla_sharding_eager_fallback(input, sharding, unspecified_dims, name, ctx):
  """Slow-path eager execution of XlaSharding."""
  # Canonicalize attrs to their defaults and expected types.
  sharding = _execute.make_str(sharding if sharding is not None else "",
                               "sharding")
  if unspecified_dims is None:
    unspecified_dims = []
  if not isinstance(unspecified_dims, (list, tuple)):
    raise TypeError(
        "Expected list for 'unspecified_dims' argument to "
        "'xla_sharding' Op, not %r." % unspecified_dims)
  unspecified_dims = [_execute.make_int(_d, "unspecified_dims")
                      for _d in unspecified_dims]
  # The T attr is inferred from the single tensor operand.
  attr_t, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  flat_inputs = [input]
  op_attrs = ("T", attr_t, "sharding", sharding, "unspecified_dims",
              unspecified_dims)
  outputs = _execute.execute(b"XlaSharding", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSharding", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result

3760 

3761 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_sort')
def xla_sort(input, name=None):
  r"""Wraps the XLA Sort operator, documented at

  https://www.tensorflow.org/performance/xla/operation_semantics#sort
  .

  Sorts a tensor. Currently only sorts in ascending order are supported.

  Args:
    input: A `Tensor`. A `Tensor` of type T.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`. A `Tensor` of type T.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: first try the C++ fast-path executor.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaSort", name, input)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path unavailable: consult the type-based dispatcher, then fall
    # back to the slower Python eager execution path.
    try:
      _result = _dispatcher_for_xla_sort(
          (input, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_sort_eager_fallback(
          input, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give the fallback dispatch list a chance before re-raising.
      _result = _dispatch.dispatch(
            xla_sort, (), dict(input=input, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher may still intercept the call.
    _result = _dispatcher_for_xla_sort(
        (input, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaSort", input=input, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_sort, (), dict(input=input, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record the op on the gradient tape so backprop can find it.
    _attrs = ("T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaSort", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

3831 

# Export the op as a raw-op endpoint (tf.raw_ops.XlaSort) and capture the
# type-based dispatcher handle from the decorated wrapper.
XlaSort = tf_export("raw_ops.XlaSort")(_ops.to_raw_op(xla_sort))
_dispatcher_for_xla_sort = xla_sort._tf_type_based_dispatcher.Dispatch

3834 

3835 

def xla_sort_eager_fallback(input, name, ctx):
  """Slow-path eager execution of XlaSort."""
  # The T attr is inferred from the single tensor operand.
  attr_t, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  flat_inputs = [input]
  op_attrs = ("T", attr_t)
  outputs = _execute.execute(b"XlaSort", 1, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSort", flat_inputs, op_attrs, outputs)
  result, = outputs
  return result

3847 

3848 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_spmd_full_to_shard_shape')
def xla_spmd_full_to_shard_shape(input, manual_sharding, dim=-1, unspecified_dims=[], name=None):
  r"""An op used by XLA SPMD partitioner to switch from automatic partitioning to

  manual partitioning. It annotates the input (full-shape, to be automatically
  partitioned) with the same sharding used by manual partitioning, and outputs a
  shard-shaped tensor to be consumed by later manually-partitioned ops. If the
  shape is not evenly partitionable, the padding region will be masked with 0s.
  The conversion can happen partially in subgroups, by specifying the dim
  attribute, where only that dim will be converted.

  Args:
    input: A `Tensor`.
    manual_sharding: A `string`.
    dim: An optional `int`. Defaults to `-1`.
    unspecified_dims: An optional list of `ints`. Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: first try the C++ fast-path executor.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaSpmdFullToShardShape", name, input, "manual_sharding",
        manual_sharding, "dim", dim, "unspecified_dims", unspecified_dims)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path unavailable: consult the type-based dispatcher, then fall
    # back to the slower Python eager execution path.
    try:
      _result = _dispatcher_for_xla_spmd_full_to_shard_shape(
          (input, manual_sharding, dim, unspecified_dims, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_spmd_full_to_shard_shape_eager_fallback(
          input, manual_sharding=manual_sharding, dim=dim,
          unspecified_dims=unspecified_dims, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give the fallback dispatch list a chance before re-raising.
      _result = _dispatch.dispatch(
            xla_spmd_full_to_shard_shape, (), dict(input=input,
                                                   manual_sharding=manual_sharding,
                                                   dim=dim,
                                                   unspecified_dims=unspecified_dims,
                                                   name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher may still intercept the call.
    _result = _dispatcher_for_xla_spmd_full_to_shard_shape(
        (input, manual_sharding, dim, unspecified_dims, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize attrs to their defaults and expected types first.
  manual_sharding = _execute.make_str(manual_sharding, "manual_sharding")
  if dim is None:
    dim = -1
  dim = _execute.make_int(dim, "dim")
  if unspecified_dims is None:
    unspecified_dims = []
  if not isinstance(unspecified_dims, (list, tuple)):
    raise TypeError(
        "Expected list for 'unspecified_dims' argument to "
        "'xla_spmd_full_to_shard_shape' Op, not %r." % unspecified_dims)
  unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaSpmdFullToShardShape", input=input,
                                   manual_sharding=manual_sharding, dim=dim,
                                   unspecified_dims=unspecified_dims,
                                   name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_spmd_full_to_shard_shape, (), dict(input=input,
                                                 manual_sharding=manual_sharding,
                                                 dim=dim,
                                                 unspecified_dims=unspecified_dims,
                                                 name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record the op on the gradient tape so backprop can find it.
    _attrs = ("T", _op._get_attr_type("T"), "manual_sharding",
              _op.get_attr("manual_sharding"), "dim",
              _op._get_attr_int("dim"), "unspecified_dims",
              _op.get_attr("unspecified_dims"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaSpmdFullToShardShape", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

3950 

# Export the op as a raw-op endpoint (tf.raw_ops.XlaSpmdFullToShardShape) and
# capture the type-based dispatcher handle from the decorated wrapper.
XlaSpmdFullToShardShape = tf_export("raw_ops.XlaSpmdFullToShardShape")(_ops.to_raw_op(xla_spmd_full_to_shard_shape))
_dispatcher_for_xla_spmd_full_to_shard_shape = xla_spmd_full_to_shard_shape._tf_type_based_dispatcher.Dispatch

3953 

3954 

def xla_spmd_full_to_shard_shape_eager_fallback(input, manual_sharding, dim, unspecified_dims, name, ctx):
  """Slow-path eager execution of XlaSpmdFullToShardShape."""
  # Canonicalize attrs to their defaults and expected types.
  manual_sharding = _execute.make_str(manual_sharding, "manual_sharding")
  dim = _execute.make_int(dim if dim is not None else -1, "dim")
  if unspecified_dims is None:
    unspecified_dims = []
  if not isinstance(unspecified_dims, (list, tuple)):
    raise TypeError(
        "Expected list for 'unspecified_dims' argument to "
        "'xla_spmd_full_to_shard_shape' Op, not %r." % unspecified_dims)
  unspecified_dims = [_execute.make_int(_d, "unspecified_dims")
                      for _d in unspecified_dims]
  # The T attr is inferred from the single tensor operand.
  attr_t, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  flat_inputs = [input]
  op_attrs = ("T", attr_t, "manual_sharding", manual_sharding, "dim", dim,
              "unspecified_dims", unspecified_dims)
  outputs = _execute.execute(b"XlaSpmdFullToShardShape", 1,
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSpmdFullToShardShape", flat_inputs,
                             op_attrs, outputs)
  result, = outputs
  return result

3979 

3980 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_spmd_shard_to_full_shape')
def xla_spmd_shard_to_full_shape(input, manual_sharding, full_shape, dim=-1, unspecified_dims=[], name=None):
  r"""An op used by XLA SPMD partitioner to switch from manual partitioning to

  automatic partitioning. It converts the shard-shaped, manually partitioned input
  into full-shaped tensor to be partitioned automatically with the same sharding
  used by manual partitioning. The conversion can happen partially in subgroups,
  by specifying the dim attribute, where only that dim will be converted.

  Args:
    input: A `Tensor`.
    manual_sharding: A `string`.
    full_shape: A `tf.TensorShape` or list of `ints`.
    dim: An optional `int`. Defaults to `-1`.
    unspecified_dims: An optional list of `ints`. Defaults to `[]`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: first try the C++ fast-path executor.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaSpmdShardToFullShape", name, input, "manual_sharding",
        manual_sharding, "full_shape", full_shape, "dim", dim,
        "unspecified_dims", unspecified_dims)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path unavailable: consult the type-based dispatcher, then fall
    # back to the slower Python eager execution path.
    try:
      _result = _dispatcher_for_xla_spmd_shard_to_full_shape(
          (input, manual_sharding, full_shape, dim, unspecified_dims, name,),
          None)
      if _result is not NotImplemented:
        return _result
      return xla_spmd_shard_to_full_shape_eager_fallback(
          input, manual_sharding=manual_sharding, full_shape=full_shape,
          dim=dim, unspecified_dims=unspecified_dims, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give the fallback dispatch list a chance before re-raising.
      _result = _dispatch.dispatch(
            xla_spmd_shard_to_full_shape, (), dict(input=input,
                                                   manual_sharding=manual_sharding,
                                                   full_shape=full_shape,
                                                   dim=dim,
                                                   unspecified_dims=unspecified_dims,
                                                   name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher may still intercept the call.
    _result = _dispatcher_for_xla_spmd_shard_to_full_shape(
        (input, manual_sharding, full_shape, dim, unspecified_dims, name,),
        None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize attrs to their defaults and expected types first.
  manual_sharding = _execute.make_str(manual_sharding, "manual_sharding")
  full_shape = _execute.make_shape(full_shape, "full_shape")
  if dim is None:
    dim = -1
  dim = _execute.make_int(dim, "dim")
  if unspecified_dims is None:
    unspecified_dims = []
  if not isinstance(unspecified_dims, (list, tuple)):
    raise TypeError(
        "Expected list for 'unspecified_dims' argument to "
        "'xla_spmd_shard_to_full_shape' Op, not %r." % unspecified_dims)
  unspecified_dims = [_execute.make_int(_i, "unspecified_dims") for _i in unspecified_dims]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaSpmdShardToFullShape", input=input,
                                   manual_sharding=manual_sharding,
                                   full_shape=full_shape, dim=dim,
                                   unspecified_dims=unspecified_dims,
                                   name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_spmd_shard_to_full_shape, (), dict(input=input,
                                                 manual_sharding=manual_sharding,
                                                 full_shape=full_shape,
                                                 dim=dim,
                                                 unspecified_dims=unspecified_dims,
                                                 name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record the op on the gradient tape so backprop can find it.
    _attrs = ("T", _op._get_attr_type("T"), "manual_sharding",
              _op.get_attr("manual_sharding"), "full_shape",
              _op.get_attr("full_shape"), "dim", _op._get_attr_int("dim"),
              "unspecified_dims", _op.get_attr("unspecified_dims"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaSpmdShardToFullShape", _inputs_flat, _attrs, _result)
  _result, = _result
  return _result

4088 

# Export the op as a raw-op endpoint (tf.raw_ops.XlaSpmdShardToFullShape) and
# capture the type-based dispatcher handle from the decorated wrapper.
XlaSpmdShardToFullShape = tf_export("raw_ops.XlaSpmdShardToFullShape")(_ops.to_raw_op(xla_spmd_shard_to_full_shape))
_dispatcher_for_xla_spmd_shard_to_full_shape = xla_spmd_shard_to_full_shape._tf_type_based_dispatcher.Dispatch

4091 

4092 

def xla_spmd_shard_to_full_shape_eager_fallback(input, manual_sharding, full_shape, dim, unspecified_dims, name, ctx):
  """Slow-path eager execution of XlaSpmdShardToFullShape."""
  # Canonicalize attrs to their defaults and expected types.
  manual_sharding = _execute.make_str(manual_sharding, "manual_sharding")
  full_shape = _execute.make_shape(full_shape, "full_shape")
  dim = _execute.make_int(dim if dim is not None else -1, "dim")
  if unspecified_dims is None:
    unspecified_dims = []
  if not isinstance(unspecified_dims, (list, tuple)):
    raise TypeError(
        "Expected list for 'unspecified_dims' argument to "
        "'xla_spmd_shard_to_full_shape' Op, not %r." % unspecified_dims)
  unspecified_dims = [_execute.make_int(_d, "unspecified_dims")
                      for _d in unspecified_dims]
  # The T attr is inferred from the single tensor operand.
  attr_t, (input,) = _execute.args_to_matching_eager([input], ctx, [])
  flat_inputs = [input]
  op_attrs = ("T", attr_t, "manual_sharding", manual_sharding, "full_shape",
              full_shape, "dim", dim, "unspecified_dims", unspecified_dims)
  outputs = _execute.execute(b"XlaSpmdShardToFullShape", 1,
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSpmdShardToFullShape", flat_inputs,
                             op_attrs, outputs)
  result, = outputs
  return result

4118 

4119_XlaSvdOutput = collections.namedtuple( 

4120 "XlaSvd", 

4121 ["s", "u", "v"]) 

4122 

4123 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_svd')
def xla_svd(a, max_iter, epsilon, precision_config, name=None):
  r"""Computes the eigen decomposition of a batch of self-adjoint matrices

  (Note: Only real inputs are supported).

  Computes the eigenvalues and eigenvectors of the innermost M-by-N matrices in
  tensor such that tensor[...,:,:] = u[..., :, :] * Diag(s[..., :]) * Transpose(v[...,:,:]).

  Args:
    a: A `Tensor`. Must be one of the following types: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`.
      the input tensor.
    max_iter: An `int`.
      maximum number of sweep update, i.e., the whole lower triangular
      part or upper triangular part based on parameter lower. Heuristically, it has
      been argued that approximately log(min (M, N)) sweeps are needed in practice
      (Ref: Golub & van Loan "Matrix Computation").
    epsilon: A `float`. the tolerance ratio.
    precision_config: A `string`. a serialized xla::PrecisionConfig proto.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (s, u, v).

    s: A `Tensor`. Has the same type as `a`. Singular values. The values are sorted in reverse order of magnitude, so
      s[..., 0] is the largest value, s[..., 1] is the second largest, etc.
    u: A `Tensor`. Has the same type as `a`. Left singular vectors.
    v: A `Tensor`. Has the same type as `a`. Right singular vectors.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: first try the C++ fast-path executor.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaSvd", name, a, "max_iter", max_iter, "epsilon", epsilon,
        "precision_config", precision_config)
      # Wrap the flat output list in the (s, u, v) namedtuple.
      _result = _XlaSvdOutput._make(_result)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass
    # Fast path unavailable: consult the type-based dispatcher, then fall
    # back to the slower Python eager execution path.
    try:
      _result = _dispatcher_for_xla_svd(
          (a, max_iter, epsilon, precision_config, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_svd_eager_fallback(
          a, max_iter=max_iter, epsilon=epsilon,
          precision_config=precision_config, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give the fallback dispatch list a chance before re-raising.
      _result = _dispatch.dispatch(
            xla_svd, (), dict(a=a, max_iter=max_iter, epsilon=epsilon,
                              precision_config=precision_config, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: the type-based dispatcher may still intercept the call.
    _result = _dispatcher_for_xla_svd(
        (a, max_iter, epsilon, precision_config, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Canonicalize attrs to their expected types first.
  max_iter = _execute.make_int(max_iter, "max_iter")
  epsilon = _execute.make_float(epsilon, "epsilon")
  precision_config = _execute.make_str(precision_config, "precision_config")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaSvd", a=a, max_iter=max_iter, epsilon=epsilon,
                  precision_config=precision_config, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_svd, (), dict(a=a, max_iter=max_iter, epsilon=epsilon,
                            precision_config=precision_config, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Record the op on the gradient tape so backprop can find it.
    _attrs = ("max_iter", _op._get_attr_int("max_iter"), "epsilon",
              _op.get_attr("epsilon"), "precision_config",
              _op.get_attr("precision_config"), "T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaSvd", _inputs_flat, _attrs, _result)
  # Wrap the flat output list in the (s, u, v) namedtuple.
  _result = _XlaSvdOutput._make(_result)
  return _result

4217 

# Export the op as a raw-op endpoint (tf.raw_ops.XlaSvd) and capture the
# type-based dispatcher handle from the decorated wrapper.
XlaSvd = tf_export("raw_ops.XlaSvd")(_ops.to_raw_op(xla_svd))
_dispatcher_for_xla_svd = xla_svd._tf_type_based_dispatcher.Dispatch

4220 

4221 

def xla_svd_eager_fallback(a, max_iter, epsilon, precision_config, name, ctx):
  """Slow-path eager execution of XlaSvd; returns an (s, u, v) namedtuple."""
  # Canonicalize attrs to their expected types.
  max_iter = _execute.make_int(max_iter, "max_iter")
  epsilon = _execute.make_float(epsilon, "epsilon")
  precision_config = _execute.make_str(precision_config, "precision_config")
  # The T attr is inferred from `a`, constrained to the op's allowed dtypes.
  attr_t, (a,) = _execute.args_to_matching_eager([a], ctx, [_dtypes.float32, _dtypes.float64, _dtypes.int32, _dtypes.uint8, _dtypes.int16, _dtypes.int8, _dtypes.complex64, _dtypes.int64, _dtypes.qint8, _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16, _dtypes.qint16, _dtypes.quint16, _dtypes.uint16, _dtypes.complex128, _dtypes.half, _dtypes.uint32, _dtypes.uint64, ])
  flat_inputs = [a]
  op_attrs = ("max_iter", max_iter, "epsilon", epsilon, "precision_config",
              precision_config, "T", attr_t)
  outputs = _execute.execute(b"XlaSvd", 3, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaSvd", flat_inputs, op_attrs, outputs)
  return _XlaSvdOutput._make(outputs)

4237 

4238 

4239@_dispatch.add_fallback_dispatch_list 

4240@_dispatch.add_type_based_api_dispatcher 

4241@tf_export('xla_variadic_reduce') 

4242def xla_variadic_reduce(input, init_value, dimensions_to_reduce, reducer, name=None): 

4243 r"""Wraps the variadic XLA Reduce operator. 

4244 

4245 Semantics are documented at 

4246 https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce. 

4247 

4248 This version is limited to operands of the same dtype. 

4249 XlaVariadicReduceV2 is a version that supports heterogeneous operands. 

4250 

4251 Args: 

4252 input: A list of at least 1 `Tensor` objects with the same type in: `float32`, `float64`, `int32`, `uint8`, `int16`, `int8`, `complex64`, `int64`, `qint8`, `quint8`, `qint32`, `bfloat16`, `qint16`, `quint16`, `uint16`, `complex128`, `half`, `uint32`, `uint64`, `bool`. 

4253 the input tensor(s) 

4254 init_value: A list with the same length as `input` of `Tensor` objects with the same type as `input`. 

4255 scalar initial value(s) for the reduction 

4256 dimensions_to_reduce: A list of `ints`. 

4257 dimension numbers over which to reduce 

4258 reducer: A function decorated with @Defun. a reducer function to apply 

4259 name: A name for the operation (optional). 

4260 

4261 Returns: 

4262 A list with the same length as `input` of `Tensor` objects with the same type as `input`. 

4263 """ 

4264 _ctx = _context._context or _context.context() 

4265 tld = _ctx._thread_local_data 

4266 if tld.is_eager: 

4267 try: 

4268 _result = pywrap_tfe.TFE_Py_FastPathExecute( 

4269 _ctx, "XlaVariadicReduce", name, input, init_value, 

4270 "dimensions_to_reduce", dimensions_to_reduce, "reducer", reducer) 

4271 return _result 

4272 except _core._NotOkStatusException as e: 

4273 _ops.raise_from_not_ok_status(e, name) 

4274 except _core._FallbackException: 

4275 pass 

4276 try: 

4277 _result = _dispatcher_for_xla_variadic_reduce( 

4278 (input, init_value, dimensions_to_reduce, reducer, name,), None) 

4279 if _result is not NotImplemented: 

4280 return _result 

4281 return xla_variadic_reduce_eager_fallback( 

4282 input, init_value, dimensions_to_reduce=dimensions_to_reduce, 

4283 reducer=reducer, name=name, ctx=_ctx) 

4284 except _core._SymbolicException: 

4285 pass # Add nodes to the TensorFlow graph. 

4286 except (TypeError, ValueError): 

4287 _result = _dispatch.dispatch( 

4288 xla_variadic_reduce, (), dict(input=input, init_value=init_value, 

4289 dimensions_to_reduce=dimensions_to_reduce, 

4290 reducer=reducer, name=name) 

4291 ) 

4292 if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: 

4293 return _result 

4294 raise 

4295 else: 

4296 _result = _dispatcher_for_xla_variadic_reduce( 

4297 (input, init_value, dimensions_to_reduce, reducer, name,), None) 

4298 if _result is not NotImplemented: 

4299 return _result 

4300 # Add nodes to the TensorFlow graph. 

4301 if not isinstance(input, (list, tuple)): 

4302 raise TypeError( 

4303 "Expected list for 'input' argument to " 

4304 "'xla_variadic_reduce' Op, not %r." % input) 

4305 _attr_N = len(input) 

4306 if not isinstance(init_value, (list, tuple)): 

4307 raise TypeError( 

4308 "Expected list for 'init_value' argument to " 

4309 "'xla_variadic_reduce' Op, not %r." % init_value) 

4310 if len(init_value) != _attr_N: 

4311 raise ValueError( 

4312 "List argument 'init_value' to 'xla_variadic_reduce' Op with length %d " 

4313 "must match length %d of argument 'input'." % 

4314 (len(init_value), _attr_N)) 

4315 if not isinstance(dimensions_to_reduce, (list, tuple)): 

4316 raise TypeError( 

4317 "Expected list for 'dimensions_to_reduce' argument to " 

4318 "'xla_variadic_reduce' Op, not %r." % dimensions_to_reduce) 

4319 dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce] 

4320 try: 

4321 _, _, _op, _outputs = _op_def_library._apply_op_helper( 

4322 "XlaVariadicReduce", input=input, init_value=init_value, 

4323 dimensions_to_reduce=dimensions_to_reduce, 

4324 reducer=reducer, name=name) 

4325 except (TypeError, ValueError): 

4326 _result = _dispatch.dispatch( 

4327 xla_variadic_reduce, (), dict(input=input, init_value=init_value, 

4328 dimensions_to_reduce=dimensions_to_reduce, 

4329 reducer=reducer, name=name) 

4330 ) 

4331 if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED: 

4332 return _result 

4333 raise 

4334 _result = _outputs[:] 

4335 if _execute.must_record_gradient(): 

4336 _attrs = ("N", _op._get_attr_int("N"), "T", _op._get_attr_type("T"), 

4337 "dimensions_to_reduce", _op.get_attr("dimensions_to_reduce"), 

4338 "reducer", _op.get_attr("reducer")) 

4339 _inputs_flat = _op.inputs 

4340 _execute.record_gradient( 

4341 "XlaVariadicReduce", _inputs_flat, _attrs, _result) 

4342 return _result 

4343 

# Expose the generated wrapper as a raw op (tf.raw_ops.XlaVariadicReduce) and
# capture its type-based dispatcher so the eager fast path above can call it
# directly without an attribute lookup on every invocation.
XlaVariadicReduce = tf_export("raw_ops.XlaVariadicReduce")(_ops.to_raw_op(xla_variadic_reduce))
_dispatcher_for_xla_variadic_reduce = xla_variadic_reduce._tf_type_based_dispatcher.Dispatch

4346 

4347 

def xla_variadic_reduce_eager_fallback(input, init_value, dimensions_to_reduce, reducer, name, ctx):
  """Pure-Python eager fallback for the XlaVariadicReduce op.

  Validates the list arguments, coerces all operands to one matching dtype,
  and executes the op via the generic eager executor. Same contract as
  `xla_variadic_reduce`; used when the C++ fast path declines the call.
  """
  if not isinstance(input, (list, tuple)):
    raise TypeError(
        "Expected list for 'input' argument to "
        "'xla_variadic_reduce' Op, not %r." % input)
  if not isinstance(init_value, (list, tuple)):
    raise TypeError(
        "Expected list for 'init_value' argument to "
        "'xla_variadic_reduce' Op, not %r." % init_value)
  num_operands = len(input)
  if len(init_value) != num_operands:
    raise ValueError(
        "List argument 'init_value' to 'xla_variadic_reduce' Op with length %d "
        "must match length %d of argument 'input'." %
        (len(init_value), num_operands))
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_variadic_reduce' Op, not %r." % dimensions_to_reduce)
  dimensions_to_reduce = [_execute.make_int(_dim, "dimensions_to_reduce")
                          for _dim in dimensions_to_reduce]
  # All operands and initial values share one dtype attr T; match them jointly.
  allowed_types = [_dtypes.float32, _dtypes.float64, _dtypes.int32,
                   _dtypes.uint8, _dtypes.int16, _dtypes.int8,
                   _dtypes.complex64, _dtypes.int64, _dtypes.qint8,
                   _dtypes.quint8, _dtypes.qint32, _dtypes.bfloat16,
                   _dtypes.qint16, _dtypes.quint16, _dtypes.uint16,
                   _dtypes.complex128, _dtypes.half, _dtypes.uint32,
                   _dtypes.uint64, _dtypes.bool, ]
  _attr_T, flat_tensors = _execute.args_to_matching_eager(
      list(input) + list(init_value), ctx, allowed_types)
  # Split the flat, dtype-matched tensor list back into the two operand lists.
  input = flat_tensors[:num_operands]
  init_value = flat_tensors[num_operands:]
  _inputs_flat = list(input) + list(init_value)
  _attrs = ("N", num_operands, "T", _attr_T, "dimensions_to_reduce",
            dimensions_to_reduce, "reducer", reducer)
  _result = _execute.execute(b"XlaVariadicReduce", num_operands,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaVariadicReduce", _inputs_flat, _attrs, _result)
  return _result

4382 

4383 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_variadic_reduce_v2')
def xla_variadic_reduce_v2(inputs, init_values, dimensions_to_reduce, reducer, name=None):
  r"""Wraps the variadic XLA Reduce operator.

  Semantics are documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#variadic_reduce.

  This is an expanded version of XlaVariadicReduce, with support for
  operands of different dtypes, and improved shape inference.

  Args:
    inputs: A list of `Tensor` objects. the input tensor(s)
    init_values: A list of `Tensor` objects. Must have the same type as `inputs`.
      scalar initial value(s) for the reduction
    dimensions_to_reduce: A list of `ints`.
      dimension numbers over which to reduce
    reducer: A function decorated with @Defun. a reducer function to apply
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `inputs`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaVariadicReduceV2", name, inputs, init_values,
        "dimensions_to_reduce", dimensions_to_reduce, "reducer", reducer)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these inputs; fall through to the slower paths.
      pass
    try:
      # Give registered type-based dispatchers a chance to handle the call.
      _result = _dispatcher_for_xla_variadic_reduce_v2(
          (inputs, init_values, dimensions_to_reduce, reducer, name,), None)
      if _result is not NotImplemented:
        return _result
      # No dispatcher claimed the call: run the pure-Python eager fallback.
      return xla_variadic_reduce_v2_eager_fallback(
          inputs, init_values, dimensions_to_reduce=dimensions_to_reduce,
          reducer=reducer, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Last resort: the fallback dispatch list gets a chance before re-raise.
      _result = _dispatch.dispatch(
            xla_variadic_reduce_v2, (), dict(inputs=inputs,
                                             init_values=init_values,
                                             dimensions_to_reduce=dimensions_to_reduce,
                                             reducer=reducer, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still honor type-based dispatchers before building the op.
    _result = _dispatcher_for_xla_variadic_reduce_v2(
        (inputs, init_values, dimensions_to_reduce, reducer, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_variadic_reduce_v2' Op, not %r." % dimensions_to_reduce)
  dimensions_to_reduce = [_execute.make_int(_i, "dimensions_to_reduce") for _i in dimensions_to_reduce]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaVariadicReduceV2", inputs=inputs, init_values=init_values,
                               dimensions_to_reduce=dimensions_to_reduce,
                               reducer=reducer, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_variadic_reduce_v2, (), dict(inputs=inputs,
                                           init_values=init_values,
                                           dimensions_to_reduce=dimensions_to_reduce,
                                           reducer=reducer, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture attrs/inputs so the gradient tape can replay this op.
    _attrs = ("T", _op.get_attr("T"), "dimensions_to_reduce",
              _op.get_attr("dimensions_to_reduce"), "reducer",
              _op.get_attr("reducer"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaVariadicReduceV2", _inputs_flat, _attrs, _result)
  return _result

4475 

# Expose the generated wrapper as a raw op (tf.raw_ops.XlaVariadicReduceV2)
# and capture its type-based dispatcher for the eager fast path above.
XlaVariadicReduceV2 = tf_export("raw_ops.XlaVariadicReduceV2")(_ops.to_raw_op(xla_variadic_reduce_v2))
_dispatcher_for_xla_variadic_reduce_v2 = xla_variadic_reduce_v2._tf_type_based_dispatcher.Dispatch

4478 

4479 

def xla_variadic_reduce_v2_eager_fallback(inputs, init_values, dimensions_to_reduce, reducer, name, ctx):
  """Pure-Python eager fallback for the XlaVariadicReduceV2 op.

  Unlike the V1 fallback, operands may have heterogeneous dtypes, so the
  two operand groups are converted to mixed eager tensors rather than
  matched to a single dtype. Same contract as `xla_variadic_reduce_v2`.
  """
  if not isinstance(dimensions_to_reduce, (list, tuple)):
    raise TypeError(
        "Expected list for 'dimensions_to_reduce' argument to "
        "'xla_variadic_reduce_v2' Op, not %r." % dimensions_to_reduce)
  coerced_dims = []
  for _dim in dimensions_to_reduce:
    coerced_dims.append(_execute.make_int(_dim, "dimensions_to_reduce"))
  dimensions_to_reduce = coerced_dims
  # Convert both operand groups to eager tensors; each operand keeps its
  # own dtype, and _attr_T collects the per-operand dtype list.
  _attr_T, (inputs, init_values) = _execute.args_to_mixed_eager_tensors(
      (inputs, init_values), ctx)
  _inputs_flat = list(inputs) + list(init_values)
  _attrs = ("T", _attr_T, "dimensions_to_reduce", dimensions_to_reduce,
            "reducer", reducer)
  _result = _execute.execute(b"XlaVariadicReduceV2", len(inputs),
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaVariadicReduceV2", _inputs_flat, _attrs, _result)
  return _result

4497 

4498 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_variadic_sort')
def xla_variadic_sort(inputs, dimension, comparator, is_stable, name=None):
  r"""Wraps the XLA Sort operator, documented at

   https://www.tensorflow.org/performance/xla/operation_semantics#sort
  .

  Sorts one or more tensors, with support for custom comparator, dimension, and
  is_stable attributes.

  Args:
    inputs: A list of `Tensor` objects.
      A list of `Tensor` of identical shape but possibly different types.
    dimension: A `Tensor` of type `int32`.
      The dimension along which to sort. Must be a compile-time constant.
    comparator: A function decorated with @Defun.
      A comparator function to apply to 2*N scalars and returning a
      boolean. N is the number of sort inputs. If you want to sort in ascending
      order then the comparator should perform a less-than comparison.
    is_stable: A `bool`. Whether to use stable sort.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `inputs`.
    A list of `Tensor` of same shape and types as the `input`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaVariadicSort", name, inputs, dimension, "comparator",
        comparator, "is_stable", is_stable)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these inputs; fall through to the slower paths.
      pass
    try:
      # Give registered type-based dispatchers a chance to handle the call.
      _result = _dispatcher_for_xla_variadic_sort(
          (inputs, dimension, comparator, is_stable, name,), None)
      if _result is not NotImplemented:
        return _result
      # No dispatcher claimed the call: run the pure-Python eager fallback.
      return xla_variadic_sort_eager_fallback(
          inputs, dimension, comparator=comparator, is_stable=is_stable,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Last resort: the fallback dispatch list gets a chance before re-raise.
      _result = _dispatch.dispatch(
            xla_variadic_sort, (), dict(inputs=inputs, dimension=dimension,
                                        comparator=comparator,
                                        is_stable=is_stable, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still honor type-based dispatchers before building the op.
    _result = _dispatcher_for_xla_variadic_sort(
        (inputs, dimension, comparator, is_stable, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  is_stable = _execute.make_bool(is_stable, "is_stable")
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaVariadicSort", inputs=inputs, dimension=dimension,
                           comparator=comparator, is_stable=is_stable,
                           name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_variadic_sort, (), dict(inputs=inputs, dimension=dimension,
                                      comparator=comparator,
                                      is_stable=is_stable, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Capture attrs/inputs so the gradient tape can replay this op.
    _attrs = ("T", _op.get_attr("T"), "comparator",
              _op.get_attr("comparator"), "is_stable",
              _op._get_attr_bool("is_stable"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaVariadicSort", _inputs_flat, _attrs, _result)
  return _result

4588 

# Expose the generated wrapper as a raw op (tf.raw_ops.XlaVariadicSort) and
# capture its type-based dispatcher for the eager fast path above.
XlaVariadicSort = tf_export("raw_ops.XlaVariadicSort")(_ops.to_raw_op(xla_variadic_sort))
_dispatcher_for_xla_variadic_sort = xla_variadic_sort._tf_type_based_dispatcher.Dispatch

4591 

4592 

def xla_variadic_sort_eager_fallback(inputs, dimension, comparator, is_stable, name, ctx):
  """Pure-Python eager fallback for the XlaVariadicSort op.

  Coerces the attrs and operands, then executes the op via the generic eager
  executor. Same contract as `xla_variadic_sort`; used when the C++ fast path
  declines the call.
  """
  # Coerce the bool attr up front so a bad value fails before any conversion.
  is_stable = _execute.make_bool(is_stable, "is_stable")
  # Sort operands may have heterogeneous dtypes; _attr_T is the dtype list.
  _attr_T, inputs = _execute.convert_to_mixed_eager_tensors(inputs, ctx)
  dimension = _ops.convert_to_tensor(dimension, _dtypes.int32)
  num_outputs = len(inputs)
  # Flat input order is: all sort operands, then the dimension scalar.
  _inputs_flat = list(inputs)
  _inputs_flat.append(dimension)
  _attrs = ("T", _attr_T, "comparator", comparator, "is_stable", is_stable)
  _result = _execute.execute(b"XlaVariadicSort", num_outputs,
                             inputs=_inputs_flat, attrs=_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaVariadicSort", _inputs_flat, _attrs, _result)
  return _result

4606 

4607 

@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_while')
def xla_while(input, cond, body, name=None):
  r"""output = input; While (Cond(output)) { output = Body(output) }

  Args:
    input: A list of `Tensor` objects.
      A list of input tensors whose types are T.
    cond: A function decorated with @Defun.
      A function takes 'input' and returns a tensor. If the tensor is
      a scalar of non-boolean, the scalar is converted to a boolean
      according to the following rule: if the scalar is a numerical
      value, non-zero means True and zero means False; if the scalar is
      a string, non-empty means True and empty means False. If the
      tensor is not a scalar, non-emptiness means True and False
      otherwise.
    body: A function decorated with @Defun.
      A function that takes a list of tensors and returns another
      list of tensors. Both lists have the same types as specified by T.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects. Has the same type as `input`.
    A list of output tensors whose types are T.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager mode: try the C++ fast path first.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaWhile", name, input, "cond", cond, "body", body)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined these inputs; fall through to the slower paths.
      pass
    try:
      # Give registered type-based dispatchers a chance to handle the call.
      _result = _dispatcher_for_xla_while(
          (input, cond, body, name,), None)
      if _result is not NotImplemented:
        return _result
      # No dispatcher claimed the call: run the pure-Python eager fallback.
      return xla_while_eager_fallback(
          input, cond=cond, body=body, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Last resort: the fallback dispatch list gets a chance before re-raise.
      _result = _dispatch.dispatch(
            xla_while, (), dict(input=input, cond=cond, body=body, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: still honor type-based dispatchers before building the op.
    _result = _dispatcher_for_xla_while(
        (input, cond, body, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaWhile", input=input, cond=cond, body=body, name=name)
  except (TypeError, ValueError):
    _result = _dispatch.dispatch(
          xla_while, (), dict(input=input, cond=cond, body=body, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if not _result:
    # Zero-output case: return the Operation itself, matching the graph-mode
    # convention for ops without outputs.
    return _op
  if _execute.must_record_gradient():
    # Capture attrs/inputs so the gradient tape can replay this op.
    _attrs = ("T", _op.get_attr("T"), "cond", _op.get_attr("cond"), "body",
              _op.get_attr("body"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaWhile", _inputs_flat, _attrs, _result)
  return _result

4687 

# Expose the generated wrapper as a raw op (tf.raw_ops.XlaWhile) and capture
# its type-based dispatcher for the eager fast path above.
XlaWhile = tf_export("raw_ops.XlaWhile")(_ops.to_raw_op(xla_while))
_dispatcher_for_xla_while = xla_while._tf_type_based_dispatcher.Dispatch

4690 

4691 

def xla_while_eager_fallback(input, cond, body, name, ctx):
  """Pure-Python eager fallback for the XlaWhile op.

  Converts the loop-carried operands to eager tensors and executes the op via
  the generic eager executor. Same contract as `xla_while`; used when the C++
  fast path declines the call.
  """
  # Loop-carried values may have heterogeneous dtypes; _attr_T collects them.
  _attr_T, input = _execute.convert_to_mixed_eager_tensors(input, ctx)
  _inputs_flat = list(input)
  # XlaWhile produces one output per loop-carried input.
  num_outputs = len(_inputs_flat)
  op_attrs = ("T", _attr_T, "cond", cond, "body", body)
  results = _execute.execute(b"XlaWhile", num_outputs, inputs=_inputs_flat,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "XlaWhile", _inputs_flat, op_attrs, results)
  return results

4702