Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/warm_starting_util.py: 15%


# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to warm-start TF.Learn Estimators."""

import collections

from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_ops
from tensorflow.python.training import checkpoint_utils
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training.saving import saveable_object_util
from tensorflow.python.util.tf_export import tf_export


@tf_export(v1=["train.VocabInfo"])
class VocabInfo(
    collections.namedtuple("VocabInfo", [
        "new_vocab",
        "new_vocab_size",
        "num_oov_buckets",
        "old_vocab",
        "old_vocab_size",
        "backup_initializer",
        "axis",
    ])):
  """Vocabulary information for warm-starting.

  See `tf.estimator.WarmStartSettings` for examples of using
  VocabInfo to warm-start.

  Args:
    new_vocab: [Required] A path to the new vocabulary file (used with the
      model to be trained).
    new_vocab_size: [Required] An integer indicating how many entries of the
      new vocabulary will be used in training.
    num_oov_buckets: [Required] An integer indicating how many OOV buckets are
      associated with the vocabulary.
    old_vocab: [Required] A path to the old vocabulary file (used with the
      checkpoint to be warm-started from).
    old_vocab_size: [Optional] An integer indicating how many entries of the
      old vocabulary were used in the creation of the checkpoint. If not
      provided, the entire old vocabulary will be used.
    backup_initializer: [Optional] A variable initializer used for variables
      corresponding to new vocabulary entries and OOV. If not provided, these
      entries will be zero-initialized.
    axis: [Optional] Denotes what axis the vocabulary corresponds to. The
      default, 0, corresponds to the most common use case (embeddings or
      linear weights for binary classification / regression). An axis of 1
      could be used for warm-starting output layers with class vocabularies.

  Returns:
    A `VocabInfo` which represents the vocabulary information for
    warm-starting.

  Raises:
    ValueError: `axis` is neither 0 nor 1.

  Example Usage:

  ```python
  embeddings_vocab_info = tf.VocabInfo(
      new_vocab='embeddings_vocab',
      new_vocab_size=100,
      num_oov_buckets=1,
      old_vocab='pretrained_embeddings_vocab',
      old_vocab_size=10000,
      backup_initializer=tf.compat.v1.truncated_normal_initializer(
          mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
      axis=0)

  softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
      axis=1)

  softmax_output_layer_bias_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.compat.v1.zeros_initializer(),
      axis=0)

  # Currently, only axis=0 and axis=1 are supported.
  ```
  """

  def __new__(cls,
              new_vocab,
              new_vocab_size,
              num_oov_buckets,
              old_vocab,
              old_vocab_size=-1,
              backup_initializer=None,
              axis=0):
    if axis != 0 and axis != 1:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    return super(VocabInfo, cls).__new__(
        cls,
        new_vocab,
        new_vocab_size,
        num_oov_buckets,
        old_vocab,
        old_vocab_size,
        backup_initializer,
        axis,
    )

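# A minimal illustrative sketch (hypothetical vocab paths; not part of the
# original module) of the defaults and axis validation above:
#
# >>> info = VocabInfo(
# ...     new_vocab="/tmp/new_vocab.txt", new_vocab_size=100,
# ...     num_oov_buckets=1, old_vocab="/tmp/old_vocab.txt")
# >>> (info.old_vocab_size, info.axis)
# (-1, 0)
# >>> VocabInfo("v", 1, 0, "o", axis=2)
# Traceback (most recent call last):
#   ...
# ValueError: The only supported values for the axis argument are 0 and 1. Provided axis: 2
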

def _infer_var_name(var):
  """Returns name of the `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`
  """
  name_to_var_dict = saveable_object_util.op_list_to_dict(var)
  if len(name_to_var_dict) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints. "
                    "name_to_var_dict = %s" % (var, name_to_var_dict))
  return list(name_to_var_dict.keys())[0]

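# Illustrative note (assumed behavior of saveable_object_util.op_list_to_dict):
# the inferred name is the variable's checkpoint key, i.e. the op name without
# the ":0" tensor suffix, and all slices of one partitioned variable collapse
# to a single key, e.g.
#
# >>> v = tf.compat.v1.get_variable("w", shape=[2])  # hypothetical variable
# >>> _infer_var_name([v])
# 'w'
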

def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standardizing Variable and naming.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following: (i) `Variable` (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
      variable. (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to look up in the provided
      `prev_ckpt`. If None, we look up the tensor with the same name as the
      given `var`.

  Returns:
    A tuple of the Tensor name and var.
  """
  if checkpoint_utils._is_variable(var):  # pylint: disable=protected-access
    current_var_name = _infer_var_name([var])
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):  # pylint: disable=protected-access
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = current_var_name

  return prev_tensor_name, var


# pylint: disable=protected-access
# Accesses protected members of tf.Variable to reset the variable's internal
# state.
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the
      current vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or
      path to checkpoint. The given checkpoint must have a tensor with name
      `prev_tensor_name` (if not None) or a tensor with the same name as the
      given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain the previous vocab to the
      first `previous_vocab_size` entries. -1 means use the entire previous
      vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for the given `var`.
    prev_tensor_name: Name of the tensor to look up in the provided
      `prev_ckpt`. If None, we look up the tensor with the same name as the
      given `var`.
    initializer: Variable initializer to be used for missing entries. If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    var = var
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    v._initializer_op = state_ops.assign(v, new_init_val)


# pylint: enable=protected-access

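# A minimal pure-Python sketch (illustrative only; not part of the original
# module) of how the axis branch above picks the remapping parameters passed
# to checkpoint_ops._load_and_remap_matrix_initializer: axis=0 remaps rows
# (embeddings / input vocabularies), axis=1 remaps columns (class vocabularies
# in output layers).
def _remap_params_sketch(axis, current_vocab_size, current_oov_buckets,
                         total_rows, num_cols):
  """Returns (new_row_size, new_col_size, row_oov, col_oov) for the remap."""
  if axis == 0:
    # Vocabulary indexes rows: remap rows, keep every column as-is.
    return current_vocab_size, num_cols, current_oov_buckets, 0
  if axis == 1:
    # Vocabulary indexes columns: remap columns; rows span all partitions.
    return total_rows, current_vocab_size, 0, current_oov_buckets
  raise ValueError("axis must be 0 or 1, got {}".format(axis))
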

def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection). This expression will
        only consider variables in the TRAINABLE_VARIABLES collection.
      - A list of strings, each representing a full variable name to
        warm-start. These will consider variables in the GLOBAL_VARIABLES
        collection.
      - A list of Variables to warm-start.
      - `None`, in which case all variables in TRAINABLE_VARIABLES will be
        used.

  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.

  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  # TODO(b/143899805): Remove unicode checks when deprecating Python2.
  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, str) for v in vars_to_warm_start):
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(
            ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`. Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start` must be a `list` or `str`. Given "
                     "type is {}".format(type(vars_to_warm_start)))
  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  for v in list_of_vars:
    t = [v] if not isinstance(v, list) else v
    var_name = _infer_var_name(t)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables

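# A minimal pure-Python sketch of the grouping step above: slices that share
# an inferred name accumulate under one key (strings here are stand-ins for
# `Variable` objects).
#
# >>> grouped = {}
# >>> for name, v in [("w", "w/part_0"), ("w", "w/part_1"), ("b", "b")]:
# ...   grouped.setdefault(name, []).append(v)
# >>> grouped
# {'w': ['w/part_0', 'w/part_1'], 'b': ['b']}
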

def _get_object_checkpoint_renames(path, variable_names):
  """Returns a dictionary mapping variable names to checkpoint keys.

  The warm-starting utility expects variable names to match with the variable
  names in the checkpoint. For object-based checkpoints, the variable names
  and names in the checkpoint are different. Thus, for object-based
  checkpoints, this function is used to obtain the map from variable names to
  checkpoint keys.

  Args:
    path: path to checkpoint directory or file.
    variable_names: list of variable names to load from the checkpoint.

  Returns:
    If the checkpoint is object-based, this function returns a map from
    variable names to their corresponding checkpoint keys.
    If the checkpoint is name-based, this returns an empty dict.

  Raises:
    ValueError: If the object-based checkpoint is missing variables.
  """
  fname = checkpoint_utils._get_checkpoint_filename(path)  # pylint: disable=protected-access
  try:
    names_to_keys = saver_lib.object_graph_key_mapping(fname)
  except errors.NotFoundError:
    # If an error is raised from `object_graph_key_mapping`, then the
    # checkpoint is name-based. There are no renames, so return an empty dict.
    return {}

  missing_names = set(variable_names) - set(names_to_keys.keys())
  if missing_names:
    raise ValueError(
        "Attempting to warm-start from an object-based checkpoint, but found "
        "that the checkpoint did not contain values for all variables. The "
        "following variables were missing: {}".format(missing_names))
  return {name: names_to_keys[name] for name in variable_names}


@tf_export(v1=["train.warm_start"])
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection). This expression will
        only consider variables in the TRAINABLE_VARIABLES collection -- if
        you need to warm-start non-TRAINABLE vars (such as optimizer
        accumulators or batch norm statistics), please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection). For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - A list of Variables to warm-start. If you do not have access to the
        `Variable` objects at the call site, please use the above option.
      - `None`, in which case only TRAINABLE variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection. Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions. If not explicitly provided, the
      variable is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model. Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used. This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)

  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}

  if not var_name_to_prev_var_name:
    # Detect whether the checkpoint is object-based, in which case the
    # var_name_to_prev_var_name dictionary should map variable names to
    # checkpoint keys. If the user has specified var_name_to_prev_var_name, we
    # do not override it.
    var_name_to_prev_var_name = _get_object_checkpoint_renames(
        ckpt_to_initialize_from, grouped_variables.keys())

  warmstarted_count = 0

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used. Err on the safer side by throwing
  # an exception if any are unused by the end of the loop. It is easy to
  # misname a variable during this configuration, in which case without this
  # check, we would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in grouped_variables.items():
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      vocab_info_used.add(var_name)
      warmstarted_count += 1
      logging.debug(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
              vocab_info.old_vocab, (vocab_info.old_vocab_size if
                                     vocab_info.old_vocab_size > 0 else "All"),
              vocab_info.num_oov_buckets, prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        warmstarted_count += 1
        logging.debug("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in
        # order for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        if prev_tensor_name in vocabless_vars:
          # The API for checkpoint_utils.init_from_checkpoint accepts a
          # mapping from checkpoint tensor names to model variable names, so
          # it does not support warm-starting two variables from the same
          # tensor. Our workaround is to run init_from_checkpoint multiple
          # times, each time we encounter a new variable that should be
          # initialized by a previously-used tensor.
          logging.debug("Requested prev_var_name {} to initialize both {} "
                        "and {}; calling init_from_checkpoint.".format(
                            prev_tensor_name,
                            vocabless_vars[prev_tensor_name],
                            var))
          checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                                vocabless_vars)
          vocabless_vars.clear()
        vocabless_vars[prev_tensor_name] = var

  if vocabless_vars:
    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                          vocabless_vars)
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  logging.info("Warm-started %d variables.", warmstarted_count)

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}. Perhaps you misspelled them? Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        "Perhaps you misspelled them? Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))