# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup operations."""

from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export


def _check_table_initializer_element_spec(element_spec):
  """Raises an error if the given table initializer element spec is invalid."""
  base_error = ("Datasets used to initialize lookup tables must "
                "produce elements in the form (key, value), where "
                "the keys and values are scalar tensors. ")
  if len(element_spec) != 2:
    raise ValueError(base_error + "However, the given dataset produces "
                     f"{len(element_spec)} components instead of two "
                     "(key, value) components. Full dataset element spec: "
                     f"{element_spec}.")
  if not isinstance(element_spec[0], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor keys of type {type(element_spec[0])}.")
  if not isinstance(element_spec[1], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor values of type {type(element_spec[1])}.")
  if element_spec[0].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar key Tensors of rank {element_spec[0].shape.rank}.")
  if element_spec[1].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar value Tensors of rank {element_spec[1].shape.rank}.")


@tf_export("data.experimental.DatasetInitializer")
class DatasetInitializer(lookup_ops.TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.data.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)
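
  Looking up a key that is absent from the dataset returns the table's
  default value (the empty string above):

  >>> table.lookup(tf.constant([999], dtype=tf.int64)).numpy()
  array([b''], dtype=object)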

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as a key and the second as value.

  Raises:
    ValueError: If `dataset` doesn't conform to specifications.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as a key and the second as value.

    Returns:
      A `DatasetInitializer` object.

    Raises:
      ValueError: If `dataset` doesn't conform to specifications.
    """
    # Assert that the dataset element spec is a tuple of TensorSpecs where
    # each tensor is a scalar.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    _check_table_initializer_element_spec(elem_spec)

    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    lookup_ops.check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = ged_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
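
# In graph mode, the op returned by `initialize` is also registered in the
# `TABLE_INITIALIZERS` collection above, so it can be run via
# `tf.compat.v1.tables_initializer()`. An illustrative TF1-style sketch,
# assuming `ds` is a dataset of scalar (key, value) string pairs and `keys`
# is a string tensor of lookup keys:
#
#   init = DatasetInitializer(ds)
#   table = lookup_ops.StaticHashTableV1(init, default_value="")
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.tables_initializer())
#     print(sess.run(table.lookup(keys)))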


@tf_export("data.experimental.table_from_dataset")
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table from a dataset of (key, value)
  pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on
  its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned
  the `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)
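
  Out-of-vocabulary keys map to `default_value` when `num_oov_buckets` is
  zero (the default):

  >>> table.lookup(tf.constant([-1], dtype=tf.int64)).numpy()
  array([b'n/a'], dtype=object)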

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` specifying the hash function to use for
      assigning out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs,
      * the second item of the `dataset` pairs has a dtype that is
        incompatible with `default_value`,
      * `num_oov_buckets` is negative, or
      * `vocab_size` is not greater than zero.
    TypeError: If `key_dtype` is neither an integer nor a string type.
  """
  elem_spec = dataset.element_spec
  _check_table_initializer_element_spec(elem_spec)
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("`default_value` must be specified when creating a "
                       "table from a dataset that produces values of type "
                       f"{elem_spec[1].dtype}.")
  if num_oov_buckets < 0:
    raise ValueError("`num_oov_buckets` must be greater than or equal to 0, "
                     f"got {num_oov_buckets}.")
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("`key_dtype` must be either an integer or string type, "
                    f"but got {key_dtype}")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table


@tf_export("data.experimental.index_table_from_dataset")
def index_table_from_dataset(dataset=None,
                             num_oov_buckets=0,
                             vocab_size=None,
                             default_value=-1,
                             hasher_spec=lookup_ops.FastHashSpec,
                             key_dtype=dtypes.string,
                             name=None):
  """Returns an index lookup table based on the given dataset.

  This operation constructs a lookup table that maps each key produced by the
  given dataset to its index in the dataset.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on
  its hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned
  the `default_value`. The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample usage:

  >>> ds = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> table = tf.data.experimental.index_table_from_dataset(
  ...     ds, key_dtype=tf.string)
  >>> table.lookup(tf.constant(['0', '2', '4'], dtype=tf.string)).numpy()
  array([0, 1, 2])
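
  Keys absent from the dataset map to `default_value` (-1 by default), since
  `num_oov_buckets` is zero here:

  >>> table.lookup(tf.constant(['7'], dtype=tf.string)).numpy()
  array([-1])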

  Args:
    dataset: A dataset of keys.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` specifying the hash function to use for
      assigning out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `num_oov_buckets` is negative, or
      * `vocab_size` is not greater than zero.
    TypeError: If `key_dtype` is neither an integer nor a string type.
  """
  return table_from_dataset(dataset.enumerate().map(lambda v, k: (k, v)),
                            num_oov_buckets, vocab_size, default_value,
                            hasher_spec, key_dtype, name)