Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/ragged/ragged_string_ops.py: 25%
255 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Ragged operations for working with string Tensors."""
17import typing
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import tensor_spec
23from tensorflow.python.framework import tensor_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import array_ops_stack
26from tensorflow.python.ops import cond
27from tensorflow.python.ops import gen_string_ops
28from tensorflow.python.ops import string_ops
29from tensorflow.python.ops.ragged import ragged_array_ops
30from tensorflow.python.ops.ragged import ragged_functional_ops
31from tensorflow.python.ops.ragged import ragged_math_ops
32from tensorflow.python.ops.ragged import ragged_tensor
33from tensorflow.python.util import compat as util_compat
34from tensorflow.python.util import deprecation
35from tensorflow.python.util import dispatch
36from tensorflow.python.util.lazy_loader import LazyLoader
37from tensorflow.python.util.tf_export import tf_export
# map_fn is loaded lazily via LazyLoader rather than imported eagerly —
# presumably to break an import cycle with tensorflow.python.ops (confirm).
map_fn_lib = LazyLoader(
    "map_fn_lib", globals(), "tensorflow.python.ops.map_fn")
@tf_export("strings.bytes_split")
@dispatch.add_dispatch_support
def string_bytes_split(input, name=None):  # pylint: disable=redefined-builtin
  """Split string elements of `input` into bytes.

  Examples:

  >>> tf.strings.bytes_split('hello').numpy()
  array([b'h', b'e', b'l', b'l', b'o'], dtype=object)
  >>> tf.strings.bytes_split(['hello', '123'])
  <tf.RaggedTensor [[b'h', b'e', b'l', b'l', b'o'], [b'1', b'2', b'3']]>

  Note that this op splits strings into bytes, not unicode characters.  To
  split strings into unicode characters, use `tf.strings.unicode_split`.

  See also: `tf.io.decode_raw`, `tf.strings.split`, `tf.strings.unicode_split`.

  Args:
    input: A string `Tensor` or `RaggedTensor`: the strings to split.  Must
      have a statically known rank (`N`).
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` of rank `N+1`: the bytes that make up the source strings.

  Raises:
    ValueError: If the rank of `input` is not statically known.
  """
  with ops.name_scope(name, "StringsByteSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name="input")

    # Ragged input: split the flat values, reuse the outer row partitions.
    if isinstance(input, ragged_tensor.RaggedTensor):
      return input.with_flat_values(string_bytes_split(input.flat_values))

    ndims = input.shape.ndims
    if ndims is None:
      raise ValueError("input must have a statically-known rank.")

    if ndims == 0:
      # Wrap the scalar in a vector, split it, and unwrap the result.
      return string_bytes_split(array_ops_stack.stack([input]))[0]
    if ndims == 1:
      # Base case: splitting with an empty delimiter yields individual bytes.
      indices, values, shape = gen_string_ops.string_split(
          input, delimiter="", skip_empty=False)
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=values, value_rowids=indices[:, 0], nrows=shape[0],
          validate=False)
    # Rank >= 2: convert to ragged and recurse into the ragged branch above.
    return string_bytes_split(ragged_tensor.RaggedTensor.from_tensor(input))
# pylint: disable=redefined-builtin
@tf_export("strings.unicode_encode")
@dispatch.add_dispatch_support
def unicode_encode(input,
                   output_encoding,
                   errors="replace",
                   replacement_char=65533,
                   name=None):
  r"""Encodes each sequence of Unicode code points in `input` into a string.

  `result[i1...iN]` is the string formed by concatenating the Unicode
  codepoints `input[i1...iN, :]`, encoded using `output_encoding`.

  Args:
    input: An `N+1` dimensional potentially ragged integer tensor with shape
      `[D1...DN, num_chars]`.
    output_encoding: Unicode encoding that should be used to encode each
      codepoint sequence.  Can be `"UTF-8"`, `"UTF-16-BE"`, or `"UTF-32-BE"`.
    errors: Specifies the response when an invalid codepoint is encountered
      (optional). One of:
            * `'replace'`: Replace invalid codepoint with the
              `replacement_char`. (default)
            * `'ignore'`: Skip invalid codepoints.
            * `'strict'`: Raise an exception for any invalid codepoint.
    replacement_char: The replacement character codepoint to be used in place
      of any invalid input when `errors='replace'`. Any valid unicode codepoint
      may be used. The default value is the default unicode replacement
      character U+FFFD (decimal 65533).
    name: A name for the operation (optional).

  Returns:
    A `N` dimensional `string` tensor with shape `[D1...DN]`.

  #### Example:

  >>> input = tf.ragged.constant(
  ...     [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]])
  >>> print(unicode_encode(input, 'UTF-8'))
  tf.Tensor([b'G\xc3\xb6\xc3\xb6dnight' b'\xf0\x9f\x98\x8a'],
            shape=(2,), dtype=string)
  """
  with ops.name_scope(name, "UnicodeEncode", [input]):
    input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    if input_tensor.shape.ndims is None:
      raise ValueError("Rank of input_tensor must be statically known.")

    if ragged_tensor.is_ragged(input_tensor):
      if input_tensor.flat_values.shape.ndims > 1:
        # Multi-dimensional flat_values: encode them on their own; the
        # result keeps the same nested row partitions as the input.
        encoded = unicode_encode(input_tensor.flat_values, output_encoding,
                                 errors, replacement_char)
        return input_tensor.with_flat_values(encoded)
      if input_tensor.ragged_rank > 1:
        # Peel off one ragged dimension and recurse on the values.
        encoded = unicode_encode(input_tensor.values, output_encoding,
                                 errors, replacement_char)
        return input_tensor.with_values(encoded)
      # Canonical shape (ragged_rank 1 over a rank-1 values tensor):
      # hand it straight to the kernel.
      return gen_string_ops.unicode_encode(
          input_values=input_tensor.values,
          input_splits=input_tensor.row_splits,
          output_encoding=output_encoding,
          errors=errors,
          replacement_char=replacement_char)

    # Dense input: normalize to the ragged canonical form above.
    dense_rank = input_tensor.shape.ndims
    if dense_rank == 0:
      raise ValueError("input_tensor's rank must be at least 1.")
    if dense_rank == 1:
      # Add a trivial outer dimension (row_splits = [0, num_chars]) so the
      # input matches the ragged case, then strip it from the scalar result.
      ragged_input_tensor = ragged_tensor.RaggedTensor.from_row_splits(
          input_tensor,
          array_ops_stack.stack(
              [0, array_ops.shape(input_tensor, out_type=dtypes.int32)[0]]),
          validate=False)
      output_tensor = unicode_encode(ragged_input_tensor, output_encoding,
                                     errors, replacement_char)
      return array_ops.reshape(output_tensor, [])
    if dense_rank == 2:
      # Correct 2-D shape; it just isn't ragged yet.
      return unicode_encode(
          ragged_tensor.RaggedTensor.from_tensor(input_tensor),
          output_encoding, errors, replacement_char)
    # dense_rank > 2: flatten to 2-D, encode, then restore the outer shape.
    flat_input_tensor = array_ops.reshape(
        input_tensor,
        array_ops_stack.stack([-1, array_ops.shape(input_tensor)[-1]]))
    flat_output_tensor = unicode_encode(flat_input_tensor, output_encoding,
                                        errors, replacement_char)
    return array_ops.reshape(flat_output_tensor, input_tensor.shape[:-1])
# pylint: disable=redefined-builtin
@tf_export("strings.unicode_decode")
@dispatch.add_dispatch_support
def unicode_decode(input,
                   input_encoding,
                   errors="replace",
                   replacement_char=0xFFFD,
                   replace_control_characters=False,
                   name=None):
  r"""Decodes each string in `input` into a sequence of Unicode code points.

  `result[i1...iN, j]` is the Unicode codepoint for the `j`th character in
  `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used
      to decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`; and in place of C0
      control characters in `input` when `replace_control_characters=True`.
    replace_control_characters: Whether to replace the C0 control characters
      `(U+0000 - U+001F)` with the `replacement_char`.
    name: A name for the operation (optional).

  Returns:
    A `N+1` dimensional `int32` tensor with shape `[D1...DN, (num_chars)]`.
    The returned tensor is a `tf.Tensor` if `input` is a scalar, or a
    `tf.RaggedTensor` otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> tf.strings.unicode_decode(input, 'UTF-8').to_list()
  [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]
  """
  # Thin wrapper: all the work happens in the shared _unicode_decode helper.
  with ops.name_scope(name, "UnicodeDecode", [input]):
    return _unicode_decode(
        input,
        input_encoding,
        errors,
        replacement_char,
        replace_control_characters,
        with_offsets=False)
@tf_export("strings.unicode_decode_with_offsets")
@dispatch.add_dispatch_support
def unicode_decode_with_offsets(input,
                                input_encoding,
                                errors="replace",
                                replacement_char=0xFFFD,
                                replace_control_characters=False,
                                name=None):
  r"""Decodes each string into a sequence of code points with start offsets.

  This op is similar to `tf.strings.unicode_decode(...)`, but it also returns
  the start offset for each character in its respective string.  This
  information can be used to align the characters with the original byte
  sequence.

  Returns a tuple `(codepoints, start_offsets)` where:

  * `codepoints[i1...iN, j]` is the Unicode codepoint for the `j`th character
    in `input[i1...iN]`, when decoded using `input_encoding`.
  * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th
    character in `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used
      to decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`; and in place of C0
      control characters in `input` when `replace_control_characters=True`.
    replace_control_characters: Whether to replace the C0 control characters
      `(U+0000 - U+001F)` with the `replacement_char`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`.

    * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`.
    * `offsets` is an `int64` tensor with shape `[D1...DN, (num_chars)]`.

    The returned tensors are `tf.Tensor`s if `input` is a scalar, or
    `tf.RaggedTensor`s otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> result = tf.strings.unicode_decode_with_offsets(input, 'UTF-8')
  >>> result[0].to_list()  # codepoints
  [[71, 246, 246, 100, 110, 105, 103, 104, 116], [128522]]
  >>> result[1].to_list()  # offsets
  [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]]

  """
  # Same helper as unicode_decode, but asking for byte offsets as well.
  with ops.name_scope(name, "UnicodeDecodeWithOffsets", [input]):
    return _unicode_decode(
        input,
        input_encoding,
        errors,
        replacement_char,
        replace_control_characters,
        with_offsets=True)
@tf_export("strings.unicode_split")
@dispatch.add_dispatch_support
def unicode_split(input,
                  input_encoding,
                  errors="replace",
                  replacement_char=0xFFFD,
                  name=None):
  r"""Splits each string in `input` into a sequence of Unicode code points.

  `result[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its
  `j`th character, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used
      to decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`.
    name: A name for the operation (optional).

  Returns:
    A `N+1` dimensional `string` tensor with shape `[D1...DN, (num_chars)]`.
    The returned tensor is a `tf.Tensor` if `input` is a scalar, or a
    `tf.RaggedTensor` otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> tf.strings.unicode_split(input, 'UTF-8').to_list()
  [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'],
   [b'\xf0\x9f\x98\x8a']]
  """
  with ops.name_scope(name, "UnicodeSplit", [input]):
    # Decode to codepoints, then re-encode each codepoint individually: the
    # trailing size-1 dimension makes every character its own sequence.
    codepoints = _unicode_decode(input, input_encoding, errors,
                                 replacement_char, False, with_offsets=False)
    per_char = ragged_array_ops.expand_dims(codepoints, -1)
    return unicode_encode(
        per_char,
        output_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char)
@tf_export("strings.unicode_split_with_offsets")
@dispatch.add_dispatch_support
def unicode_split_with_offsets(input,
                               input_encoding,
                               errors="replace",
                               replacement_char=0xFFFD,
                               name=None):
  r"""Splits each string into a sequence of code points with start offsets.

  This op is similar to `tf.strings.unicode_decode(...)`, but it also returns
  the start offset for each character in its respective string.  This
  information can be used to align the characters with the original byte
  sequence.

  Returns a tuple `(chars, start_offsets)` where:

  * `chars[i1...iN, j]` is the substring of `input[i1...iN]` that encodes its
    `j`th character, when decoded using `input_encoding`.
  * `start_offsets[i1...iN, j]` is the start byte offset for the `j`th
    character in `input[i1...iN]`, when decoded using `input_encoding`.

  Args:
    input: An `N` dimensional potentially ragged `string` tensor with shape
      `[D1...DN]`.  `N` must be statically known.
    input_encoding: String name for the unicode encoding that should be used
      to decode each string.
    errors: Specifies the response when an input string can't be converted
      using the indicated encoding. One of:
      * `'strict'`: Raise an exception for any illegal substrings.
      * `'replace'`: Replace illegal substrings with `replacement_char`.
      * `'ignore'`: Skip illegal substrings.
    replacement_char: The replacement codepoint to be used in place of invalid
      substrings in `input` when `errors='replace'`.
    name: A name for the operation (optional).

  Returns:
    A tuple of `N+1` dimensional tensors `(codepoints, start_offsets)`.

    * `codepoints` is an `int32` tensor with shape `[D1...DN, (num_chars)]`.
    * `offsets` is an `int64` tensor with shape `[D1...DN, (num_chars)]`.

    The returned tensors are `tf.Tensor`s if `input` is a scalar, or
    `tf.RaggedTensor`s otherwise.

  #### Example:

  >>> input = [s.encode('utf8') for s in (u'G\xf6\xf6dnight', u'\U0001f60a')]
  >>> result = tf.strings.unicode_split_with_offsets(input, 'UTF-8')
  >>> result[0].to_list()  # character substrings
  [[b'G', b'\xc3\xb6', b'\xc3\xb6', b'd', b'n', b'i', b'g', b'h', b't'],
   [b'\xf0\x9f\x98\x8a']]
  >>> result[1].to_list()  # offsets
  [[0, 1, 3, 5, 6, 7, 8, 9, 10], [0]]

  """
  with ops.name_scope(name, "UnicodeSplitWithOffsets", [input]):
    # Decode with offsets, then re-encode each codepoint individually to
    # recover the per-character substrings.
    codepoints, offsets = _unicode_decode(input, input_encoding, errors,
                                          replacement_char, False,
                                          with_offsets=True)
    per_char = ragged_array_ops.expand_dims(codepoints, -1)
    chars = unicode_encode(
        per_char,
        output_encoding=input_encoding,
        errors=errors,
        replacement_char=replacement_char)
    return chars, offsets
def _unicode_decode(input, input_encoding, errors, replacement_char,
                    replace_control_characters, with_offsets):
  """Decodes each string into a sequence of codepoints.

  Shared implementation for `unicode_decode`, `unicode_decode_with_offsets`,
  `unicode_split`, and `unicode_split_with_offsets`.  Flattens the input to a
  vector of strings, runs the decode kernel once, then rebuilds the original
  (possibly ragged) structure around the result.
  """
  input = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name="input")
  input_ndims = input.shape.ndims
  if input_ndims is None:
    raise ValueError("Rank of `input` must be statically known.")

  if input_ndims > 1:
    # Normalize to a ragged tensor with ragged_rank = input_ndims - 1, so the
    # outer structure can be reattached to the flat result below.
    if not ragged_tensor.is_ragged(input):
      input = ragged_tensor.RaggedTensor.from_tensor(
          input, ragged_rank=input_ndims - 1)
    elif input.ragged_rank < input_ndims - 1:
      input = input.with_flat_values(
          ragged_tensor.RaggedTensor.from_tensor(
              input.flat_values,
              ragged_rank=input_ndims - input.ragged_rank - 1))

  # Flatten to a 1-D vector of strings and decode it in a single kernel call.
  if ragged_tensor.is_ragged(input):
    flat_input = array_ops.reshape(input.flat_values, [-1])
  else:
    flat_input = array_ops.reshape(input, [-1])

  decode_op = (
      gen_string_ops.unicode_decode_with_offsets
      if with_offsets else gen_string_ops.unicode_decode)
  flat_result = decode_op(
      input=flat_input,
      input_encoding=input_encoding,
      errors=errors,
      replacement_char=replacement_char,
      replace_control_characters=replace_control_characters)

  if input_ndims == 0:
    # Scalar input: return plain tensors, no ragged wrapping.
    if with_offsets:
      return flat_result.char_values, flat_result.char_to_byte_starts
    return flat_result.char_values

  # Rebuild the ragged structure: one row of codepoints per input string,
  # then (for rank > 1) reattach the outer partitions saved above.
  codepoints = ragged_tensor.RaggedTensor.from_row_splits(
      flat_result.char_values, flat_result.row_splits, validate=False)
  if input_ndims > 1:
    codepoints = input.with_flat_values(codepoints)
  if not with_offsets:
    return codepoints

  offsets = ragged_tensor.RaggedTensor.from_row_splits(
      flat_result.char_to_byte_starts, flat_result.row_splits,
      validate=False)
  if input_ndims > 1:
    offsets = input.with_flat_values(offsets)
  return codepoints, offsets
@tf_export("strings.split", v1=[])
@dispatch.add_dispatch_support
def string_split_v2(input, sep=None, maxsplit=-1, name=None):  # pylint: disable=redefined-builtin
  """Split elements of `input` based on `sep` into a `RaggedTensor`.

  Let N be the size of `input` (typically N will be the batch size). Split
  each element of `input` based on `sep` and return a `RaggedTensor`
  containing the split tokens. Empty tokens are ignored.

  Example:

  >>> tf.strings.split('hello world').numpy()
  array([b'hello', b'world'], dtype=object)
  >>> tf.strings.split(['hello world', 'a b c'])
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and
  `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    input: A string `Tensor` of rank `N`, the strings to split.  If
      `rank(input)` is not known statically, then it is assumed to be `1`.
    sep: `0-D` string `Tensor`, the delimiter string.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    name: A name for the operation (optional).

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `RaggedTensor` of rank `N+1`, the strings split according to the
    delimiter.
  """
  with ops.name_scope(name, "StringSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")

    # Ragged input: split the flat values and keep the outer partitions.
    if isinstance(input, ragged_tensor.RaggedTensor):
      return input.with_flat_values(
          string_split_v2(input.flat_values, sep, maxsplit))

    ndims = input.shape.ndims
    if ndims == 0:
      # Wrap the scalar in a vector, split it, and unwrap the result.
      return string_split_v2(array_ops_stack.stack([input]), sep, maxsplit)[0]
    if ndims is None or ndims == 1:
      # Base case (unknown rank is treated as rank 1): split via the sparse
      # op, then convert the SparseTensor pieces into a RaggedTensor.
      sparse = string_ops.string_split_v2(input, sep=sep, maxsplit=maxsplit)
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=sparse.values,
          value_rowids=sparse.indices[:, 0],
          nrows=sparse.dense_shape[0],
          validate=False)
    # Rank >= 2: convert to ragged and recurse into the ragged branch above.
    return string_split_v2(
        ragged_tensor.RaggedTensor.from_tensor(input), sep, maxsplit)
@tf_export(v1=["string_split"])
@dispatch.add_dispatch_support
@deprecation.deprecated_args(None,
                             "delimiter is deprecated, please use sep instead.",
                             "delimiter")
def string_split(source, sep=None, skip_empty=True, delimiter=None,
                 result_type="SparseTensor", name=None):  # pylint: disable=invalid-name
  """Split elements of `source` based on `delimiter`.

  Let N be the size of `source` (typically N will be the batch size). Split
  each element of `source` based on `delimiter` and return a `SparseTensor`
  or `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  If `sep` is an empty string, each element of the `source` is split
  into individual strings, each containing one byte. (This includes splitting
  multibyte sequences of UTF-8.) If delimiter contains multiple bytes, it is
  treated as a set of delimiters with each considered a potential split point.

  Examples:

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.string_split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  Args:
    source: `1-D` string `Tensor`, the strings to split.
    sep: `0-D` string `Tensor`, the delimiter character, the string should
      be length 0 or 1. Default is ' '.
    skip_empty: A `bool`. If `True`, skip the empty strings from the result.
    delimiter: deprecated alias for `sep`.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    name: A name for the operation (optional).

  Raises:
    ValueError: If delimiter is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `2`, the strings split
    according to the delimiter.  The first column of the indices corresponds
    to the row in `source` and the second column corresponds to the index of
    the split component in this row.
  """
  with ops.name_scope(name, "StringSplit", [source]):
    # Always split via the sparse op first (this also validates the
    # sep/delimiter arguments), then convert if a RaggedTensor was requested.
    sparse = string_ops.string_split(
        source, sep=sep, skip_empty=skip_empty, delimiter=delimiter)
    if result_type == "SparseTensor":
      return sparse
    if result_type == "RaggedTensor":
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=sparse.values,
          value_rowids=sparse.indices[:, 0],
          nrows=sparse.dense_shape[0],
          validate=False)
    raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
# In TensorFlow 1.x, "tf.strings.split" uses the new signature (with maxsplit),
# but we need to add the result_type argument.
@tf_export(v1=["strings.split"])
@dispatch.add_dispatch_support
def strings_split_v1(input=None, sep=None, maxsplit=-1,  # pylint: disable=redefined-builtin
                     result_type="SparseTensor", source=None, name=None):
  """Split elements of `input` based on `sep`.

  Let N be the size of `input` (typically N will be the batch size). Split
  each element of `input` based on `sep` and return a `SparseTensor` or
  `RaggedTensor` containing the split tokens. Empty tokens are ignored.

  Examples:

  >>> print(tf.compat.v1.strings.split(['hello world', 'a b c']))
  SparseTensor(indices=tf.Tensor( [[0 0] [0 1] [1 0] [1 1] [1 2]], ...),
               values=tf.Tensor([b'hello' b'world' b'a' b'b' b'c'], ...),
               dense_shape=tf.Tensor([2 3], shape=(2,), dtype=int64))

  >>> print(tf.compat.v1.strings.split(['hello world', 'a b c'],
  ...     result_type="RaggedTensor"))
  <tf.RaggedTensor [[b'hello', b'world'], [b'a', b'b', b'c']]>

  If `sep` is given, consecutive delimiters are not grouped together and are
  deemed to delimit empty strings. For example, `input` of `"1<>2<><>3"` and
  `sep` of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
  string, consecutive whitespace are regarded as a single separator, and the
  result will contain no empty strings at the start or end if the string has
  leading or trailing whitespace.

  Note that the above mentioned behavior matches python's str.split.

  Args:
    input: A string `Tensor` of rank `N`, the strings to split.  If
      `rank(input)` is not known statically, then it is assumed to be `1`.
    sep: `0-D` string `Tensor`, the delimiter character.
    maxsplit: An `int`. If `maxsplit > 0`, limit of the split of the result.
    result_type: The tensor type for the result: one of `"RaggedTensor"` or
      `"SparseTensor"`.
    source: alias for "input" argument.
    name: A name for the operation (optional).

  Raises:
    ValueError: If sep is not a string.

  Returns:
    A `SparseTensor` or `RaggedTensor` of rank `N+1`, the strings split
    according to the delimiter.
  """
  # Resolve the deprecated "source" alias before anything else.
  input = deprecation.deprecated_argument_lookup(
      "input", input, "source", source)
  with ops.name_scope(name, "StringSplit", [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, dtype=dtypes.string, name="input")

    # Promote scalars to vectors so the split always has a batch dimension.
    if input.shape.rank == 0:
      input = array_ops.expand_dims(input, 0)

    if result_type == "RaggedTensor":
      return string_split_v2(input, sep=sep, maxsplit=maxsplit)
    if result_type == "SparseTensor":
      if input.shape.rank == 1:
        # Rank-1 input maps directly onto the sparse split op.
        return string_ops.string_split_v2(input, sep=sep, maxsplit=maxsplit)
      # Higher rank: split as ragged, then convert the result to sparse.
      return string_split_v2(input, sep=sep, maxsplit=maxsplit).to_sparse()
    raise ValueError("result_type must be 'RaggedTensor' or 'SparseTensor'.")
@dispatch.dispatch_for_api(string_ops.reduce_join_v2)
def reduce_join(inputs: ragged_tensor.Ragged,
                axis=None,
                keepdims=None,
                separator="",
                name=None):
  """For docs, see: _RAGGED_REDUCE_DOCSTRING."""
  # Delegate to the generic ragged reduction, using reduce_join for full
  # reductions and unsorted_segment_join for per-segment reductions.
  op_name = name or "RaggedSegmentJoin"
  return ragged_math_ops.ragged_reduce_aggregate(
      string_ops.reduce_join,
      string_ops.unsorted_segment_join,
      inputs,
      axis,
      keepdims,
      separator,
      op_name)
676@tf_export("strings.ngrams")
677@dispatch.add_dispatch_support
678def ngrams(data,
679 ngram_width,
680 separator=" ",
681 pad_values=None,
682 padding_width=None,
683 preserve_short_sequences=False,
684 name=None):
685 """Create a tensor of n-grams based on `data`.
687 Creates a tensor of n-grams based on `data`. The n-grams are created by
688 joining windows of `width` adjacent strings from the inner axis of `data`
689 using `separator`.
691 The input data can be padded on both the start and end of the sequence, if
692 desired, using the `pad_values` argument. If set, `pad_values` should contain
693 either a tuple of strings or a single string; the 0th element of the tuple
694 will be used to pad the left side of the sequence and the 1st element of the
695 tuple will be used to pad the right side of the sequence. The `padding_width`
696 arg controls how many padding values are added to each side; it defaults to
697 `ngram_width-1`.
699 If this op is configured to not have padding, or if it is configured to add
700 padding with `padding_width` set to less than ngram_width-1, it is possible
701 that a sequence, or a sequence plus padding, is smaller than the ngram
702 width. In that case, no ngrams will be generated for that sequence. This can
703 be prevented by setting `preserve_short_sequences`, which will cause the op
704 to always generate at least one ngram per non-empty sequence.
706 Examples:
708 >>> tf.strings.ngrams(["A", "B", "C", "D"], 2).numpy()
709 array([b'A B', b'B C', b'C D'], dtype=object)
710 >>> tf.strings.ngrams(["TF", "and", "keras"], 1).numpy()
711 array([b'TF', b'and', b'keras'], dtype=object)
713 Args:
714 data: A Tensor or RaggedTensor containing the source data for the ngrams.
715 ngram_width: The width(s) of the ngrams to create. If this is a list or
716 tuple, the op will return ngrams of all specified arities in list order.
717 Values must be non-Tensor integers greater than 0.
718 separator: The separator string used between ngram elements. Must be a
719 string constant, not a Tensor.
720 pad_values: A tuple of (left_pad_value, right_pad_value), a single string,
721 or None. If None, no padding will be added; if a single string, then that
722 string will be used for both left and right padding. Values must be Python
723 strings.
724 padding_width: If set, `padding_width` pad values will be added to both
725 sides of each sequence. Defaults to `ngram_width`-1. Must be greater than
726 0. (Note that 1-grams are never padded, regardless of this value.)
727 preserve_short_sequences: If true, then ensure that at least one ngram is
728 generated for each input sequence. In particular, if an input sequence is
729 shorter than `min(ngram_width) + 2*pad_width`, then generate a single
730 ngram containing the entire sequence. If false, then no ngrams are
731 generated for these short input sequences.
732 name: The op name.
734 Returns:
735 A RaggedTensor of ngrams. If `data.shape=[D1...DN, S]`, then
736 `output.shape=[D1...DN, NUM_NGRAMS]`, where
737 `NUM_NGRAMS=S-ngram_width+1+2*padding_width`.
739 Raises:
740 TypeError: if `pad_values` is set to an invalid type.
741 ValueError: if `pad_values`, `padding_width`, or `ngram_width` is set to an
742 invalid value.
743 """
745 with ops.name_scope(name, "StringNGrams", [data]):
746 if pad_values is None:
747 left_pad = ""
748 right_pad = ""
749 elif isinstance(pad_values, (list, tuple)):
750 if (not isinstance(pad_values[0], util_compat.bytes_or_text_types) or
751 not isinstance(pad_values[1], util_compat.bytes_or_text_types)):
752 raise TypeError(
753 "pad_values must be a string, tuple of strings, or None.")
754 left_pad = pad_values[0]
755 right_pad = pad_values[1]
756 else:
757 if not isinstance(pad_values, util_compat.bytes_or_text_types):
758 raise TypeError(
759 "pad_values must be a string, tuple of strings, or None.")
760 left_pad = pad_values
761 right_pad = pad_values
763 if padding_width is not None and padding_width < 1:
764 raise ValueError("padding_width must be greater than 0.")
766 if padding_width is not None and pad_values is None:
767 raise ValueError("pad_values must be provided if padding_width is set.")
769 data = ragged_tensor.convert_to_tensor_or_ragged_tensor(
770 data, name="data", dtype=dtypes.string)
772 # preserve the shape of the data if it is a tensor
773 to_tensor = False
774 if isinstance(data, ops.Tensor):
775 dense_shape = array_ops.concat([array_ops.shape(data)[:-1], [-1]], axis=0)
776 to_tensor = True
778 if not isinstance(data, ragged_tensor.RaggedTensor):
779 if data.shape.ndims is None:
780 raise ValueError("Rank of data must be known.")
781 elif data.shape.ndims == 0:
782 raise ValueError("Data must have rank>0")
783 elif data.shape.ndims == 1:
784 rt = ragged_tensor.RaggedTensor.from_row_starts(
785 data, [0], validate=False)
786 return ngrams(rt, ngram_width, separator, pad_values, padding_width,
787 preserve_short_sequences, name)[0]
788 else:
789 data = ragged_tensor.RaggedTensor.from_tensor(
790 data, ragged_rank=data.shape.ndims - 1)
792 if data.ragged_rank > 1:
793 output = data.with_values(
794 ngrams(data.values, ngram_width, separator, pad_values, padding_width,
795 preserve_short_sequences, name))
796 return array_ops.reshape(output.flat_values,
797 dense_shape) if to_tensor else output
799 if pad_values is None:
800 padding_width = 0
802 if pad_values is not None and padding_width is None:
803 padding_width = -1
805 if not isinstance(ngram_width, (list, tuple)):
806 ngram_widths = [ngram_width]
807 else:
808 ngram_widths = ngram_width
809 for width in ngram_widths:
810 if width < 1:
811 raise ValueError("All ngram_widths must be greater than 0. Got %s" %
812 ngram_width)
814 output, output_splits = gen_string_ops.string_n_grams(
815 data=data.flat_values,
816 data_splits=data.row_splits,
817 separator=separator,
818 ngram_widths=ngram_widths,
819 left_pad=left_pad,
820 right_pad=right_pad,
821 pad_width=padding_width,
822 preserve_short_sequences=preserve_short_sequences)
824 # if the input is Dense tensor, the output should also be a dense tensor
825 output = ragged_tensor.RaggedTensor.from_row_splits(
826 values=output, row_splits=output_splits, validate=False)
827 return array_ops.reshape(output.flat_values,
828 dense_shape) if to_tensor else output
@dispatch.dispatch_for_api(string_ops.string_format)
def string_format(
    template: str,
    inputs: typing.Union[ragged_tensor.Ragged,
                         typing.List[ragged_tensor.RaggedOrDense]],
    placeholder="{}",
    summarize=3,
    name=None):
  """Version of tf.strings.format that handles RaggedTensors."""
  # A single (ragged or dense) tensor is treated as a one-element input list.
  if tensor_util.is_tf_type(inputs) or ragged_tensor.is_ragged(inputs):
    inputs = [inputs]

  # Splitting on the placeholder yields one more literal segment than there
  # are placeholders, so the input count must equal len(segments) - 1.
  segments = template.split(placeholder)
  num_placeholders = len(segments) - 1
  if num_placeholders != len(inputs):
    raise ValueError("num placeholders in template and num inputs must match"
                     ": {} vs {}".format(num_placeholders, len(inputs)))

  with ops.name_scope(name, "StringFormat", [inputs]):
    # Interleave the literal template segments with the rendered inputs.
    pieces = [constant_op.constant(segments[0])]
    for index, value in enumerate(inputs):
      if ragged_tensor.is_ragged(value):
        rendered = ragged_tensor_to_string(value, summarize)
      else:
        rendered = string_ops.string_format(
            "{}", [value], summarize=summarize)
      pieces.append(rendered)
      pieces.append(constant_op.constant(segments[index + 1]))
    if len(pieces) > 1:
      return string_ops.reduce_join(pieces)
    return pieces[0]
def ragged_tensor_to_string(rt, summarize=None):
  """Returns a scalar string tensor with the contents of a RaggedTensor.

  Requires that `rt.shape.rank` is not `None`.

  Note: this converts the entire `RaggedTensor` into a single string scalar.
  If you want to convert individual elements, use `tf.strings.as_string(rt)`.

  >>> rt1 = tf.ragged.constant([[1, 2, 3], [4, 5]])
  >>> ragged_tensor_to_string(rt1).numpy()
  b'[[1, 2, 3], [4, 5]]'

  >>> rt2 = tf.ragged.constant([[['a'], ['b', 'c']], [['d', 'e', 'f'], []]])
  >>> ragged_tensor_to_string(rt2).numpy()
  b"[[['a'], ['b', 'c']], [['d', 'e', 'f'], []]]"

  >>> rt3 = tf.ragged.constant([[1], [2, 3, 4, 5, 6], [], [], [7], [8, 9]])
  >>> ragged_tensor_to_string(rt3, summarize=2).numpy()
  b'[[1], [2, 3, ..., 5, 6], ..., [7], [8, 9]]'

  Args:
    rt: The RaggedTensor that should be converted to a string.
    summarize: If specified, then only the first and last `summarize` elements
      within each dimension are included in the string. If `-1` or `None`, then
      all elements are included.
  """
  # `None` and `-1` both mean "show everything"; anything else must be a
  # positive int.
  show_all = summarize is None or summarize == -1
  if not show_all and not (isinstance(summarize, int) and summarize > 0):
    raise ValueError("Expected summarize to be -1 or a positive int, got %r" %
                     summarize)
  with ops.name_scope(None, "AsString", [rt]):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
    if rt.shape.rank is None:
      raise ValueError("RaggedTensor to_string requires that rt.shape.rank "
                       "is not None.")
    # Render every element as a string.  String elements are single-quoted,
    # with embedded quotes and backslashes escaped; other dtypes go through
    # as_string.
    if rt.dtype == dtypes.string:
      escaped = string_ops.regex_replace(rt.flat_values, r"(['\\])", r"\\\1")
      as_strings = rt.with_flat_values("'" + escaped + "'")
    else:
      as_strings = rt.with_flat_values(string_ops.as_string(rt.flat_values))

    return _ragged_tensor_to_string(as_strings, summarize)
def _ragged_tensor_to_string(string_tensor, summarize):
  """Returns a scalar string tensor with the contents of `string_tensor`.

  Args:
    string_tensor: A potentially ragged tensor with dtype=string.
    summarize: Include only the first and last `summarize` elements of each
      dimension. If `-1` or `None`, then include all elements.

  Returns:
    A scalar string Tensor.
  """
  if string_tensor.shape.rank == 1:
    # Base case: the rows are already scalar strings.
    elements = string_tensor
  else:
    # Recursive case: render each row to a string scalar.
    elements = map_fn_lib.map_fn(
        lambda row: _ragged_tensor_to_string(row, summarize),
        string_tensor,
        fn_output_signature=tensor_spec.TensorSpec(None, dtypes.string))
  if summarize not in (-1, None):
    # Replace the middle rows with "..." when there are more than
    # 2 * summarize of them.
    def _abbreviate():
      return array_ops.concat(
          [elements[:summarize], ["..."], elements[-summarize:]], axis=0)

    elements = cond.cond(
        _nrows(string_tensor) <= 2 * summarize,
        lambda: elements,
        _abbreviate)
  return "[" + string_ops.reduce_join(elements, separator=", ") + "]"
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the size of `tensor`'s outermost dimension as a scalar tensor."""
  if not isinstance(tensor, ragged_tensor.RaggedTensor):
    return array_ops.shape(tensor, out_type=out_type)[0]
  return tensor.nrows(out_type=out_type)
@dispatch.dispatch_for_api(string_ops.string_join)
def string_join(inputs: typing.List[ragged_tensor.RaggedOrDense],
                separator="",
                name=None):
  """RaggedTensor implementation for tf.strings.join.

  Args:
    inputs: A list of ragged and/or dense string tensors to join elementwise.
    separator: The string placed between each pair of joined values.
    name: The op name.

  Returns:
    The result of mapping `tf.strings.join` over the flat values of `inputs`.

  Raises:
    ValueError: If `inputs` is empty.
  """
  # BUG FIX: the guard was previously `len(inputs) < 0`, which is always
  # false (lengths are non-negative), so an empty input list was never
  # rejected and fell through to map_flat_values. Check for emptiness.
  if not inputs:
    raise ValueError("tf.strings.join: expected at least one input.")
  with ops.name_scope(name, "RaggedStringJoin", inputs):
    return ragged_functional_ops.map_flat_values(string_ops.string_join, inputs,
                                                 separator)