Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/gen_ragged_array_ops.py: 16%
243 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1"""Python wrappers around TensorFlow ops.
3This file is MACHINE GENERATED! Do not edit.
4"""
6import collections
8from tensorflow.python import pywrap_tfe as pywrap_tfe
9from tensorflow.python.eager import context as _context
10from tensorflow.python.eager import core as _core
11from tensorflow.python.eager import execute as _execute
12from tensorflow.python.framework import dtypes as _dtypes
13from tensorflow.security.fuzzing.py import annotation_types as _atypes
15from tensorflow.python.framework import op_def_registry as _op_def_registry
16from tensorflow.python.framework import ops as _ops
17from tensorflow.python.framework import op_def_library as _op_def_library
18from tensorflow.python.util.deprecation import deprecated_endpoints
19from tensorflow.python.util import dispatch as _dispatch
20from tensorflow.python.util.tf_export import tf_export
22from typing import TypeVar
# Named result of the RaggedCross op: (output_values, output_row_splits).
_RaggedCrossOutput = collections.namedtuple(
    "RaggedCross", ("output_values", "output_row_splits"))
def ragged_cross(ragged_values, ragged_row_splits, sparse_indices, sparse_values, sparse_shape, dense_inputs, input_order, hashed_output, num_buckets, hash_key, out_values_type, out_row_splits_type, name=None):
  r"""Generates a feature cross from a list of tensors, returned as a RaggedTensor.

  See `tf.ragged.cross` for more details.

  In eager mode the C fast path is attempted first, then
  `ragged_cross_eager_fallback`.  Outside eager mode, or when the fallback
  raises `_core._SymbolicException`, a `RaggedCross` node is added to the
  graph instead.

  Args:
    ragged_values: A list of `Tensor` objects with types from: `int64`, `string`.
      The values tensor for each RaggedTensor input.
    ragged_row_splits: A list of `Tensor` objects with types from: `int32`, `int64`.
      The row_splits tensor for each RaggedTensor input.
    sparse_indices: A list of `Tensor` objects with type `int64`.
      The indices tensor for each SparseTensor input.
    sparse_values: A list of `Tensor` objects with types from: `int64`, `string`.
      The values tensor for each SparseTensor input.
    sparse_shape: A list with the same length as `sparse_indices` of `Tensor`
      objects with type `int64`.
      The dense_shape tensor for each SparseTensor input.
    dense_inputs: A list of `Tensor` objects with types from: `int64`, `string`.
      The tf.Tensor inputs.
    input_order: A `string`.
      String specifying the tensor type for each input.  The `i`th character in
      this string specifies the type of the `i`th input, and is one of:
      'R' (ragged), 'D' (dense), or 'S' (sparse).  This attr is used to ensure
      that the crossed values are combined in the order of the inputs from the
      call to tf.ragged.cross.
    hashed_output: A `bool`.
    num_buckets: An `int` that is `>= 0`.
    hash_key: An `int`.
    out_values_type: A `tf.DType` from: `tf.int64, tf.string`.
    out_row_splits_type: A `tf.DType` from: `tf.int32, tf.int64`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output_values, output_row_splits).

    output_values: A `Tensor` of type `out_values_type`.
    output_row_splits: A `Tensor` of type `out_row_splits_type`.

  Raises:
    TypeError: If `sparse_indices` or `sparse_shape` is not a list or tuple
      when the graph-construction path validates them.
    ValueError: If `sparse_shape` and `sparse_indices` differ in length when
      the graph-construction path validates them.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      fast_outputs = pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "RaggedCross", name, ragged_values, ragged_row_splits,
          sparse_indices, sparse_values, sparse_shape, dense_inputs,
          "input_order", input_order, "hashed_output", hashed_output,
          "num_buckets", num_buckets, "hash_key", hash_key,
          "out_values_type", out_values_type,
          "out_row_splits_type", out_row_splits_type)
      return _RaggedCrossOutput._make(fast_outputs)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; try the Python eager fallback below.
    try:
      return ragged_cross_eager_fallback(
          ragged_values, ragged_row_splits, sparse_indices, sparse_values,
          sparse_shape, dense_inputs, input_order=input_order,
          hashed_output=hashed_output, num_buckets=num_buckets,
          hash_key=hash_key, out_values_type=out_values_type,
          out_row_splits_type=out_row_splits_type, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  if not isinstance(sparse_indices, (list, tuple)):
    raise TypeError(
        "Expected list for 'sparse_indices' argument to "
        "'ragged_cross' Op, not %r." % sparse_indices)
  num_sparse = len(sparse_indices)
  if not isinstance(sparse_shape, (list, tuple)):
    raise TypeError(
        "Expected list for 'sparse_shape' argument to "
        "'ragged_cross' Op, not %r." % sparse_shape)
  if len(sparse_shape) != num_sparse:
    raise ValueError(
        "List argument 'sparse_shape' to 'ragged_cross' Op with length %d "
        "must match length %d of argument 'sparse_indices'." %
        (len(sparse_shape), num_sparse))

  # Canonicalise attribute values, in the same order as the attr list.
  input_order = _execute.make_str(input_order, "input_order")
  hashed_output = _execute.make_bool(hashed_output, "hashed_output")
  num_buckets = _execute.make_int(num_buckets, "num_buckets")
  hash_key = _execute.make_int(hash_key, "hash_key")
  out_values_type = _execute.make_type(out_values_type, "out_values_type")
  out_row_splits_type = _execute.make_type(
      out_row_splits_type, "out_row_splits_type")

  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "RaggedCross", ragged_values=ragged_values,
      ragged_row_splits=ragged_row_splits,
      sparse_indices=sparse_indices,
      sparse_values=sparse_values, sparse_shape=sparse_shape,
      dense_inputs=dense_inputs, input_order=input_order,
      hashed_output=hashed_output, num_buckets=num_buckets,
      hash_key=hash_key, out_values_type=out_values_type,
      out_row_splits_type=out_row_splits_type, name=name)
  results = op_outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = (
        "Nsparse", graph_op._get_attr_int("Nsparse"),
        "input_order", graph_op.get_attr("input_order"),
        "hashed_output", graph_op._get_attr_bool("hashed_output"),
        "num_buckets", graph_op._get_attr_int("num_buckets"),
        "hash_key", graph_op._get_attr_int("hash_key"),
        "ragged_values_types", graph_op.get_attr("ragged_values_types"),
        "ragged_splits_types", graph_op.get_attr("ragged_splits_types"),
        "sparse_values_types", graph_op.get_attr("sparse_values_types"),
        "dense_types", graph_op.get_attr("dense_types"),
        "out_values_type", graph_op._get_attr_type("out_values_type"),
        "out_row_splits_type",
        graph_op._get_attr_type("out_row_splits_type"))
    _execute.record_gradient(
        "RaggedCross", graph_op.inputs, grad_attrs, results)
  return _RaggedCrossOutput._make(results)
# Export the wrapper (passed through `_ops.to_raw_op`) as `raw_ops.RaggedCross`.
RaggedCross = tf_export("raw_ops.RaggedCross")(
    _ops.to_raw_op(ragged_cross))
def ragged_cross_eager_fallback(ragged_values, ragged_row_splits, sparse_indices, sparse_values, sparse_shape, dense_inputs, input_order, hashed_output, num_buckets, hash_key, out_values_type, out_row_splits_type, name, ctx):
  """Runs `RaggedCross` eagerly through `_execute.execute`.

  Validates the list arguments, canonicalises the attrs, converts every input
  to eager tensors and executes the op in `ctx`.

  Args:
    ragged_values, ragged_row_splits, sparse_indices, sparse_values,
    sparse_shape, dense_inputs: Op inputs; see `ragged_cross`.
    input_order, hashed_output, num_buckets, hash_key, out_values_type,
    out_row_splits_type: Op attrs; see `ragged_cross`.
    name: A name for the operation.
    ctx: The eager context the op is executed in.

  Returns:
    A `_RaggedCrossOutput` tuple (output_values, output_row_splits).

  Raises:
    TypeError: If `sparse_indices` or `sparse_shape` is not a list or tuple.
    ValueError: If `sparse_shape` and `sparse_indices` differ in length.
  """
  if not isinstance(sparse_indices, (list, tuple)):
    raise TypeError(
        "Expected list for 'sparse_indices' argument to "
        "'ragged_cross' Op, not %r." % sparse_indices)
  num_sparse = len(sparse_indices)
  if not isinstance(sparse_shape, (list, tuple)):
    raise TypeError(
        "Expected list for 'sparse_shape' argument to "
        "'ragged_cross' Op, not %r." % sparse_shape)
  if len(sparse_shape) != num_sparse:
    raise ValueError(
        "List argument 'sparse_shape' to 'ragged_cross' Op with length %d "
        "must match length %d of argument 'sparse_indices'." %
        (len(sparse_shape), num_sparse))

  # Canonicalise attribute values.
  input_order = _execute.make_str(input_order, "input_order")
  hashed_output = _execute.make_bool(hashed_output, "hashed_output")
  num_buckets = _execute.make_int(num_buckets, "num_buckets")
  hash_key = _execute.make_int(hash_key, "hash_key")
  out_values_type = _execute.make_type(out_values_type, "out_values_type")
  out_row_splits_type = _execute.make_type(
      out_row_splits_type, "out_row_splits_type")

  # Mixed-dtype input lists: each conversion also yields the dtype list attr.
  ragged_values_types, ragged_values = (
      _execute.convert_to_mixed_eager_tensors(ragged_values, ctx))
  ragged_splits_types, ragged_row_splits = (
      _execute.convert_to_mixed_eager_tensors(ragged_row_splits, ctx))
  sparse_values_types, sparse_values = (
      _execute.convert_to_mixed_eager_tensors(sparse_values, ctx))
  dense_types, dense_inputs = (
      _execute.convert_to_mixed_eager_tensors(dense_inputs, ctx))
  sparse_indices = _ops.convert_n_to_tensor(sparse_indices, _dtypes.int64)
  sparse_shape = _ops.convert_n_to_tensor(sparse_shape, _dtypes.int64)

  flat_inputs = [
      *ragged_values, *ragged_row_splits, *sparse_indices,
      *sparse_values, *sparse_shape, *dense_inputs,
  ]
  op_attrs = (
      "Nsparse", num_sparse,
      "input_order", input_order,
      "hashed_output", hashed_output,
      "num_buckets", num_buckets,
      "hash_key", hash_key,
      "ragged_values_types", ragged_values_types,
      "ragged_splits_types", ragged_splits_types,
      "sparse_values_types", sparse_values_types,
      "dense_types", dense_types,
      "out_values_type", out_values_type,
      "out_row_splits_type", out_row_splits_type)
  results = _execute.execute(b"RaggedCross", 2, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "RaggedCross", flat_inputs, op_attrs, results)
  return _RaggedCrossOutput._make(results)
# Named result of the RaggedFillEmptyRows op (four outputs, in op order).
_RaggedFillEmptyRowsOutput = collections.namedtuple(
    "RaggedFillEmptyRows",
    ("output_value_rowids", "output_values",
     "empty_row_indicator", "reverse_index_map"))
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('ragged_fill_empty_rows')
def ragged_fill_empty_rows(value_rowids, values, nrows, default_value, name=None):
  r"""Runs the `RaggedFillEmptyRows` op.

  The op's semantics are not documented in this generated wrapper (the
  upstream op description is still a TODO).

  In eager mode the C fast path is attempted first, then the type-based
  dispatcher, then `ragged_fill_empty_rows_eager_fallback`.  In graph mode the
  type-based dispatcher is consulted before a graph node is added.  A
  `TypeError`/`ValueError` from the eager fallback or from graph construction
  is offered to `_dispatch.dispatch` and re-raised only if no dispatcher
  supports the call.

  Args:
    value_rowids: A `Tensor` of type `int64`.
    values: A `Tensor`.
    nrows: A `Tensor` of type `int64`.
    default_value: A `Tensor`. Must have the same type as `values`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output_value_rowids, output_values,
    empty_row_indicator, reverse_index_map).

    output_value_rowids: A `Tensor` of type `int64`.
    output_values: A `Tensor`. Has the same type as `values`.
    empty_row_indicator: A `Tensor` of type `bool`.
    reverse_index_map: A `Tensor` of type `int64`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      fast_outputs = pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "RaggedFillEmptyRows", name, value_rowids, values,
          nrows, default_value)
      return _RaggedFillEmptyRowsOutput._make(fast_outputs)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; continue with dispatch / fallback.
    try:
      dispatched = _dispatcher_for_ragged_fill_empty_rows(
          (value_rowids, values, nrows, default_value, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return ragged_fill_empty_rows_eager_fallback(
          value_rowids, values, nrows, default_value, name=name,
          ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      fallback = _dispatch.dispatch(
          ragged_fill_empty_rows, (),
          dict(value_rowids=value_rowids, values=values, nrows=nrows,
               default_value=default_value, name=name))
      if fallback is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return fallback
      raise
  else:
    dispatched = _dispatcher_for_ragged_fill_empty_rows(
        (value_rowids, values, nrows, default_value, name,), None)
    if dispatched is not NotImplemented:
      return dispatched

  # Add nodes to the TensorFlow graph.
  try:
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "RaggedFillEmptyRows", value_rowids=value_rowids, values=values,
        nrows=nrows, default_value=default_value, name=name)
  except (TypeError, ValueError):
    fallback = _dispatch.dispatch(
        ragged_fill_empty_rows, (),
        dict(value_rowids=value_rowids, values=values, nrows=nrows,
             default_value=default_value, name=name))
    if fallback is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return fallback
    raise
  results = op_outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", graph_op._get_attr_type("T"))
    _execute.record_gradient(
        "RaggedFillEmptyRows", graph_op.inputs, grad_attrs, results)
  return _RaggedFillEmptyRowsOutput._make(results)
# Export the wrapper (passed through `_ops.to_raw_op`) as
# `raw_ops.RaggedFillEmptyRows`.
RaggedFillEmptyRows = tf_export("raw_ops.RaggedFillEmptyRows")(
    _ops.to_raw_op(ragged_fill_empty_rows))
# Module-level alias for the wrapper's type-based dispatcher entry point; the
# wrapper body calls it by this name.
_dispatcher_for_ragged_fill_empty_rows = (
    ragged_fill_empty_rows._tf_type_based_dispatcher.Dispatch)
def ragged_fill_empty_rows_eager_fallback(value_rowids, values, nrows, default_value, name, ctx):
  """Runs `RaggedFillEmptyRows` eagerly through `_execute.execute`.

  `values` and `default_value` are converted together so they share the `T`
  attr; `value_rowids` and `nrows` are converted to `int64` tensors.

  Args:
    value_rowids: Converted to a tensor of type `int64`.
    values: Converted together with `default_value` to one matching dtype.
    nrows: Converted to a tensor of type `int64`.
    default_value: See `values`.
    name: A name for the operation.
    ctx: The eager context the op is executed in.

  Returns:
    A `_RaggedFillEmptyRowsOutput` tuple of the op's four outputs.
  """
  dtype_attr, (values, default_value) = _execute.args_to_matching_eager(
      [values, default_value], ctx, [])
  value_rowids = _ops.convert_to_tensor(value_rowids, _dtypes.int64)
  nrows = _ops.convert_to_tensor(nrows, _dtypes.int64)
  flat_inputs = [value_rowids, values, nrows, default_value]
  op_attrs = ("T", dtype_attr)
  results = _execute.execute(b"RaggedFillEmptyRows", 4, inputs=flat_inputs,
                             attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "RaggedFillEmptyRows", flat_inputs, op_attrs, results)
  return _RaggedFillEmptyRowsOutput._make(results)
# Named result of the RaggedFillEmptyRowsGrad op: (d_values, d_default_value).
_RaggedFillEmptyRowsGradOutput = collections.namedtuple(
    "RaggedFillEmptyRowsGrad", ("d_values", "d_default_value"))
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('ragged_fill_empty_rows_grad')
def ragged_fill_empty_rows_grad(reverse_index_map, grad_values, name=None):
  r"""Runs the `RaggedFillEmptyRowsGrad` op.

  The op's semantics are not documented in this generated wrapper (the
  upstream op description is still a TODO).

  In eager mode the C fast path is attempted first, then the type-based
  dispatcher, then `ragged_fill_empty_rows_grad_eager_fallback`.  In graph
  mode the type-based dispatcher is consulted before a graph node is added.
  A `TypeError`/`ValueError` from the eager fallback or from graph
  construction is offered to `_dispatch.dispatch` and re-raised only if no
  dispatcher supports the call.

  Args:
    reverse_index_map: A `Tensor` of type `int64`.
    grad_values: A `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (d_values, d_default_value).

    d_values: A `Tensor`. Has the same type as `grad_values`.
    d_default_value: A `Tensor`. Has the same type as `grad_values`.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      fast_outputs = pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "RaggedFillEmptyRowsGrad", name, reverse_index_map,
          grad_values)
      return _RaggedFillEmptyRowsGradOutput._make(fast_outputs)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; continue with dispatch / fallback.
    try:
      dispatched = _dispatcher_for_ragged_fill_empty_rows_grad(
          (reverse_index_map, grad_values, name,), None)
      if dispatched is not NotImplemented:
        return dispatched
      return ragged_fill_empty_rows_grad_eager_fallback(
          reverse_index_map, grad_values, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      fallback = _dispatch.dispatch(
          ragged_fill_empty_rows_grad, (),
          dict(reverse_index_map=reverse_index_map,
               grad_values=grad_values, name=name))
      if fallback is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return fallback
      raise
  else:
    dispatched = _dispatcher_for_ragged_fill_empty_rows_grad(
        (reverse_index_map, grad_values, name,), None)
    if dispatched is not NotImplemented:
      return dispatched

  # Add nodes to the TensorFlow graph.
  try:
    _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
        "RaggedFillEmptyRowsGrad", reverse_index_map=reverse_index_map,
        grad_values=grad_values, name=name)
  except (TypeError, ValueError):
    fallback = _dispatch.dispatch(
        ragged_fill_empty_rows_grad, (),
        dict(reverse_index_map=reverse_index_map,
             grad_values=grad_values, name=name))
    if fallback is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return fallback
    raise
  results = op_outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = ("T", graph_op._get_attr_type("T"))
    _execute.record_gradient(
        "RaggedFillEmptyRowsGrad", graph_op.inputs, grad_attrs, results)
  return _RaggedFillEmptyRowsGradOutput._make(results)
# Export the wrapper (passed through `_ops.to_raw_op`) as
# `raw_ops.RaggedFillEmptyRowsGrad`.
RaggedFillEmptyRowsGrad = tf_export("raw_ops.RaggedFillEmptyRowsGrad")(
    _ops.to_raw_op(ragged_fill_empty_rows_grad))
# Module-level alias for the wrapper's type-based dispatcher entry point; the
# wrapper body calls it by this name.
_dispatcher_for_ragged_fill_empty_rows_grad = (
    ragged_fill_empty_rows_grad._tf_type_based_dispatcher.Dispatch)
def ragged_fill_empty_rows_grad_eager_fallback(reverse_index_map, grad_values, name, ctx):
  """Runs `RaggedFillEmptyRowsGrad` eagerly through `_execute.execute`.

  Args:
    reverse_index_map: Converted to a tensor of type `int64`.
    grad_values: Converted to an eager tensor; its dtype becomes the `T` attr.
    name: A name for the operation.
    ctx: The eager context the op is executed in.

  Returns:
    A `_RaggedFillEmptyRowsGradOutput` tuple (d_values, d_default_value).
  """
  dtype_attr, (grad_values,) = _execute.args_to_matching_eager(
      [grad_values], ctx, [])
  reverse_index_map = _ops.convert_to_tensor(reverse_index_map, _dtypes.int64)
  flat_inputs = [reverse_index_map, grad_values]
  op_attrs = ("T", dtype_attr)
  results = _execute.execute(b"RaggedFillEmptyRowsGrad", 2,
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "RaggedFillEmptyRowsGrad", flat_inputs, op_attrs, results)
  return _RaggedFillEmptyRowsGradOutput._make(results)
# Named result of the RaggedGather op:
# (output_nested_splits, output_dense_values).
_RaggedGatherOutput = collections.namedtuple(
    "RaggedGather", ("output_nested_splits", "output_dense_values"))
def ragged_gather(params_nested_splits, params_dense_values, indices, OUTPUT_RAGGED_RANK, name=None):
  r"""Gather ragged slices from `params` axis `0` according to `indices`.

  Outputs a `RaggedTensor` output composed from `output_dense_values` and
  `output_nested_splits`, such that:

  ```python
  output.shape = indices.shape + params.shape[1:]
  output.ragged_rank = indices.shape.ndims + params.ragged_rank
  output[i...j, d0...dn] = params[indices[i...j], d0...dn]
  ```

  where

  * `params =
     ragged.from_nested_row_splits(params_dense_values, params_nested_splits)`
     provides the values that should be gathered.
  * `indices` is a dense tensor with dtype `int32` or `int64`, indicating which
     values should be gathered.
  * `output =
     ragged.from_nested_row_splits(output_dense_values, output_nested_splits)`
     is the output tensor.

  (Note: This c++ op is used to implement the higher-level python
  `tf.ragged.gather` op, which also supports ragged indices.)

  In eager mode the C fast path is attempted first, then
  `ragged_gather_eager_fallback`.  Outside eager mode, or when the fallback
  raises `_core._SymbolicException`, a `RaggedGather` node is added to the
  graph instead.

  Args:
    params_nested_splits: A list of at least 1 `Tensor` objects with the same
      type in: `int32`, `int64`.
      The `nested_row_splits` tensors that define the row-partitioning for the
      `params` RaggedTensor input.
    params_dense_values: A `Tensor`.
      The `flat_values` for the `params` RaggedTensor. There was a terminology
      change at the python level from dense_values to flat_values, so
      dense_values is the deprecated name.
    indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
      Indices in the outermost dimension of `params` of the values that should
      be gathered.
    OUTPUT_RAGGED_RANK: An `int` that is `>= 0`.
      The ragged rank of the output RaggedTensor. `output_nested_splits` will
      contain this number of `row_splits` tensors. This value should equal
      `indices.shape.ndims + params.ragged_rank - 1`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `Tensor` objects (output_nested_splits, output_dense_values).

    output_nested_splits: A list of `OUTPUT_RAGGED_RANK` `Tensor` objects with
      the same type as `params_nested_splits`.
    output_dense_values: A `Tensor`. Has the same type as
      `params_dense_values`.

  Raises:
    TypeError: If `params_nested_splits` is not a list or tuple when the
      graph-construction path validates it.
  """
  active_ctx = _context._context or _context.context()
  if active_ctx._thread_local_data.is_eager:
    try:
      fast_outputs = pywrap_tfe.TFE_Py_FastPathExecute(
          active_ctx, "RaggedGather", name, params_nested_splits,
          params_dense_values, indices,
          "OUTPUT_RAGGED_RANK", OUTPUT_RAGGED_RANK)
      return _RaggedGatherOutput._make(fast_outputs)
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      pass  # Fast path unavailable; try the Python eager fallback below.
    try:
      return ragged_gather_eager_fallback(
          params_nested_splits, params_dense_values, indices,
          OUTPUT_RAGGED_RANK=OUTPUT_RAGGED_RANK, name=name, ctx=active_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.

  # Add nodes to the TensorFlow graph.
  if not isinstance(params_nested_splits, (list, tuple)):
    raise TypeError(
        "Expected list for 'params_nested_splits' argument to "
        "'ragged_gather' Op, not %r." % params_nested_splits)
  # The PARAMS_RAGGED_RANK attr is not computed here: the original assigned
  # len(params_nested_splits) to a local that was never read on this path.
  OUTPUT_RAGGED_RANK = _execute.make_int(
      OUTPUT_RAGGED_RANK, "OUTPUT_RAGGED_RANK")
  _, _, graph_op, op_outputs = _op_def_library._apply_op_helper(
      "RaggedGather", params_nested_splits=params_nested_splits,
      params_dense_values=params_dense_values,
      indices=indices,
      OUTPUT_RAGGED_RANK=OUTPUT_RAGGED_RANK, name=name)
  results = op_outputs[:]
  if _execute.must_record_gradient():
    grad_attrs = (
        "Tvalues", graph_op._get_attr_type("Tvalues"),
        "Tindices", graph_op._get_attr_type("Tindices"),
        "Tsplits", graph_op._get_attr_type("Tsplits"),
        "PARAMS_RAGGED_RANK", graph_op._get_attr_int("PARAMS_RAGGED_RANK"),
        "OUTPUT_RAGGED_RANK", graph_op._get_attr_int("OUTPUT_RAGGED_RANK"))
    _execute.record_gradient(
        "RaggedGather", graph_op.inputs, grad_attrs, results)
  # Regroup the flat output list: the first OUTPUT_RAGGED_RANK tensors become
  # one nested list (output_nested_splits); the remainder follows unchanged.
  results = [results[:OUTPUT_RAGGED_RANK]] + results[OUTPUT_RAGGED_RANK:]
  return _RaggedGatherOutput._make(results)
# Export the wrapper (passed through `_ops.to_raw_op`) as `raw_ops.RaggedGather`.
RaggedGather = tf_export("raw_ops.RaggedGather")(
    _ops.to_raw_op(ragged_gather))
def ragged_gather_eager_fallback(params_nested_splits, params_dense_values, indices, OUTPUT_RAGGED_RANK, name, ctx):
  """Runs `RaggedGather` eagerly through `_execute.execute`.

  Args:
    params_nested_splits: A list or tuple of row-splits inputs; converted
      together to one dtype out of `int32`/`int64` (default `int64`).
    params_dense_values: Converted to an eager tensor; its dtype becomes the
      `Tvalues` attr.
    indices: Converted to an eager tensor of type `int32` or `int64`.
    OUTPUT_RAGGED_RANK: An `int`; the op is executed with
      `OUTPUT_RAGGED_RANK + 1` outputs.
    name: A name for the operation.
    ctx: The eager context the op is executed in.

  Returns:
    A `_RaggedGatherOutput` tuple (output_nested_splits, output_dense_values),
    where `output_nested_splits` is a list of the first `OUTPUT_RAGGED_RANK`
    outputs.

  Raises:
    TypeError: If `params_nested_splits` is not a list or tuple.
  """
  if not isinstance(params_nested_splits, (list, tuple)):
    raise TypeError(
        "Expected list for 'params_nested_splits' argument to "
        "'ragged_gather' Op, not %r." % params_nested_splits)
  params_ragged_rank = len(params_nested_splits)
  OUTPUT_RAGGED_RANK = _execute.make_int(
      OUTPUT_RAGGED_RANK, "OUTPUT_RAGGED_RANK")

  values_dtype, (params_dense_values,) = _execute.args_to_matching_eager(
      [params_dense_values], ctx, [])
  indices_dtype, (indices,) = _execute.args_to_matching_eager(
      [indices], ctx, [_dtypes.int32, _dtypes.int64, ])
  splits_dtype, params_nested_splits = _execute.args_to_matching_eager(
      list(params_nested_splits), ctx, [_dtypes.int32, _dtypes.int64, ],
      _dtypes.int64)

  flat_inputs = [*params_nested_splits, params_dense_values, indices]
  op_attrs = (
      "Tvalues", values_dtype,
      "Tindices", indices_dtype,
      "Tsplits", splits_dtype,
      "PARAMS_RAGGED_RANK", params_ragged_rank,
      "OUTPUT_RAGGED_RANK", OUTPUT_RAGGED_RANK)
  results = _execute.execute(b"RaggedGather", OUTPUT_RAGGED_RANK + 1,
                             inputs=flat_inputs, attrs=op_attrs, ctx=ctx,
                             name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "RaggedGather", flat_inputs, op_attrs, results)
  # Regroup the flat output list: the first OUTPUT_RAGGED_RANK tensors become
  # one nested list (output_nested_splits); the remainder follows unchanged.
  results = [results[:OUTPUT_RAGGED_RANK]] + results[OUTPUT_RAGGED_RANK:]
  return _RaggedGatherOutput._make(results)