Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/data/experimental/ops/lookup_ops.py: 28% (65 statements)
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup operations."""

16 

17from tensorflow.python.data.experimental.ops.cardinality import assert_cardinality 

18from tensorflow.python.framework import dtypes 

19from tensorflow.python.framework import ops 

20from tensorflow.python.framework import tensor_spec 

21from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops 

22from tensorflow.python.ops import lookup_ops 

23from tensorflow.python.ops import math_ops 

24from tensorflow.python.util.tf_export import tf_export 


def _check_table_initializer_element_spec(element_spec):
  """Raises an error if the given table initializer element spec is invalid."""
  base_error = ("Datasets used to initialize lookup tables must "
                "produce elements in the form (key, value), where "
                "the keys and values are scalar tensors. ")
  specific_error = None
  if len(element_spec) != 2:
    raise ValueError(base_error + "However, the given dataset produces "
                     f"{len(element_spec)} components instead of two "
                     "(key, value) components. Full dataset element spec: "
                     f"{element_spec}.")
  if not isinstance(element_spec[0], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor keys of type {type(element_spec[0])}.")
  if not isinstance(element_spec[1], tensor_spec.TensorSpec):
    raise ValueError(base_error + "However, the given dataset produces "
                     f"non-Tensor values of type {type(element_spec[1])}.")
  if element_spec[0].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar key Tensors of rank {element_spec[0].shape.rank}.")
  if element_spec[1].shape.rank not in (None, 0):
    raise ValueError(
        base_error + "However, the given dataset produces "
        f"non-scalar value Tensors of rank {element_spec[1].shape.rank}.")


@tf_export("data.experimental.DatasetInitializer")
class DatasetInitializer(lookup_ops.TableInitializerBase):
  """Creates a table initializer from a `tf.data.Dataset`.

  Sample usage:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> init = tf.data.experimental.DatasetInitializer(ds)
  >>> table = tf.lookup.StaticHashTable(init, "")
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Attributes:
    dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
      first scalar is treated as the key and the second as the value.

  Raises:
    ValueError: If `dataset` doesn't produce (key, value) pairs of scalar
      tensors.
  """

  def __init__(self, dataset):
    """Creates a table initializer from a `tf.data.Dataset`.

    Args:
      dataset: A `tf.data.Dataset` object that produces tuples of scalars. The
        first scalar is treated as the key and the second as the value.

    Raises:
      ValueError: If `dataset` doesn't produce (key, value) pairs of scalar
        tensors.
    """
    # Assert that the dataset element spec is a tuple of TensorSpecs where
    # each tensor is a scalar.
    self.dataset = dataset
    elem_spec = self.dataset.element_spec
    _check_table_initializer_element_spec(elem_spec)

    key_type = elem_spec[0].dtype
    value_type = elem_spec[1].dtype
    super(DatasetInitializer, self).__init__(key_type, value_type)

  def initialize(self, table):
    lookup_ops.check_table_dtypes(table, self._key_dtype, self._value_dtype)
    init_op = ged_ops.initialize_table_from_dataset(
        table.resource_handle, self.dataset._variant_tensor)  # pylint: disable=protected-access
    ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
    return init_op
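
# Hedged usage sketch (added for illustration; the v1 session flow below is an
# assumption about typical graph-mode use, not part of this module): the init
# op that `initialize` registers under `GraphKeys.TABLE_INITIALIZERS` is what
# `tf.compat.v1.tables_initializer()` runs in graph mode, e.g.:
#
#   import tensorflow as tf
#   tf.compat.v1.disable_eager_execution()
#   keys = tf.data.Dataset.range(3)
#   values = tf.data.Dataset.range(3).map(tf.strings.as_string)
#   init = tf.data.experimental.DatasetInitializer(
#       tf.data.Dataset.zip((keys, values)))
#   table = tf.lookup.StaticHashTable(init, default_value="")
#   out = table.lookup(tf.constant([1], dtype=tf.int64))
#   with tf.compat.v1.Session() as sess:
#     sess.run(tf.compat.v1.tables_initializer())
#     print(sess.run(out))  # b'1'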


@tf_export("data.experimental.table_from_dataset")
def table_from_dataset(dataset=None,
                       num_oov_buckets=0,
                       vocab_size=None,
                       default_value=None,
                       hasher_spec=lookup_ops.FastHashSpec,
                       key_dtype=dtypes.string,
                       name=None):
  """Returns a lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of
  (key, value) pairs.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> keys = tf.data.Dataset.range(100)
  >>> values = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> ds = tf.data.Dataset.zip((keys, values))
  >>> table = tf.data.experimental.table_from_dataset(
  ...     ds, default_value='n/a', key_dtype=tf.int64)
  >>> table.lookup(tf.constant([0, 1, 2], dtype=tf.int64)).numpy()
  array([b'0', b'2', b'4'], dtype=object)

  Args:
    dataset: A dataset containing (key, value) pairs.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assigning out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `dataset` does not contain pairs
      * The 2nd item in the `dataset` pairs has a dtype which is incompatible
        with `default_value`
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  elem_spec = dataset.element_spec
  _check_table_initializer_element_spec(elem_spec)
  if default_value is None:
    default_value = -1
    if not (elem_spec[1].dtype.is_integer or elem_spec[1].dtype.is_floating):
      raise ValueError("`default_value` must be specified when creating a "
                       "table from a dataset that produces values of type "
                       f"{elem_spec[1].dtype}.")
  if num_oov_buckets < 0:
    raise ValueError("`num_oov_buckets` must be greater than or equal to 0, "
                     f"got {num_oov_buckets}.")
  if (not isinstance(vocab_size, ops.Tensor) and vocab_size is not None and
      vocab_size < 1):
    raise ValueError(f"`vocab_size` must be greater than 0, got {vocab_size}.")
  if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
    raise TypeError("`key_dtype` must be either an integer or string type, "
                    f"but got {key_dtype}")
  if vocab_size is not None:
    if isinstance(vocab_size, ops.Tensor):
      vocab_size = math_ops.cast(vocab_size, dtypes.int64)
    dataset = dataset.take(vocab_size)
    dataset = dataset.apply(assert_cardinality(vocab_size))
  with ops.name_scope(name, "string_to_index"):
    initializer = DatasetInitializer(dataset)
    with ops.name_scope(None, "hash_table"):
      table = lookup_ops.StaticHashTableV1(initializer, default_value)
      if num_oov_buckets:
        table = lookup_ops.IdTableWithHashBuckets(
            table,
            num_oov_buckets=num_oov_buckets,
            hasher_spec=hasher_spec,
            key_dtype=key_dtype)
      return table
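
# Hedged sketch (added for illustration; uses only the public `tf` API shown
# in the docstring above): with `num_oov_buckets > 0` the returned table hashes
# unknown keys into the ID range [vocab_size, vocab_size + num_oov_buckets - 1]:
#
#   import tensorflow as tf
#   ds = tf.data.Dataset.zip(
#       (tf.data.Dataset.from_tensor_slices(["a", "b", "c"]),
#        tf.data.Dataset.range(3)))
#   table = tf.data.experimental.table_from_dataset(
#       ds, num_oov_buckets=2, vocab_size=3)
#   table.lookup(tf.constant(["a", "zzz"])).numpy()
#   # -> array([0, n]) with n in {3, 4}, chosen by hashing "zzz".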


@tf_export("data.experimental.index_table_from_dataset")
def index_table_from_dataset(dataset=None,
                             num_oov_buckets=0,
                             vocab_size=None,
                             default_value=-1,
                             hasher_spec=lookup_ops.FastHashSpec,
                             key_dtype=dtypes.string,
                             name=None):
  """Returns an index lookup table based on the given dataset.

  This operation constructs a lookup table based on the given dataset of keys.

  Any lookup of an out-of-vocabulary token will return a bucket ID based on its
  hash if `num_oov_buckets` is greater than zero. Otherwise it is assigned the
  `default_value`.
  The bucket ID range is
  `[vocabulary size, vocabulary size + num_oov_buckets - 1]`.

  Sample Usages:

  >>> ds = tf.data.Dataset.range(100).map(
  ...     lambda x: tf.strings.as_string(x * 2))
  >>> table = tf.data.experimental.index_table_from_dataset(
  ...     ds, key_dtype=tf.string)
  >>> table.lookup(tf.constant(['0', '2', '4'], dtype=tf.string)).numpy()
  array([0, 1, 2])

  Args:
    dataset: A dataset of keys.
    num_oov_buckets: The number of out-of-vocabulary buckets.
    vocab_size: The number of elements in the vocabulary, if known.
    default_value: The value to use for out-of-vocabulary feature values.
      Defaults to -1.
    hasher_spec: A `HasherSpec` to specify the hash function to use for
      assigning out-of-vocabulary buckets.
    key_dtype: The `key` data type.
    name: A name for this op (optional).

  Returns:
    The lookup table based on the given dataset.

  Raises:
    ValueError: If
      * `num_oov_buckets` is negative
      * `vocab_size` is not greater than zero
      * The `key_dtype` is not integer or string
  """
  return table_from_dataset(dataset.enumerate().map(lambda v, k: (k, v)),
                            num_oov_buckets, vocab_size, default_value,
                            hasher_spec, key_dtype, name)
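
# Illustrative note (added; not part of the original module): `enumerate()`
# yields (index, element) pairs, and the map above flips each pair to
# (element, index) so that the dataset element becomes the lookup key. For
# example, a dataset of ["a", "b"] becomes ("a", 0), ("b", 1) before being
# handed to `table_from_dataset`.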