# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""
import csv
from typing import Optional

import tensorflow as tf

from tensorflow_addons.utils.resource_loader import LazySO
from tensorflow_addons.utils.types import AcceptableDTypes, FloatTensorLike, TensorLike

_skip_gram_so = LazySO("custom_ops/text/_skip_gram_ops.so")

tf.no_gradient("Addons>SkipGramGenerateCandidates")


def skip_gram_sample(
    input_tensor: TensorLike,
    min_skips: FloatTensorLike = 1,
    max_skips: FloatTensorLike = 5,
    start: FloatTensorLike = 0,
    limit: FloatTensorLike = -1,
    emit_self_as_target: bool = False,
    vocab_freq_table: tf.lookup.KeyValueTensorInitializer = None,
    vocab_min_count: Optional[FloatTensorLike] = None,
    vocab_subsampling: Optional[FloatTensorLike] = None,
    corpus_size: Optional[FloatTensorLike] = None,
    seed: Optional[FloatTensorLike] = None,
    name: Optional[str] = None,
) -> tf.Tensor:
44 """Generates skip-gram token and label paired Tensors from the input
45 tensor.
47 Generates skip-gram `("token", "label")` pairs using each element in the
48 rank-1 `input_tensor` as a token. The window size used for each token will
49 be randomly selected from the range specified by `[min_skips, max_skips]`,
50 inclusive. See https://arxiv.org/abs/1301.3781 for more details about
51 skip-gram.
53 For example, given `input_tensor = ["the", "quick", "brown", "fox",
54 "jumps"]`, `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`,
55 the output `(tokens, labels)` pairs for the token "quick" will be randomly
56 selected from either `(tokens=["quick", "quick"], labels=["the", "brown"])`
57 for 1 skip, or `(tokens=["quick", "quick", "quick"],
58 labels=["the", "brown", "fox"])` for 2 skips.
60 If `emit_self_as_target = True`, each token will also be emitted as a label
61 for itself. From the previous example, the output will be either
62 `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])`
63 for 1 skip, or `(tokens=["quick", "quick", "quick", "quick"],
64 labels=["the", "quick", "brown", "fox"])` for 2 skips.
66 The same process is repeated for each element of `input_tensor` and
67 concatenated together into the two output rank-1 `Tensors` (one for all the
68 tokens, another for all the labels).
70 If `vocab_freq_table` is specified, tokens in `input_tensor` that are not
71 present in the vocabulary are discarded. Tokens whose frequency counts are
72 below `vocab_min_count` are also discarded. Tokens whose frequency
73 proportions in the corpus exceed `vocab_subsampling` may be randomly
74 down-sampled. See Eq. 5 in http://arxiv.org/abs/1310.4546 for more details
75 about subsampling.
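
    (Concretely, with `t = vocab_subsampling` and `f` a token's frequency
    proportion in the corpus, a token is kept with probability
    `min(1, sqrt(t / f) + t / f)`, following the reference word2vec
    implementation's variant of Eq. 5.)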

    Args:
      input_tensor: A rank-1 `Tensor` from which to generate skip-gram
        candidates.
      min_skips: `int` or scalar `Tensor` specifying the minimum window size to
        randomly use for each token. Must be >= 0 and <= `max_skips`. If
        `min_skips` and `max_skips` are both 0, the only label emitted will be
        the token itself when `emit_self_as_target = True`; otherwise, no
        output is emitted.
      max_skips: `int` or scalar `Tensor` specifying the maximum window size to
        randomly use for each token. Must be >= 0.
      start: `int` or scalar `Tensor` specifying the position in
        `input_tensor` from which to start generating skip-gram candidates.
      limit: `int` or scalar `Tensor` specifying the maximum number of
        elements in `input_tensor` to use in generating skip-gram candidates.
        -1 means to use the rest of the `Tensor` after `start`.
      emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
        each token as a label for itself.
      vocab_freq_table: (Optional) A lookup table (subclass of
        `lookup.InitializableLookupTableBase`) that maps tokens to their raw
        frequency counts. If specified, any token in `input_tensor` that is not
        found in `vocab_freq_table` will be filtered out before generating
        skip-gram candidates. While this will typically map to integer raw
        frequency counts, it could also map to float frequency proportions.
        `vocab_min_count` and `corpus_size` should be in the same units
        as this.
      vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
        the minimum frequency threshold (from `vocab_freq_table`) for a token
        to be kept in `input_tensor`. If this is specified, `vocab_freq_table`
        must also be specified, and they should both be in the same units.
      vocab_subsampling: (Optional) `float` specifying the frequency proportion
        threshold for tokens from `input_tensor`. Tokens that occur more
        frequently (based on the ratio of the token's `vocab_freq_table` value
        to the `corpus_size`) will be randomly down-sampled. Reasonable
        starting values may be around 1e-3 or 1e-5. If this is specified, both
        `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
        in http://arxiv.org/abs/1310.4546 for more details.
      corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
        total number of tokens in the corpus (e.g., sum of all the frequency
        counts of `vocab_freq_table`). Used with `vocab_subsampling` for
        down-sampling frequently occurring tokens. If this is specified,
        `vocab_freq_table` and `vocab_subsampling` must also be specified.
      seed: (Optional) `int` used to create a random seed for window size and
        subsampling. See `set_random_seed` docs for behavior.
      name: (Optional) A `string` name or a name scope for the operations.

    Returns:
      A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
      rank-1 and has the same type as `input_tensor`.

    Raises:
      ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
        `vocab_subsampling`, or `corpus_size` is specified.
        If `vocab_subsampling` and `corpus_size` are not both present or
        both absent.
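
    Example (an illustrative sketch; the exact pairs emitted depend on the
    randomly drawn window sizes):

    >>> import tensorflow_addons as tfa
    >>> tokens, labels = tfa.text.skip_gram_sample(
    ...     ["the", "quick", "brown", "fox"], min_skips=1, max_skips=2, seed=42
    ... )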
131 """
133 if vocab_freq_table is None and (
134 vocab_min_count is not None
135 or vocab_subsampling is not None
136 or corpus_size is not None
137 ):
138 raise ValueError(
139 "vocab_freq_table is not provided, but vocab_min_count={}, "
140 "vocab_subsampling={}, or corpus_size={} is not None."
141 "These settings are useless without a vocab_freq_table.".format(
142 vocab_min_count, vocab_subsampling, corpus_size
143 )
144 )

    if (vocab_subsampling is None) != (corpus_size is None):
        raise ValueError(
            "vocab_subsampling is {} while corpus_size is {} - both must be "
            "provided in order for subsampling to work.".format(
                vocab_subsampling, corpus_size
            )
        )

    with tf.name_scope(name or "skip_gram_sample"):
        input_tensor = _filter_input(
            input_tensor=input_tensor,
            vocab_freq_table=vocab_freq_table,
            vocab_min_count=vocab_min_count,
            vocab_subsampling=vocab_subsampling,
            corpus_size=corpus_size,
            seed=seed,
        )

        seed1, seed2 = tf.compat.v1.get_seed(seed)
        tokens, labels = _skip_gram_so.ops.addons_skip_gram_generate_candidates(
            input_tensor=input_tensor,
            min_skips=min_skips,
            max_skips=max_skips,
            start=start,
            limit=limit,
            emit_self_as_target=emit_self_as_target,
            # Note that seed here should be seed1! This is due to
            # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
            seed=seed1,
            seed2=seed2,
        )

    # TODO(weiho): If the need arises, add support for sparse input_tensor that
    # figures out sentence boundaries, then calls
    # skip_gram_generate_candidates() on each sentence.

    return tokens, labels


def skip_gram_sample_with_text_vocab(
    input_tensor: TensorLike,
    vocab_freq_file: str,
    vocab_token_index: FloatTensorLike = 0,
    vocab_token_dtype: Optional[AcceptableDTypes] = tf.dtypes.string,
    vocab_freq_index: FloatTensorLike = 1,
    vocab_freq_dtype: Optional[AcceptableDTypes] = tf.dtypes.float64,
    vocab_delimiter: str = ",",
    vocab_min_count: Optional[FloatTensorLike] = None,
    vocab_subsampling: Optional[FloatTensorLike] = None,
    corpus_size: Optional[FloatTensorLike] = None,
    min_skips: FloatTensorLike = 1,
    max_skips: FloatTensorLike = 5,
    start: FloatTensorLike = 0,
    limit: FloatTensorLike = -1,
    emit_self_as_target: bool = False,
    seed: Optional[FloatTensorLike] = None,
    name: Optional[str] = None,
) -> tf.Tensor:
205 """Skip-gram sampling with a text vocabulary file.
207 Wrapper around `skip_gram_sample()` for use with a text vocabulary file.
208 The vocabulary file is expected to be a plain-text file, with lines of
209 `vocab_delimiter`-separated columns. The `vocab_token_index` column should
210 contain the vocabulary term, while the `vocab_freq_index` column should
211 contain the number of times that term occurs in the corpus. For example,
212 with a text vocabulary file of:
214 ```
215 bonjour,fr,42
216 hello,en,777
217 hola,es,99
218 ```
220 You should set `vocab_delimiter=","`, `vocab_token_index=0`, and
221 `vocab_freq_index=2`.
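
    With that file, a call might look like the following (an illustrative
    sketch; the file path is hypothetical):

    >>> import tensorflow_addons as tfa
    >>> tokens, labels = tfa.text.skip_gram_sample_with_text_vocab(
    ...     input_tensor=["hello", "hola", "bonjour"],
    ...     vocab_freq_file="/path/to/vocab.csv",  # hypothetical path
    ...     vocab_token_index=0,
    ...     vocab_freq_index=2,
    ... )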

    See `skip_gram_sample()` documentation for more details about the skip-gram
    sampling process.

    Args:
      input_tensor: A rank-1 `Tensor` from which to generate skip-gram
        candidates.
      vocab_freq_file: `string` specifying the full file path to the text vocab
        file.
      vocab_token_index: `int` specifying which column in the text vocab file
        contains the tokens.
      vocab_token_dtype: `DType` specifying the format of the tokens in the
        text vocab file.
      vocab_freq_index: `int` specifying which column in the text vocab file
        contains the frequency counts of the tokens.
      vocab_freq_dtype: `DType` specifying the format of the frequency counts
        in the text vocab file.
      vocab_delimiter: `string` specifying the delimiter used in the text vocab
        file.
      vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
        the minimum frequency threshold (from `vocab_freq_file`) for a token to
        be kept in `input_tensor`. This should correspond with
        `vocab_freq_dtype`.
      vocab_subsampling: (Optional) `float` specifying the frequency proportion
        threshold for tokens from `input_tensor`. Tokens that occur more
        frequently will be randomly down-sampled. Reasonable starting values
        may be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546
        for more details.
      corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
        total number of tokens in the corpus (e.g., sum of all the frequency
        counts of `vocab_freq_file`). Used with `vocab_subsampling` for
        down-sampling frequently occurring tokens. If this is specified,
        `vocab_freq_file` and `vocab_subsampling` must also be specified.
        If `corpus_size` is needed but not supplied, it will be calculated from
        `vocab_freq_file`. You might want to supply your own value if you have
        already eliminated infrequent tokens from your vocabulary files (where
        frequency < `vocab_min_count`) to save memory in the internal token
        lookup table; otherwise, the unused tokens' variables will waste
        memory. The user-supplied `corpus_size` value must be greater than or
        equal to the sum of all the frequency counts of `vocab_freq_file`.
      min_skips: `int` or scalar `Tensor` specifying the minimum window size to
        randomly use for each token. Must be >= 0 and <= `max_skips`. If
        `min_skips` and `max_skips` are both 0, the only label emitted will be
        the token itself (when `emit_self_as_target = True`).
      max_skips: `int` or scalar `Tensor` specifying the maximum window size to
        randomly use for each token. Must be >= 0.
      start: `int` or scalar `Tensor` specifying the position in `input_tensor`
        from which to start generating skip-gram candidates.
      limit: `int` or scalar `Tensor` specifying the maximum number of elements
        in `input_tensor` to use in generating skip-gram candidates. -1 means
        to use the rest of the `Tensor` after `start`.
      emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
        each token as a label for itself.
      seed: (Optional) `int` used to create a random seed for window size and
        subsampling. See `set_random_seed` docs for behavior.
      name: (Optional) A `string` name or a name scope for the operations.

    Returns:
      A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
      rank-1 and has the same type as `input_tensor`.

    Raises:
      ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0
        or exceeds the number of columns in `vocab_freq_file`.
        If `vocab_token_index` and `vocab_freq_index` are both set to the same
        column. If any token in `vocab_freq_file` has a negative frequency.
    """

    if vocab_token_index < 0 or vocab_freq_index < 0:
        raise ValueError(
            "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".format(
                vocab_token_index, vocab_freq_index
            )
        )
    if vocab_token_index == vocab_freq_index:
        raise ValueError(
            "vocab_token_index and vocab_freq_index should be different, "
            "but are both {}.".format(vocab_token_index)
        )

    # Iterates through the vocab file and calculates the number of vocab terms
    # as well as the total corpus size (by summing the frequency counts of all
    # the vocab terms).
    calculated_corpus_size = 0.0
    vocab_size = 0
    with tf.io.gfile.GFile(vocab_freq_file, mode="r") as f:
        reader = csv.reader(f, delimiter=vocab_delimiter)
        for row in reader:
            if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
                raise ValueError(
                    "Row in vocab file only has {} columns, "
                    "so vocab_token_index={} or "
                    "vocab_freq_index={} is out of bounds. Row content: {}".format(
                        len(row), vocab_token_index, vocab_freq_index, row
                    )
                )
            vocab_size += 1
            freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
            if freq < 0:
                raise ValueError(
                    "Row in vocab file has negative frequency of {}. "
                    "Row content: {}".format(freq, row)
                )
            # Note: tokens whose frequencies are below vocab_min_count will
            # still contribute to the total corpus size used for vocab
            # subsampling.
            calculated_corpus_size += freq

    if not corpus_size:
        corpus_size = calculated_corpus_size
    elif calculated_corpus_size - corpus_size > 1e-6:
        raise ValueError(
            "`corpus_size`={} must be greater than or equal to the "
            "sum of all the frequency counts ({}) of `vocab_freq_file` ({}).".format(
                corpus_size, calculated_corpus_size, vocab_freq_file
            )
        )

    vocab_freq_table = tf.lookup.StaticHashTable(
        tf.lookup.TextFileInitializer(
            filename=vocab_freq_file,
            key_dtype=vocab_token_dtype,
            key_index=vocab_token_index,
            value_dtype=vocab_freq_dtype,
            value_index=vocab_freq_index,
            vocab_size=vocab_size,
            delimiter=vocab_delimiter,
        ),
        # For vocab terms not in vocab file, use a default value of -1.
        default_value=-1,
    )

    return skip_gram_sample(
        input_tensor,
        min_skips=min_skips,
        max_skips=max_skips,
        start=start,
        limit=limit,
        emit_self_as_target=emit_self_as_target,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=vocab_subsampling,
        # corpus_size is not used unless vocab_subsampling is specified.
        corpus_size=None if vocab_subsampling is None else corpus_size,
        seed=seed,
        name=name,
    )


def _filter_input(
    input_tensor,
    vocab_freq_table,
    vocab_min_count,
    vocab_subsampling,
    corpus_size,
    seed,
):
    """Filters input tensor based on vocab freq, threshold, and subsampling."""
    input_tensor = tf.convert_to_tensor(input_tensor)

    if vocab_freq_table is None:
        return input_tensor

    if not isinstance(vocab_freq_table, tf.lookup.StaticHashTable):
        raise ValueError(
            "vocab_freq_table must be a tf.lookup.StaticHashTable (got type "
            "{}).".format(type(vocab_freq_table))
        )

    with tf.name_scope("filter_vocab"):
        freq = vocab_freq_table.lookup(input_tensor)
        # Filters out elements in input_tensor that are not found in
        # vocab_freq_table (the table returns the default value of -1 specified
        # above when an element is not found).
        mask = tf.math.not_equal(freq, vocab_freq_table.default_value)

        # Filters out elements whose vocab frequencies are less than the
        # threshold.
        if vocab_min_count is not None:
            cast_threshold = tf.cast(vocab_min_count, freq.dtype)
            mask = tf.math.logical_and(
                mask, tf.math.greater_equal(freq, cast_threshold)
            )

        input_tensor = tf.boolean_mask(input_tensor, mask)
        freq = tf.boolean_mask(freq, mask)

    if not vocab_subsampling:
        return input_tensor

    if vocab_subsampling < 0 or vocab_subsampling > 1:
        raise ValueError(
            "Invalid vocab_subsampling={} - it should be within range [0, 1].".format(
                vocab_subsampling
            )
        )

    # Subsamples the input tokens based on vocabulary frequency and the
    # vocab_subsampling threshold (i.e., randomly discards commonly appearing
    # tokens).
    with tf.name_scope("subsample_vocab"):
        corpus_size = tf.cast(corpus_size, tf.dtypes.float64)
        freq = tf.cast(freq, tf.dtypes.float64)
        vocab_subsampling = tf.cast(vocab_subsampling, tf.dtypes.float64)

        # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which
        # is supposed to correlate with Eq. 5 in http://arxiv.org/abs/1310.4546.
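        # With t = vocab_subsampling and f = freq / corpus_size (a token's
        # frequency proportion), this evaluates to sqrt(t / f) + t / f: the
        # keep probability sqrt(t / f) implied by Eq. 5's discard probability
        # 1 - sqrt(t / f), plus an extra t / f term used by the reference
        # word2vec implementation.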
        keep_prob = (tf.math.sqrt(freq / (vocab_subsampling * corpus_size)) + 1.0) * (
            vocab_subsampling * corpus_size / freq
        )
        random_prob = tf.random.uniform(
            tf.shape(freq), minval=0, maxval=1, dtype=tf.dtypes.float64, seed=seed
        )

        mask = tf.math.less_equal(random_prob, keep_prob)
        return tf.boolean_mask(input_tensor, mask)