Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/compiler/tf2xla/python/xla.py: 62%
232 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.

It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics
of HLO operators as closely as possible.

Note: Most of the operators defined in this module are used by the jax2tf
converter (see go/jax2tf for details) and appear in SavedModels produced by
jax2tf. Hence, we need to maintain backwards compatibility for these
operators. Please reach out to the JAX team if you want to make changes.
"""
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# gather/scatter
# collapse

# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

constant = constant_op.constant

# Unary operators.

# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.


def _unary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def unary_op_wrapper(x, name=None):
    return fn(x, name=name)

  return unary_op_wrapper


abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
# numbers halfway between two integers.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tan = _unary_op(math_ops.tan)
tanh = _unary_op(math_ops.tanh)

# Bessel
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)

# Binary operators

# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.


def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting."""

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Inner wrapper function."""
    broadcast_dims = broadcast_dims or []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper
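
# Illustrative example of the XLA-style broadcasting these wrappers expose
# (`add` is one of the wrapped ops defined further below; shapes here are
# this example's, not requirements):
#
#   x = tf.ones([5, 6])  # rank 2
#   y = tf.ones([6])     # rank 1
#   # tf.add(x, y) would broadcast implicitly, Numpy-style. The XLA-style
#   # wrapper instead requires naming which dimension of `x` that `y` maps to:
#   z = add(x, y, broadcast_dims=[1])  # result has shape [5, 6]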

# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}

def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type."""
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type."""
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output
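
# Illustrative behavior of the two shift flavors for 8-bit operands (the
# public wrappers `shift_right_logical` and `shift_right_arithmetic` are
# defined below):
#
#   x = constant(-1, dtype=dtypes.int8)  # bit pattern 0xff
#   n = constant(1, dtype=dtypes.int8)
#   shift_right_logical(x, n)     # via uint8: 255 >> 1 == 127
#   shift_right_arithmetic(x, n)  # sign-extending: -1 >> 1 == -1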

add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)

def _binary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  x = ops.convert_to_tensor(x)
  shape = array_ops.concat([constant_op.constant(dims),
                            array_ops.shape(x)],
                           axis=0)
  return array_ops.broadcast_to(x, shape, name=name)
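
# For example (illustrative), `broadcast` prepends the requested dimensions:
#
#   x = constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
#   broadcast(x, [4])                     # result has shape [4, 2, 3]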

def clamp(a, x, b, name=None):
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat

def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: an `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  needs_v2 = (
      preferred_element_type or (lhs.dtype != rhs.dtype) or
      batch_group_count > 1)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)
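
# A minimal, illustrative sketch of calling `conv` for a 2-D NHWC/HWIO
# convolution. The dimension-number assignments are this example's choice,
# not the only valid layout:
#
#   dnums = xla_data_pb2.ConvolutionDimensionNumbers()
#   dnums.input_batch_dimension = 0                 # N
#   dnums.input_spatial_dimensions.extend([1, 2])   # H, W
#   dnums.input_feature_dimension = 3               # C
#   dnums.kernel_spatial_dimensions.extend([0, 1])  # H, W
#   dnums.kernel_input_feature_dimension = 2        # I
#   dnums.kernel_output_feature_dimension = 3       # O
#   dnums.output_batch_dimension = 0
#   dnums.output_spatial_dimensions.extend([1, 2])
#   dnums.output_feature_dimension = 3
#   out = conv(lhs, rhs, window_strides=[1, 1], padding=[(0, 0), (0, 0)],
#              lhs_dilation=[1, 1], rhs_dilation=[1, 1],
#              dimension_numbers=dnums)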

convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig


def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)
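
# Illustrative sketch: a plain matrix multiply expressed via `dot_general`.
# The dimension numbers say "contract dimension 1 of lhs with dimension 0 of
# rhs", which for two rank-2 operands is exactly matmul:
#
#   dnums = DotDimensionNumbers()
#   dnums.lhs_contracting_dimensions.append(1)
#   dnums.rhs_contracting_dimensions.append(0)
#   out = dot_general(lhs, rhs, dimension_numbers=dnums)  # [m, k] x [k, n]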

def self_adjoint_eig(a, lower, max_iter, epsilon):
  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)

def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of tf.random.Algorithm.{PHILOX,
      THREEFRY, AUTO_SELECT}.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    A tuple with a new state and generated data of the given shape.
  """
  alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(
      alg_int, initial_state, shape, dtype=dtype)
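
# Illustrative usage (shapes and algorithm choice are this example's):
#
#   state = tf.zeros([2], dtype=tf.uint64)  # THREEFRY state is u64[2]
#   new_state, bits = rng_bit_generator(
#       tf.random.Algorithm.THREEFRY, state, [4, 4], dtype=tf.uint32)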

recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

ops.no_gradient("XlaVariadicReduce")

def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow.

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilations to apply to the input, as a list of integers.
      Optional; if omitted, defaults to dilations of 1 (no dilation).
    window_dilations: dilations to apply to the window, as a list of integers.
      Optional; if omitted, defaults to dilations of 1 (no dilation).
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  window_strides = window_strides or [1] * len(window_dimensions)
  base_dilations = base_dilations or [1] * len(window_dimensions)
  window_dilations = window_dilations or [1] * len(window_dimensions)
  padding = padding or [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)
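
# An illustrative sketch of 2x2 max pooling over NHWC images via
# reduce_window. Whether a plain `tf.function` is accepted as the reducer is
# an assumption of this sketch; the reducer must combine two scalars:
#
#   @tf.function(input_signature=[
#       tf.TensorSpec([], tf.float32), tf.TensorSpec([], tf.float32)])
#   def max_reducer(a, b):
#     return math_ops.maximum(a, b)
#
#   pooled = reduce_window(images, constant(float("-inf")), max_reducer,
#                          window_dimensions=[1, 2, 2, 1],
#                          window_strides=[1, 2, 2, 1])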

replica_id = gen_xla_ops.xla_replica_id

# Set a static bound for the given input value as a hint to the XLA compiler,
# returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3)  # Tells xla the constraint that p <= 3.
#   return t[:p]             # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound

# Make a static dimension into an XLA bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes
# the dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#   # Tells xla the valid size of the array is 3.
#   dim = 0
#   p = xla_set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6)  # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size

# Inverse of xla_set_dynamic_dimension_size. Make an XLA bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size

def reshape(x, new_sizes, dimensions=None, name=None):
  if dimensions is not None:
    x = array_ops.transpose(x, dimensions)
  x = array_ops.reshape(x, new_sizes, name=name)
  return x
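
# For example (illustrative): with `dimensions=[1, 0]` the operand is first
# transposed, then reshaped, matching XLA's Reshape-with-dimensions:
#
#   x = constant([[1, 2, 3], [4, 5, 6]])  # shape [2, 3]
#   reshape(x, [6], dimensions=[1, 0])    # [1, 4, 2, 5, 3, 6]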

def select(condition, x, y, name=None):
  return array_ops.where(condition, x, y, name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send

def slice(x, start_dims, limit_dims, strides):
  spec = [
      _slice(start, limit, stride)
      for (start, limit, stride) in zip(start_dims, limit_dims, strides)
  ]
  return x[tuple(spec)]
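
# For example (illustrative): per-dimension (start, limit, stride) triples
# map directly onto Python slice syntax:
#
#   slice(x, [0, 1], [2, 3], [1, 1])  # equivalent to x[0:2, 1:3]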

sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op."""
  sharding_attr = op.get_attr("sharding")
  grad_sharding = gen_xla_ops.xla_sharding(
      grad,
      sharding=sharding_attr,
      unspecified_dims=op.get_attr("unspecified_dims"))
  # pylint: disable=protected-access
  grad_sharding.op._set_attr("_XlaSharding",
                             attr_value_pb2.AttrValue(s=sharding_attr))
  return [grad_sharding]

spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  s2f = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=op.get_attr("manual_sharding"),
      full_shape=op.inputs[0].shape.as_list(),
      dim=op.get_attr("dim"),
      unspecified_dims=op.get_attr("unspecified_dims"))
  return [s2f]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  f2s = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad,
      manual_sharding=op.get_attr("manual_sharding"),
      dim=op.get_attr("dim"),
      unspecified_dims=op.get_attr("unspecified_dims"))
  return [f2s]

sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call

def custom_call_v2(
    call_target_name,
    operands,
    result_specs,
    backend_config=None,
    has_side_effect=None,
    name=None,
):
  """Emits an HLO `CustomCall` operation with multiple outputs.

  See `CustomCall` specification at
  https://tensorflow.org/xla/operation_semantics#customcall,
  and `mhlo.custom_call` specification at
  https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop.

  Args:
    call_target_name: Name of the user function. The function signature must
      conform to version 3 of the API, see
      `API_VERSION_STATUS_RETURNING_UNIFIED`. All operands and results are
      assumed to be in the default layout.
    operands: A sequence of tensors with possibly different types.
    result_specs: A sequence of tensor specs for all results.
    backend_config: A string that encodes metadata for the backend. Empty
      string by default.
    has_side_effect: Indicates whether the custom call has side effects.
      `False` by default.
    name: Optional name of the operation.

  Returns:
    A tuple of output tensors.
  """
  return gen_xla_ops.xla_custom_call_v2(
      operands=operands,
      call_target_name=call_target_name,
      backend_config="" if backend_config is None else backend_config,
      has_side_effect=False if has_side_effect is None else has_side_effect,
      result_dtypes=tuple(spec.dtype for spec in result_specs),
      result_shapes=tuple(spec.shape for spec in result_specs),
      name=name,
  )
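
# Illustrative usage; "my_custom_fn" is a hypothetical call target that the
# user must register with XLA separately:
#
#   out, = custom_call_v2(
#       "my_custom_fn",
#       operands=(x,),
#       result_specs=(tf.TensorSpec([2, 2], tf.float32),),
#   )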

def call_module(args, *, version=4, module, Tout, Sout,
                dim_args_spec=(), platforms=()):
  # See documentation for the XlaCallModule op.
  return gen_xla_ops.xla_call_module(
      args, version=version, module=module, dim_args_spec=dim_args_spec,
      Tout=Tout, Sout=Sout, platforms=platforms)

def gather(operand,
           start_indices,
           dimension_numbers,
           slice_sizes,
           indices_are_sorted=False,
           name=None):
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)
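
# Illustrative sketch: gathering rows of a rank-2 operand. The dimension
# numbers below are this example's assumptions; see the XLA Gather docs for
# the full semantics:
#
#   dnums = xla_data_pb2.GatherDimensionNumbers()
#   dnums.offset_dims.append(1)           # the row contents
#   dnums.collapsed_slice_dims.append(0)  # drop the gathered dimension
#   dnums.start_index_map.append(0)       # indices address dimension 0
#   dnums.index_vector_dim = 1
#   rows = gather(operand, indices, dnums, slice_sizes=[1, operand.shape[1]])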

def scatter(operand,
            scatter_indices,
            updates,
            update_computation,
            dimension_numbers,
            indices_are_sorted=False,
            name=None):
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)

def optimization_barrier(*args):
  return gen_xla_ops.xla_optimization_barrier(args)


def reduce_precision(operand, exponent_bits, mantissa_bits):
  return gen_xla_ops.xla_reduce_precision(operand, exponent_bits, mantissa_bits)