Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/ops/lookup_ops.py: 36%
696 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup operations."""
# pylint: disable=g-bad-name
import collections
import functools
import uuid

from tensorflow.python.checkpoint import saveable_compat
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_lookup_ops import *
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.saved_model import registration
from tensorflow.python.trackable import asset
# pylint: enable=wildcard-import
from tensorflow.python.trackable import base as trackable_base
from tensorflow.python.trackable import resource
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util import compat as compat_util
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["initialize_all_tables"])
@deprecated(None, "Use `tf.tables_initializer` instead.")
def initialize_all_tables(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables. Note that if there are
    no tables the returned Op is a NoOp.
  """
  return tables_initializer(name)


@tf_export(v1=["initializers.tables_initializer", "tables_initializer"])
def tables_initializer(name="init_all_tables"):
  """Returns an Op that initializes all tables of the default graph.

  Args:
    name: Optional name for the initialization op.

  Returns:
    An Op that initializes all tables. Note that if there are
    no tables the returned Op is a NoOp.

  @compatibility(TF2)
  `tf.compat.v1.tables_initializer` is no longer needed with eager execution
  and `tf.function`. In TF2, when creating an initializable table like a
  `tf.lookup.StaticHashTable`, the table will automatically be initialized on
  creation.

  #### Before & After Usage Example

  Before:

  >>> with tf.compat.v1.Session():
  ...   init = tf.compat.v1.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  ...   table = tf.compat.v1.lookup.StaticHashTable(init, default_value=-1)
  ...   tf.compat.v1.tables_initializer().run()
  ...   result = table.lookup(tf.constant(['a', 'c'])).eval()
  >>> result
  array([ 1, -1], dtype=int32)

  After:

  >>> init = tf.lookup.KeyValueTensorInitializer(['a', 'b'], [1, 2])
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['a', 'c'])).numpy()
  array([ 1, -1], dtype=int32)

  @end_compatibility
  """
  initializers = ops.get_collection(ops.GraphKeys.TABLE_INITIALIZERS)
  if initializers:
    return control_flow_ops.group(*initializers, name=name)
  return control_flow_ops.no_op(name=name)


def check_table_dtypes(table, key_dtype, value_dtype):
  """Checks that the given key_dtype and value_dtype match the table dtypes.

  Args:
    table: The table to check types against.
    key_dtype: The key data type to check.
    value_dtype: The value data type to check.

  Raises:
    TypeError: when `key_dtype` or `value_dtype` doesn't match the table data
      types.
  """
  if key_dtype.base_dtype != table.key_dtype:
    raise TypeError(f"Invalid key dtype for table, expected {table.key_dtype} "
                    f"but got {key_dtype}.")
  if value_dtype.base_dtype != table.value_dtype:
    raise TypeError("Invalid value dtype for table, expected "
                    f"{table.value_dtype} but got {value_dtype}.")


class LookupInterface(resource.TrackableResource):
  """Represent a lookup table that persists across different steps."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a lookup table interface.

    Args:
      key_dtype: The table key type.
      value_dtype: The table value type.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)
    super(LookupInterface, self).__init__()

  def _create_resource(self):
    raise NotImplementedError

  @property
  def key_dtype(self):
    """The table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The table value dtype."""
    return self._value_dtype

  @property
  def name(self):
    """The name of the table."""
    raise NotImplementedError

  def size(self, name=None):
    """Compute the number of elements in this table."""
    raise NotImplementedError

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values."""
    raise NotImplementedError

  def __getitem__(self, keys):
    """Looks up `keys` in a table, outputs the corresponding values."""
    return self.lookup(keys)


class InitializableLookupTableBase(LookupInterface):
  """Initializable lookup table interface.

  An initializable lookup table persists across different steps.
  """

  def __init__(self, default_value, initializer):
    """Construct a table object from a table reference.

    It requires a table initializer object (subclass of `TableInitializerBase`)
    that provides the table key and value types, as well as the op to
    initialize the table. The caller is responsible for executing the
    initialization op.

    Args:
      default_value: The value to use if a key is missing in the table.
      initializer: The table initializer to use.
    """
    super(InitializableLookupTableBase, self).__init__(initializer.key_dtype,
                                                       initializer.value_dtype)
    self._default_value = ops.convert_to_tensor(
        default_value, dtype=self._value_dtype)
    self._default_value.get_shape().merge_with(tensor_shape.TensorShape([]))
    if isinstance(initializer, trackable_base.Trackable):
      self._initializer = self._track_trackable(initializer, "_initializer")
    with ops.init_scope():
      self._resource_handle = self._create_resource()
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      with ops.init_scope():
        self._init_op = self._initialize()
    else:
      self._init_op = self._initialize()

  def _initialize(self):
    return self._initializer.initialize(self)

  @property
  def default_value(self):
    """The default value of the table."""
    return self._default_value

  def size(self, name=None):
    """Compute the number of elements in this table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this table.
    """
    with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
      return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)

  def lookup(self, keys, name=None):
    """Looks up `keys` in a table, outputs the corresponding values.

    The `default_value` is used for keys not present in the table.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: A name for the operation (optional).

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` or `default_value` doesn't match the table data
        types.
    """
    key_tensor = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      key_tensor = keys.values

    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")

    with ops.name_scope(
        name, "%s_Lookup" % self.name,
        (self.resource_handle, key_tensor, self._default_value)):
      values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle,
                                                   key_tensor,
                                                   self._default_value)

    values.set_shape(key_tensor.get_shape())
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, values, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(values)
    else:
      return values
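
  # Illustrative sketch (assumes `table` maps b'a' -> 1 with default_value=-1):
  # composite inputs keep their structure, since only `keys.values` is passed
  # through the underlying find op.
  #   sp = sparse_tensor.SparseTensor(
  #       indices=[[0, 0], [1, 1]], values=['a', 'z'], dense_shape=[2, 2])
  #   table.lookup(sp)  # SparseTensor with values [1, -1], indices unchanged.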


class InitializableLookupTableBaseV1(InitializableLookupTableBase):

  @property
  def initializer(self):
    return self._init_op


@registration.register_tf_serializable(
    predicate=lambda obj: isinstance(obj, StaticHashTable))
@tf_export("lookup.StaticHashTable", v1=[])
class StaticHashTable(InitializableLookupTableBase):
  """A generic hash table that is immutable once initialized.

  Example usage:

  >>> keys_tensor = tf.constant(['a', 'b', 'c'])
  >>> vals_tensor = tf.constant([7, 8, 9])
  >>> input_tensor = tf.constant(['a', 'f'])
  >>> table = tf.lookup.StaticHashTable(
  ...     tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
  ...     default_value=-1)
  >>> table.lookup(input_tensor).numpy()
  array([ 7, -1], dtype=int32)

  Or for more pythonic code:

  >>> table[input_tensor].numpy()
  array([ 7, -1], dtype=int32)

  The result of a lookup operation has the same shape as the argument:

  >>> input_tensor = tf.constant([['a', 'b'], ['c', 'd']])
  >>> table[input_tensor].numpy()
  array([[ 7,  8],
         [ 9, -1]], dtype=int32)
  """

  def __init__(self,
               initializer,
               default_value,
               name=None,
               experimental_is_anonymous=False):
    """Creates a non-initialized `HashTable` object.

    Creates a table whose key and value types are specified by the
    initializer.
    Before using the table you will have to initialize it. After initialization
    the table will be immutable.

    Args:
      initializer: The table initializer to use. See `HashTable` kernel for
        supported key and value types.
      default_value: The value to use if a key is missing in the table.
      name: A name for the operation (optional).
      experimental_is_anonymous: Whether to use anonymous mode for the
        table (default is False). In anonymous mode, the table
        resource can only be accessed via a resource handle. It can't
        be looked up by a name. When all resource handles pointing to
        that resource are gone, the resource will be deleted
        automatically.

    Returns:
      A `HashTable` object.
    """
    self._initializer = initializer
    self._default_value = default_value
    self._is_anonymous = experimental_is_anonymous
    if not self._is_anonymous:
      self._shared_name = self._initializer._shared_name  # pylint: disable=protected-access
      if not self._shared_name:
        # Force using a shared name so that StaticHashTable resources can be
        # shared across different kernels. If no "shared_name" is set and
        # "use_node_name_sharing" is False, then each kernel gets its own local
        # resource.
        self._shared_name = "hash_table_%s" % (str(uuid.uuid4()),)
    self._name = name or "hash_table"
    self._table_name = None
    super(StaticHashTable, self).__init__(default_value, initializer)
    self._value_shape = self._default_value.get_shape()

  def _create_resource(self):
    if self._is_anonymous:
      table_ref = gen_lookup_ops.anonymous_hash_table(
          key_dtype=self._initializer.key_dtype,
          value_dtype=self._initializer.value_dtype,
          name=self._name)
    else:
      table_ref = gen_lookup_ops.hash_table_v2(
          shared_name=self._shared_name,
          key_dtype=self._initializer.key_dtype,
          value_dtype=self._initializer.value_dtype,
          name=self._name)
    if context.executing_eagerly():
      self._table_name = None
    else:
      self._table_name = table_ref.op.name.split("/")[-1]
    return table_ref

  @property
  def name(self):
    return self._table_name

  def export(self, name=None):
    """Returns tensors of all keys and values in the table.

    Args:
      name: A name for the operation (optional).

    Returns:
      A pair of tensors with the first tensor containing all keys and the
      second tensor containing all values in the table.
    """
    with ops.name_scope(name, "%s_Export" % self.name, [self.resource_handle]):
      exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
          self.resource_handle, self._key_dtype, self._value_dtype)

    exported_values.set_shape(exported_keys.get_shape().concatenate(
        self._value_shape))
    return exported_keys, exported_values

  def _serialize_to_proto(self, **unused_kwargs):
    return None

  def _add_trackable_child(self, name, value):
    setattr(self, name, value)
    if isinstance(value, trackable_base.Trackable):
      self._track_trackable(value, name)  # pylint:disable=protected-access

  @classmethod
  def _deserialize_from_proto(cls, **kwargs):

    class _RestoredStaticHashTable(resource.RestoredResource):  # pylint: disable=protected-access

      @classmethod
      def _resource_type(cls):
        return "RestoredStaticHashTable"

    return _RestoredStaticHashTable._deserialize_from_proto(**kwargs)  # pylint: disable=protected-access
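
  # Illustrative sketch (hypothetical values): `export` returns parallel
  # tensors of every key/value pair in the table, in no guaranteed order.
  #   table = StaticHashTable(
  #       KeyValueTensorInitializer(['a', 'b'], [1, 2]), default_value=-1)
  #   keys, values = table.export()  # e.g. keys=[b'b', b'a'], values=[2, 1]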


@tf_export(v1=["lookup.StaticHashTable"])
class StaticHashTableV1(StaticHashTable):
  """A generic hash table that is immutable once initialized.

  When running in graph mode, you must evaluate the tensor returned by
  `tf.tables_initializer()` before evaluating the tensor returned by
  this class's `lookup()` method. Example usage in graph mode:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  out = table.lookup(input_tensor)
  with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(out))
  ```

  Note that in graph mode if you set `experimental_is_anonymous` to
  `True`, you should only call `Session.run` once, otherwise each
  `Session.run` will create (and destroy) a new table unrelated to the
  others, leading to errors such as "Table not initialized".
  You can do so like this:

  ```python
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1,
      experimental_is_anonymous=True)
  with tf.control_dependencies([tf.tables_initializer()]):
    out = table.lookup(input_tensor)
  with tf.Session() as sess:
    print(sess.run(out))
  ```

  In eager mode, no special code is needed to initialize the table.
  Example usage in eager mode:

  ```python
  tf.enable_eager_execution()
  keys_tensor = tf.constant([1, 2])
  vals_tensor = tf.constant([3, 4])
  input_tensor = tf.constant([1, 5])
  table = tf.lookup.StaticHashTable(
      tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor), -1)
  print(table.lookup(input_tensor))
  ```
  """

  @property
  def initializer(self):
    return self._init_op


# For backwards compatibility. This will be removed in TF 2.0.
class HashTable(StaticHashTableV1):

  @property
  def init(self):
    return self.initializer


class TableInitializerBase(trackable_base.Trackable):
  """Base class for lookup table initializers."""

  def __init__(self, key_dtype, value_dtype):
    """Construct a table initializer object.

    Args:
      key_dtype: Type of the table keys.
      value_dtype: Type of the table values.
    """
    self._key_dtype = dtypes.as_dtype(key_dtype)
    self._value_dtype = dtypes.as_dtype(value_dtype)

  @property
  def key_dtype(self):
    """The expected table key dtype."""
    return self._key_dtype

  @property
  def value_dtype(self):
    """The expected table value dtype."""
    return self._value_dtype

  def initialize(self, table):
    """Returns the table initialization op."""
    raise NotImplementedError

  @property
  def _shared_name(self):
    """Returns a shared name to be used by the table."""
    shared_name = ""
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.anonymous_name() instead.
      shared_name += str(ops.uid())
    return shared_name
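
  # A minimal custom initializer only needs to implement `initialize`,
  # returning the op that fills `table`. Hedged sketch of that contract
  # (hypothetical subclass with a single hard-coded pair):
  #   class SingleEntryInitializer(TableInitializerBase):
  #     def __init__(self):
  #       super().__init__(dtypes.string, dtypes.int64)
  #     def initialize(self, table):
  #       return gen_lookup_ops.lookup_table_import_v2(
  #           table.resource_handle, constant_op.constant(["k"]),
  #           constant_op.constant([1], dtypes.int64))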
515@tf_export("lookup.KeyValueTensorInitializer")
516class KeyValueTensorInitializer(TableInitializerBase):
517 """Table initializers given `keys` and `values` tensors.
519 >>> keys_tensor = tf.constant(['a', 'b', 'c'])
520 >>> vals_tensor = tf.constant([7, 8, 9])
521 >>> input_tensor = tf.constant(['a', 'f'])
522 >>> init = tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor)
523 >>> table = tf.lookup.StaticHashTable(
524 ... init,
525 ... default_value=-1)
526 >>> table.lookup(input_tensor).numpy()
527 array([ 7, -1], dtype=int32)
529 """

  def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
    """Constructs a table initializer object based on keys and values tensors.

    Args:
      keys: The tensor for the keys.
      values: The tensor for the values.
      key_dtype: The `keys` data type. Used when `keys` is a python array.
      value_dtype: The `values` data type. Used when `values` is a python
        array.
      name: A name for the operation (optional).
    """
    if (not context.executing_eagerly() and
        ops.get_default_graph()._get_control_flow_context() is not None):  # pylint: disable=protected-access
      with ops.init_scope():
        self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
        self._values = ops.convert_to_tensor(
            values, dtype=value_dtype, name="values")
    else:
      self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
      self._values = ops.convert_to_tensor(
          values, dtype=value_dtype, name="values")
    self._name = name if name is not None else "key_value_init"
    if context.executing_eagerly():
      # Ensure a unique name when eager execution is enabled to avoid spurious
      # sharing issues.
      # TODO(rohanj): Use context.anonymous_name() instead.
      self._name += str(ops.uid())

    super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
                                                    self._values.dtype)

  def initialize(self, table):
    """Initializes the given `table` with `keys` and `values` tensors.

    Args:
      table: The table to initialize.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
        key and value data types.
    """
    check_table_dtypes(table, self._keys.dtype, self._values.dtype)
    with ops.name_scope(
        self._name, values=(table.resource_handle, self._keys, self._values)):
      init_op = gen_lookup_ops.lookup_table_import_v2(table.resource_handle,
                                                      self._keys, self._values)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op


@tf_export("lookup.TextFileIndex")
class TextFileIndex:
  """The key and value content to get from each line.

  This class defines the key and value used for `tf.lookup.TextFileInitializer`.

  The key and value content to get from each line is specified either by one
  of the following sentinels, or by a value `>= 0`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.

  A value `>= 0` means use the index (starting at zero) of the split line based
  on `delimiter`.
  """
  WHOLE_LINE = -2
  LINE_NUMBER = -1
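

# Worked example (illustrative): for a file whose line 0 is "emerson 10" and
# delimiter " ", the index choices above select:
#   TextFileIndex.LINE_NUMBER -> 0 (int64)
#   TextFileIndex.WHOLE_LINE  -> b"emerson 10" (string)
#   0                         -> b"emerson"
#   1                         -> "10", parsed per the requested dtype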


@tf_export("lookup.TextFileInitializer")
class TextFileInitializer(TableInitializerBase):
  r"""Table initializers from a text file.

  This initializer assigns one entry in the table for each line in the file.

  The key and value type of the table to initialize is given by `key_dtype` and
  `value_dtype`.

  The key and value content to get from each line is specified by
  the `key_index` and `value_index`.

  * `TextFileIndex.LINE_NUMBER` means use the line number starting from zero,
    expects data type int64.
  * `TextFileIndex.WHOLE_LINE` means use the whole line content, expects data
    type string.
  * A value `>=0` means use the index (starting at zero) of the split line based
    on `delimiter`.

  For example if we have a file with the following content:

  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> content='\n'.join(["emerson 10", "lake 20", "palmer 30",])
  >>> f.file.write(content.encode('utf-8'))
  >>> f.file.close()

  The following snippet initializes a table with the first column as keys and
  second column as values:

  * `emerson -> 10`
  * `lake -> 20`
  * `palmer -> 30`

  >>> init = tf.lookup.TextFileInitializer(
  ...     filename=f.name,
  ...     key_dtype=tf.string, key_index=0,
  ...     value_dtype=tf.int64, value_index=1,
  ...     delimiter=" ")
  >>> table = tf.lookup.StaticHashTable(init, default_value=-1)
  >>> table.lookup(tf.constant(['palmer','lake','tarkus'])).numpy()
  array([30, 20, -1])

  Similarly, to initialize the whole line as keys and the line number as
  values:

  * `emerson 10 -> 0`
  * `lake 20 -> 1`
  * `palmer 30 -> 2`

  >>> init = tf.lookup.TextFileInitializer(
  ...     filename=f.name,
  ...     key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...     value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticHashTable(init, -1)
  >>> table.lookup(tf.constant('palmer 30')).numpy()
  2
  """

  def __init__(self,
               filename,
               key_dtype,
               key_index,
               value_dtype,
               value_index,
               vocab_size=None,
               delimiter="\t",
               name=None,
               value_index_offset=0):
    """Constructs a table initializer object to populate from a text file.

    It generates one key-value pair per line. The type of table key and
    value are specified by `key_dtype` and `value_dtype`, respectively.
    Similarly, the content of the key and value are specified by `key_index`
    and `value_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_dtype: The `key` data type.
      key_index: the index that represents information of a line to get the
        table 'key' values from.
      value_dtype: The `value` data type.
      value_index: the index that represents information of a line to get the
        table 'value' values from.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: A name for the operation (optional).
      value_index_offset: A number to add to all indices extracted from the
        file. This is useful for cases where a user would like to reserve one
        or more low index values for control characters. For instance, if you
        would like to ensure that no vocabulary item is mapped to index 0 (so
        you can reserve 0 for a masking value), you can set value_index_offset
        to 1; this will mean that the first vocabulary element is mapped to 1
        instead of 0.

    Raises:
      ValueError: when the filename is empty, or when the table key and value
        data types do not match the expected data types.
    """
    if not isinstance(filename, ops.Tensor) and not filename:
      raise ValueError(
          "`filename` argument required for tf.lookup.TextFileInitializer")

    self._filename_arg = filename
    key_dtype = dtypes.as_dtype(key_dtype)
    value_dtype = dtypes.as_dtype(value_dtype)

    if key_index < -2:
      raise ValueError(f"`key_index` should be >= -2, received: {key_index}.")

    if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
      raise ValueError("`key_dtype` must be int64 if `key_index` is "
                       f"{TextFileIndex.LINE_NUMBER}, received: {key_dtype}")
    if ((key_index == TextFileIndex.WHOLE_LINE) and
        (not key_dtype.is_integer) and (key_dtype != dtypes.string)):
      raise ValueError(
          "`key_dtype` should be either integer or string for `key_index` "
          f"{TextFileIndex.WHOLE_LINE}, received: {key_dtype}")
    if value_index < -2:
      raise ValueError("`value_index` should be >= -2, received: "
                       f"{value_index}")

    if value_index == TextFileIndex.LINE_NUMBER and value_dtype != dtypes.int64:
      raise ValueError("`value_dtype` must be int64 for `value_index` "
                       f"{TextFileIndex.LINE_NUMBER}, received: {value_dtype}")
    if ((value_index == TextFileIndex.WHOLE_LINE) and
        (not value_dtype.is_integer) and (value_dtype != dtypes.string)):
      raise ValueError(
          "`value_dtype` should be either integer or string for `value_index` "
          f"{TextFileIndex.WHOLE_LINE}, received: {value_dtype}")

    if (vocab_size is not None) and (vocab_size <= 0):
      raise ValueError(f"`vocab_size` should be > 0, received: {vocab_size}")

    self._key_index = key_index
    self._value_index = value_index
    self._vocab_size = vocab_size
    self._delimiter = delimiter
    self._name = name
    self._filename = self._track_trackable(asset.Asset(filename), "_filename")
    self._offset = value_index_offset

    super(TextFileInitializer, self).__init__(key_dtype, value_dtype)

  def initialize(self, table):
    """Initializes the table from a text file.

    Args:
      table: The table to be initialized.

    Returns:
      The operation that initializes the table.

    Raises:
      TypeError: when the keys and values data types do not match the table
        key and value data types.
    """
    check_table_dtypes(table, self.key_dtype, self.value_dtype)
    with ops.name_scope(self._name, "text_file_init", (table.resource_handle,)):
      filename = ops.convert_to_tensor(
          self._filename, dtypes.string, name="asset_filepath")
      init_op = gen_lookup_ops.initialize_table_from_text_file_v2(
          table.resource_handle, filename, self._key_index, self._value_index,
          -1 if self._vocab_size is None else self._vocab_size, self._delimiter,
          self._offset)
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    # If the filename tensor is anything other than a string constant (e.g.,
    # if it is a placeholder) then it does not make sense to track it as an
    # asset.
    if not context.executing_eagerly() and constant_op.is_constant(filename):
      ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
    return init_op

  @property
  def _shared_name(self):
    if self._vocab_size:
      # Keep the shared_name:
      # <table_type>_<filename>_<vocab_size>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%d_%s_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index, self._offset)
      else:
        shared_name = "hash_table_%s_%d_%s_%s" % (
            self._filename_arg, self._vocab_size, self._key_index,
            self._value_index)
    else:
      # Keep the shared_name
      # <table_type>_<filename>_<key_index>_<value_index>_<offset>
      if self._offset:
        shared_name = "hash_table_%s_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index,
            self._offset)
      else:
        shared_name = "hash_table_%s_%s_%s" % (
            self._filename_arg, self._key_index, self._value_index)

    return shared_name
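
  # Illustrative sketch: with filename="vocab.txt", vocab_size=3,
  # key_index=TextFileIndex.WHOLE_LINE (-2),
  # value_index=TextFileIndex.LINE_NUMBER (-1), and no offset, the branches
  # above produce shared_name == "hash_table_vocab.txt_3_-2_-1".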


class TextFileStringTableInitializer(TextFileInitializer):
  """Table initializer for `int64` IDs to string tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.LINE_NUMBER,
               value_column_index=TextFileIndex.WHOLE_LINE,
               vocab_size=None,
               delimiter="\t",
               name="text_file_string_table_init"):
    """Constructs an initializer for an id-to-string table from a text file.

    It populates a table whose key and value types are int64 and string,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string or int64.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the keys
        from. The default is to use the line number, starting from zero.
      value_column_index: The column index from the text file to get the
        values from. The default is to use the whole line content.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
        data types do not match the expected data types.
    """
    super(TextFileStringTableInitializer, self).__init__(
        filename,
        dtypes.int64,
        key_column_index,
        dtypes.string,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class TextFileIdTableInitializer(TextFileInitializer):
  """Table initializer for string to `int64` IDs tables from a text file."""

  def __init__(self,
               filename,
               key_column_index=TextFileIndex.WHOLE_LINE,
               value_column_index=TextFileIndex.LINE_NUMBER,
               vocab_size=None,
               delimiter="\t",
               name="text_file_id_table_init",
               key_dtype=dtypes.string):
    """Constructs an initializer for a string-to-id table from a text file.

    It populates a table whose key and value types are string and int64,
    respectively. It generates one key-value pair per line.
    The content of the key and value are specified by `key_column_index`
    and `value_column_index`.

    - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
      expects data type int64.
    - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
      type string.
    - A value >=0 means use the index (starting at zero) of the split line based
      on `delimiter`.

    Args:
      filename: The filename of the text file to be used for initialization.
        The path must be accessible from wherever the graph is initialized
        (eg. trainer or eval workers). The filename may be a scalar `Tensor`.
      key_column_index: The column index from the text file to get the `key`
        values from. The default is to use the whole line content.
      value_column_index: The column index from the text file to get the
        `value` values from. The default is to use the line number, starting
        from zero.
      vocab_size: The number of elements in the file, if known.
      delimiter: The delimiter to separate fields in a line.
      name: Optional name for the op.
      key_dtype: The `key` data type.

    Raises:
      TypeError: when the filename is empty, or when the table key and value
        data types do not match the expected data types.
    """
    super(TextFileIdTableInitializer, self).__init__(
        filename,
        key_dtype,
        key_column_index,
        dtypes.int64,
        value_column_index,
        vocab_size=vocab_size,
        delimiter=delimiter,
        name=name)


class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
  """A structure for the spec of the hashing function to use for hash buckets.

  `hasher` is the name of the hashing function to use (eg. "fasthash",
  "stronghash").
  `key` is optional and specifies the key to use for the hash function if
  supported, currently only used by a strong hash.

  Fields:
    hasher: The hasher name to use.
    key: The key to be used by the hashing function, if required.
  """
  __slots__ = ()


FastHashSpec = HasherSpec("fasthash", None)  # pylint: disable=invalid-name


class StrongHashSpec(HasherSpec):
  """A structure to specify a key of the strong keyed hash spec.

  The strong hash requires a `key`, which is a list of 2 unsigned integer
  numbers. These should be non-zero; random numbers generated from random.org
  would be a fine choice.

  Fields:
    key: The key to be used by the keyed hashing function.
  """
  __slots__ = ()

  def __new__(cls, key):
    if len(key) != 2:
      raise ValueError(f"`key` must have size 2, received {len(key)}")

    if not isinstance(key[0], compat_util.integral_types) or not isinstance(
        key[1], compat_util.integral_types):
      raise TypeError("Invalid key %s. Must be unsigned integer values." % key)

    return super(StrongHashSpec, cls).__new__(cls, "stronghash", key)
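

# Illustrative usage (hypothetical key values): a strong keyed hash is
# selected by supplying two unsigned integers; FastHashSpec needs no key.
#   spec = StrongHashSpec(key=(0x1234, 0x5678))
#   spec.hasher          # -> "stronghash"
#   FastHashSpec.hasher  # -> "fasthash"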


def _as_string(tensor):
  if dtypes.string == tensor.dtype.base_dtype:
    return tensor
  return string_ops.as_string(tensor)


class IdTableWithHashBuckets(LookupInterface):
  r"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.

  For example, if an instance of `IdTableWithHashBuckets` is initialized with a
  string-to-id table that maps:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`

  The `IdTableWithHashBuckets` object will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where bucket_id will be between `3` and
    `3 + num_oov_buckets - 1`, calculated by:
    `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is `["emerson", "lake", "palmer", "king", "crimson"]`,
  the lookup result is `[0, 1, 2, 4, 7]`.

  If `table` is None, only out-of-vocabulary buckets are used.

  Example usage:

  ```python
  num_oov_buckets = 3
  input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimson"])
  table = tf.IdTableWithHashBuckets(
      tf.StaticHashTable(
          tf.lookup.TextFileInitializer(
              filename,
              key_dtype=tf.string,
              key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
              value_dtype=tf.int64,
              value_index=tf.lookup.TextFileIndex.LINE_NUMBER,
              delimiter="\t"),
          default_value),
      num_oov_buckets)
  out = table.lookup(input_tensor)
  table.init.run()
  print(out.eval())
  ```

  The hash function used for generating out-of-vocabulary buckets ID is handled
  by `hasher_spec`.
  """

  def __init__(self,
               table,
               num_oov_buckets,
               hasher_spec=FastHashSpec,
               name=None,
               key_dtype=None):
    """Construct an `IdTableWithHashBuckets` object.

    Args:
      table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
      hasher_spec: A `HasherSpec` to specify the hash function to use for
        assignation of out-of-vocabulary buckets (optional).
      name: A name for the operation (optional).
      key_dtype: Data type of keys passed to `lookup`. Defaults to
        `table.key_dtype` if `table` is specified, otherwise `tf.string`. Must
        be string or integer, and must be castable to `table.key_dtype`.

    Raises:
      ValueError: when `table` is None and `num_oov_buckets` is not positive.
      TypeError: when `hasher_spec` is invalid.
    """
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if table:
      if key_dtype is None:
        key_dtype = table.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if table.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid `key_dtype`, expected one of "
                        f"{supported_table_key_dtypes}, received {key_dtype}.")
      if table.key_dtype.is_integer != key_dtype.is_integer:
        raise TypeError("Invalid `key_dtype`, expected %s but got %s." %
                        ("integer" if key_dtype.is_integer else "non-integer",
                         table.key_dtype))
      if table.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`: expected int64 but got %s." %
                        (table.value_dtype))
      self._table = table
      name = name or self._table.name
    else:
      if num_oov_buckets <= 0:
        raise ValueError(
            "`num_oov_buckets` must be > 0 if no `table` is supplied.")
      key_dtype = dtypes.string if key_dtype is None else key_dtype
      self._table = None
      name = name or "hash_bucket"
    if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{key_dtype}.")
    self._num_oov_buckets = num_oov_buckets

    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    self._hasher_spec = hasher_spec
    if name:
      self._table_name = name.split("/")[-1]
    else:
      self._table_name = None
    super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  @deprecated("2018-12-15", "Use `initializer` instead.")
  def init(self):
    return self.initializer

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def _get_string_to_hash_bucket_fn(self, hasher_spec):
    """Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
    if not isinstance(hasher_spec, HasherSpec):
      raise TypeError("`hasher_spec` must be of type HasherSpec, got "
                      f"{type(hasher_spec)}.")
    if hasher_spec.hasher == "fasthash":
      return string_ops.string_to_hash_bucket_fast
    if hasher_spec.hasher == "legacy":
      return string_ops.string_to_hash_bucket
    if hasher_spec.hasher == "stronghash":
      return functools.partial(
          string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
    raise ValueError(
        f"Found unknown hasher {hasher_spec.hasher} in `hasher_spec`")

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    if self._num_oov_buckets == 0:
      ids = self._table.lookup(values, name=name)
    else:
      # TODO(yleon): Consider moving this functionality to its own kernel.
      with ops.name_scope(name, "%s_Lookup" % self.name):
        str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
            self._hasher_spec)
        buckets = str_to_hash_bucket(
            _as_string(values),
            num_buckets=self._num_oov_buckets,
            name="hash_bucket")
        if self._table:
          ids = self._table.lookup(values)
          buckets = math_ops.add(buckets, self._table.size())
          is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
          ids = array_ops.where_v2(is_id_non_default, ids, buckets)
        else:
          ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids


@tf_export("lookup.StaticVocabularyTable", v1=[])
class StaticVocabularyTable(LookupInterface):
  r"""String to Id table that assigns out-of-vocabulary keys to hash buckets.

  For example, if an instance of `StaticVocabularyTable` is initialized with a
  string-to-id initializer that maps:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(['emerson', 'lake', 'palmer']),
  ...     values=tf.constant([0, 1, 2], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...     init,
  ...     num_oov_buckets=5)

  The table will perform the following mapping:

  * `emerson -> 0`
  * `lake -> 1`
  * `palmer -> 2`
  * `<other term> -> bucket_id`, where `bucket_id` will be between `3` and
    `3 + num_oov_buckets - 1 = 7`, calculated by:
    `hash(<term>) % num_oov_buckets + vocab_size`

  If input_tensor is:

  >>> input_tensor = tf.constant(["emerson", "lake", "palmer",
  ...                             "king", "crimson"])
  >>> table[input_tensor].numpy()
  array([0, 1, 2, 6, 7])

  If `initializer` is None, only out-of-vocabulary buckets are used.

  Example usage:

  >>> num_oov_buckets = 3
  >>> vocab = ["emerson", "lake", "palmer", "crimnson"]
  >>> import tempfile
  >>> f = tempfile.NamedTemporaryFile(delete=False)
  >>> f.write('\n'.join(vocab).encode('utf-8'))
  >>> f.close()

  >>> init = tf.lookup.TextFileInitializer(
  ...     f.name,
  ...     key_dtype=tf.string, key_index=tf.lookup.TextFileIndex.WHOLE_LINE,
  ...     value_dtype=tf.int64, value_index=tf.lookup.TextFileIndex.LINE_NUMBER)
  >>> table = tf.lookup.StaticVocabularyTable(init, num_oov_buckets)
  >>> table.lookup(tf.constant(["palmer", "crimnson" , "king",
  ...                           "tarkus", "black", "moon"])).numpy()
  array([2, 3, 5, 6, 6, 4])

  The hash function used for generating out-of-vocabulary buckets ID is
  Fingerprint64.

  Note that the out-of-vocabulary bucket IDs always range from the table `size`
  up to `size + num_oov_buckets - 1` regardless of the table values, which
  could cause unexpected collisions:

  >>> init = tf.lookup.KeyValueTensorInitializer(
  ...     keys=tf.constant(["emerson", "lake", "palmer"]),
  ...     values=tf.constant([1, 2, 3], dtype=tf.int64))
  >>> table = tf.lookup.StaticVocabularyTable(
  ...     init,
  ...     num_oov_buckets=1)
  >>> input_tensor = tf.constant(["emerson", "lake", "palmer", "king"])
  >>> table[input_tensor].numpy()
  array([1, 2, 3, 3])
  """

  def __init__(self,
               initializer,
               num_oov_buckets,
               lookup_key_dtype=None,
               name=None,
               experimental_is_anonymous=False):
    """Construct a `StaticVocabularyTable` object.

    Args:
      initializer: A `TableInitializerBase` object that contains the data used
        to initialize the table. If None, then we only use out-of-vocab
        buckets.
      num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
        Must be greater than zero. If out-of-vocab buckets are not required,
        use `StaticHashTable` instead.
      lookup_key_dtype: Data type of keys passed to `lookup`. Defaults to
        `initializer.key_dtype` if `initializer` is specified, otherwise
        `tf.string`. Must be string or integer, and must be castable to
        `initializer.key_dtype`.
      name: A name for the operation (optional).
      experimental_is_anonymous: Whether to use anonymous mode for the
        table (default is False). In anonymous mode, the table
        resource can only be accessed via a resource handle. It can't
        be looked up by a name. When all resource handles pointing to
        that resource are gone, the resource will be deleted
        automatically.

    Raises:
      ValueError: when `num_oov_buckets` is not positive.
      TypeError: when lookup_key_dtype or initializer.key_dtype are not
        integer or string. Also when initializer.value_dtype != int64.
    """
    if num_oov_buckets <= 0:
      raise ValueError("`num_oov_buckets` must be > 0; use StaticHashTable.")
    # If a name ends with a '/' it is a "name scope", remove all trailing '/'
    # characters to use as table name.
    if name:
      name = name.rstrip("/")
    if initializer:
      if lookup_key_dtype is None:
        lookup_key_dtype = initializer.key_dtype
      supported_table_key_dtypes = (dtypes.int64, dtypes.string)
      if initializer.key_dtype not in supported_table_key_dtypes:
        raise TypeError("Invalid `key_dtype`, expected one of %s, but got %s."
                        % (supported_table_key_dtypes, initializer.key_dtype))
      if initializer.key_dtype.is_integer != lookup_key_dtype.is_integer:
        raise TypeError(
            "Invalid `key_dtype`, expected %s but got %s." %
            ("integer" if lookup_key_dtype.is_integer else "non-integer",
             initializer.key_dtype))
      if initializer.value_dtype != dtypes.int64:
        raise TypeError("Invalid `value_dtype`, expected %s but got %s." %
                        (dtypes.int64, initializer.value_dtype))
      if isinstance(initializer, trackable_base.Trackable):
        self._initializer = self._track_trackable(initializer, "_initializer")
      self._table = HashTable(
          initializer,
          default_value=-1,
          experimental_is_anonymous=experimental_is_anonymous)
      name = name or self._table.name
    else:
      lookup_key_dtype = dtypes.string
      self._table = None
      name = name or "hash_bucket"
    if (not lookup_key_dtype.is_integer) and (dtypes.string !=
                                              lookup_key_dtype):
      raise TypeError("Invalid `key_dtype`, expected integer or string, got "
                      f"{lookup_key_dtype}")
    self._num_oov_buckets = num_oov_buckets

    self._table_name = None
    if name is not None:
      self._table_name = name.split("/")[-1]
    super(StaticVocabularyTable, self).__init__(lookup_key_dtype, dtypes.int64)

  def _create_resource(self):
    if self._table is not None:
      return self._table._create_resource()  # pylint: disable=protected-access
    return None

  def _initialize(self):
    if self._table is not None:
      return self._table._initialize()  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()

  @property
  def resource_handle(self):
    if self._table is not None:
      return self._table.resource_handle
    return None

  @property
  def name(self):
    return self._table_name

  def size(self, name=None):
    """Compute the number of elements in this table."""
    with ops.name_scope(name, "%s_Size" % self.name):
      if self._table:
        tsize = self._table.size()
      else:
        tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
      return tsize + self._num_oov_buckets

  def lookup(self, keys, name=None):
    """Looks up `keys` in the table, outputs the corresponding values.

    It assigns out-of-vocabulary keys to buckets based on their hashes.

    Args:
      keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
      name: Optional name for the op.

    Returns:
      A `SparseTensor` if keys are sparse, a `RaggedTensor` if keys are ragged,
      otherwise a dense `Tensor`.

    Raises:
      TypeError: when `keys` doesn't match the table key data type.
    """
    if keys.dtype.base_dtype != self._key_dtype:
      raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
                      f"received: {keys.dtype}")
    values = keys
    if isinstance(keys,
                  (sparse_tensor.SparseTensor, ragged_tensor.RaggedTensor)):
      values = keys.values
    if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
      values = math_ops.cast(values, dtypes.int64)

    # TODO(yleon): Consider moving this functionality to its own kernel.
    with ops.name_scope(name, "%s_Lookup" % self.name):
      buckets = string_ops.string_to_hash_bucket_fast(
          _as_string(values),
          num_buckets=self._num_oov_buckets,
          name="hash_bucket")
      if self._table:
        ids = self._table.lookup(values)
        buckets = math_ops.add(buckets, self._table.size())
        is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
        ids = array_ops.where_v2(is_id_non_default, ids, buckets)
      else:
        ids = buckets
    if isinstance(keys, sparse_tensor.SparseTensor):
      return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
    elif isinstance(keys, ragged_tensor.RaggedTensor):
      return keys.with_values(ids)
    return ids
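
  # Worked example of the OOV arithmetic above (illustrative): with a 3-entry
  # vocabulary and num_oov_buckets=5, an unknown term hashing to bucket 4 is
  # mapped to id 4 + size(table) = 7, while in-vocabulary terms keep their
  # table ids via the where_v2 select.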


@tf_export(v1=["lookup.StaticVocabularyTable"])
class StaticVocabularyTableV1(StaticVocabularyTable):

  @property
  def initializer(self):
    if self._table is not None:
      return self._table._init_op  # pylint: disable=protected-access
    with ops.name_scope(None, "init"):
      return control_flow_ops.no_op()


def index_table_from_file(vocabulary_file=None,
                          num_oov_buckets=0,
                          vocab_size=None,
                          default_value=-1,
                          hasher_spec=FastHashSpec,
                          key_dtype=dtypes.string,
                          name=None,
                          key_column_index=TextFileIndex.WHOLE_LINE,
                          value_column_index=TextFileIndex.LINE_NUMBER,
                          delimiter="\t"):
  """Returns a lookup table that converts a string tensor into int64 IDs.

  This operation constructs a lookup table to convert a tensor of strings into
  int64 IDs. The mapping can be initialized from a vocabulary file specified in
  `vocabulary_file`, where the whole line is the key and the zero-based line
  number is the ID.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  The underlying table must be initialized by calling
  `session.run(tf.compat.v1.tables_initializer())` or
  `session.run(table.initializer)` once.

  To specify multi-column vocabulary files, use `key_column_index`,
  `value_column_index`, and `delimiter`.

  - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
    expects data type int64.
  - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
    type string.
  - A value >=0 means use the index (starting at zero) of the split line based
    on `delimiter`.

  Sample Usages:

  If we have a vocabulary file "test.txt" with the following content:

  ```
  emerson
  lake
  palmer
  ```

  ```python
  features = tf.constant(["emerson", "lake", "and", "palmer"])
  table = tf.lookup.index_table_from_file(
      vocabulary_file="test.txt", num_oov_buckets=1)
  ids = table.lookup(features)
  ...
  tf.compat.v1.tables_initializer().run()

  ids.eval()  ==> [0, 1, 3, 2]  # where 3 is the out-of-vocabulary bucket
  ```

  Args:
    vocabulary_file: The vocabulary filename, may be a constant scalar
      `Tensor`.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: Number of the elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assignation of out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).
    key_column_index: The column index from the text file to get the `key`
      values from. The default is to use the whole line content.
    value_column_index: The column index from the text file to get the `value`
      values from. The default is to use the line number, starting from zero.
    delimiter: The delimiter to separate fields in a line.

  Returns:
    The lookup table to map a `key_dtype` `Tensor` to index `int64` `Tensor`.

  Raises:
    ValueError: If `vocabulary_file` is not set.
    ValueError: If `num_oov_buckets` is negative or `vocab_size` is not greater
      than zero.
  """
  if vocabulary_file is None or (isinstance(vocabulary_file, str) and
                                 not vocabulary_file):
    raise ValueError(
        "`vocabulary_file` must be specified and must not be empty.")
  if num_oov_buckets < 0:
    raise ValueError(
        "`num_oov_buckets` must be greater than or equal to 0, got %d." %
        num_oov_buckets)
  if vocab_size is not None and vocab_size < 1:
    vocab_file_value = vocabulary_file
    if isinstance(vocabulary_file, ops.Tensor):
      vocab_file_value = tensor_util.constant_value(vocabulary_file) or "?"
    raise ValueError("`vocab_size` must be greater than 0, got %d for "
                     "vocabulary_file: %s." % (vocab_size, vocab_file_value))
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("Dtype for `keys` should be either integer or string.")

  with ops.name_scope(name, "string_to_index"):
    table = None
    with ops.name_scope(None, "hash_table"):
      init = TextFileIdTableInitializer(
          vocabulary_file,
          vocab_size=vocab_size,
          key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
          name="table_init",
          key_column_index=key_column_index,
          value_column_index=value_column_index,
          delimiter=delimiter)

      table = StaticHashTableV1(init, default_value)
    if num_oov_buckets:
      table = IdTableWithHashBuckets(
          table,
          num_oov_buckets=num_oov_buckets,
          hasher_spec=hasher_spec,
          key_dtype=key_dtype)

    return table
1532def index_table_from_tensor(vocabulary_list,
1533 num_oov_buckets=0,
1534 default_value=-1,
1535 hasher_spec=FastHashSpec,
1536 dtype=dtypes.string,
1537 name=None):
1538 """Returns a lookup table that converts a string tensor into int64 IDs.
1540 This operation constructs a lookup table to convert tensor of strings into
1541 int64 IDs. The mapping can be initialized from a string `vocabulary_list` 1-D
1542 tensor where each element is a key and corresponding index within the tensor
1543 is the value.
1545 Any lookup of an out-of-vocabulary token will return a bucket ID based on its
1546 hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
1547 `default_value`. The bucket ID range is
1548 `[vocabulary list size, vocabulary list size + num_oov_buckets - 1]`.
1550 The underlying table must be initialized by calling
1551 `session.run(tf.compat.v1.tables_initializer())` or
1552 `session.run(table.init())` once.
1554 Elements in `vocabulary_list` must not contain duplicates; otherwise, executing
1555 the table initializer op will throw a `FailedPreconditionError`.
1557 Sample Usages:
1559 ```python
1560 vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
1561 table = tf.lookup.index_table_from_tensor(
1562 vocabulary_list=vocabulary_list, num_oov_buckets=1, default_value=-1)
1563 features = tf.constant(["emerson", "lake", "and", "palmer"])
1564 ids = table.lookup(features)
1565 ...
1566 tf.compat.v1.tables_initializer().run()
1568 ids.eval() ==> [0, 1, 3, 2]
1569 ```
1571 Args:
1572 vocabulary_list: A 1-D `Tensor` that specifies the mapping of keys to
1573 indices. The type of this object must be castable to `dtype`.
1574 num_oov_buckets: The number of out-of-vocabulary buckets.
1575 default_value: The value to use for out-of-vocabulary feature values.
1576 Defaults to -1.
1577 hasher_spec: A `HasherSpec` to specify the hash function to use for
1578 assignment of out-of-vocabulary buckets.
1579 dtype: The type of values passed to `lookup`. Only string and integers are
1580 supported.
1581 name: A name for this op (optional).
1583 Returns:
1584 The lookup table to map an input `Tensor` to an index `int64` `Tensor`.
1586 Raises:
1587 ValueError: If `vocabulary_list` is invalid.
1588 ValueError: If `num_oov_buckets` is negative.
1589 """
1590 if vocabulary_list is None:
1591 raise ValueError("`vocabulary_list` must be specified.")
1593 if num_oov_buckets < 0:
1594 raise ValueError(
1595 "`num_oov_buckets` must be greater or equal than 0, got %d." %
1596 num_oov_buckets)
1598 if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
1599 raise TypeError("`dtype` must be either integer or string.")
1601 with ops.name_scope(name, "string_to_index"):
1602 keys = ops.convert_to_tensor(vocabulary_list)
1603 if keys.dtype.is_integer != dtype.is_integer:
1604 raise ValueError(
1605 "Invalid `dtype`: Expected %s, got %s." %
1606 ("integer" if dtype.is_integer else "non-integer", keys.dtype))
1607 if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
1608 raise ValueError("Invalid `dtype`: Expected %s, got %s." %
1609 (dtype, keys.dtype))
1610 num_elements = array_ops.size(keys)
1611 values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
1613 with ops.name_scope(None, "hash_table"):
1614 table_keys = math_ops.cast(
1615 keys, dtypes.int64) if keys.dtype.is_integer else keys
1616 init = KeyValueTensorInitializer(
1617 table_keys,
1618 values,
1619 table_keys.dtype.base_dtype,
1620 dtypes.int64,
1621 name="table_init")
1622 table = StaticHashTableV1(init, default_value)
1623 if num_oov_buckets:
1624 table = IdTableWithHashBuckets(
1625 table,
1626 num_oov_buckets=num_oov_buckets,
1627 hasher_spec=hasher_spec,
1628 key_dtype=dtype)
1629 return table
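# A minimal sketch of `index_table_from_tensor` with an integer vocabulary
# (illustrative values; integer keys are cast to int64 internally):
#
#   vocab = constant_op.constant([42, 7, 99], dtype=dtypes.int64)
#   table = index_table_from_tensor(vocabulary_list=vocab, dtype=dtypes.int64)
#   ids = table.lookup(constant_op.constant([7, 123], dtype=dtypes.int64))
#   # After running tables_initializer(): ids ==> [1, -1]
#   # (123 is out of vocabulary; with num_oov_buckets=0 it gets default_value).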
1632def index_to_string_table_from_file(vocabulary_file,
1633 vocab_size=None,
1634 default_value="UNK",
1635 name=None,
1636 key_column_index=TextFileIndex.LINE_NUMBER,
1637 value_column_index=TextFileIndex.WHOLE_LINE,
1638 delimiter="\t"):
1639 """Returns a lookup table that maps a `Tensor` of indices into strings.
1641 This operation constructs a lookup table to map int64 indices into string
1642 values. The table is initialized from a vocabulary file specified in
1643 `vocabulary_file`, where the whole line is the value and the
1644 zero-based line number is the index.
1646 Any input which does not have a corresponding index in the vocabulary file
1647 (an out-of-vocabulary entry) is assigned the `default_value`.
1649 The underlying table must be initialized by calling
1650 `session.run(tf.compat.v1.tables_initializer())` or
1651 `session.run(table.init())` once.
1653 To specify multi-column vocabulary files, use `key_column_index`,
1654 `value_column_index`, and `delimiter`; a multi-column sketch follows the list below.
1656 - TextFileIndex.LINE_NUMBER means use the line number starting from zero,
1657 expects data type int64.
1658 - TextFileIndex.WHOLE_LINE means use the whole line content, expects data
1659 type string.
1660 - A value >=0 means use the index (starting at zero) of the split line based
1661 on `delimiter`.
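For example, a hypothetical tab-delimited file "vocab.tsv" could carry the
`int64` key in its first column and the string value in its second (the file
name and its contents are illustrative):

```python
# vocab.tsv:
#   10<TAB>emerson
#   20<TAB>lake
table = tf.lookup.index_to_string_table_from_file(
    vocabulary_file="vocab.tsv",
    key_column_index=0,
    value_column_index=1,
    delimiter="\t")
```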
1663 Sample Usages:
1665 If we have a vocabulary file "test.txt" with the following content:
1667 ```
1668 emerson
1669 lake
1670 palmer
1671 ```
1673 ```python
1674 indices = tf.constant([1, 5], tf.int64)
1675 table = tf.lookup.index_to_string_table_from_file(
1676 vocabulary_file="test.txt", default_value="UNKNOWN")
1677 values = table.lookup(indices)
1678 ...
1679 tf.compat.v1.tables_initializer().run()
1681 values.eval() ==> ["lake", "UNKNOWN"]
1682 ```
1684 Args:
1685 vocabulary_file: The vocabulary filename, may be a constant scalar `Tensor`.
1686 vocab_size: Number of elements in the vocabulary, if known.
1687 default_value: The value to use for out-of-vocabulary indices.
1688 name: A name for this op (optional).
1689 key_column_index: The column index from the text file to get the `key`
1690 values from. The default is to use the line number, starting from zero.
1691 value_column_index: The column index from the text file to get the `value`
1692 values from. The default is to use the whole line content.
1693 delimiter: The delimiter to separate fields in a line.
1695 Returns:
1696 The lookup table mapping `int64` index `Tensors` to the associated string
1697 values.
1699 Raises:
1700 ValueError: when `vocabulary_file` is empty.
1701 ValueError: when `vocab_size` is invalid.
1702 """
1703 if vocabulary_file is None or (isinstance(vocabulary_file, str) and
1704 not vocabulary_file):
1705 raise ValueError(
1706 "`vocabulary_file` must be specified and must not be empty.")
1708 if vocab_size is not None and vocab_size < 1:
1709 raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")
1711 with ops.name_scope(name, "index_to_string"):
1712 init = TextFileStringTableInitializer(
1713 vocabulary_file,
1714 vocab_size=vocab_size,
1715 name="table_init",
1716 key_column_index=key_column_index,
1717 value_column_index=value_column_index,
1718 delimiter=delimiter)
1720 # TODO(yleon): Use a more efficient structure.
1721 return StaticHashTableV1(init, default_value)
1724def index_to_string_table_from_tensor(vocabulary_list,
1725 default_value="UNK",
1726 name=None):
1727 """Returns a lookup table that maps a `Tensor` of indices into strings.
1729 This operation constructs a lookup table to map int64 indices into string
1730 values. The mapping is initialized from a string `vocabulary_list` 1-D
1731 `Tensor` where each element is a value and the corresponding index within the
1732 tensor is the key.
1734 Any input which does not have a corresponding index in `vocabulary_list`
1735 (an out-of-vocabulary entry) is assigned the `default_value`.
1737 The underlying table must be initialized by calling
1738 `session.run(tf.compat.v1.tables_initializer())` or
1739 `session.run(table.init())` once.
1741 Elements in `vocabulary_list` must not contain duplicates; otherwise, executing
1742 the table initializer op will throw a `FailedPreconditionError`.
1744 Sample Usages:
1746 ```python
1747 vocabulary_list = tf.constant(["emerson", "lake", "palmer"])
1748 indices = tf.constant([1, 5], tf.int64)
1749 table = tf.lookup.index_to_string_table_from_tensor(
1750 vocabulary_list, default_value="UNKNOWN")
1751 values = table.lookup(indices)
1752 ...
1753 tf.compat.v1.tables_initializer().run()
1755 values.eval() ==> ["lake", "UNKNOWN"]
1756 ```
1758 Args:
1759 vocabulary_list: A 1-D string `Tensor` that specifies the strings to map
1760 from indices.
1761 default_value: The value to use for out-of-vocabulary indices.
1762 name: A name for this op (optional).
1764 Returns:
1765 The lookup table mapping `int64` index `Tensors` to the associated string
1766 values.
1768 Raises:
1769 ValueError: when `vocabulary_list` is not set.
1770 """
1772 if vocabulary_list is None:
1773 raise ValueError("`vocabulary_list` argument must be specified.")
1775 with ops.name_scope(name, "index_to_string"):
1776 vocabulary_list = ops.convert_to_tensor(vocabulary_list, dtypes.string)
1777 num_elements = array_ops.size(vocabulary_list)
1778 keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
1780 init = KeyValueTensorInitializer(
1781 keys, vocabulary_list, dtypes.int64, dtypes.string, name="table_init")
1782 # TODO(yleon): Use a more efficient structure.
1783 return StaticHashTableV1(init, default_value)
1786@tf_export("lookup.experimental.MutableHashTable")
1787@saveable_compat.legacy_saveable_name("table")
1788class MutableHashTable(LookupInterface):
1789 """A generic mutable hash table implementation.
1791 Data can be inserted by calling the `insert` method and removed by calling the
1792 `remove` method. It does not support initialization via the init method.
1794 `MutableHashTable` requires additional memory during checkpointing and restore
1795 operations to create temporary key and value tensors.
1797 Example usage:
1799 >>> table = tf.lookup.experimental.MutableHashTable(key_dtype=tf.string,
1800 ... value_dtype=tf.int64,
1801 ... default_value=-1)
1802 >>> keys_tensor = tf.constant(['a', 'b', 'c'])
1803 >>> vals_tensor = tf.constant([7, 8, 9], dtype=tf.int64)
1804 >>> input_tensor = tf.constant(['a', 'f'])
1805 >>> table.insert(keys_tensor, vals_tensor)
1806 >>> table.lookup(input_tensor).numpy()
1807 array([ 7, -1])
1808 >>> table.remove(tf.constant(['c']))
1809 >>> table.lookup(keys_tensor).numpy()
1810 array([ 7, 8, -1])
1811 >>> sorted(table.export()[0].numpy())
1812 [b'a', b'b']
1813 >>> sorted(table.export()[1].numpy())
1814 [7, 8]
1815 """
1817 def __init__(self,
1818 key_dtype,
1819 value_dtype,
1820 default_value,
1821 name="MutableHashTable",
1822 checkpoint=True,
1823 experimental_is_anonymous=False):
1824 """Creates an empty `MutableHashTable` object.
1826 Creates a table whose key and value types are specified by `key_dtype`
1827 and `value_dtype`, respectively.
1829 Args:
1830 key_dtype: the type of the key tensors.
1831 value_dtype: the type of the value tensors.
1832 default_value: The value to use if a key is missing in the table.
1833 name: A name for the operation (optional).
1834 checkpoint: if True, the contents of the table are saved to and restored
1835 from checkpoints. If `shared_name` is empty for a checkpointed table, it
1836 is shared using the table node name.
1837 experimental_is_anonymous: Whether to use anonymous mode for the
1838 table (default is False). In anonymous mode, the table
1839 resource can only be accessed via a resource handle. It can't
1840 be looked up by a name. When all resource handles pointing to
1841 that resource are gone, the resource will be deleted
1842 automatically.
1844 Returns:
1845 A `MutableHashTable` object.
1847 Raises:
1848 ValueError: If checkpoint is True and no name was specified.
1849 """
1850 self._default_value = ops.convert_to_tensor(
1851 default_value, dtype=value_dtype)
1852 self._value_shape = self._default_value.get_shape()
1853 self._checkpoint = checkpoint
1854 self._key_dtype = key_dtype
1855 self._value_dtype = value_dtype
1856 self._name = name
1857 self._is_anonymous = experimental_is_anonymous
1858 if not self._is_anonymous:
1859 self._shared_name = None
1860 if context.executing_eagerly():
1861 # TODO(allenl): This will leak memory due to kernel caching by
1862 # the shared_name attribute value (but is better than the
1863 # alternative of sharing everything by default when executing
1864 # eagerly; hopefully creating tables in a loop is uncommon).
1865 self._shared_name = "table_%d" % (ops.uid(),)
1866 super(MutableHashTable, self).__init__(key_dtype, value_dtype)
1867 self._resource_handle = self._create_resource()
1868 if checkpoint:
1869 saveable = MutableHashTable._Saveable(self, name)
1870 if not context.executing_eagerly():
1871 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
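# A minimal construction sketch for anonymous mode (illustrative values;
# assumes TF2 eager execution):
#
#   table = MutableHashTable(
#       key_dtype=dtypes.string, value_dtype=dtypes.int64,
#       default_value=-1, experimental_is_anonymous=True)
#   # The resource cannot be looked up by name; it is deleted automatically
#   # once the last handle (here, `table`) goes away.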
1873 def _create_resource(self):
1874 if self._is_anonymous:
1875 if self._default_value.get_shape().ndims == 0:
1876 table_ref = gen_lookup_ops.anonymous_mutable_hash_table(
1877 key_dtype=self._key_dtype,
1878 value_dtype=self._value_dtype,
1879 name=self._name)
1880 else:
1881 table_ref = gen_lookup_ops.anonymous_mutable_hash_table_of_tensors(
1882 key_dtype=self._key_dtype,
1883 value_dtype=self._value_dtype,
1884 value_shape=self._default_value.get_shape(),
1885 name=self._name)
1886 else:
1887 # The table must be shared if checkpointing is requested for multi-worker
1888 # training to work correctly. Use the node name if no shared_name has been
1889 # explicitly specified.
1890 use_node_name_sharing = self._checkpoint and self._shared_name is None
1891 if self._default_value.get_shape().ndims == 0:
1892 table_ref = gen_lookup_ops.mutable_hash_table_v2(
1893 shared_name=self._shared_name,
1894 use_node_name_sharing=use_node_name_sharing,
1895 key_dtype=self._key_dtype,
1896 value_dtype=self._value_dtype,
1897 name=self._name)
1898 else:
1899 table_ref = gen_lookup_ops.mutable_hash_table_of_tensors_v2(
1900 shared_name=self._shared_name,
1901 use_node_name_sharing=use_node_name_sharing,
1902 key_dtype=self._key_dtype,
1903 value_dtype=self._value_dtype,
1904 value_shape=self._default_value.get_shape(),
1905 name=self._name)
1907 if context.executing_eagerly():
1908 self._table_name = None
1909 else:
1910 self._table_name = table_ref.op.name.split("/")[-1]
1911 return table_ref
1913 @property
1914 def name(self):
1915 return self._table_name
1917 def size(self, name=None):
1918 """Compute the number of elements in this table.
1920 Args:
1921 name: A name for the operation (optional).
1923 Returns:
1924 A scalar tensor containing the number of elements in this table.
1925 """
1926 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
1927 with ops.colocate_with(self.resource_handle):
1928 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)
1930 def remove(self, keys, name=None):
1931 """Removes `keys` and its associated values from the table.
1933 If a key is not present in the table, it is silently ignored.
1935 Args:
1936 keys: Keys to remove. Can be a tensor of any shape. Must match the table's
1937 key type.
1938 name: A name for the operation (optional).
1940 Returns:
1941 The created Operation.
1943 Raises:
1944 TypeError: when `keys` do not match the table data types.
1945 """
1946 if keys.dtype != self._key_dtype:
1947 raise TypeError(f"Dtype of argument `keys` must be {self._key_dtype}, "
1948 f"received: {keys.dtype}")
1950 with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
1951 (self.resource_handle, keys, self._default_value)):
1952 op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)
1954 return op
1956 def lookup(self, keys, dynamic_default_values=None, name=None):
1957 """Looks up `keys` in a table, outputs the corresponding values.
1959 The `default_value` is used for keys not present in the table.
1961 Args:
1962 keys: Keys to look up. Can be a tensor of any shape. Must match the
1963 table's key_dtype.
1964 dynamic_default_values: The values to use if a key is missing in the
1965 table. If None (by default), the `table.default_value` will be used.
1966 The shape of `dynamic_default_values` must be the same as the shape of
1967 `table.default_value` or of the lookup result tensor.
1968 In the latter case, each key will have a different default value.
1970 For example:
1972 ```python
1973 keys = [0, 1, 3]
1974 dynamic_default_values = [[1, 3, 4], [2, 3, 9], [8, 3, 0]]
1976 # The key '0' will use [1, 3, 4] as default value.
1977 # The key '1' will use [2, 3, 9] as default value.
1978 # The key '3' will use [8, 3, 0] as default value.
1979 ```
1981 name: A name for the operation (optional).
1983 Returns:
1984 A tensor containing the values in the same shape as `keys` using the
1985 table's value type.
1987 Raises:
1988 TypeError: when `keys` do not match the table data types.
1989 """
1990 with ops.name_scope(name, "%s_lookup_table_find" % self.name,
1991 (self.resource_handle, keys, self._default_value)):
1992 keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
1993 with ops.colocate_with(self.resource_handle):
1994 values = gen_lookup_ops.lookup_table_find_v2(
1995 self.resource_handle, keys, dynamic_default_values
1996 if dynamic_default_values is not None else self._default_value)
1997 return values
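# A minimal sketch of `lookup` with per-key dynamic defaults (illustrative
# values; assumes an int64->int64 table whose values have shape [3]):
#
#   keys = constant_op.constant([0, 1, 3], dtypes.int64)
#   dynamic_defaults = constant_op.constant(
#       [[1, 3, 4], [2, 3, 9], [8, 3, 0]], dtypes.int64)
#   values = table.lookup(keys, dynamic_default_values=dynamic_defaults)
#   # A missing key picks up the default row at its own position, e.g. the
#   # key 3, if absent, resolves to [8, 3, 0].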
1999 def insert(self, keys, values, name=None):
2000 """Associates `keys` with `values`.
2002 Args:
2003 keys: Keys to insert. Can be a tensor of any shape. Must match the table's
2004 key type.
2005 values: Values to be associated with keys. Must be a tensor of the same
2006 shape as `keys` and match the table's value type.
2007 name: A name for the operation (optional).
2009 Returns:
2010 The created Operation.
2012 Raises:
2013 TypeError: when `keys` or `values` doesn't match the table data
2014 types.
2015 """
2016 with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
2017 [self.resource_handle, keys, values]):
2018 keys = ops.convert_to_tensor(keys, self._key_dtype, name="keys")
2019 values = ops.convert_to_tensor(values, self._value_dtype, name="values")
2020 with ops.colocate_with(self.resource_handle):
2021 # pylint: disable=protected-access
2022 op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
2023 values)
2024 return op
2026 def export(self, name=None):
2027 """Returns tensors of all keys and values in the table.
2029 Args:
2030 name: A name for the operation (optional).
2032 Returns:
2033 A pair of tensors with the first tensor containing all keys and the
2034 second tensor containing all values in the table.
2035 """
2036 with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
2037 [self.resource_handle]):
2038 with ops.colocate_with(self.resource_handle):
2039 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
2040 self.resource_handle, self._key_dtype, self._value_dtype)
2041 return exported_keys, exported_values
2043 def _serialize_to_tensors(self):
2044 """Implements checkpointing protocols for `Trackable`."""
2045 tensors = self.export()
2046 return {"-keys": tensors[0], "-values": tensors[1]}
2048 def _restore_from_tensors(self, restored_tensors):
2049 """Implements checkpointing protocols for `Trackable`."""
2050 with ops.name_scope("%s_table_restore" % self._name):
2051 with ops.colocate_with(self.resource_handle):
2052 return gen_lookup_ops.lookup_table_import_v2(
2053 self.resource_handle,
2054 restored_tensors["-keys"],
2055 restored_tensors["-values"])
2057 # This class is needed for `MutableHashTable(checkpoint=True)`.
2058 class _Saveable(BaseSaverBuilder.SaveableObject):
2059 """SaveableObject implementation for DenseHashTable."""
2061 def __init__(self, table, name, table_name=None):
2062 tensors = table.export()
2063 specs = [
2064 BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
2065 BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
2066 ]
2067 self.table_name = table_name or name
2068 # pylint: disable=protected-access
2069 super(MutableHashTable._Saveable, self).__init__(table, specs, name)
2071 def restore(self, restored_tensors, restored_shapes):
2072 del restored_shapes # unused
2073 # pylint: disable=protected-access
2074 with ops.name_scope("%s_table_restore" % self.table_name):
2075 with ops.colocate_with(self.op.resource_handle):
2076 return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
2077 restored_tensors[0],
2078 restored_tensors[1])
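# A minimal object-based checkpointing sketch for `MutableHashTable` (the
# path is illustrative; `tf.train.Checkpoint` saves and restores the table
# through `_serialize_to_tensors` / `_restore_from_tensors` above):
#
#   table = MutableHashTable(dtypes.string, dtypes.int64, default_value=-1)
#   table.insert(constant_op.constant(["a"]),
#                constant_op.constant([1], dtypes.int64))
#   ckpt = tf.train.Checkpoint(table=table)
#   path = ckpt.write("/tmp/table_ckpt")
#   ckpt.restore(path)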
2081@tf_export("lookup.experimental.DenseHashTable")
2082@saveable_compat.legacy_saveable_name("table")
2083class DenseHashTable(LookupInterface):
2084 """A mutable hash table with faster lookups and higher memory usage.
2086 Data can be inserted by calling the `insert` method and removed by calling the
2087 `remove` method. It does not support initialization via the init method.
2089 Compared to `MutableHashTable`, `DenseHashTable` offers generally faster
2090 `insert`, `remove` and `lookup` operations, in exchange for a higher overall
2091 memory footprint.
2093 It uses "open addressing" with quadratic reprobing to resolve collisions. This
2094 requires specifying two keys in the key space, `empty_key` and `deleted_key`,
2095 that can never be inserted into the table.
2097 Unlike `MutableHashTable`, `DenseHashTable` does not require additional memory
2098 for temporary tensors created during checkpointing and restore operations.
2100 Example usage:
2102 >>> table = tf.lookup.experimental.DenseHashTable(
2103 ... key_dtype=tf.string,
2104 ... value_dtype=tf.int64,
2105 ... default_value=-1,
2106 ... empty_key='',
2107 ... deleted_key='$')
2108 >>> keys = tf.constant(['a', 'b', 'c'])
2109 >>> values = tf.constant([0, 1, 2], dtype=tf.int64)
2110 >>> table.insert(keys, values)
2111 >>> table.remove(tf.constant(['c']))
2112 >>> table.lookup(tf.constant(['a', 'b', 'c', 'd'])).numpy()
2113 array([ 0, 1, -1, -1])
2114 """
2116 # TODO(andreasst): consider extracting common code with MutableHashTable into
2117 # a common superclass.
2118 def __init__(self,
2119 key_dtype,
2120 value_dtype,
2121 default_value,
2122 empty_key,
2123 deleted_key,
2124 initial_num_buckets=None,
2125 name="MutableDenseHashTable",
2126 checkpoint=True,
2127 experimental_is_anonymous=False):
2128 """Creates an empty `DenseHashTable` object.
2130 Creates a table whose key and value types are specified by `key_dtype`
2131 and `value_dtype`, respectively.
2133 Args:
2134 key_dtype: the type of the key tensors.
2135 value_dtype: the type of the value tensors.
2136 default_value: The value to use if a key is missing in the table.
2137 empty_key: the key to use to represent empty buckets internally. Must not
2138 be used in insert, remove or lookup operations.
2139 deleted_key: the key to use to represent deleted buckets internally. Must
2140 not be used in insert, remove or lookup operations, and must be different
2141 from the `empty_key`.
2142 initial_num_buckets: the initial number of buckets (optional,
2143 defaults to 2^17=131072). Note that the default value is
2144 relatively large (~1MB), so if you are going to create many
2145 tables (likely the case when `experimental_is_anonymous` is
2146 `True`), you should set `initial_num_buckets` to a smaller
2147 value to reduce memory usage.
2148 name: A name for the operation (optional).
2149 checkpoint: if True, the contents of the table are saved to and restored
2150 from checkpoints. If `shared_name` is empty for a checkpointed table, it
2151 is shared using the table node name.
2152 experimental_is_anonymous: Whether to use anonymous mode for the
2153 table (default is False). In anonymous mode, the table
2154 resource can only be accessed via a resource handle. It can't
2155 be looked up by a name. When all resource handles pointing to
2156 that resource are gone, the resource will be deleted
2157 automatically.
2159 Returns:
2160 A `DenseHashTable` object.
2162 Raises:
2163 ValueError: If checkpoint is True and no name was specified.
2164 """
2165 self._default_value = ops.convert_to_tensor(
2166 default_value, dtype=value_dtype, name="default_value")
2167 self._key_dtype = key_dtype
2168 self._value_dtype = value_dtype
2169 # TODO(b/201578996): Pick a good default for initial_num_buckets
2170 # other than 2^17.
2171 self._initial_num_buckets = initial_num_buckets
2172 self._value_shape = self._default_value.get_shape()
2173 self._checkpoint = checkpoint
2174 self._name = name
2175 self._empty_key = empty_key
2176 self._deleted_key = deleted_key
2177 self._is_anonymous = experimental_is_anonymous
2178 if not self._is_anonymous:
2179 self._shared_name = None
2180 if context.executing_eagerly():
2181 # TODO(allenl): This will leak memory due to kernel caching by
2182 # the shared_name attribute value (but is better than the
2183 # alternative of sharing everything by default when executing
2184 # eagerly; hopefully creating tables in a loop is uncommon).
2185 self._shared_name = "table_%d" % (ops.uid(),)
2186 super(DenseHashTable, self).__init__(key_dtype, value_dtype)
2187 self._resource_handle = self._create_resource()
2188 if checkpoint:
2189 saveable = DenseHashTable._Saveable(self, name)
2190 if not context.executing_eagerly():
2191 ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
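# A minimal construction sketch with a small bucket count (illustrative
# values; `initial_num_buckets` is expected to be a power of two, and
# `empty_key`/`deleted_key` must be distinct and never inserted or looked up):
#
#   table = DenseHashTable(
#       key_dtype=dtypes.int64, value_dtype=dtypes.int64, default_value=0,
#       empty_key=-1, deleted_key=-2, initial_num_buckets=16)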
2193 def _create_resource(self):
2194 empty_key = ops.convert_to_tensor(
2195 self._empty_key, dtype=self._key_dtype, name="empty_key")
2196 deleted_key = ops.convert_to_tensor(
2197 self._deleted_key, dtype=self._key_dtype, name="deleted_key")
2198 if self._is_anonymous:
2199 table_ref = gen_lookup_ops.anonymous_mutable_dense_hash_table(
2200 empty_key=empty_key,
2201 deleted_key=deleted_key,
2202 value_dtype=self._value_dtype,
2203 value_shape=self._value_shape,
2204 initial_num_buckets=self._initial_num_buckets,
2205 name=self._name)
2206 else:
2207 # The table must be shared if checkpointing is requested for multi-worker
2208 # training to work correctly. Use the node name if no shared_name has been
2209 # explicitly specified.
2210 use_node_name_sharing = self._checkpoint and self._shared_name is None
2211 table_ref = gen_lookup_ops.mutable_dense_hash_table_v2(
2212 empty_key=empty_key,
2213 deleted_key=deleted_key,
2214 shared_name=self._shared_name,
2215 use_node_name_sharing=use_node_name_sharing,
2216 value_dtype=self._value_dtype,
2217 value_shape=self._value_shape,
2218 initial_num_buckets=self._initial_num_buckets,
2219 name=self._name)
2220 if context.executing_eagerly():
2221 self._table_name = None
2222 else:
2223 self._table_name = table_ref.op.name.split("/")[-1]
2224 return table_ref
2226 @property
2227 def name(self):
2228 return self._table_name
2230 def size(self, name=None):
2231 """Compute the number of elements in this table.
2233 Args:
2234 name: A name for the operation (optional).
2236 Returns:
2237 A scalar tensor containing the number of elements in this table.
2238 """
2239 with ops.name_scope(name, "%s_Size" % self.name, [self.resource_handle]):
2240 with ops.colocate_with(self.resource_handle):
2241 return gen_lookup_ops.lookup_table_size_v2(self.resource_handle)
2243 def lookup(self, keys, name=None):
2244 """Looks up `keys` in a table, outputs the corresponding values.
2246 The `default_value` is used for keys not present in the table.
2248 Args:
2249 keys: Keys to look up. Can be a tensor of any shape. Must match the
2250 table's key_dtype.
2251 name: A name for the operation (optional).
2253 Returns:
2254 A tensor containing the values in the same shape as `keys` using the
2255 table's value type.
2257 Raises:
2258 TypeError: when `keys` do not match the table data types.
2259 """
2260 with ops.name_scope(name, "%s_lookup_table_find" % self.name,
2261 [self.resource_handle, keys]):
2262 keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
2263 with ops.colocate_with(self.resource_handle):
2264 values = gen_lookup_ops.lookup_table_find_v2(self.resource_handle, keys,
2265 self._default_value)
2267 return values
2269 def insert_or_assign(self, keys, values, name=None):
2270 """Associates `keys` with `values`.
2272 Args:
2273 keys: Keys to insert. Can be a tensor of any shape. Must match the table's
2274 key type.
2275 values: Values to be associated with keys. Must be a tensor of the same
2276 shape as `keys` and match the table's value type.
2277 name: A name for the operation (optional).
2279 Returns:
2280 The created Operation.
2282 Raises:
2283 TypeError: when `keys` or `values` doesn't match the table data
2284 types.
2285 """
2286 with ops.name_scope(name, "%s_lookup_table_insert" % self.name,
2287 [self.resource_handle, keys, values]):
2288 keys = ops.convert_to_tensor(keys, dtype=self._key_dtype, name="keys")
2289 values = ops.convert_to_tensor(
2290 values, dtype=self._value_dtype, name="values")
2291 with ops.colocate_with(self.resource_handle):
2292 op = gen_lookup_ops.lookup_table_insert_v2(self.resource_handle, keys,
2293 values)
2294 return op
2296 def insert(self, keys, values, name=None):
2297 """Associates `keys` with `values`.
2299 Args:
2300 keys: Keys to insert. Can be a tensor of any shape. Must match the table's
2301 key type.
2302 values: Values to be associated with keys. Must be a tensor of the same
2303 shape as `keys` and match the table's value type.
2304 name: A name for the operation (optional).
2306 Returns:
2307 The created Operation.
2309 Raises:
2310 TypeError: when `keys` or `values` doesn't match the table data
2311 types.
2312 """
2313 return self.insert_or_assign(keys, values, name)
2315 def erase(self, keys, name=None):
2316 """Removes `keys` and its associated values from the table.
2318 If a key is not present in the table, it is silently ignored.
2320 Args:
2321 keys: Keys to remove. Can be a tensor of any shape. Must match the table's
2322 key type.
2323 name: A name for the operation (optional).
2325 Returns:
2326 The created Operation.
2328 Raises:
2329 TypeError: when `keys` do not match the table data types.
2330 """
2331 if keys.dtype != self._key_dtype:
2332 raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
2333 (self._key_dtype, keys.dtype))
2335 with ops.name_scope(name, "%s_lookup_table_remove" % self.name,
2336 (self.resource_handle, keys, self._default_value)):
2337 # pylint: disable=protected-access
2338 op = gen_lookup_ops.lookup_table_remove_v2(self.resource_handle, keys)
2340 return op
2342 def remove(self, keys, name=None):
2343 """Removes `keys` and its associated values from the table.
2345 If a key is not present in the table, it is silently ignored.
2347 Args:
2348 keys: Keys to remove. Can be a tensor of any shape. Must match the table's
2349 key type.
2350 name: A name for the operation (optional).
2352 Returns:
2353 The created Operation.
2355 Raises:
2356 TypeError: when `keys` do not match the table data types.
2357 """
2358 return self.erase(keys, name)
2360 def export(self, name=None):
2361 """Returns tensors of all keys and values in the table.
2363 Args:
2364 name: A name for the operation (optional).
2366 Returns:
2367 A pair of tensors with the first tensor containing all keys and the
2368 second tensor containing all values in the table.
2369 """
2370 with ops.name_scope(name, "%s_lookup_table_export_values" % self.name,
2371 [self.resource_handle]):
2372 with ops.colocate_with(self.resource_handle):
2373 exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
2374 self.resource_handle, self._key_dtype, self._value_dtype)
2376 return exported_keys, exported_values
2378 def _serialize_to_tensors(self):
2379 """Implements checkpointing interface in `Trackable`."""
2380 tensors = self.export()
2381 return {"-keys": tensors[0], "-values": tensors[1]}
2383 def _restore_from_tensors(self, restored_tensors):
2384 """Implements checkpointing interface in `Trackable`."""
2385 with ops.name_scope("%s_table_restore" % self._name):
2386 with ops.colocate_with(self.resource_handle):
2387 return gen_lookup_ops.lookup_table_import_v2(
2388 self.resource_handle,
2389 restored_tensors["-keys"],
2390 restored_tensors["-values"])
2392 # This class is needed for `DenseHashTable(checkpoint=True)`.
2393 class _Saveable(BaseSaverBuilder.SaveableObject):
2394 """SaveableObject implementation for DenseHashTable."""
2396 def __init__(self, table, name, table_name=None):
2397 tensors = table.export()
2398 specs = [
2399 BaseSaverBuilder.SaveSpec(tensors[0], "", name + "-keys"),
2400 BaseSaverBuilder.SaveSpec(tensors[1], "", name + "-values")
2401 ]
2402 self.table_name = table_name or name
2403 # pylint: disable=protected-access
2404 super(DenseHashTable._Saveable, self).__init__(table, specs, name)
2406 def restore(self, restored_tensors, restored_shapes):
2407 del restored_shapes # unused
2408 # pylint: disable=protected-access
2409 with ops.name_scope("%s_table_restore" % self.table_name):
2410 with ops.colocate_with(self.op.resource_handle):
2411 return gen_lookup_ops.lookup_table_import_v2(self.op.resource_handle,
2412 restored_tensors[0],
2413 restored_tensors[1])
2416ops.NotDifferentiable("LookupTableFind")
2417ops.NotDifferentiable("LookupTableFindV2")
2418ops.NotDifferentiable("LookupTableInsert")
2419ops.NotDifferentiable("LookupTableInsertV2")
2420ops.NotDifferentiable("LookupTableSize")
2421ops.NotDifferentiable("LookupTableSizeV2")
2422ops.NotDifferentiable("HashTable")
2423ops.NotDifferentiable("HashTableV2")
2424ops.NotDifferentiable("InitializeTable")
2425ops.NotDifferentiable("InitializeTableV2")
2426ops.NotDifferentiable("InitializeTableFromTextFile")
2427ops.NotDifferentiable("InitializeTableFromTextFileV2")
2428ops.NotDifferentiable("MutableDenseHashTable")
2429ops.NotDifferentiable("MutableDenseHashTableV2")
2430ops.NotDifferentiable("MutableHashTable")
2431ops.NotDifferentiable("MutableHashTableV2")
2432ops.NotDifferentiable("MutableHashTableOfTensors")
2433ops.NotDifferentiable("MutableHashTableOfTensorsV2")