Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/preprocessing/text_vectorization.py: 29%
155 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Keras text vectorization preprocessing layer."""
18import numpy as np
19import tensorflow.compat.v2 as tf
21from keras.src import backend
22from keras.src.engine import base_preprocessing_layer
23from keras.src.layers.preprocessing import preprocessing_utils as utils
24from keras.src.layers.preprocessing import string_lookup
25from keras.src.saving.legacy.saved_model import layer_serialization
26from keras.src.saving.serialization_lib import deserialize_keras_object
27from keras.src.utils import layer_utils
28from keras.src.utils import tf_utils
30# isort: off
31from tensorflow.python.util.tf_export import keras_export
33LOWER_AND_STRIP_PUNCTUATION = "lower_and_strip_punctuation"
34STRIP_PUNCTUATION = "strip_punctuation"
35LOWER = "lower"
37WHITESPACE = "whitespace"
38CHARACTER = "character"
40TF_IDF = utils.TF_IDF
41INT = utils.INT
42MULTI_HOT = utils.MULTI_HOT
43COUNT = utils.COUNT
45 # This is an explicit regex of all the characters that will be stripped if
46# LOWER_AND_STRIP_PUNCTUATION is set. If an application requires other
47# stripping, a Callable should be passed into the 'standardize' arg.
48DEFAULT_STRIP_REGEX = r'[!"#$%&()\*\+,-\./:;<=>?@\[\\\]^_`{|}~\']'
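# As an illustrative sketch only (not used anywhere in this module, and the
# function name `custom_standardization` is hypothetical), a callable passed
# as `standardize` could perform its own lowercasing and stripping:
#
#     def custom_standardization(input_data):
#         lowercase = tf.strings.lower(input_data)
#         return tf.strings.regex_replace(lowercase, r"[^\w\s]", "")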
51@keras_export(
52 "keras.layers.TextVectorization",
53 "keras.layers.experimental.preprocessing.TextVectorization",
54 v1=[],
55)
56class TextVectorization(base_preprocessing_layer.PreprocessingLayer):
57 """A preprocessing layer which maps text features to integer sequences.
59 This layer has basic options for managing text in a Keras model. It
60 transforms a batch of strings (one example = one string) into either a list
61 of token indices (one example = 1D tensor of integer token indices) or a
62 dense representation (one example = 1D tensor of float values representing
63 data about the example's tokens). This layer is meant to handle natural
64 language inputs. To handle simple string inputs (categorical strings or
65 pre-tokenized strings) see `tf.keras.layers.StringLookup`.
67 The vocabulary for the layer must be either supplied on construction or
68 learned via `adapt()`. When this layer is adapted, it will analyze the
69 dataset, determine the frequency of individual string values, and create a
70 vocabulary from them. This vocabulary can have unlimited size or be capped,
71 depending on the configuration options for this layer; if there are more
72 unique values in the input than the maximum vocabulary size, the most
73 frequent terms will be used to create the vocabulary.
75 The processing of each example contains the following steps:
77 1. Standardize each example (usually lowercasing + punctuation stripping)
78 2. Split each example into substrings (usually words)
79 3. Recombine substrings into tokens (usually ngrams)
80 4. Index tokens (associate a unique int value with each token)
81 5. Transform each example using this index, either into a vector of ints or
82 a dense float vector.
84 Some notes on passing callables to customize splitting and normalization for
85 this layer:
87 1. Any callable can be passed to this Layer, but if you want to serialize
88 this object you should only pass functions that are registered Keras
89 serializables (see `tf.keras.saving.register_keras_serializable` for more
90 details).
91 2. When using a custom callable for `standardize`, the data received
92 by the callable will be exactly as passed to this layer. The callable
93 should return a tensor of the same shape as the input.
94 3. When using a custom callable for `split`, the data received by the
95 callable will have the 1st dimension squeezed out - instead of
96 `[["string to split"], ["another string to split"]]`, the Callable will
97 see `["string to split", "another string to split"]`. The callable should
98 return a Tensor with the first dimension containing the split tokens -
99 in this example, we should see something like `[["string", "to",
100 "split"], ["another", "string", "to", "split"]]`. This makes the callable
101 site natively compatible with `tf.strings.split()`.
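For example, a custom `split` callable that follows this contract could be
registered and passed to the layer like this (an illustrative sketch; the
function name `comma_split` and the package name are hypothetical):

>>> @tf.keras.saving.register_keras_serializable(package="my_package")
... def comma_split(input_data):
...     return tf.strings.split(input_data, sep=",")
>>> layer = tf.keras.layers.TextVectorization(split=comma_split)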
103 For an overview and full list of preprocessing layers, see the preprocessing
104 [guide](https://www.tensorflow.org/guide/keras/preprocessing_layers).
106 Args:
107 max_tokens: Maximum size of the vocabulary for this layer. This should
108 only be specified when adapting a vocabulary or when setting
109 `pad_to_max_tokens=True`. Note that this vocabulary
110 contains 1 OOV token, so the effective number of tokens is
111 `(max_tokens - 1 - (1 if output_mode == "int" else 0))`.
112 standardize: Optional specification for standardization to apply to the
113 input text. Values can be:
114 - `None`: No standardization.
115 - `"lower_and_strip_punctuation"`: Text will be lowercased and all
116 punctuation removed.
117 - `"lower"`: Text will be lowercased.
118 - `"strip_punctuation"`: All punctuation will be removed.
119 - Callable: Inputs will be passed to the callable function, which
120 should standardize and return them.
121 split: Optional specification for splitting the input text. Values can be:
122 - `None`: No splitting.
123 - `"whitespace"`: Split on whitespace.
124 - `"character"`: Split on each unicode character.
125 - Callable: Standardized inputs will be passed to the callable
126 function, which should split and return them.
127 ngrams: Optional specification for ngrams to create from the
128 possibly-split input text. Values can be None, an integer or tuple of
129 integers; passing an integer will create ngrams up to that integer, and
130 passing a tuple of integers will create ngrams for the specified values
131 in the tuple. Passing None means that no ngrams will be created.
132 output_mode: Optional specification for the output of the layer. Values
133 can be `"int"`, `"multi_hot"`, `"count"` or `"tf_idf"`, configuring the
134 layer as follows:
135 - `"int"`: Outputs integer indices, one integer index per split string
136 token. When `output_mode == "int"`, 0 is reserved for masked
137 locations; this reduces the vocab size to
138 `max_tokens - 2` instead of `max_tokens - 1`.
139 - `"multi_hot"`: Outputs a single int array per batch, of either
140 vocab_size or max_tokens size, containing 1s in all elements where
141 the token mapped to that index exists at least once in the batch
142 item.
143 - `"count"`: Like `"multi_hot"`, but the int array contains a count of
144 the number of times the token at that index appeared in the
145 batch item.
146 - `"tf_idf"`: Like `"multi_hot"`, but the TF-IDF algorithm is applied
147 to find the value in each token slot.
148 For `"int"` output, any shape of input and output is supported. For all
149 other output modes, currently only rank 1 inputs (and rank 2 outputs
150 after splitting) are supported.
151 output_sequence_length: Only valid in INT mode. If set, the output will
152 have its time dimension padded or truncated to exactly
153 `output_sequence_length` values, resulting in a tensor of shape
154 `(batch_size, output_sequence_length)` regardless of how many tokens
155 resulted from the splitting step. Defaults to `None`.
156 pad_to_max_tokens: Only valid in `"multi_hot"`, `"count"`, and `"tf_idf"`
157 modes. If True, the output will have its feature axis padded to
158 `max_tokens` even if the number of unique tokens in the vocabulary is
159 less than max_tokens, resulting in a tensor of shape `(batch_size,
160 max_tokens)` regardless of vocabulary size. Defaults to `False`.
161 vocabulary: Optional. Either an array of strings or a string path to a
162 text file. If passing an array, can pass a tuple, list, 1D numpy array,
163 or 1D tensor containing the string vocabulary terms. If passing a file
164 path, the file should contain one line per term in the vocabulary. If
165 this argument is set, there is no need to `adapt()` the layer.
166 idf_weights: Only valid when `output_mode` is `"tf_idf"`. A tuple, list,
167 1D numpy array, or 1D tensor of the same length as the vocabulary,
168 containing the floating point inverse document frequency weights, which
169 will be multiplied by per sample term counts for the final `tf_idf`
170 weight. If the `vocabulary` argument is set, and `output_mode` is
171 `"tf_idf"`, this argument must be supplied.
172 ragged: Boolean. Only applicable to `"int"` output mode. If True, returns
173 a `RaggedTensor` instead of a dense `Tensor`, where each sequence may
174 have a different length after string splitting. Defaults to `False`.
175 sparse: Boolean. Only applicable to `"multi_hot"`, `"count"`, and
176 `"tf_idf"` output modes. If True, returns a `SparseTensor` instead of a
177 dense `Tensor`. Defaults to `False`.
178 encoding: Optional. The text encoding to use to interpret the input
179 strings. Defaults to `"utf-8"`.
181 Example:
183 This example instantiates a `TextVectorization` layer that lowercases text,
184 splits on whitespace, strips punctuation, and outputs integer vocab indices.
186 >>> text_dataset = tf.data.Dataset.from_tensor_slices(["foo", "bar", "baz"])
187 >>> max_features = 5000 # Maximum vocab size.
188 >>> max_len = 4 # Sequence length to pad the outputs to.
189 >>>
190 >>> # Create the layer.
191 >>> vectorize_layer = tf.keras.layers.TextVectorization(
192 ... max_tokens=max_features,
193 ... output_mode='int',
194 ... output_sequence_length=max_len)
195 >>>
196 >>> # Now that the vocab layer has been created, call `adapt` on the
197 >>> # text-only dataset to create the vocabulary. You don't have to batch,
198 >>> # but batching a large dataset means we avoid keeping a spare copy of
199 >>> # it in memory.
200 >>> vectorize_layer.adapt(text_dataset.batch(64))
201 >>>
202 >>> # Create the model that uses the vectorize text layer
203 >>> model = tf.keras.models.Sequential()
204 >>>
205 >>> # Start by creating an explicit input layer. It needs to have a shape of
206 >>> # (1,) (because we need to guarantee that there is exactly one string
207 >>> # input per example), and the dtype needs to be 'string'.
208 >>> model.add(tf.keras.Input(shape=(1,), dtype=tf.string))
209 >>>
210 >>> # The first layer in our model is the vectorization layer. After this
211 >>> # layer, we have a tensor of shape (batch_size, max_len) containing
212 >>> # vocab indices.
213 >>> model.add(vectorize_layer)
214 >>>
215 >>> # Now, the model can map strings to integers, and you can add an
216 >>> # embedding layer to map these integers to learned embeddings.
217 >>> input_data = [["foo qux bar"], ["qux baz"]]
218 >>> model.predict(input_data)
219 array([[2, 1, 4, 0],
220 [1, 3, 0, 0]])
222 Example:
224 This example instantiates a `TextVectorization` layer by passing a list
225 of vocabulary terms to the layer's `__init__()` method.
227 >>> vocab_data = ["earth", "wind", "and", "fire"]
228 >>> max_len = 4 # Sequence length to pad the outputs to.
229 >>>
230 >>> # Create the layer, passing the vocab directly. You can also pass the
231 >>> # vocabulary arg a path to a file containing one vocabulary word per
232 >>> # line.
233 >>> vectorize_layer = tf.keras.layers.TextVectorization(
234 ... max_tokens=max_features,
235 ... output_mode='int',
236 ... output_sequence_length=max_len,
237 ... vocabulary=vocab_data)
238 >>>
239 >>> # Because we've passed the vocabulary directly, we don't need to adapt
240 >>> # the layer - the vocabulary is already set. The vocabulary contains the
241 >>> # padding token ('') and OOV token ('[UNK]') as well as the passed
242 >>> # tokens.
243 >>> vectorize_layer.get_vocabulary()
244 ['', '[UNK]', 'earth', 'wind', 'and', 'fire']
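Example:

As a rough sketch (exact values depend on the vocabulary), the `"multi_hot"`
output mode produces one fixed-length vector per example instead of a
sequence of token indices:

>>> vocab_data = ["earth", "wind", "and", "fire"]
>>> vectorize_layer = tf.keras.layers.TextVectorization(
...     output_mode='multi_hot',
...     vocabulary=vocab_data)
>>> # Each example becomes a single vector of length `vocabulary_size()`,
>>> # with ones at the indices of the tokens (and OOV bucket) it contains.
>>> multi_hot_output = vectorize_layer(["earth wind", "fire and earth"])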
246 """
248 def __init__(
249 self,
250 max_tokens=None,
251 standardize="lower_and_strip_punctuation",
252 split="whitespace",
253 ngrams=None,
254 output_mode="int",
255 output_sequence_length=None,
256 pad_to_max_tokens=False,
257 vocabulary=None,
258 idf_weights=None,
259 sparse=False,
260 ragged=False,
261 encoding="utf-8",
262 **kwargs,
263 ):
265 # This layer only applies to string processing, and so should only have
266 # a dtype of 'string'.
267 if "dtype" in kwargs and kwargs["dtype"] != tf.string:
268 raise ValueError(
269 "`TextVectorization` may only have a dtype of string. "
270 f"Received dtype: {kwargs['dtype']}."
271 )
272 elif "dtype" not in kwargs:
273 kwargs["dtype"] = tf.string
275 # 'standardize' must be one of
276 # (None, LOWER_AND_STRIP_PUNCTUATION, LOWER, STRIP_PUNCTUATION,
277 # callable)
278 layer_utils.validate_string_arg(
279 standardize,
280 allowable_strings=(
281 LOWER_AND_STRIP_PUNCTUATION,
282 LOWER,
283 STRIP_PUNCTUATION,
284 ),
285 layer_name="TextVectorization",
286 arg_name="standardize",
287 allow_none=True,
288 allow_callables=True,
289 )
291 # 'split' must be one of (None, WHITESPACE, CHARACTER, callable)
292 layer_utils.validate_string_arg(
293 split,
294 allowable_strings=(WHITESPACE, CHARACTER),
295 layer_name="TextVectorization",
296 arg_name="split",
297 allow_none=True,
298 allow_callables=True,
299 )
301 # Support deprecated names for output_modes.
302 if output_mode == "binary":
303 output_mode = MULTI_HOT
304 if output_mode == "tf-idf":
305 output_mode = TF_IDF
306 # 'output_mode' must be one of (None, INT, COUNT, MULTI_HOT, TF_IDF)
307 layer_utils.validate_string_arg(
308 output_mode,
309 allowable_strings=(INT, COUNT, MULTI_HOT, TF_IDF),
310 layer_name="TextVectorization",
311 arg_name="output_mode",
312 allow_none=True,
313 )
315 # 'ngrams' must be one of (None, int, tuple(int))
316 if not (
317 ngrams is None
318 or isinstance(ngrams, int)
319 or isinstance(ngrams, tuple)
320 and all(isinstance(item, int) for item in ngrams)
321 ):
322 raise ValueError(
323 "`ngrams` must be None, an integer, or a tuple of "
324 f"integers. Received: ngrams={ngrams}"
325 )
327 # 'output_sequence_length' must be one of (None, int) and is only
328 # set if output_mode is INT.
329 if output_mode == INT and not (
330 isinstance(output_sequence_length, int)
331 or (output_sequence_length is None)
332 ):
333 raise ValueError(
334 "`output_sequence_length` must be either None or an "
335 "integer when `output_mode` is 'int'. Received: "
336 f"output_sequence_length={output_sequence_length}"
337 )
339 if output_mode != INT and output_sequence_length is not None:
340 raise ValueError(
341 "`output_sequence_length` must not be set if `output_mode` is "
342 "not 'int'. "
343 f"Received output_sequence_length={output_sequence_length}."
344 )
346 if ragged and output_mode != INT:
347 raise ValueError(
348 "`ragged` may only be true if `output_mode` is "
349 f"`'int'`. Received: ragged={ragged} and "
350 f"output_mode={output_mode}"
351 )
353 if ragged and output_sequence_length is not None:
354 raise ValueError(
355 "`output_sequence_length` must not be set if ragged "
356 f"is True. Received: ragged={ragged} and "
357 f"output_sequence_length={output_sequence_length}"
358 )
360 self._max_tokens = max_tokens
361 self._standardize = standardize
362 self._split = split
363 self._ngrams_arg = ngrams
364 if isinstance(ngrams, int):
365 self._ngrams = tuple(range(1, ngrams + 1))
366 else:
367 self._ngrams = ngrams
368 self._ragged = ragged
370 self._output_mode = output_mode
371 self._output_sequence_length = output_sequence_length
372 self._encoding = encoding
374 # VocabularySavedModelSaver will clear the config vocabulary to restore
375 # the lookup table ops directly. We persist this hidden option to
376 # record the fact that we have a non-adaptable layer with a
377 # manually set vocab.
378 self._has_input_vocabulary = kwargs.pop(
379 "has_input_vocabulary", (vocabulary is not None)
380 )
382 vocabulary_size = kwargs.pop("vocabulary_size", None)
384 super().__init__(**kwargs)
385 base_preprocessing_layer.keras_kpl_gauge.get_cell(
386 "TextVectorization"
387 ).set(True)
389 self._lookup_layer = string_lookup.StringLookup(
390 max_tokens=max_tokens,
391 vocabulary=vocabulary,
392 idf_weights=idf_weights,
393 pad_to_max_tokens=pad_to_max_tokens,
394 mask_token="",
395 output_mode=output_mode if output_mode is not None else INT,
396 sparse=sparse,
397 has_input_vocabulary=self._has_input_vocabulary,
398 encoding=encoding,
399 vocabulary_size=vocabulary_size,
400 )
402 def compute_output_shape(self, input_shape):
403 if self._output_mode == INT:
404 return tf.TensorShape(
405 [input_shape[0], self._output_sequence_length]
406 )
408 if self._split is None:
409 if len(input_shape) <= 1:
410 input_shape = tuple(input_shape) + (1,)
411 else:
412 input_shape = tuple(input_shape) + (None,)
413 return self._lookup_layer.compute_output_shape(input_shape)
415 def compute_output_signature(self, input_spec):
416 output_shape = self.compute_output_shape(input_spec.shape.as_list())
417 output_dtype = (
418 tf.int64 if self._output_mode == INT else backend.floatx()
419 )
420 return tf.TensorSpec(shape=output_shape, dtype=output_dtype)
422 # We override this method solely to generate a docstring.
423 def adapt(self, data, batch_size=None, steps=None):
424 """Computes a vocabulary of string terms from tokens in a dataset.
426 Calling `adapt()` on a `TextVectorization` layer is an alternative to
427 passing in a precomputed vocabulary on construction via the `vocabulary`
428 argument. A `TextVectorization` layer should always be either adapted
429 over a dataset or supplied with a vocabulary.
431 During `adapt()`, the layer will build a vocabulary of all string tokens
432 seen in the dataset, sorted by occurrence count, with ties broken by
433 sort order of the tokens (high to low). At the end of `adapt()`, if
434 `max_tokens` is set, the vocabulary will be truncated to `max_tokens`
435 size. For example, adapting a layer with `max_tokens=1000` will compute
436 the 1000 most frequent tokens occurring in the input dataset. If
437 `output_mode='tf-idf'`, `adapt()` will also learn the document
438 frequencies of each token in the input dataset.
440 In order to make `TextVectorization` efficient in any distribution
441 context, the vocabulary is kept static with respect to any compiled
442 `tf.Graph`s that call the layer. As a consequence, if the layer is
443 adapted a second time, any models using the layer should be re-compiled.
444 For more information see
445 `tf.keras.layers.experimental.preprocessing.PreprocessingLayer.adapt`.
447 `adapt()` is meant only as a single machine utility to compute layer
448 state. To analyze a dataset that cannot fit on a single machine, see
449 [Tensorflow Transform](
450 https://www.tensorflow.org/tfx/transform/get_started) for a
451 multi-machine, map-reduce solution.
453 Arguments:
454 data: The data to train on. It can be passed either as a
455 `tf.data.Dataset`, or as a numpy array.
456 batch_size: Integer or `None`.
457 Number of samples per state update.
458 If unspecified, `batch_size` will default to 32.
459 Do not specify the `batch_size` if your data is in the
460 form of datasets, generators, or `keras.utils.Sequence` instances
461 (since they generate batches).
462 steps: Integer or `None`.
463 Total number of steps (batches of samples).
464 When training with input tensors such as
465 TensorFlow data tensors, the default `None` is equal to
466 the number of samples in your dataset divided by
467 the batch size, or 1 if that cannot be determined. If x is a
468 `tf.data` dataset, and 'steps' is None, the epoch will run until
469 the input dataset is exhausted. When passing an infinitely
470 repeating dataset, you must specify the `steps` argument. This
471 argument is not supported with array inputs.
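For example (an illustrative sketch; the resulting vocabulary depends on the
dataset and on `max_tokens`):

>>> layer = tf.keras.layers.TextVectorization(max_tokens=5000)
>>> dataset = tf.data.Dataset.from_tensor_slices(["foo bar", "bar baz"])
>>> layer.adapt(dataset.batch(2))
>>> vocab = layer.get_vocabulary()  # ['', '[UNK]', 'bar', ...]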
472 """
473 super().adapt(data, batch_size=batch_size, steps=steps)
475 def update_state(self, data):
476 self._lookup_layer.update_state(self._preprocess(data))
478 def finalize_state(self):
479 self._lookup_layer.finalize_state()
481 def reset_state(self):
482 self._lookup_layer.reset_state()
484 def get_vocabulary(self, include_special_tokens=True):
485 """Returns the current vocabulary of the layer.
487 Args:
488 include_special_tokens: If True, the returned vocabulary will include
489 the padding and OOV tokens, and a term's index in the vocabulary
490 will equal the term's index when calling the layer. If False, the
491 returned vocabulary will not include any padding or OOV tokens.
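For example, with a directly supplied vocabulary and the default `"int"`
output mode (a sketch):

>>> layer = tf.keras.layers.TextVectorization(vocabulary=["earth", "wind"])
>>> layer.get_vocabulary()
['', '[UNK]', 'earth', 'wind']
>>> layer.get_vocabulary(include_special_tokens=False)
['earth', 'wind']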
492 """
493 return self._lookup_layer.get_vocabulary(include_special_tokens)
495 def vocabulary_size(self):
496 """Gets the current size of the layer's vocabulary.
498 Returns:
499 The integer size of the vocabulary, including optional mask and
500 OOV indices.
501 """
502 return self._lookup_layer.vocabulary_size()
504 def get_config(self):
505 config = {
506 "max_tokens": self._lookup_layer.max_tokens,
507 "standardize": self._standardize,
508 "split": self._split,
509 "ngrams": self._ngrams_arg,
510 "output_mode": self._output_mode,
511 "output_sequence_length": self._output_sequence_length,
512 "pad_to_max_tokens": self._lookup_layer.pad_to_max_tokens,
513 "sparse": self._lookup_layer.sparse,
514 "ragged": self._ragged,
515 "vocabulary": utils.listify_tensors(
516 self._lookup_layer.input_vocabulary
517 ),
518 "idf_weights": utils.listify_tensors(
519 self._lookup_layer.input_idf_weights
520 ),
521 "encoding": self._encoding,
522 "vocabulary_size": self.vocabulary_size(),
523 }
524 base_config = super().get_config()
525 return dict(list(base_config.items()) + list(config.items()))
527 @classmethod
528 def from_config(cls, config):
529 if config["standardize"] not in (
530 LOWER_AND_STRIP_PUNCTUATION,
531 LOWER,
532 STRIP_PUNCTUATION,
533 ):
534 config["standardize"] = deserialize_keras_object(
535 config["standardize"]
536 )
537 if config["split"] not in (WHITESPACE, CHARACTER):
538 config["split"] = deserialize_keras_object(config["split"])
539 return cls(**config)
541 def set_vocabulary(self, vocabulary, idf_weights=None):
542 """Sets vocabulary (and optionally document frequency) for this layer.
544 This method sets the vocabulary and idf weights for this layer directly,
545 instead of analyzing a dataset through 'adapt'. It should be used
546 whenever the vocab (and optionally document frequency) information is
547 already known. If vocabulary data is already present in the layer, this
548 method will replace it.
550 Args:
551 vocabulary: Either an array or a string path to a text file. If
552 passing an array, can pass a tuple, list, 1D numpy array, or 1D
553 tensor containing the vocabulary terms. If passing a file path, the
554 file should contain one line per term in the vocabulary.
555 idf_weights: A tuple, list, 1D numpy array, or 1D tensor of inverse
556 document frequency weights with equal length to vocabulary. Must be
557 set if `output_mode` is `"tf_idf"`. Should not be set otherwise.
559 Raises:
560 ValueError: If there are too many inputs, the inputs do not match, or
561 input data is missing.
562 RuntimeError: If the vocabulary cannot be set when this function is
563 called. This happens when using `"multi_hot"`, `"count"`, or
564 `"tf_idf"` modes, if `pad_to_max_tokens` is False and the layer
565 itself has already been called.
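For example (an illustrative sketch; the idf weights shown are arbitrary):

>>> layer = tf.keras.layers.TextVectorization(output_mode='tf_idf')
>>> layer.set_vocabulary(["earth", "wind"], idf_weights=[0.8, 0.3])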
566 """
567 self._lookup_layer.set_vocabulary(vocabulary, idf_weights=idf_weights)
569 def _preprocess(self, inputs):
570 inputs = utils.ensure_tensor(inputs, dtype=tf.string)
571 if self._standardize in (LOWER, LOWER_AND_STRIP_PUNCTUATION):
572 inputs = tf.strings.lower(inputs)
573 if self._standardize in (
574 STRIP_PUNCTUATION,
575 LOWER_AND_STRIP_PUNCTUATION,
576 ):
577 inputs = tf.strings.regex_replace(inputs, DEFAULT_STRIP_REGEX, "")
578 if callable(self._standardize):
579 inputs = self._standardize(inputs)
581 if self._split is not None:
582 # If we are splitting, we validate that the last axis is of dimension
583 # 1 and so can be squeezed out. We do this here instead of after
584 # splitting for performance reasons - it's more expensive to squeeze
585 # a ragged tensor.
586 if inputs.shape.rank > 1:
587 if inputs.shape[-1] != 1:
588 raise ValueError(
589 "When using `TextVectorization` to tokenize strings, "
590 "the input rank must be 1 or the last shape dimension "
591 f"must be 1. Received: inputs.shape={inputs.shape} "
592 f"with rank={inputs.shape.rank}"
593 )
594 else:
595 inputs = tf.squeeze(inputs, axis=-1)
596 if self._split == WHITESPACE:
597 # This treats multiple whitespaces as one whitespace, and strips
598 # leading and trailing whitespace.
599 inputs = tf.strings.split(inputs)
600 elif self._split == CHARACTER:
601 inputs = tf.strings.unicode_split(inputs, "UTF-8")
602 elif callable(self._split):
603 inputs = self._split(inputs)
604 else:
605 raise ValueError(
606 "%s is not a supported splitting. "
607 "TextVectorization supports the following options "
608 "for `split`: None, 'whitespace', 'character', or a Callable."
609 % self._split
610 )
612 # Note that 'inputs' here can be either ragged or dense depending on the
613 # configuration choices for this Layer. The strings.ngrams op, however,
614 # does support both ragged and dense inputs.
615 if self._ngrams is not None:
616 inputs = tf.strings.ngrams(
617 inputs, ngram_width=self._ngrams, separator=" "
618 )
620 return inputs
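# As a rough illustration of the preprocessing steps above (hypothetical
# input, default standardize/split, ngrams=2):
#   ["Hello, World!"]
#   -> standardize -> ["hello world"]
#   -> split       -> [["hello", "world"]]
#   -> ngrams      -> [["hello", "world", "hello world"]]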
622 def call(self, inputs):
623 if isinstance(inputs, (list, tuple, np.ndarray)):
624 inputs = tf.convert_to_tensor(inputs)
626 inputs = self._preprocess(inputs)
628 # If we're not doing any output processing, return right away.
629 if self._output_mode is None:
630 return inputs
632 lookup_data = self._lookup_layer(inputs)
634 # For any non-int output, we can return directly from the underlying
635 # layer.
636 if self._output_mode != INT:
637 return lookup_data
639 if self._ragged:
640 return lookup_data
642 # If we have a ragged tensor, we can pad during the conversion to dense.
643 if tf_utils.is_ragged(lookup_data):
644 shape = lookup_data.shape.as_list()
645 # If output sequence length is None, to_tensor will pad the last
646 # dimension to the bounding shape of the ragged dimension.
647 shape[-1] = self._output_sequence_length
648 return lookup_data.to_tensor(default_value=0, shape=shape)
650 # If we have a dense tensor, we need to pad/trim directly.
651 if self._output_sequence_length is not None:
652 # Maybe trim the output.
653 lookup_data = lookup_data[..., : self._output_sequence_length]
655 # Maybe pad the output. We need to be careful to use dynamic shape
656 # here as required_space_to_batch_paddings requires a fully known
657 # shape.
658 shape = tf.shape(lookup_data)
659 padded_shape = tf.concat(
660 (shape[:-1], [self._output_sequence_length]), 0
661 )
662 padding, _ = tf.required_space_to_batch_paddings(
663 shape, padded_shape
664 )
665 return tf.pad(lookup_data, padding)
667 return lookup_data
669 @property
670 def _trackable_saved_model_saver(self):
671 return layer_serialization.VocabularySavedModelSaver(self)
673 def save_own_variables(self, store):
674 self._lookup_layer.save_own_variables(store)
676 def load_own_variables(self, store):
677 self._lookup_layer.load_own_variables(store)
679 def save_assets(self, dir_path):
680 self._lookup_layer.save_assets(dir_path)
682 def load_assets(self, dir_path):
683 self._lookup_layer.load_assets(dir_path)