# Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/keras/src/layers/preprocessing/index_lookup.py: 16%
# (404 statements, coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras index lookup preprocessing layer."""

import collections

import numpy as np
import tensorflow.compat.v2 as tf

from keras.src import backend
from keras.src.engine import base_layer_utils
from keras.src.engine import base_preprocessing_layer
from keras.src.layers.preprocessing import preprocessing_utils as utils
from keras.src.saving.legacy.saved_model import layer_serialization
from keras.src.utils import layer_utils
from keras.src.utils import tf_utils

# isort: off
from tensorflow.python.platform import tf_logging as logging

INT = utils.INT
MULTI_HOT = utils.MULTI_HOT
ONE_HOT = utils.ONE_HOT
COUNT = utils.COUNT
TF_IDF = utils.TF_IDF

_VOCAB_NAME = "vocab"
_IDF_WEIGHTS_NAME = "idf_weights"


class NullInitializer(tf.lookup.KeyValueTensorInitializer):
    """A placeholder initializer for restoring this layer from a SavedModel."""

    def __init__(self, key_dtype, value_dtype):
        """Construct a table initializer object.

        Args:
            key_dtype: Type of the table keys.
            value_dtype: Type of the table values.
        """
        self._key_dtype = key_dtype
        self._value_dtype = value_dtype

    @property
    def key_dtype(self):
        """The expected table key dtype."""
        return self._key_dtype

    @property
    def value_dtype(self):
        """The expected table value dtype."""
        return self._value_dtype

    def initialize(self, table):
        """Returns the table initialization op."""
        pass


class VocabWeightHandler(base_layer_utils.TrackableWeightHandler):
    """Adds the vocabulary as a layer weight during serialization."""

    def __init__(self, lookup_layer):
        # Note that this class doesn't call super().__init__() in order to
        # have customized behavior. The fields like '_dtype' and
        # '_distribute_strategy' are required by the parent class, as well as
        # by tf.distribute. See `strategy.extended.variable_created_in_scope`.
        self._layer = lookup_layer
        self._dtype = lookup_layer.vocabulary_dtype
        self._distribute_strategy = tf.distribute.get_strategy()

    @property
    def num_tensors(self):
        return 1

    def set_weights(self, weights):
        tokens = tf.convert_to_tensor(weights[0], self._dtype)
        self._layer.lookup_table = self._layer._lookup_table_from_tokens(tokens)

    def get_tensors(self):
        # Just save the non-config part of the vocab (no special tokens).
        tokens = self._layer.get_vocabulary(include_special_tokens=False)
        tokens = tf.convert_to_tensor(tokens, self._dtype)
        return [tokens]


class IndexLookup(base_preprocessing_layer.PreprocessingLayer):
    """Maps values from a vocabulary to integer indices.

    This layer translates a set of arbitrary hashables into an integer output
    via a table-based lookup, with optional out-of-vocabulary handling. This is
    the base layer for both IntegerLookup and StringLookup; it holds the common
    logic but is not intended to be exported as part of the Keras API.

    Args:
        max_tokens: The maximum size of the vocabulary for this layer. If None,
            there is no cap on the size of the vocabulary. Note that this size
            includes the OOV and mask tokens.
        num_oov_indices: The number of out-of-vocabulary tokens to use. If this
            value is more than 1, OOV inputs are hashed to determine their OOV
            value. If this value is 0, OOV inputs will cause an error when
            calling the layer.
        mask_token: A token that represents masked inputs. When `output_mode`
            is `"int"`, the token is included in the vocabulary and mapped to
            index 0. In other output modes, the token will not appear in the
            vocabulary and instances of the mask token in the input will be
            dropped. If set to None, no mask term will be added.
        oov_token: Only used when `invert` is True. The token to return for
            OOV indices.
        vocabulary: Optional. Either an array or a string path to a text file.
            If passing an array, can pass a tuple, list, 1D numpy array, or 1D
            tensor containing the vocabulary terms. If passing a file path,
            the file should contain one line per term in the vocabulary. If
            this argument is set, there is no need to `adapt` the layer.
        vocabulary_dtype: The dtype of the vocabulary terms. For example,
            `"int64"` or `"string"`.
        idf_weights: Only valid when `output_mode` is `"tf_idf"`. A tuple,
            list, 1D numpy array, or 1D tensor of the same length as the
            vocabulary, containing the floating point inverse document
            frequency weights, which will be multiplied by per sample term
            counts for the final `tf_idf` weight. If the `vocabulary` argument
            is set, and `output_mode` is `"tf_idf"`, this argument must be
            supplied.
        invert: Only valid when `output_mode` is `"int"`. If True, this layer
            will map indices to vocabulary items instead of mapping vocabulary
            items to indices. Defaults to `False`.
        output_mode: Specification for the output of the layer. Values can be
            `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or `"tf_idf"`,
            configuring the layer as follows:
            - `"int"`: Return the raw integer indices of the input tokens.
            - `"one_hot"`: Encodes each individual element in the input into
                an array the same size as the vocabulary, containing a 1 at
                the element index. If the last dimension is size 1, will
                encode on that dimension. If the last dimension is not size 1,
                will append a new dimension for the encoded output.
            - `"multi_hot"`: Encodes each sample in the input into a single
                array the same size as the vocabulary, containing a 1 for each
                vocabulary term present in the sample. Treats the last
                dimension as the sample dimension; if input shape is
                (..., sample_length), output shape will be (..., num_tokens).
            - `"count"`: As `"multi_hot"`, but the int array contains a count
                of the number of times the token at that index appeared in
                the sample.
            - `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is
                applied to find the value in each token slot.
            Defaults to `"int"`.
        pad_to_max_tokens: Only valid when `output_mode` is `"multi_hot"`,
            `"count"`, or `"tf_idf"`. If True, the output will have its
            feature axis padded to `max_tokens` even if the number of unique
            tokens in the vocabulary is less than `max_tokens`, resulting in a
            tensor of shape `[batch_size, max_tokens]` regardless of
            vocabulary size. Defaults to `False`.
        sparse: Boolean. Only applicable to `"one_hot"`, `"multi_hot"`,
            `"count"` and `"tf_idf"` output modes. If True, returns a
            `SparseTensor` instead of a dense `Tensor`. Defaults to `False`.
    """

    def __init__(
        self,
        max_tokens,
        num_oov_indices,
        mask_token,
        oov_token,
        vocabulary_dtype,
        vocabulary=None,
        idf_weights=None,
        invert=False,
        output_mode="int",
        sparse=False,
        pad_to_max_tokens=False,
        **kwargs,
    ):
        # If max_tokens is set, the value must be greater than 1 - otherwise
        # we are creating a 0-element vocab, which doesn't make sense.
        if max_tokens is not None and max_tokens <= 1:
            raise ValueError(
                "If set, `max_tokens` must be greater than 1. "
                f"Received: max_tokens={max_tokens}"
            )

        if pad_to_max_tokens and max_tokens is None:
            raise ValueError(
                "If pad_to_max_tokens is True, must set `max_tokens`. "
                f"Received: max_tokens={max_tokens}"
            )

        if num_oov_indices < 0:
            raise ValueError(
                "`num_oov_indices` must be greater than or equal to 0. "
                f"Received: num_oov_indices={num_oov_indices}"
            )

        # Support deprecated names for output_modes.
        if output_mode == "binary":
            output_mode = MULTI_HOT
        if output_mode == "tf-idf":
            output_mode = TF_IDF
        # 'output_mode' must be one of (INT, ONE_HOT, MULTI_HOT, COUNT, TF_IDF)
        layer_utils.validate_string_arg(
            output_mode,
            allowable_strings=(INT, ONE_HOT, MULTI_HOT, COUNT, TF_IDF),
            layer_name=self.__class__.__name__,
            arg_name="output_mode",
        )

        if invert and output_mode != INT:
            raise ValueError(
                "`output_mode` must be `'int'` when `invert` is true. "
                f"Received: output_mode={output_mode}"
            )

        if sparse and output_mode == INT:
            raise ValueError(
                "`sparse` may only be true if `output_mode` is "
                "`'one_hot'`, `'multi_hot'`, `'count'` or `'tf_idf'`. "
                f"Received: sparse={sparse} and "
                f"output_mode={output_mode}"
            )

        if idf_weights is not None and output_mode != TF_IDF:
            raise ValueError(
                "`idf_weights` should only be set if `output_mode` is "
                f"`'tf_idf'`. Received: idf_weights={idf_weights} and "
                f"output_mode={output_mode}"
            )

        self.invert = invert
        self.max_tokens = max_tokens
        self.num_oov_indices = num_oov_indices
        self.mask_token = mask_token
        self.oov_token = oov_token
        self.output_mode = output_mode
        self.sparse = sparse
        self.pad_to_max_tokens = pad_to_max_tokens
        self.vocabulary_dtype = vocabulary_dtype
        self._frozen_vocab_size = kwargs.pop("vocabulary_size", None)

        self.input_vocabulary = vocabulary
        self.input_idf_weights = idf_weights
        # VocabularySavedModelSaver will clear the config vocabulary to
        # restore the lookup table ops directly. We persist this hidden
        # option to record the fact that we have a non-adaptable layer with
        # a manually set vocab.
        self._has_input_vocabulary = kwargs.pop(
            "has_input_vocabulary", (vocabulary is not None)
        )

        # Drop deprecated config options.
        kwargs.pop("has_static_table", None)

        # By default, output int64 when output_mode='int' and floats
        # otherwise.
        if "dtype" not in kwargs:
            kwargs["dtype"] = (
                tf.int64 if output_mode == INT else backend.floatx()
            )

        super().__init__(**kwargs)

        # Check dtype only after the base layer parses it; dtype parsing is
        # complex.
        if (
            output_mode == INT
            and not tf.as_dtype(self.compute_dtype).is_integer
        ):
            input_dtype = kwargs["dtype"]
            raise ValueError(
                "When `output_mode='int'`, `dtype` should be an integer "
                f"type. Received: dtype={input_dtype}"
            )

        if invert:
            self._key_dtype = self.dtype if output_mode == INT else tf.int64
            self._value_dtype = tf.as_dtype(self.vocabulary_dtype)
            mask_key = 0
            mask_value = mask_token
            self._default_value = self.oov_token
        else:
            self._key_dtype = tf.as_dtype(self.vocabulary_dtype)
            self._value_dtype = self.dtype if output_mode == INT else tf.int64
            mask_key = mask_token
            # Masks should map to 0 for int output and be dropped otherwise.
            # Max ints will be dropped from the bincount op.
            mask_value = 0 if self.output_mode == INT else self._value_dtype.max
            if self.num_oov_indices == 0:
                # If there are no OOV indices, we map OOV tokens to -1 and
                # error out during call if we find a negative index.
                self._default_value = -1
            elif self.num_oov_indices == 1:
                # If there is only one OOV index, we can set that index as
                # the default value of the index_lookup table.
                self._default_value = self._oov_start_index()
            else:
                # If we have multiple OOV values, we need to do a further
                # hashing step; to make this easier, we set the OOV value to
                # -1. (This lets us do a vectorized add and cast to boolean
                # to determine locations where we need to do extra hashing.)
                self._default_value = -1
        if self.mask_token is not None:
            self._mask_key = tf.convert_to_tensor(mask_key, self._key_dtype)
            self._mask_value = tf.convert_to_tensor(
                mask_value, self._value_dtype
            )

        if self.output_mode == TF_IDF:
            self.idf_weights = tf.Variable(
                [0] * self._token_start_index(),
                shape=(None,),
                dtype=self.compute_dtype,
                trainable=False,
            )
            self.idf_weights_const = self.idf_weights.value()

        if vocabulary is not None:
            self.set_vocabulary(vocabulary, idf_weights)
        else:
            # When restoring from a keras SavedModel, the loading code will
            # expect to find and restore a lookup_table attribute on the
            # layer. This table needs to be uninitialized as a
            # StaticHashTable cannot be initialized twice.
            self.lookup_table = self._uninitialized_lookup_table()

        # Only set up adapt state if we did not receive a vocab on
        # construction.
        if not self._has_input_vocabulary:
            # Add a custom weight handler to return the layer's vocab as a
            # weight.
            self._add_trackable(VocabWeightHandler(self), False)
            # Set adapt state.
            self.token_counts = tf.lookup.experimental.MutableHashTable(
                key_dtype=vocabulary_dtype,
                value_dtype=tf.int64,
                default_value=0,
            )
            if self.output_mode == TF_IDF:
                self.token_document_counts = (
                    tf.lookup.experimental.MutableHashTable(
                        key_dtype=vocabulary_dtype,
                        value_dtype=tf.int64,
                        default_value=0,
                    )
                )
                self.num_documents = tf.Variable(
                    0, dtype=tf.int64, trainable=False
                )

    def compute_output_shape(self, input_shape):
        if self.output_mode == INT:
            return input_shape
        depth = (
            self.max_tokens
            if self.pad_to_max_tokens
            else self._frozen_vocab_size
        )
        return tf.TensorShape([input_shape[0], depth])

    def compute_output_signature(self, input_spec):
        output_shape = self.compute_output_shape(input_spec.shape.as_list())
        output_dtype = (
            self.vocabulary_dtype if self.invert else self.compute_dtype
        )
        return tf.TensorSpec(shape=output_shape, dtype=output_dtype)

    def get_vocabulary(self, include_special_tokens=True):
        """Returns the current vocabulary of the layer.

        Args:
            include_special_tokens: If True, the returned vocabulary will
                include mask and OOV tokens, and a term's index in the
                vocabulary will equal the term's index when calling the
                layer. If False, the returned vocabulary will not include any
                mask or OOV tokens.
        """
        # The lookup table data will not be sorted, so we will create an
        # inverted lookup here, and use that to lookup a range of indices
        # [0, vocab_size).
        if self.lookup_table.size() == 0:
            vocab, indices = [], []
        else:
            keys, values = self.lookup_table.export()
            vocab, indices = (values, keys) if self.invert else (keys, values)
            vocab, indices = (
                self._tensor_vocab_to_numpy(vocab),
                indices.numpy(),
            )
        lookup = collections.defaultdict(
            lambda: self.oov_token, zip(indices, vocab)
        )
        vocab = [lookup[x] for x in range(self.vocabulary_size())]
        if self.mask_token is not None and self.output_mode == INT:
            vocab[0] = self.mask_token
        if not include_special_tokens:
            vocab = vocab[self._token_start_index() :]
        return vocab
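
    # A short sketch of the special-token behavior, assuming a string layer
    # with mask_token="", oov_token="[UNK]", output_mode="int", and a learned
    # vocabulary ["a", "b"]:
    #
    #     layer.get_vocabulary()                              # ["", "[UNK]", "a", "b"]
    #     layer.get_vocabulary(include_special_tokens=False)  # ["a", "b"]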

    def vocabulary_size(self):
        """Gets the current size of the layer's vocabulary.

        Returns:
            The integer size of the vocabulary, including optional mask and
            OOV indices.
        """
        if tf.executing_eagerly():
            return (
                int(self.lookup_table.size().numpy())
                + self._token_start_index()
            )
        else:
            return self.lookup_table.size() + self._token_start_index()

    def vocab_size(self):
        logging.warning("vocab_size is deprecated, please use vocabulary_size.")
        return self.vocabulary_size()

    def get_config(self):
        config = {
            "invert": self.invert,
            "max_tokens": self.max_tokens,
            "num_oov_indices": self.num_oov_indices,
            "oov_token": self.oov_token,
            "mask_token": self.mask_token,
            "output_mode": self.output_mode,
            "sparse": self.sparse,
            "pad_to_max_tokens": self.pad_to_max_tokens,
            "vocabulary_dtype": self.vocabulary_dtype,
            "idf_weights": utils.listify_tensors(self.input_idf_weights),
            "vocabulary": utils.listify_tensors(self.input_vocabulary),
            "vocabulary_size": self._frozen_vocab_size,
        }
        base_config = super().get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def _record_vocabulary_size(self):
        self._ensure_vocab_size_unchanged()
        with tf.init_scope():
            self._frozen_vocab_size = self.vocabulary_size()

    def set_vocabulary(self, vocabulary, idf_weights=None):
        """Sets vocabulary (and optionally document frequency) for this layer.

        This method sets the vocabulary and idf weights for this layer
        directly, instead of analyzing a dataset through `adapt`. It should
        be used whenever the vocab (and optionally document frequency)
        information is already known. If vocabulary data is already present
        in the layer, this method will replace it.

        Args:
            vocabulary: Either an array or a string path to a text file. If
                passing an array, can pass a tuple, list, 1D numpy array, or
                1D tensor containing the vocabulary terms. If passing a file
                path, the file should contain one line per term in the
                vocabulary.
            idf_weights: A tuple, list, 1D numpy array, or 1D tensor of
                inverse document frequency weights with equal length to
                vocabulary. Must be set if `output_mode` is `"tf_idf"`.
                Should not be set otherwise.

        Raises:
            ValueError: If there are too many inputs, the inputs do not
                match, or input data is missing.
            RuntimeError: If the vocabulary cannot be set when this function
                is called. This happens in `"multi_hot"`, `"count"`, and
                `"tf_idf"` modes if `pad_to_max_tokens` is False and the
                layer itself has already been called.
            RuntimeError: If a tensor vocabulary is passed outside of eager
                execution.
        """
        if self.output_mode == TF_IDF:
            if idf_weights is None:
                raise ValueError(
                    "`idf_weights` must be set if output_mode is TF_IDF"
                )
        elif idf_weights is not None:
            raise ValueError(
                "`idf_weights` should only be set if output_mode is "
                f"`'tf_idf'`. Received: output_mode={self.output_mode} "
                f"and idf_weights={idf_weights}"
            )

        if isinstance(vocabulary, str):
            if not tf.io.gfile.exists(vocabulary):
                raise ValueError(
                    f"Vocabulary file {vocabulary} does not exist."
                )
            if self.output_mode == TF_IDF:
                raise ValueError(
                    "output_mode `'tf_idf'` does not support loading a "
                    "vocabulary from file."
                )
            self.lookup_table = self._lookup_table_from_file(vocabulary)
            self._record_vocabulary_size()
            return

        if not tf.executing_eagerly() and (
            tf.is_tensor(vocabulary) or tf.is_tensor(idf_weights)
        ):
            raise RuntimeError(
                "Cannot set a tensor vocabulary on {} layer {} when not "
                "executing eagerly. Create this layer or call "
                "`set_vocabulary` outside of any `tf.function`s and with "
                "eager execution enabled.".format(
                    self.__class__.__name__, self.name
                )
            )

        # TODO(mattdangerw): for better performance we should rewrite this
        # entire function to operate on tensors and convert vocabulary to a
        # tensor here.
        if tf.is_tensor(vocabulary):
            vocabulary = self._tensor_vocab_to_numpy(vocabulary)
        elif isinstance(vocabulary, (list, tuple)):
            vocabulary = np.array(vocabulary)
        if tf.is_tensor(idf_weights):
            idf_weights = idf_weights.numpy()
        elif isinstance(idf_weights, (list, tuple)):
            idf_weights = np.array(idf_weights)

        if vocabulary.size == 0:
            raise ValueError(
                f"Cannot set an empty vocabulary, you passed {vocabulary}."
            )

        oov_start = self._oov_start_index()
        token_start = self._token_start_index()
        special_tokens = [self.mask_token] * oov_start + [
            self.oov_token
        ] * self.num_oov_indices
        found_special_tokens = np.array_equal(
            special_tokens, vocabulary[:token_start]
        )
        if found_special_tokens:
            tokens = vocabulary[token_start:]
        else:
            tokens = vocabulary

        repeated_tokens = self._find_repeated_tokens(tokens)
        if repeated_tokens:
            raise ValueError(
                "The passed vocabulary has at least one repeated "
                "term. Please uniquify your dataset. The repeated terms "
                "are {}".format(repeated_tokens)
            )

        if self.mask_token is not None and self.mask_token in tokens:
            mask_index = np.argwhere(vocabulary == self.mask_token)[-1]
            raise ValueError(
                "Found reserved mask token at unexpected location in "
                "`vocabulary`. Note that passed `vocabulary` does not need "
                "to include the OOV and mask tokens. Either remove all mask "
                "and OOV tokens, or include them only at the start of the "
                f"vocabulary in precisely this order: {special_tokens}. "
                f"Received: mask_token={self.mask_token} at "
                f"vocabulary index {mask_index}"
            )
        # Only error out for oov_token when invert=True. When invert=False,
        # oov_token is unused during lookup.
        if (
            self.oov_token is not None
            and self.invert
            and self.oov_token in tokens
        ):
            oov_index = np.argwhere(vocabulary == self.oov_token)[-1]
            raise ValueError(
                "Found reserved OOV token at unexpected location in "
                "`vocabulary`. Note that passed `vocabulary` does not need "
                "to include the OOV and mask tokens. Either remove all mask "
                "and OOV tokens, or include them only at the start of the "
                f"vocabulary in precisely this order: {special_tokens}. "
                f"Received: oov_token={self.oov_token} at "
                f"vocabulary index {oov_index}"
            )

        new_vocab_size = token_start + len(tokens)
        if self.max_tokens is not None and (new_vocab_size > self.max_tokens):
            raise ValueError(
                "Attempted to set a vocabulary larger than the maximum vocab "
                "size. Passed vocab size is {}, max vocab size is {}.".format(
                    new_vocab_size, self.max_tokens
                )
            )
        self.lookup_table = self._lookup_table_from_tokens(tokens)
        self._record_vocabulary_size()

        if self.output_mode == TF_IDF and idf_weights is not False:
            if len(vocabulary) != len(idf_weights):
                raise ValueError(
                    "`idf_weights` must be the same length as vocabulary. "
                    "len(idf_weights) is {}, len(vocabulary) is {}".format(
                        len(idf_weights), len(vocabulary)
                    )
                )
            idf_weights = self._convert_to_ndarray(idf_weights)
            if idf_weights.ndim != 1:
                raise ValueError(
                    "TF-IDF data must be a 1-dimensional array, "
                    "but received {}".format(type(idf_weights))
                )

            # If the passed vocabulary has no special tokens, we need to pad
            # the front of idf_weights. We don't have real document
            # frequencies for these tokens so we will use an average of all
            # idf_weights passed in as a reasonable default.
            if found_special_tokens:
                front_padding = 0
                front_padding_value = 0
            else:
                front_padding = token_start
                front_padding_value = np.average(idf_weights)
            # If pad_to_max_tokens is true, and max_tokens is greater than
            # our total vocab size, we need to pad the back of idf_weights
            # with zeros as well.
            back_padding_value = 0
            if self.pad_to_max_tokens and self.max_tokens is not None:
                back_padding = (
                    self.max_tokens - front_padding - len(idf_weights)
                )
            else:
                back_padding = 0
            weights = np.pad(
                idf_weights,
                (front_padding, back_padding),
                "constant",
                constant_values=(front_padding_value, back_padding_value),
            )
            weights = tf.convert_to_tensor(weights, dtype=self.compute_dtype)
            self.idf_weights.assign(weights)
            self.idf_weights_const = self.idf_weights.value()
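
    # A brief sketch of direct vocabulary assignment (illustrative values):
    #
    #     layer.set_vocabulary(["a", "b", "c"])
    #     # For output_mode="tf_idf", weights of the same length must
    #     # accompany the vocab:
    #     layer.set_vocabulary(["a", "b", "c"], idf_weights=[0.8, 0.5, 0.25])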

    def update_state(self, data):
        if self._has_input_vocabulary:
            raise ValueError(
                "Cannot adapt {} layer after setting a static vocabulary "
                "via init argument or `set_vocabulary`.".format(
                    self.__class__.__name__
                )
            )

        data = utils.ensure_tensor(data, dtype=self.vocabulary_dtype)
        if data.shape.rank == 0:
            data = tf.expand_dims(data, 0)
        if data.shape.rank == 1:
            # Expand dims on axis 0 for tf-idf. A 1-d tensor is a single
            # document.
            data = tf.expand_dims(data, 0)

        tokens, counts = self._num_tokens(data)
        self.token_counts.insert(
            tokens, counts + self.token_counts.lookup(tokens)
        )

        if self.output_mode == TF_IDF:
            # Dedupe each row of our dataset.
            deduped_doc_data = tf.map_fn(lambda x: tf.unique(x)[0], data)
            # Flatten and count tokens.
            tokens, doc_counts = self._num_tokens(deduped_doc_data)
            self.token_document_counts.insert(
                tokens, doc_counts + self.token_document_counts.lookup(tokens)
            )
            if tf_utils.is_ragged(data):
                self.num_documents.assign_add(data.nrows())
            else:
                self.num_documents.assign_add(
                    tf.shape(data, out_type=tf.int64)[0]
                )

    def finalize_state(self):
        if self._has_input_vocabulary or tf.equal(self.token_counts.size(), 0):
            # Finalize idf_weights to a const for call even if we don't need
            # to compute a new vocabulary.
            if self.output_mode == TF_IDF:
                self.idf_weights_const = self.idf_weights.value()
            self._record_vocabulary_size()
            return

        # Remove special tokens from our counts.
        if self.mask_token is not None:
            self.token_counts.remove(
                tf.convert_to_tensor([self.mask_token], self.vocabulary_dtype)
            )
        if self.oov_token is not None:
            self.token_counts.remove(
                tf.convert_to_tensor([self.oov_token], self.vocabulary_dtype)
            )

        tokens, counts = self.token_counts.export()
        # To keep vocabs deterministic, we sort our tokens by count and break
        # ties by sorting the tokens themselves. TensorFlow has no ops for
        # sorting strings, so we need to use numpy for the sort.
        sorted_indices = np.lexsort((tokens.numpy(), counts.numpy()))[::-1]
        token_start = self._token_start_index()
        if self.max_tokens:
            max_learned_tokens = self.max_tokens - token_start
            sorted_indices = sorted_indices[:max_learned_tokens]
        tokens = tf.gather(tokens, sorted_indices)
        self.lookup_table = self._lookup_table_from_tokens(tokens)

        if self.output_mode == TF_IDF:
            token_document_counts = self.token_document_counts.lookup(tokens)
            idf_weights = self._inverse_document_frequency(
                token_document_counts, self.num_documents
            )
            idf_weights = tf.cast(idf_weights, self.compute_dtype)
            # Pad the front of idf_weights with the average idf weight for
            # OOV tokens. We cannot compute the real idf weight of OOV in a
            # single pass.
            idf_weights = tf.pad(
                idf_weights,
                [[self._token_start_index(), 0]],
                constant_values=tf.reduce_mean(idf_weights),
            )
            if self.pad_to_max_tokens and self.max_tokens is not None:
                # Pad the back of idf_weights with zeros.
                idf_weights = tf.pad(
                    idf_weights,
                    [[0, self.max_tokens - tf.size(idf_weights)]],
                    constant_values=0,
                )
            self.idf_weights.assign(idf_weights)
            self.idf_weights_const = self.idf_weights.value()

        # We call this here to save memory; now that we've built our
        # vocabulary, we don't want to keep every token we've seen in
        # separate lookup tables.
        self.reset_state()
        self._record_vocabulary_size()

    def reset_state(self):
        if self._has_input_vocabulary:
            return

        self.token_counts.remove(self.token_counts.export()[0])
        if self.output_mode == TF_IDF:
            self.token_document_counts.remove(
                self.token_document_counts.export()[0]
            )
            self.num_documents.assign(0)
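
    # How these three methods fit together during `adapt` (a sketch; the
    # driver loop lives in the PreprocessingLayer base class): `adapt` calls
    # update_state() once per batch to accumulate counts, then
    # finalize_state() once to build the lookup table, which in turn calls
    # reset_state() to free the accumulator tables.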

    def call(self, inputs):
        self._ensure_known_vocab_size()

        inputs = utils.ensure_tensor(inputs, dtype=self._key_dtype)
        original_shape = inputs.shape
        # Some ops will not handle scalar input, so uprank to rank 1.
        if inputs.shape.rank == 0:
            inputs = self._expand_dims(inputs, -1)

        if tf_utils.is_sparse(inputs):
            lookups = tf.SparseTensor(
                inputs.indices,
                self._lookup_dense(inputs.values),
                inputs.dense_shape,
            )
        elif tf_utils.is_ragged(inputs):
            lookups = tf.ragged.map_flat_values(self._lookup_dense, inputs)
        else:
            lookups = self._lookup_dense(inputs)

        if self.output_mode == INT:
            # If we received a scalar input, downrank back to a scalar.
            if original_shape.rank == 0:
                lookups = tf.squeeze(lookups, -1)
            return lookups

        depth = (
            self.max_tokens
            if self.pad_to_max_tokens
            else self._frozen_vocab_size
        )
        idf_weights = (
            self.idf_weights_const if self.output_mode == TF_IDF else None
        )
        return utils.encode_categorical_inputs(
            lookups,
            output_mode=self.output_mode,
            depth=depth,
            dtype=self.compute_dtype,
            sparse=self.sparse,
            idf_weights=idf_weights,
        )

    def _lookup_dense(self, inputs):
        """Lookup table values for a dense Tensor, handling masking and OOV."""
        # When executing eagerly and tracing keras.Input objects,
        # do not call lookup.
        # This is critical for restoring SavedModel, which will first trace
        # layer.call and then attempt to restore the table. We need the table
        # to be uninitialized for the restore to work, but calling the table
        # uninitialized would error.
        if tf.executing_eagerly() and backend.is_keras_tensor(inputs):
            lookups = tf.zeros_like(inputs, dtype=self._value_dtype)
        else:
            lookups = self.lookup_table.lookup(inputs)

        if self.mask_token is not None:
            mask_locations = tf.equal(inputs, self._mask_key)
            lookups = tf.where(mask_locations, self._mask_value, lookups)

        if self.invert:
            return lookups

        lookup_checks = []

        if self.num_oov_indices == 0:
            # If we have zero oov indices, we need to check for oov inputs.
            oov_indices = tf.where(tf.equal(lookups, -1))
            oov_inputs = tf.gather_nd(inputs, oov_indices)
            msg = tf.strings.format(
                "When `num_oov_indices=0` all inputs should be in vocabulary, "
                "found OOV values {}, consider setting `num_oov_indices=1`.",
                (oov_inputs,),
            )
            assertion = tf.Assert(tf.equal(tf.size(oov_indices), 0), [msg])
            lookup_checks.append(assertion)
        elif self.num_oov_indices > 1:
            # If we have multiple oov indices, we need a further hashing step.
            if self._key_dtype.is_integer:
                oov_indices = tf.math.floormod(inputs, self.num_oov_indices)
            else:
                oov_indices = tf.strings.to_hash_bucket_fast(
                    inputs, num_buckets=self.num_oov_indices
                )
            oov_indices = oov_indices + self._oov_start_index()
            oov_locations = tf.equal(lookups, self._default_value)
            lookups = tf.where(oov_locations, oov_indices, lookups)

        with tf.control_dependencies(lookup_checks):
            return tf.identity(lookups)
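
    # A worked example of the multi-OOV path above (illustrative values):
    # with a mask token, output_mode="int", and num_oov_indices=2, OOV slots
    # occupy indices 1-2. An unknown string is hashed into one of 2 buckets
    # and shifted by _oov_start_index(), so it lands deterministically on
    # index 1 or 2 rather than on a single shared OOV index.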

    def save_own_variables(self, store):
        if self.output_mode == TF_IDF:
            store["idf_weights"] = self.idf_weights_const.numpy()

    def load_own_variables(self, store):
        if self.output_mode == TF_IDF:
            self.idf_weights.assign(store["idf_weights"])
            self.idf_weights_const = self.idf_weights.value()

    def save_assets(self, dir_path):
        if self.input_vocabulary:
            # Vocab saved in config.
            # TODO: consider unifying both paths.
            return
        vocabulary = self.get_vocabulary(include_special_tokens=True)
        vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt")
        with open(vocabulary_filepath, "w") as f:
            f.write("\n".join([str(w) for w in vocabulary]))

    def load_assets(self, dir_path):
        if self.input_vocabulary:
            # Vocab saved in config.
            # TODO: consider unifying both paths.
            return
        vocabulary_filepath = tf.io.gfile.join(dir_path, "vocabulary.txt")
        # TODO: fix bug with include_special_tokens and set reload from file.
        with open(vocabulary_filepath, "r") as f:
            lines = f.read().split("\n")
            if tf.as_dtype(self.vocabulary_dtype) == tf.string:
                values = [str(line) for line in lines]
            else:
                values = [int(line) for line in lines]
            if self.output_mode == TF_IDF:
                self.set_vocabulary(values, idf_weights=False)
            else:
                self.set_vocabulary(values)

    def _uninitialized_lookup_table(self):
        with tf.init_scope():
            initializer = NullInitializer(self._key_dtype, self._value_dtype)
            return tf.lookup.StaticHashTable(initializer, self._default_value)

    def _lookup_table_from_tokens(self, tokens):
        with tf.init_scope():
            token_start = self._token_start_index()
            token_end = token_start + tf.size(tokens)
            indices_dtype = (
                self._key_dtype if self.invert else self._value_dtype
            )
            indices = tf.range(token_start, token_end, dtype=indices_dtype)
            keys, values = (
                (indices, tokens) if self.invert else (tokens, indices)
            )
            initializer = tf.lookup.KeyValueTensorInitializer(
                keys, values, self._key_dtype, self._value_dtype
            )
            return tf.lookup.StaticHashTable(initializer, self._default_value)

    def _lookup_table_from_file(self, filename):
        if self.invert:
            key_index = tf.lookup.TextFileIndex.LINE_NUMBER
            value_index = tf.lookup.TextFileIndex.WHOLE_LINE
        else:
            key_index = tf.lookup.TextFileIndex.WHOLE_LINE
            value_index = tf.lookup.TextFileIndex.LINE_NUMBER
        with tf.init_scope():
            initializer = tf.lookup.TextFileInitializer(
                filename=filename,
                key_dtype=self._key_dtype,
                key_index=key_index,
                value_dtype=self._value_dtype,
                value_index=value_index,
                value_index_offset=self._token_start_index(),
            )
            return tf.lookup.StaticHashTable(initializer, self._default_value)

    def _convert_to_ndarray(self, x):
        return np.array(x) if isinstance(x, (list, tuple)) else x

    def _expand_dims(self, inputs, axis):
        if tf_utils.is_sparse(inputs):
            return tf.sparse.expand_dims(inputs, axis)
        else:
            return tf.expand_dims(inputs, axis)

    def _oov_start_index(self):
        return (
            1 if self.mask_token is not None and self.output_mode == INT else 0
        )

    def _token_start_index(self):
        return self._oov_start_index() + self.num_oov_indices
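
    # Index layout sketch (illustrative): with output_mode="int", a mask
    # token, and num_oov_indices=1, _oov_start_index() is 1 and
    # _token_start_index() is 2, so indices read [mask, OOV, token0, ...].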

    def _ensure_known_vocab_size(self):
        if self.output_mode == INT or self.pad_to_max_tokens:
            return
        if self._frozen_vocab_size is None:
            raise RuntimeError(
                f"When using `output_mode={self.output_mode}` "
                "and `pad_to_max_tokens=False`, "
                "you must set the layer's vocabulary before calling it. "
                "Either pass a `vocabulary` argument to the layer, or call "
                "`adapt` with some sample data."
            )

    def _ensure_vocab_size_unchanged(self):
        if self.output_mode == INT or self.pad_to_max_tokens:
            return

        with tf.init_scope():
            new_vocab_size = self.vocabulary_size()

        if (
            self._frozen_vocab_size is not None
            and new_vocab_size != self._frozen_vocab_size
        ):
            raise RuntimeError(
                f"When using `output_mode={self.output_mode}` "
                "and `pad_to_max_tokens=False`, "
                "the vocabulary size cannot be changed after the layer is "
                f"called. Old vocab size is {self._frozen_vocab_size}, "
                f"new vocab size is {new_vocab_size}"
            )

    def _find_repeated_tokens(self, vocabulary):
        """Return all repeated tokens in a vocabulary."""
        vocabulary_set = set(vocabulary)
        if len(vocabulary) != len(vocabulary_set):
            return [
                item
                for item, count in collections.Counter(vocabulary).items()
                if count > 1
            ]
        else:
            return []

    def _num_tokens(self, data):
        """Count the number of tokens in a ragged, sparse or dense tensor."""
        if tf_utils.is_sparse(data):
            flat_values = data.values
        elif tf_utils.is_ragged(data):
            flat_values = data.flat_values
        else:
            flat_values = tf.reshape(data, [-1])
        tokens, _, counts = tf.unique_with_counts(flat_values, out_idx=tf.int64)
        return tokens, counts

    def _inverse_document_frequency(self, token_document_counts, num_documents):
        """Computes the inverse-document-frequency (IDF) component of
        "tf_idf".

        Uses the default weighting scheme described in
        https://en.wikipedia.org/wiki/Tf%E2%80%93idf.

        Args:
            token_document_counts: An array of the # of documents each token
                appears in.
            num_documents: An int representing the total number of documents.

        Returns:
            An array of "inverse document frequency" weights.
        """
        return tf.math.log(1 + num_documents / (1 + token_document_counts))
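
    # A worked instance of the formula above (illustrative numbers): with
    # num_documents=100 and a token appearing in 9 documents,
    # idf = log(1 + 100 / (1 + 9)) = log(11) ≈ 2.398.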

    @property
    def _trackable_saved_model_saver(self):
        return layer_serialization.VocabularySavedModelSaver(self)

    # Override points for IntegerLookup and StringLookup.
    def _tensor_vocab_to_numpy(self, vocabulary):
        """Converts a tensor vocabulary to a numpy vocabulary."""
        return vocabulary.numpy()