Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/compiler/tf2xla/python/xla.py: 62%

232 statements  

« prev     ^ index     » next       coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14# ============================================================================== 

15"""Experimental library that exposes XLA operations directly in TensorFlow. 

16 

17It is sometimes useful to be able to build HLO programs directly from 

TensorFlow. This file provides TensorFlow operators that mirror the semantics of

19HLO operators as closely as possible. 

20 

21Note: Most of the operators defined in this module are used by the jax2tf 

22converter (see go/jax2tf for details) and are used in SavedModel produced 

23by jax2tf. Hence, we need to maintain backwards compatibility for these 

24operators. Please reach out to the JAX team if you want to make changes. 

25""" 

26 

27from tensorflow.compiler.tf2xla.ops import gen_xla_ops 

28from tensorflow.compiler.xla import xla_data_pb2 

29from tensorflow.core.framework import attr_value_pb2 

30from tensorflow.python.framework import constant_op 

31from tensorflow.python.framework import dtypes 

32from tensorflow.python.framework import ops 

33from tensorflow.python.ops import array_ops 

34from tensorflow.python.ops import bitwise_ops 

35from tensorflow.python.ops import gen_math_ops 

36from tensorflow.python.ops import gen_random_ops 

37from tensorflow.python.ops import math_ops 

38from tensorflow.python.ops import random_ops 

39from tensorflow.python.ops import special_math_ops 

40from tensorflow.python.ops import stateless_random_ops 

41from tensorflow.python.ops.numpy_ops import np_utils 

42 

43# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing 

44# ops include: 

45# infeed/outfeed (available via tf.contrib.tpu) 

46# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu) 

47# conditional 

48# gather/scatter 

49# collapse 

50 

# Several public names in this module (max, min, slice, ...) deliberately
# shadow Python builtins so they match XLA's operator names (e.g. xla.max).
# Keep handles to the original builtins for internal use.
# pylint: disable=redefined-builtin
_max, _min = max, min
_slice = slice  # pylint: disable=invalid-name

57 

# Re-export of `constant_op.constant`, so callers can write `xla.constant(...)`.
constant = constant_op.constant

59 

60# Unary operators. 

61 

62# For most arithmetic operators there is a TensorFlow operator 

63# that exactly corresponds to each XLA operator. Rather than defining 

64# XLA-specific variants, we reuse the corresponding TensorFlow operator. 

65# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1 

66# wrap every HLO operator, because that would allow us to be confident that the 

67# semantics match. 

68 

69 

def _unary_op(fn):
  """Narrows `fn` to the XLA-style unary signature `(x, name=None)`.

  The wrapper forwards `x` positionally and `name` by keyword. Any other
  parameters `fn` may accept are therefore not reachable through it.

  Args:
    fn: a callable accepting `(x, name=...)`.

  Returns:
    The narrowed wrapper function.
  """

  def unary_op_wrapper(x, name=None):
    result = fn(x, name=name)
    return result

  return unary_op_wrapper

77 

78 

# Each name below adapts the corresponding TensorFlow op with `_unary_op`, so
# it accepts exactly `(x, name=None)`.
abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, which rounds numbers halfway between two
# integers away from zero, this rounds them to the nearest even integer.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tan = _unary_op(math_ops.tan)
tanh = _unary_op(math_ops.tanh)

# Bessel functions, taken from special_math_ops.
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)

111 

112# Binary operators 

113 

114# The main difference between TensorFlow and XLA binary ops is the broadcasting 

115# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA 

116# requires an explicit specification of which dimensions to broadcast if the 

117# arguments have different ranks. 

118 

119 

def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a binary TensorFlow op called as `fn(x, y, name=name)`.

  Returns:
    A function `(x, y, broadcast_dims=None, name=None)`. It first aligns `x`
    and `y` with the `XlaBroadcastHelper` op according to `broadcast_dims`,
    then applies `fn` to the aligned operands.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Broadcasts `x` and `y` per `broadcast_dims`, then applies the op."""
    # Substitute the default only for None. `broadcast_dims or []` would
    # truth-test the argument, which raises for a graph Tensor and for a
    # multi-element NumPy array, rejecting otherwise valid inputs.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper

134 

135 

# Signed TF integer types paired with their same-width unsigned counterparts.
_SIGNED_TO_UNSIGNED_TABLE = dict([
    (dtypes.int8, dtypes.uint8),
    (dtypes.int16, dtypes.uint16),
    (dtypes.int32, dtypes.uint32),
    (dtypes.int64, dtypes.uint64),
])

# The same pairing in the opposite direction (unsigned -> signed).
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}

151 

152 

def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Signed inputs are reinterpreted as the unsigned type of the same width,
  shifted there, and cast back. Any other dtype is shifted as-is.

  Args:
    x: the value to shift.
    y: the shift amount; must have the same dtype as `x`.
    name: optional name for the shift op.

  Returns:
    The shifted tensor, in the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE.get(original_dtype)
  if unsigned_dtype is None:
    # Not a signed type from the table: shift without any casting.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, unsigned_dtype),
      math_ops.cast(y, unsigned_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)

166 

167 

def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Unsigned inputs are reinterpreted as the signed type of the same width,
  shifted there, and cast back. Any other dtype is shifted as-is.

  Args:
    x: the value to shift.
    y: the shift amount; must have the same dtype as `x`.
    name: optional name for the shift op.

  Returns:
    The shifted tensor, in the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE.get(original_dtype)
  if signed_dtype is None:
    # Not an unsigned type from the table: shift without any casting.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, signed_dtype),
      math_ops.cast(y, signed_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)

181 

182 

# Each name below wraps a TensorFlow binary op with `_broadcasting_binary_op`,
# giving it the signature `(x, y, broadcast_dims=None, name=None)`.
add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
# NOTE(review): `math_ops.div` follows Python 2 division semantics (floor
# division for integer inputs), which may differ from XLA's integer Div for
# operands of mixed sign. Changing it would alter graphs that jax2tf relies on;
# confirm before touching.
div = _broadcasting_binary_op(math_ops.div)
# Uses the raw `Mod` op (gen_math_ops.mod). Per its TensorFlow documentation it
# follows C-style truncated-remainder semantics rather than Python's floored
# modulo.
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
# Right shifts go through the dtype-casting helpers so that the logical and
# arithmetic variants each behave the same for signed and unsigned inputs.
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

# Special functions.
igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)

212 

213 

def _binary_op(fn):
  """Narrows `fn` to the two-operand signature `(x, y, name=None)`.

  Unlike `_broadcasting_binary_op`, no broadcasting is performed. `x` and `y`
  are forwarded positionally and `name` by keyword.

  Args:
    fn: a callable accepting `(x, y, name=...)`.

  Returns:
    The narrowed wrapper function.
  """

  def binary_op_wrapper(x, y, name=None):
    result = fn(x, y, name=name)
    return result

  return binary_op_wrapper

221 

222 

# `_binary_op` narrows these to `(x, y, name=None)`. For `transpose`, `y` is the
# permutation; for `rev`, `y` is the list of axes to reverse.
transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

# Alias for `array_ops.bitcast`.
bitcast_convert_type = array_ops.bitcast

227 

228 

def broadcast(x, dims, name=None):
  """Broadcasts `x` by prepending new leading dimensions of sizes `dims`.

  The result shape is `dims` followed by the shape of `x`.

  Args:
    x: the operand; converted with `ops.convert_to_tensor`.
    dims: sizes of the new leading dimensions, e.g. a list of ints. May be
      empty, in which case the result has the same shape as `x`.
    name: optional name for the `broadcast_to` op.

  Returns:
    The broadcast tensor.
  """
  x = ops.convert_to_tensor(x)
  # Use an int32 *hint* rather than `constant_op.constant(dims)`. An empty
  # `dims` would otherwise become a float32 tensor, which cannot be
  # concatenated with the integer shape of `x`.
  dims = ops.convert_to_tensor(dims, dtype_hint=dtypes.int32)
  # Request the shape of `x` in the same integer type as `dims`, so the concat
  # never mixes int32 and int64 (e.g. for int64 `dims`).
  shape = array_ops.concat(
      [dims, array_ops.shape(x, out_type=dims.dtype)], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)

235 

236 

def clamp(a, x, b, name=None):
  """Clamps `x` to the range [`a`, `b`], i.e. computes `min(max(a, x), b)`.

  The calls use this module's broadcasting `max`/`min` wrappers, not the Python
  builtins. Both ops receive the same `name`.
  """
  lower_bounded = max(a, x, name=name)
  return min(lower_bounded, b, name=name)

239 

240 

# Alias for `array_ops.concat`.
concatenate = array_ops.concat

242 

243 

def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  `XlaConvV2` is emitted when it is required or requested via `use_v2`. It is
  required when a `preferred_element_type` is given, when the operand dtypes
  differ, or when `batch_group_count > 1`. Otherwise the original `XlaConv` op
  is emitted.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. If None, it defaults to
      `np_utils.result_type(lhs.dtype, rhs.dtype)`, which only `XlaConvV2`
      consumes.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # The `XlaConv` fallback below takes no result type, so an explicit
  # `preferred_element_type` must always select `XlaConvV2`. Test it against
  # None rather than by truthiness: some dtype objects are falsy (e.g.
  # non-structured `numpy.dtype` instances, whose len() is 0), and the
  # requested result type would then be silently dropped.
  needs_v2 = (
      preferred_element_type is not None or lhs.dtype != rhs.dtype or
      batch_group_count > 1)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)

314 

315 

# XLA ConvertElementType, provided as an alias for `math_ops.cast`.
convert_element_type = math_ops.cast

317 

318 

def dot(lhs, rhs, name=None):
  """Contracts the last axis of `lhs` with the first axis of `rhs`.

  Implemented with `math_ops.tensordot` using `axes=1`.
  """
  contraction_axes = 1
  return math_ops.tensordot(lhs, rhs, axes=contraction_axes, name=name)

321 

322 

# Re-exported XLA protos. `PrecisionConfig` is what `conv`, `dot_general` and
# `svd` serialize for their `precision_config` argument. `DotDimensionNumbers`
# is presumably intended for `dot_general`'s `dimension_numbers` argument.
DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig

325 

326 

def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XLA DotGeneral operator.

  `XlaDotV2` is emitted when a `preferred_element_type` is given, when the
  operand dtypes differ, or when `use_v2` is set. Otherwise `XlaDot` is
  emitted.

  Args:
    lhs: the left-hand operand.
    rhs: the right-hand operand.
    dimension_numbers: a dimension-numbers proto (presumably
      `xla.DotDimensionNumbers`); it is passed serialized via
      `SerializeToString()`.
    precision_config: an optional `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. If None, it defaults to
      `np_utils.result_type(lhs.dtype, rhs.dtype)`, which only `XlaDotV2`
      consumes.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaDotV2 op even if not necessary.

  Returns:
    The output tensor of the dot operation.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # The `XlaDot` fallback below takes no result type, so an explicit
  # `preferred_element_type` must always select `XlaDotV2`. Test it against
  # None rather than by truthiness: some dtype objects are falsy (e.g.
  # non-structured `numpy.dtype` instances), which would silently drop the
  # requested result type.
  needs_v2 = preferred_element_type is not None or lhs.dtype != rhs.dtype
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)

354 

355 

def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Emits the `XlaSelfAdjointEig` op; all arguments are forwarded in order."""
  eig_args = (a, lower, max_iter, epsilon)
  return gen_xla_ops.xla_self_adjoint_eig(*eig_args)

358 

359 

def svd(a, max_iter, epsilon, precision_config=None):
  """Emits the `XlaSvd` op.

  Args:
    a: the input tensor.
    max_iter: forwarded to the op unchanged.
    epsilon: forwarded to the op unchanged.
    precision_config: an optional `xla.PrecisionConfig` proto. It is passed
      serialized, or as an empty string when not given.

  Returns:
    Whatever `gen_xla_ops.xla_svd` returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_precision)

365 

366 

# Direct aliases for generated XLA ops.
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad

374 

375 

def random_normal(mu, sigma, dims, name=None):
  """Samples a tensor of shape `dims` from a normal distribution.

  Args:
    mu: the mean. It is converted to a tensor, and that tensor's dtype is also
      used as the output dtype.
    sigma: the standard deviation.
    dims: the output shape.
    name: optional op name.

  Returns:
    The sampled tensor.
  """
  mean = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mean, stddev=sigma, dtype=mean.dtype, name=name)

380 

381 

def random_uniform(minval, maxval, dims, name=None):
  """Samples a tensor of shape `dims` uniformly between `minval` and `maxval`.

  Args:
    minval: the lower bound. It is converted to a tensor, and that tensor's
      dtype is also used as the output dtype.
    maxval: the upper bound.
    dims: the output shape.
    name: optional op name.

  Returns:
    The sampled tensor.
  """
  lower = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval=lower, maxval=maxval, dtype=lower.dtype, name=name)

386 

387 

def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator wrapping XLA's RngBitGenerator.

  The operator is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of tf.random.Algorithm.{PHILOX,
      THREEFRY, AUTO_SELECT}. It is converted to its integer id with
      `stateless_random_ops.convert_alg_to_int`.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should
      be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  algorithm_id = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(
      algorithm_id, initial_state, shape, dtype=dtype)

408 

409 

# Direct aliases for generated XLA ops. Note that `variadic_reduce` is bound to
# the V2 wrapper.
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

# Registers "XlaVariadicReduce" as non-differentiable.
# NOTE(review): only the non-V2 op name is registered here, although
# `variadic_reduce` above builds the V2 op. Confirm whether
# "XlaVariadicReduceV2" has its gradient handled elsewhere.
ops.no_gradient("XlaVariadicReduce")

415 

416 

def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars. It is
      passed as the op's `computation` attribute (presumably a traced
      TensorFlow function; confirm with callers).
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors for the operand, as a list of integers.
      Optional; if omitted, defaults to 1 (no dilation) in every dimension.
    window_dilations: dilation factors for the window, as a list of integers.
      Optional; if omitted, defaults to 1 (no dilation) in every dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Each default has one entry per window dimension. Because `or` is used, an
  # explicitly empty list is also replaced by the default.
  rank = len(window_dimensions)
  window_strides = window_strides or [1] * rank
  base_dilations = base_dilations or [1] * rank
  window_dilations = window_dilations or [1] * rank
  padding = padding or [(0, 0)] * rank
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)

460 

461 

# Alias for the generated XlaReplicaId op.
replica_id = gen_xla_ops.xla_replica_id

# Sets a static upper bound on the given input value as a hint to the XLA
# compiler, and returns the same value.
# Usage:
#   def f(t, p):
#     p = xla.set_bound(p, 3)  # Tells XLA the constraint that p <= 3.
#     return t[:p]             # XLA knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound

# Makes a static dimension into an XLA bounded dynamic dimension. The current
# static size of the dimension becomes the bound, and the third operand (the
# `3` in the example below) becomes the dynamic size of the dimension.
#
# This should mostly be used for testing.
#
#   def f():
#     array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])  # shape [1, 5]
#     # Tells XLA the valid size of dimension 1 of the array is 3.
#     dim = 1
#     p = xla.set_dynamic_dimension_size(array, dim, 3)
#     assert(reduce_sum(p) == 6)  # XLA knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size

# Inverse of `set_dynamic_dimension_size`: makes an XLA bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size

490 

491 

def reshape(x, new_sizes, dimensions=None, name=None):
  """XLA-style reshape with an optional pre-transpose.

  If `dimensions` is given, `x` is first transposed by that permutation. The
  (possibly transposed) tensor is then reshaped to `new_sizes`.
  """
  permuted = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(permuted, new_sizes, name=name)

497 

498 

def select(condition, x, y, name=None):
  """Picks elements from `x` where `condition` holds and from `y` elsewhere.

  Implemented with `array_ops.where`.
  """
  return array_ops.where(condition, x, y, name=name)

501 

502 

# Direct aliases for the generated XlaSelectAndScatter and XlaSend ops.
select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send

505 

506 

def slice(x, start_dims, limit_dims, strides):
  """XLA-style static slice of `x`.

  Builds one Python slice per dimension from the matching entries of
  `start_dims`, `limit_dims` and `strides`, then indexes `x` with them. `zip`
  stops at the shortest of the three sequences.
  """
  # Inside this function, `slice` names the function itself, so the captured
  # builtin `_slice` is required here.
  index = tuple(
      _slice(*bounds) for bounds in zip(start_dims, limit_dims, strides))
  return x[index]

513 

514 

# Alias for the generated XlaSharding op. Its gradient is registered by
# `_sharding_grad`.
sharding = gen_xla_ops.xla_sharding

516 

517 

@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  Re-applies the forward op's `sharding` and `unspecified_dims` attributes to
  the incoming gradient. It also stamps the same sharding onto the gradient
  op's `_XlaSharding` attribute.
  """
  sharding_attr = op.get_attr("sharding")
  unspecified = op.get_attr("unspecified_dims")
  sharded_grad = gen_xla_ops.xla_sharding(
      grad, sharding=sharding_attr, unspecified_dims=unspecified)
  # pylint: disable=protected-access
  sharded_grad.op._set_attr(
      "_XlaSharding", attr_value_pb2.AttrValue(s=sharding_attr))
  return [sharded_grad]

530 

531 

# Direct aliases for the generated SPMD shape-conversion ops. Each one's
# gradient is registered as the other op below.
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape

534 

535 

@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient of XlaSpmdFullToShardShape: the inverse conversion.

  The gradient is passed through XlaSpmdShardToFullShape, with the forward op's
  attributes and its input's static shape as `full_shape`.
  `TensorShape.as_list()` is used, so the forward input needs a known rank.
  """
  sharding_spec = op.get_attr("manual_sharding")
  original_shape = op.inputs[0].shape.as_list()
  sharded_dim = op.get_attr("dim")
  unspecified = op.get_attr("unspecified_dims")
  shard_to_full = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=sharding_spec,
      full_shape=original_shape,
      dim=sharded_dim,
      unspecified_dims=unspecified)
  return [shard_to_full]

545 

546 

@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient of XlaSpmdShardToFullShape: the inverse conversion.

  The gradient is passed through XlaSpmdFullToShardShape, using the forward
  op's `manual_sharding`, `dim` and `unspecified_dims` attributes.
  """
  sharding_spec = op.get_attr("manual_sharding")
  sharded_dim = op.get_attr("dim")
  unspecified = op.get_attr("unspecified_dims")
  full_to_shard = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad,
      manual_sharding=sharding_spec,
      dim=sharded_dim,
      unspecified_dims=unspecified)
  return [full_to_shard]

555 

556 

# Direct aliases for generated XLA ops.
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call

563 

564 

def custom_call_v2(
    call_target_name,
    operands,
    result_specs,
    backend_config=None,
    has_side_effect=None,
    name=None,
):
  """Emits an HLO `CustomCall` operation with multiple outputs.

  See `CustomCall` specification at
  https://tensorflow.org/xla/operation_semantics#customcall,
  and `mhlo.custom_call` specification at
  https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop.

  Args:
    call_target_name: Name of the user function. The function signature must
      conform to version 3 of the API, see
      `API_VERSION_STATUS_RETURNING_UNIFIED`. All operands and results assumed
      to be in the default layout.
    operands: A sequence of tensors with possibly different types.
    result_specs: An iterable of tensor specs (objects with `dtype` and `shape`
      attributes), one per result. It is consumed exactly once, so a generator
      works as well as a sequence.
    backend_config: A string that encodes a metadata for the backend. Empty
      string by default.
    has_side_effect: Indicates whether the custom call has side effects. `False`
      by default.
    name: Optional name of the operation.

  Returns:
    A tuple of output tensors.
  """
  # Materialize the specs once. They are read twice below (dtypes, then
  # shapes), and a one-shot iterator would yield nothing on the second pass,
  # leaving `result_shapes` empty.
  result_specs = tuple(result_specs)
  return gen_xla_ops.xla_custom_call_v2(
      operands=operands,
      call_target_name=call_target_name,
      backend_config="" if backend_config is None else backend_config,
      has_side_effect=False if has_side_effect is None else has_side_effect,
      result_dtypes=tuple(spec.dtype for spec in result_specs),
      result_shapes=tuple(spec.shape for spec in result_specs),
      name=name,
  )

605 

606 

def call_module(args, *, version=4, module, Tout, Sout,
                dim_args_spec=(), platforms=()):
  """Emits an `XlaCallModule` op on `args`.

  Every keyword argument is forwarded unchanged to the op. See the
  XlaCallModule op documentation for the meaning of each one.
  """
  op_attrs = dict(
      version=version,
      module=module,
      dim_args_spec=dim_args_spec,
      Tout=Tout,
      Sout=Sout,
      platforms=platforms)
  return gen_xla_ops.xla_call_module(args, **op_attrs)

613 

614 

def gather(operand,
           start_indices,
           dimension_numbers,
           slice_sizes,
           indices_are_sorted=False,
           name=None):
  """Emits the `XlaGather` op.

  Args:
    operand: the tensor to gather from.
    start_indices: the gather start indices.
    dimension_numbers: a proto, passed serialized via `SerializeToString()`.
    slice_sizes: forwarded unchanged.
    indices_are_sorted: forwarded unchanged.
    name: optional op name.

  Returns:
    Whatever `gen_xla_ops.xla_gather` returns.
  """
  serialized_dims = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dims,
      indices_are_sorted=indices_are_sorted,
      name=name)

628 

629 

def scatter(operand,
            scatter_indices,
            updates,
            update_computation,
            dimension_numbers,
            indices_are_sorted=False,
            name=None):
  """Emits the `XlaScatter` op.

  Args:
    operand: the tensor to scatter into.
    scatter_indices: the scatter indices.
    updates: the update values.
    update_computation: forwarded unchanged as the op's `update_computation`.
    dimension_numbers: a proto, passed serialized via `SerializeToString()`.
    indices_are_sorted: forwarded unchanged.
    name: optional op name.

  Returns:
    Whatever `gen_xla_ops.xla_scatter` returns.
  """
  serialized_dims = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dims,
      indices_are_sorted=indices_are_sorted,
      name=name)

645 

646 

def optimization_barrier(*args):
  """Emits `XlaOptimizationBarrier` over all positional arguments.

  The positional arguments are passed to the op together as a single tuple.
  """
  operands = args
  return gen_xla_ops.xla_optimization_barrier(operands)

649 

650 

def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Emits `XlaReducePrecision`; the arguments are forwarded in order."""
  precision_args = (operand, exponent_bits, mantissa_bits)
  return gen_xla_ops.xla_reduce_precision(*precision_args)