# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Skip-gram sampling ops from https://arxiv.org/abs/1301.3781."""

import csv
import tensorflow as tf

from tensorflow_addons.utils.resource_loader import LazySO

from tensorflow_addons.utils.types import AcceptableDTypes, FloatTensorLike, TensorLike
from typing import Optional, Tuple

_skip_gram_so = LazySO("custom_ops/text/_skip_gram_ops.so")

tf.no_gradient("Addons>SkipGramGenerateCandidates")


def skip_gram_sample(
    input_tensor: TensorLike,
    min_skips: FloatTensorLike = 1,
    max_skips: FloatTensorLike = 5,
    start: FloatTensorLike = 0,
    limit: FloatTensorLike = -1,
    emit_self_as_target: bool = False,
    vocab_freq_table: Optional[tf.lookup.StaticHashTable] = None,
    vocab_min_count: Optional[FloatTensorLike] = None,
    vocab_subsampling: Optional[FloatTensorLike] = None,
    corpus_size: Optional[FloatTensorLike] = None,
    seed: Optional[FloatTensorLike] = None,
    name: Optional[str] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
44 """Generates skip-gram token and label paired Tensors from the input 

45 tensor. 

46 

47 Generates skip-gram `("token", "label")` pairs using each element in the 

48 rank-1 `input_tensor` as a token. The window size used for each token will 

49 be randomly selected from the range specified by `[min_skips, max_skips]`, 

50 inclusive. See https://arxiv.org/abs/1301.3781 for more details about 

51 skip-gram. 

52 

53 For example, given `input_tensor = ["the", "quick", "brown", "fox", 

54 "jumps"]`, `min_skips = 1`, `max_skips = 2`, `emit_self_as_target = False`, 

55 the output `(tokens, labels)` pairs for the token "quick" will be randomly 

56 selected from either `(tokens=["quick", "quick"], labels=["the", "brown"])` 

57 for 1 skip, or `(tokens=["quick", "quick", "quick"], 

58 labels=["the", "brown", "fox"])` for 2 skips. 

59 

60 If `emit_self_as_target = True`, each token will also be emitted as a label 

61 for itself. From the previous example, the output will be either 

62 `(tokens=["quick", "quick", "quick"], labels=["the", "quick", "brown"])` 

63 for 1 skip, or `(tokens=["quick", "quick", "quick", "quick"], 

64 labels=["the", "quick", "brown", "fox"])` for 2 skips. 

65 

66 The same process is repeated for each element of `input_tensor` and 

67 concatenated together into the two output rank-1 `Tensors` (one for all the 

68 tokens, another for all the labels). 

69 

70 If `vocab_freq_table` is specified, tokens in `input_tensor` that are not 

71 present in the vocabulary are discarded. Tokens whose frequency counts are 

72 below `vocab_min_count` are also discarded. Tokens whose frequency 

73 proportions in the corpus exceed `vocab_subsampling` may be randomly 

74 down-sampled. See Eq. 5 in http://arxiv.org/abs/1310.4546 for more details 

75 about subsampling. 

76 

77 Args: 

78 input_tensor: A rank-1 `Tensor` from which to generate skip-gram 

79 candidates. 

      min_skips: `int` or scalar `Tensor` specifying the minimum window size to
        randomly use for each token. Must be >= 0 and <= `max_skips`. If
        `min_skips` and `max_skips` are both 0, the only label emitted will be
        the token itself when `emit_self_as_target = True`; otherwise, nothing
        is emitted.
      max_skips: `int` or scalar `Tensor` specifying the maximum window size to
        randomly use for each token. Must be >= 0.
      start: `int` or scalar `Tensor` specifying the position in
        `input_tensor` from which to start generating skip-gram candidates.
      limit: `int` or scalar `Tensor` specifying the maximum number of
        elements in `input_tensor` to use in generating skip-gram candidates.
        -1 means to use the rest of the `Tensor` after `start`.
      emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit
        each token as a label for itself.
      vocab_freq_table: (Optional) A lookup table (subclass of
        `lookup.InitializableLookupTableBase`) that maps tokens to their raw
        frequency counts. If specified, any token in `input_tensor` that is not
        found in `vocab_freq_table` will be filtered out before generating
        skip-gram candidates. While this will typically map to integer raw
        frequency counts, it could also map to float frequency proportions.
        `vocab_min_count` and `corpus_size` should be in the same units
        as this.
      vocab_min_count: (Optional) `int`, `float`, or scalar `Tensor` specifying
        minimum frequency threshold (from `vocab_freq_table`) for a token to be
        kept in `input_tensor`. If this is specified, `vocab_freq_table` must
        also be specified - and they should both be in the same units.
      vocab_subsampling: (Optional) `float` specifying frequency proportion
        threshold for tokens from `input_tensor`. Tokens that occur more
        frequently (based on the ratio of the token's `vocab_freq_table` value
        to the `corpus_size`) will be randomly down-sampled. Reasonable
        starting values may be around 1e-3 or 1e-5. If this is specified, both
        `vocab_freq_table` and `corpus_size` must also be specified. See Eq. 5
        in http://arxiv.org/abs/1310.4546 for more details.
      corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the
        total number of tokens in the corpus (e.g., sum of all the frequency
        counts of `vocab_freq_table`). Used with `vocab_subsampling` for
        down-sampling frequently occurring tokens. If this is specified,
        `vocab_freq_table` and `vocab_subsampling` must also be specified.
      seed: (Optional) `int` used to create a random seed for window size and
        subsampling. See `set_random_seed` docs for behavior.
      name: (Optional) A `string` name or a name scope for the operations.

    Returns:
      A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
      rank-1 and has the same type as `input_tensor`.

    Raises:
      ValueError: If `vocab_freq_table` is not provided, but `vocab_min_count`,
        `vocab_subsampling`, or `corpus_size` is specified.
        If `vocab_subsampling` and `corpus_size` are not both present or
        both absent.
    """

    if vocab_freq_table is None and (
        vocab_min_count is not None
        or vocab_subsampling is not None
        or corpus_size is not None
    ):
        raise ValueError(
            "vocab_freq_table is not provided, but vocab_min_count={}, "
            "vocab_subsampling={}, or corpus_size={} is not None. "
            "These settings are useless without a vocab_freq_table.".format(
                vocab_min_count, vocab_subsampling, corpus_size
            )
        )

    if (vocab_subsampling is None) != (corpus_size is None):
        raise ValueError(
            "vocab_subsampling is {} while corpus_size is {} - both must be "
            "provided in order for subsampling to work.".format(
                vocab_subsampling, corpus_size
            )
        )

    with tf.name_scope(name or "skip_gram_sample"):
        input_tensor = _filter_input(
            input_tensor=input_tensor,
            vocab_freq_table=vocab_freq_table,
            vocab_min_count=vocab_min_count,
            vocab_subsampling=vocab_subsampling,
            corpus_size=corpus_size,
            seed=seed,
        )

        seed1, seed2 = tf.compat.v1.get_seed(seed)
        tokens, labels = _skip_gram_so.ops.addons_skip_gram_generate_candidates(
            input_tensor=input_tensor,
            min_skips=min_skips,
            max_skips=max_skips,
            start=start,
            limit=limit,
            emit_self_as_target=emit_self_as_target,
            # Note that seed here should be seed1! This is due to
            # GuardedPhiloxRandom's hard-coded attributes of "seed" and "seed2".
            seed=seed1,
            seed2=seed2,
        )

    # TODO(weiho): If the need arises, add support for sparse input_tensor that
    # figures out sentence boundaries, then calls
    # skip_gram_generate_candidates() on each sentence.

    return tokens, labels


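# --- Illustrative sketch (not part of the original module) -------------------
# A minimal pure-Python rendering of the pair-generation logic that the custom
# op above implements, included here only to clarify the algorithm. The real
# kernel draws each token's window size randomly from [min_skips, max_skips];
# the `random.randint` call below is a stand-in for that draw, and this helper
# is a hypothetical addition, not part of the TensorFlow Addons API.
def _skip_gram_pairs_sketch(tokens, min_skips, max_skips, emit_self_as_target=False):
    import random

    out_tokens, out_labels = [], []
    for i, token in enumerate(tokens):
        # Pick this token's window size, as the compiled kernel does.
        skips = random.randint(min_skips, max_skips)
        for j in range(max(0, i - skips), min(len(tokens), i + skips + 1)):
            if j == i and not emit_self_as_target:
                continue  # Emit the token itself only if requested.
            out_tokens.append(token)
            out_labels.append(tokens[j])
    return out_tokens, out_labels


# With tokens = ["the", "quick", "brown", "fox", "jumps"] and a window of 1,
# the pairs emitted for "quick" are (["quick", "quick"], ["the", "brown"]),
# matching the docstring example above.
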

def skip_gram_sample_with_text_vocab(
    input_tensor: TensorLike,
    vocab_freq_file: str,
    vocab_token_index: FloatTensorLike = 0,
    vocab_token_dtype: Optional[AcceptableDTypes] = tf.dtypes.string,
    vocab_freq_index: FloatTensorLike = 1,
    vocab_freq_dtype: Optional[AcceptableDTypes] = tf.dtypes.float64,
    vocab_delimiter: str = ",",
    vocab_min_count: Optional[FloatTensorLike] = None,
    vocab_subsampling: Optional[FloatTensorLike] = None,
    corpus_size: Optional[FloatTensorLike] = None,
    min_skips: FloatTensorLike = 1,
    max_skips: FloatTensorLike = 5,
    start: FloatTensorLike = 0,
    limit: FloatTensorLike = -1,
    emit_self_as_target: bool = False,
    seed: Optional[FloatTensorLike] = None,
    name: Optional[str] = None,
) -> Tuple[tf.Tensor, tf.Tensor]:
205 """Skip-gram sampling with a text vocabulary file. 

206 

207 Wrapper around `skip_gram_sample()` for use with a text vocabulary file. 

208 The vocabulary file is expected to be a plain-text file, with lines of 

209 `vocab_delimiter`-separated columns. The `vocab_token_index` column should 

210 contain the vocabulary term, while the `vocab_freq_index` column should 

211 contain the number of times that term occurs in the corpus. For example, 

212 with a text vocabulary file of: 

213 

214 ``` 

215 bonjour,fr,42 

216 hello,en,777 

217 hola,es,99 

218 ``` 

219 

220 You should set `vocab_delimiter=","`, `vocab_token_index=0`, and 

221 `vocab_freq_index=2`. 

222 

223 See `skip_gram_sample()` documentation for more details about the skip-gram 

224 sampling process. 

225 

226 Args: 

227 input_tensor: 

228 A rank-1 `Tensor` from which to generate skip-gram candidates. 

229 vocab_freq_file: 

230 `string` specifying full file path to the text vocab file. 

231 vocab_token_index: `int` specifying which column in the text vocab file 

232 contains the tokens. 

233 vocab_token_dtype: 

234 `DType` specifying the format of the tokens in the text vocab file. 

235 vocab_freq_index: `int` specifying which column in the text vocab file 

236 contains the frequency counts of the tokens. 

237 vocab_freq_dtype: `DType` specifying the format of the frequency counts 

238 in the text vocab file. 

239 vocab_delimiter: `string` specifying the delimiter used in the text vocab 

240 file. 

241 vocab_min_count: `int`, `float`, or scalar `Tensor` specifying 

242 minimum frequency threshold (from `vocab_freq_file`) for a token to be 

243 kept in `input_tensor`. This should correspond with `vocab_freq_dtype`. 

244 vocab_subsampling: (Optional) `float` specifying frequency proportion 

245 threshold for tokens from `input_tensor`. Tokens that occur more 

246 frequently will be randomly down-sampled. Reasonable starting values 

247 may be around 1e-3 or 1e-5. See Eq. 5 in http://arxiv.org/abs/1310.4546 

248 for more details. 

249 corpus_size: (Optional) `int`, `float`, or scalar `Tensor` specifying the 

250 total number of tokens in the corpus (e.g., sum of all the frequency 

251 counts of `vocab_freq_file`). Used with `vocab_subsampling` for 

252 down-sampling frequently occurring tokens. If this is specified, 

253 `vocab_freq_file` and `vocab_subsampling` must also be specified. 

254 If `corpus_size` is needed but not supplied, then it will be calculated 

255 from `vocab_freq_file`. You might want to supply your own value if you 

256 have already eliminated infrequent tokens from your vocabulary files 

257 (where frequency < vocab_min_count) to save memory in the internal 

258 token lookup table. Otherwise, the unused tokens' variables will waste 

259 memory. The user-supplied `corpus_size` value must be greater than or 

260 equal to the sum of all the frequency counts of `vocab_freq_file`. 

261 min_skips: `int` or scalar `Tensor` specifying the minimum window size to 

262 randomly use for each token. Must be >= 0 and <= `max_skips`. If 

263 `min_skips` and `max_skips` are both 0, the only label outputted will 

264 be the token itself. 

265 max_skips: `int` or scalar `Tensor` specifying the maximum window size to 

266 randomly use for each token. Must be >= 0. 

267 start: `int` or scalar `Tensor` specifying the position in `input_tensor` 

268 from which to start generating skip-gram candidates. 

269 limit: `int` or scalar `Tensor` specifying the maximum number of elements 

270 in `input_tensor` to use in generating skip-gram candidates. -1 means 

271 to use the rest of the `Tensor` after `start`. 

272 emit_self_as_target: `bool` or scalar `Tensor` specifying whether to emit 

273 each token as a label for itself. 

      seed: (Optional) `int` used to create a random seed for window size and
        subsampling. See `set_random_seed` docs for behavior.
      name: (Optional) A `string` name or a name scope for the operations.

    Returns:
      A `tuple` containing (token, label) `Tensors`. Each output `Tensor` is of
      rank-1 and has the same type as `input_tensor`.

    Raises:
      ValueError: If `vocab_token_index` or `vocab_freq_index` is less than 0
        or exceeds the number of columns in `vocab_freq_file`.
        If `vocab_token_index` and `vocab_freq_index` are both set to the same
        column. If any token in `vocab_freq_file` has a negative frequency.
    """

    if vocab_token_index < 0 or vocab_freq_index < 0:
        raise ValueError(
            "vocab_token_index={} and vocab_freq_index={} must both be >= 0.".format(
                vocab_token_index, vocab_freq_index
            )
        )
    if vocab_token_index == vocab_freq_index:
        raise ValueError(
            "vocab_token_index and vocab_freq_index should be different, "
            "but are both {}.".format(vocab_token_index)
        )

    # Iterates through the vocab file and calculates the number of vocab terms as
    # well as the total corpus size (by summing the frequency counts of all the
    # vocab terms).
    calculated_corpus_size = 0.0
    vocab_size = 0
    with tf.io.gfile.GFile(vocab_freq_file, mode="r") as f:
        reader = csv.reader(f, delimiter=vocab_delimiter)
        for row in reader:
            if vocab_token_index >= len(row) or vocab_freq_index >= len(row):
                raise ValueError(
                    "Row in vocab file only has {} columns, "
                    "so vocab_token_index={} or "
                    "vocab_freq_index={} is out of bounds. Row content: {}".format(
                        len(row), vocab_token_index, vocab_freq_index, row
                    )
                )
            vocab_size += 1
            freq = vocab_freq_dtype.as_numpy_dtype(row[vocab_freq_index])
            if freq < 0:
                raise ValueError(
                    "Row in vocab file has negative frequency of {}. "
                    "Row content: {}".format(freq, row)
                )
            # Note: tokens whose frequencies are below vocab_min_count will still
            # contribute to the total corpus size used for vocab subsampling.
            calculated_corpus_size += freq

    if not corpus_size:
        corpus_size = calculated_corpus_size
    elif calculated_corpus_size - corpus_size > 1e-6:
        raise ValueError(
            "`corpus_size`={} must be greater than or equal to the "
            "sum of all the frequency counts ({}) of `vocab_freq_file` ({}).".format(
                corpus_size, calculated_corpus_size, vocab_freq_file
            )
        )

    vocab_freq_table = tf.lookup.StaticHashTable(
        tf.lookup.TextFileInitializer(
            filename=vocab_freq_file,
            key_dtype=vocab_token_dtype,
            key_index=vocab_token_index,
            value_dtype=vocab_freq_dtype,
            value_index=vocab_freq_index,
            vocab_size=vocab_size,
            delimiter=vocab_delimiter,
        ),
        # For vocab terms not in vocab file, use a default value of -1.
        default_value=-1,
    )

    return skip_gram_sample(
        input_tensor,
        min_skips=min_skips,
        max_skips=max_skips,
        start=start,
        limit=limit,
        emit_self_as_target=emit_self_as_target,
        vocab_freq_table=vocab_freq_table,
        vocab_min_count=vocab_min_count,
        vocab_subsampling=vocab_subsampling,
        # corpus_size is not used unless vocab_subsampling is specified.
        corpus_size=None if vocab_subsampling is None else corpus_size,
        seed=seed,
        name=name,
    )


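# --- Illustrative usage sketch (not part of the original module) -------------
# A hedged example of driving skip_gram_sample_with_text_vocab() with the
# three-column vocab file from its docstring. Executing it requires the
# compiled custom op that ships with TensorFlow Addons; the file path, tokens,
# and thresholds are made up for the demo.
def _text_vocab_usage_sketch(vocab_path="/tmp/vocab_freq.csv"):
    with tf.io.gfile.GFile(vocab_path, mode="w") as f:
        f.write("bonjour,fr,42\nhello,en,777\nhola,es,99\n")
    # Column 0 holds the token and column 2 the frequency count, hence
    # vocab_token_index=0 and vocab_freq_index=2, as documented above.
    return skip_gram_sample_with_text_vocab(
        input_tensor=tf.constant(["hello", "hola", "bonjour", "hello"]),
        vocab_freq_file=vocab_path,
        vocab_token_index=0,
        vocab_freq_index=2,
        vocab_min_count=50,  # Filters out "bonjour" (frequency 42 < 50).
        min_skips=1,
        max_skips=2,
    )
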

def _filter_input(
    input_tensor,
    vocab_freq_table,
    vocab_min_count,
    vocab_subsampling,
    corpus_size,
    seed,
):
    """Filters input tensor based on vocab freq, threshold, and subsampling."""
    input_tensor = tf.convert_to_tensor(input_tensor)
    if vocab_freq_table is None:
        return input_tensor

    if not isinstance(vocab_freq_table, tf.lookup.StaticHashTable):
        raise ValueError(
            "vocab_freq_table must be a subclass of "
            "InitializableLookupTableBase (such as HashTable) instead of type "
            "{}.".format(type(vocab_freq_table))
        )

    with tf.name_scope("filter_vocab"):
        freq = vocab_freq_table.lookup(input_tensor)
        # Filters out elements in input_tensor that are not found in
        # vocab_freq_table (table returns a default value of -1 specified above
        # when an element is not found).
        mask = tf.math.not_equal(freq, vocab_freq_table.default_value)

        # Filters out elements whose vocab frequencies are less than the
        # threshold.
        if vocab_min_count is not None:
            cast_threshold = tf.cast(vocab_min_count, freq.dtype)
            mask = tf.math.logical_and(
                mask, tf.math.greater_equal(freq, cast_threshold)
            )

        input_tensor = tf.boolean_mask(input_tensor, mask)
        freq = tf.boolean_mask(freq, mask)

    if not vocab_subsampling:
        return input_tensor

    if vocab_subsampling < 0 or vocab_subsampling > 1:
        raise ValueError(
            "Invalid vocab_subsampling={} - it should be within range [0, 1].".format(
                vocab_subsampling
            )
        )

    # Subsamples the input tokens based on vocabulary frequency and
    # vocab_subsampling threshold (i.e., randomly discard commonly appearing
    # tokens).
    with tf.name_scope("subsample_vocab"):
        corpus_size = tf.cast(corpus_size, tf.dtypes.float64)
        freq = tf.cast(freq, tf.dtypes.float64)
        vocab_subsampling = tf.cast(vocab_subsampling, tf.dtypes.float64)

        # From tensorflow_models/tutorials/embedding/word2vec_kernels.cc, which is
        # supposed to correlate with Eq. 5 in http://arxiv.org/abs/1310.4546.
        keep_prob = (tf.math.sqrt(freq / (vocab_subsampling * corpus_size)) + 1.0) * (
            vocab_subsampling * corpus_size / freq
        )
        random_prob = tf.random.uniform(
            tf.shape(freq), minval=0, maxval=1, dtype=tf.dtypes.float64, seed=seed
        )

        mask = tf.math.less_equal(random_prob, keep_prob)
        return tf.boolean_mask(input_tensor, mask)
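

# --- Worked example of the keep_prob formula above (illustrative only) -------
# Assume freq = 1_000, corpus_size = 1_000_000, and vocab_subsampling = 1e-3,
# so vocab_subsampling * corpus_size = 1_000. Then:
#     keep_prob = (sqrt(1000 / 1000) + 1.0) * (1000 / 1000) = 2.0
# Since random_prob is drawn from [0, 1), a keep_prob >= 1 means the token is
# always kept. A ten-times more frequent token (freq = 10_000) instead gets:
#     keep_prob = (sqrt(10) + 1.0) * 0.1 ~= 0.416
# so roughly 58% of its occurrences are randomly discarded, which is the
# down-sampling behavior of Eq. 5 in http://arxiv.org/abs/1310.4546.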