Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/warm_starting_util.py: 15%
150 statements
« prev ^ index » next coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Utilities to warm-start TF.Learn Estimators."""
17import collections
19from tensorflow.python.framework import errors
20from tensorflow.python.framework import ops
21from tensorflow.python.ops import state_ops
22from tensorflow.python.ops import variable_scope
23from tensorflow.python.ops import variables as variables_lib
24from tensorflow.python.platform import tf_logging as logging
25from tensorflow.python.training import checkpoint_ops
26from tensorflow.python.training import checkpoint_utils
27from tensorflow.python.training import saver as saver_lib
28from tensorflow.python.training.saving import saveable_object_util
29from tensorflow.python.util.tf_export import tf_export
@tf_export(v1=["train.VocabInfo"])
class VocabInfo(
    collections.namedtuple("VocabInfo", [
        "new_vocab",
        "new_vocab_size",
        "num_oov_buckets",
        "old_vocab",
        "old_vocab_size",
        "backup_initializer",
        "axis",
    ])):
  """Vocabulary information for warm-starting.

  See `tf.estimator.WarmStartSettings` for examples of using VocabInfo to
  warm-start.

  Args:
    new_vocab: [Required] A path to the new vocabulary file (used with the
      model to be trained).
    new_vocab_size: [Required] An integer indicating how many entries of the
      new vocabulary will used in training.
    num_oov_buckets: [Required] An integer indicating how many OOV buckets are
      associated with the vocabulary.
    old_vocab: [Required] A path to the old vocabulary file (used with the
      checkpoint to be warm-started from).
    old_vocab_size: [Optional] An integer indicating how many entries of the
      old vocabulary were used in the creation of the checkpoint. If not
      provided, the entire old vocabulary will be used.
    backup_initializer: [Optional] A variable initializer used for variables
      corresponding to new vocabulary entries and OOV. If not provided, these
      entries will be zero-initialized.
    axis: [Optional] Denotes what axis the vocabulary corresponds to. The
      default, 0, corresponds to the most common use case (embeddings or
      linear weights for binary classification / regression). An axis of 1
      could be used for warm-starting output layers with class vocabularies.

  Returns:
    A `VocabInfo` which represents the vocabulary information for
    warm-starting.

  Raises:
    ValueError: `axis` is neither 0 nor 1.

  Example Usage:

  ```python
  embeddings_vocab_info = tf.VocabInfo(
      new_vocab='embeddings_vocab',
      new_vocab_size=100,
      num_oov_buckets=1,
      old_vocab='pretrained_embeddings_vocab',
      old_vocab_size=10000,
      backup_initializer=tf.compat.v1.truncated_normal_initializer(
          mean=0.0, stddev=(1 / math.sqrt(embedding_dim))),
      axis=0)

  softmax_output_layer_kernel_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.compat.v1.glorot_uniform_initializer(),
      axis=1)

  softmax_output_layer_bias_vocab_info = tf.VocabInfo(
      new_vocab='class_vocab',
      new_vocab_size=5,
      num_oov_buckets=0,  # No OOV for classes.
      old_vocab='old_class_vocab',
      old_vocab_size=8,
      backup_initializer=tf.compat.v1.zeros_initializer(),
      axis=0)

  # Currently, only axis=0 and axis=1 are supported.
  ```
  """

  def __new__(cls,
              new_vocab,
              new_vocab_size,
              num_oov_buckets,
              old_vocab,
              old_vocab_size=-1,
              backup_initializer=None,
              axis=0):
    # Only row (axis=0) and column (axis=1) vocabularies are supported.
    if axis not in (0, 1):
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    return super(VocabInfo, cls).__new__(
        cls,
        new_vocab,
        new_vocab_size,
        num_oov_buckets,
        old_vocab,
        old_vocab_size,
        backup_initializer,
        axis,
    )
def _infer_var_name(var):
  """Returns the single checkpoint-style name for `var`.

  Args:
    var: A list. The list can contain either of the following:
      (i) A single `Variable`
      (ii) A single `ResourceVariable`
      (iii) Multiple `Variable` objects which must be slices of the same larger
        variable.
      (iv) A single `PartitionedVariable`

  Returns:
    Name of the `var`.

  Raises:
    TypeError: If `var` maps to more than one checkpoint name, i.e. it is not
      one logical variable.
  """
  names = saveable_object_util.op_list_to_dict(var)
  if len(names) > 1:
    raise TypeError("`var` = %s passed as arg violates the constraints. "
                    "name_to_var_dict = %s" % (var, names))
  # Exactly one entry remains; return its key (the variable's name).
  return next(iter(names))
def _get_var_info(var, prev_tensor_name=None):
  """Helper method for standardizing Variable and naming.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following: (i) `Variable` (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
      variable. (iv) `PartitionedVariable`
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.

  Returns:
    A tuple of the Tensor name and var.
  """
  is_var = checkpoint_utils._is_variable  # pylint: disable=protected-access
  if is_var(var):
    current_var_name = _infer_var_name([var])
  elif isinstance(var, list) and all(is_var(v) for v in var):
    current_var_name = _infer_var_name(var)
  elif isinstance(var, variables_lib.PartitionedVariable):
    current_var_name = _infer_var_name([var])
    # Unwrap to the underlying list of slice variables.
    var = var._get_variable_list()  # pylint: disable=protected-access
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))
  # When no explicit previous name is given, assume it is unchanged.
  return prev_tensor_name or current_var_name, var
# pylint: disable=protected-access
# Accesses protected members of tf.Variable to reset the variable's internal
# state.
def _warm_start_var_with_vocab(var,
                               current_vocab_path,
                               current_vocab_size,
                               prev_ckpt,
                               prev_vocab_path,
                               previous_vocab_size=-1,
                               current_oov_buckets=0,
                               prev_tensor_name=None,
                               initializer=None,
                               axis=0):
  """Warm-starts given variable from `prev_tensor_name` tensor in `prev_ckpt`.

  Use this method when the `var` is backed by vocabulary. This method stitches
  the given `var` such that values corresponding to individual features in the
  vocabulary remain consistent irrespective of changing order of the features
  between old and new vocabularies.

  Args:
    var: Current graph's variable that needs to be warm-started (initialized).
      Can be either of the following:
      (i) `Variable`
      (ii) `ResourceVariable`
      (iii) list of `Variable`: The list must contain slices of the same larger
        variable.
      (iv) `PartitionedVariable`
    current_vocab_path: Path to the vocab file used for the given `var`.
    current_vocab_size: An `int` specifying the number of entries in the
      current vocab.
    prev_ckpt: A string specifying the directory with checkpoint file(s) or
      path to checkpoint. The given checkpoint must have tensor with name
      `prev_tensor_name` (if not None) or tensor with name same as given `var`.
    prev_vocab_path: Path to the vocab file used for the tensor in `prev_ckpt`.
    previous_vocab_size: If provided, will constrain previous vocab to the
      first `previous_vocab_size` entries. -1 means use the entire previous
      vocab.
    current_oov_buckets: An `int` specifying the number of out-of-vocabulary
      buckets used for given `var`.
    prev_tensor_name: Name of the tensor to lookup in provided `prev_ckpt`. If
      None, we lookup tensor with same name as given `var`.
    initializer: Variable initializer to be used for missing entries. If None,
      missing entries will be zero-initialized.
    axis: Axis of the variable that the provided vocabulary corresponds to.

  Raises:
    ValueError: If required args are not provided, or `axis` is neither 0
      nor 1.
    TypeError: If `var` is not a Variable, a list of Variable, or a
      PartitionedVariable.
  """
  if not (current_vocab_path and current_vocab_size and prev_ckpt and
          prev_vocab_path):
    raise ValueError("Invalid args: Must provide all of [current_vocab_path, "
                     "current_vocab_size, prev_ckpt, prev_vocab_path].")
  # Normalize `var` into a list of variable slices.
  if checkpoint_utils._is_variable(var):
    var = [var]
  elif (isinstance(var, list) and
        all(checkpoint_utils._is_variable(v) for v in var)):
    pass  # Already a list of variable slices.
  elif isinstance(var, variables_lib.PartitionedVariable):
    var = var._get_variable_list()
  else:
    raise TypeError(
        "var MUST be one of the following: a Variable, list of Variable or "
        "PartitionedVariable, but is {}".format(type(var)))

  if not prev_tensor_name:
    # Assume tensor name remains the same.
    prev_tensor_name = _infer_var_name(var)

  total_v_first_axis = sum(v.get_shape().as_list()[0] for v in var)
  for v in var:
    v_shape = v.get_shape().as_list()
    slice_info = v._get_save_slice_info()
    partition_info = None
    if slice_info:
      partition_info = variable_scope._PartitionInfo(
          full_shape=slice_info.full_shape, var_offset=slice_info.var_offset)

    if axis == 0:
      new_row_vocab_size = current_vocab_size
      new_col_vocab_size = v_shape[1]
      old_row_vocab_size = previous_vocab_size
      old_row_vocab_file = prev_vocab_path
      new_row_vocab_file = current_vocab_path
      old_col_vocab_file = None
      new_col_vocab_file = None
      num_row_oov_buckets = current_oov_buckets
      num_col_oov_buckets = 0
    elif axis == 1:
      # Note that we must compute this value across all partitions, whereas
      # in the axis = 0 case, we can simply use v_shape[1] because we don't
      # allow partitioning across axis = 1.
      new_row_vocab_size = total_v_first_axis
      new_col_vocab_size = current_vocab_size
      old_row_vocab_size = -1
      old_row_vocab_file = None
      new_row_vocab_file = None
      old_col_vocab_file = prev_vocab_path
      new_col_vocab_file = current_vocab_path
      num_row_oov_buckets = 0
      num_col_oov_buckets = current_oov_buckets
    else:
      raise ValueError("The only supported values for the axis argument are 0 "
                       "and 1. Provided axis: {}".format(axis))

    init = checkpoint_ops._load_and_remap_matrix_initializer(
        ckpt_path=checkpoint_utils._get_checkpoint_filename(prev_ckpt),
        old_tensor_name=prev_tensor_name,
        new_row_vocab_size=new_row_vocab_size,
        new_col_vocab_size=new_col_vocab_size,
        old_row_vocab_size=old_row_vocab_size,
        old_row_vocab_file=old_row_vocab_file,
        new_row_vocab_file=new_row_vocab_file,
        old_col_vocab_file=old_col_vocab_file,
        new_col_vocab_file=new_col_vocab_file,
        num_row_oov_buckets=num_row_oov_buckets,
        num_col_oov_buckets=num_col_oov_buckets,
        initializer=initializer)
    new_init_val = ops.convert_to_tensor(
        init(shape=v_shape, partition_info=partition_info))
    # Re-point the variable's initializer op so that (re-)initializing the
    # variable loads the remapped checkpoint values instead of the values
    # produced by its original initializer.
    v._initializer_op = state_ops.assign(v, new_init_val)


# pylint: enable=protected-access
def _get_grouped_variables(vars_to_warm_start):
  """Collects and groups (possibly partitioned) variables into a dictionary.

  The variables can be provided explicitly through vars_to_warm_start, or they
  are retrieved from collections (see below).

  Args:
    vars_to_warm_start: One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection). This expression will
        only consider variables in the TRAINABLE_VARIABLES collection.
      - A list of strings, each representing a full variable name to
        warm-start. These will consider variables in GLOBAL_VARIABLES
        collection.
      - A list of Variables to warm-start.
      - `None`, in which case all variables in TRAINABLE_VARIABLES will be
        used.

  Returns:
    A dictionary mapping variable names (strings) to lists of Variables.

  Raises:
    ValueError: If vars_to_warm_start is not a string, `None`, a list of
      `Variables`, or a list of strings.
  """
  # TODO(b/143899805): Remove unicode checks when deprecating Python2.
  if isinstance(vars_to_warm_start, str) or vars_to_warm_start is None:
    # Both vars_to_warm_start = '.*' and vars_to_warm_start = None will match
    # everything (in TRAINABLE_VARIABLES) here.
    logging.info("Warm-starting variables only in TRAINABLE_VARIABLES.")
    list_of_vars = ops.get_collection(
        ops.GraphKeys.TRAINABLE_VARIABLES, scope=vars_to_warm_start)
  elif isinstance(vars_to_warm_start, list):
    if all(isinstance(v, str) for v in vars_to_warm_start):
      # A list of strings: each one is a scope regex against GLOBAL_VARIABLES.
      list_of_vars = []
      for v in vars_to_warm_start:
        list_of_vars += ops.get_collection(
            ops.GraphKeys.GLOBAL_VARIABLES, scope=v)
    elif all(checkpoint_utils._is_variable(v) for v in vars_to_warm_start):  # pylint: disable=protected-access
      list_of_vars = vars_to_warm_start
    else:
      raise ValueError("If `vars_to_warm_start` is a list, it must be all "
                       "`Variable` or all `str`. Given types are {}".format(
                           [type(v) for v in vars_to_warm_start]))
  else:
    raise ValueError("`vars_to_warm_start` must be a `list` or `str`. Given "
                     "type is {}".format(type(vars_to_warm_start)))
  # We have to deal with partitioned variables, since get_collection flattens
  # out the list.
  grouped_variables = {}
  for v in list_of_vars:
    t = [v] if not isinstance(v, list) else v
    var_name = _infer_var_name(t)
    grouped_variables.setdefault(var_name, []).append(v)

  return grouped_variables
def _get_object_checkpoint_renames(path, variable_names):
  """Returns a dictionary mapping variable names to checkpoint keys.

  The warm-starting utility expects variable names to match with the variable
  names in the checkpoint. For object-based checkpoints, the variable names
  and names in the checkpoint are different. Thus, for object-based
  checkpoints, this function is used to obtain the map from variable names to
  checkpoint keys.

  Args:
    path: path to checkpoint directory or file.
    variable_names: list of variable names to load from the checkpoint.

  Returns:
    If the checkpoint is object-based, this function returns a map from
    variable names to their corresponding checkpoint keys.
    If the checkpoint is name-based, this returns an empty dict.

  Raises:
    ValueError: If the object-based checkpoint is missing variables.
  """
  fname = checkpoint_utils._get_checkpoint_filename(path)  # pylint: disable=protected-access
  try:
    names_to_keys = saver_lib.object_graph_key_mapping(fname)
  except errors.NotFoundError:
    # A NotFoundError from `object_graph_key_mapping` means the checkpoint is
    # name-based, in which case no renames are needed.
    return {}

  missing = set(variable_names) - set(names_to_keys)
  if missing:
    raise ValueError(
        "Attempting to warm-start from an object-based checkpoint, but found "
        "that the checkpoint did not contain values for all variables. The "
        "following variables were missing: {}"
        .format(missing))
  return {name: names_to_keys[name] for name in variable_names}
@tf_export(v1=["train.warm_start"])
def warm_start(ckpt_to_initialize_from,
               vars_to_warm_start=".*",
               var_name_to_vocab_info=None,
               var_name_to_prev_var_name=None):
  """Warm-starts a model using the given settings.

  If you are using a tf.estimator.Estimator, this will automatically be called
  during training.

  Args:
    ckpt_to_initialize_from: [Required] A string specifying the directory with
      checkpoint file(s) or path to checkpoint from which to warm-start the
      model parameters.
    vars_to_warm_start: [Optional] One of the following:

      - A regular expression (string) that captures which variables to
        warm-start (see tf.compat.v1.get_collection). This expression will only
        consider variables in the TRAINABLE_VARIABLES collection -- if you need
        to warm-start non_TRAINABLE vars (such as optimizer accumulators or
        batch norm statistics), please use the below option.
      - A list of strings, each a regex scope provided to
        tf.compat.v1.get_collection with GLOBAL_VARIABLES (please see
        tf.compat.v1.get_collection). For backwards compatibility reasons,
        this is separate from the single-string argument type.
      - A list of Variables to warm-start. If you do not have access to the
        `Variable` objects at the call site, please use the above option.
      - `None`, in which case only TRAINABLE variables specified in
        `var_name_to_vocab_info` will be warm-started.

      Defaults to `'.*'`, which warm-starts all variables in the
      TRAINABLE_VARIABLES collection. Note that this excludes variables such
      as accumulators and moving statistics from batch norm.
    var_name_to_vocab_info: [Optional] Dict of variable names (strings) to
      `tf.estimator.VocabInfo`. The variable names should be "full" variables,
      not the names of the partitions. If not explicitly provided, the variable
      is assumed to have no (changes to) vocabulary.
    var_name_to_prev_var_name: [Optional] Dict of variable names (strings) to
      name of the previously-trained variable in `ckpt_to_initialize_from`. If
      not explicitly provided, the name of the variable is assumed to be same
      between previous checkpoint and current model. Note that this has no
      effect on the set of variables that is warm-started, and only controls
      name mapping (use `vars_to_warm_start` for controlling what variables to
      warm-start).

  Raises:
    ValueError: If the WarmStartSettings contains prev_var_name or VocabInfo
      configuration for variable names that are not used. This is to ensure
      a stronger check for variable configuration than relying on users to
      examine the logs.
  """
  logging.info("Warm-starting from: {}".format(ckpt_to_initialize_from))
  grouped_variables = _get_grouped_variables(vars_to_warm_start)
  if var_name_to_vocab_info is None:
    var_name_to_vocab_info = {}

  if not var_name_to_prev_var_name:
    # Detect whether the checkpoint is object-based, in which case the
    # var_name_to_prev_var_name dictionary should map variable names to
    # checkpoint keys. If the user has specified var_name_to_prev_var_name, we
    # do not override it.
    var_name_to_prev_var_name = _get_object_checkpoint_renames(
        ckpt_to_initialize_from, grouped_variables.keys())

  warmstarted_count = 0

  # Keep track of which var_names in var_name_to_prev_var_name and
  # var_name_to_vocab_info have been used. Err on the safer side by throwing an
  # exception if any are unused by the end of the loop. It is easy to misname
  # a variable during this configuration, in which case without this check, we
  # would fail to warm-start silently.
  prev_var_name_used = set()
  vocab_info_used = set()

  # Group the vocabless vars into one call to init_from_checkpoint.
  vocabless_vars = {}
  for var_name, variable in grouped_variables.items():
    prev_var_name = var_name_to_prev_var_name.get(var_name)
    if prev_var_name:
      prev_var_name_used.add(var_name)
    vocab_info = var_name_to_vocab_info.get(var_name)
    if vocab_info:
      # Vocab-backed variables are stitched one at a time so that entries stay
      # aligned across the old and new vocabularies.
      vocab_info_used.add(var_name)
      warmstarted_count += 1
      logging.debug(
          "Warm-starting variable: {}; current_vocab: {} current_vocab_size: {}"
          " prev_vocab: {} prev_vocab_size: {} current_oov: {} prev_tensor: {}"
          " initializer: {}".format(
              var_name, vocab_info.new_vocab, vocab_info.new_vocab_size,
              vocab_info.old_vocab, (vocab_info.old_vocab_size if
                                     vocab_info.old_vocab_size > 0 else "All"),
              vocab_info.num_oov_buckets, prev_var_name or "Unchanged",
              vocab_info.backup_initializer or "zero-initialized"))
      _warm_start_var_with_vocab(
          variable,
          current_vocab_path=vocab_info.new_vocab,
          current_vocab_size=vocab_info.new_vocab_size,
          prev_ckpt=ckpt_to_initialize_from,
          prev_vocab_path=vocab_info.old_vocab,
          previous_vocab_size=vocab_info.old_vocab_size,
          current_oov_buckets=vocab_info.num_oov_buckets,
          prev_tensor_name=prev_var_name,
          initializer=vocab_info.backup_initializer,
          axis=vocab_info.axis)
    else:
      # For the special value of vars_to_warm_start = None,
      # we only warm-start variables with explicitly specified vocabularies.
      if vars_to_warm_start:
        warmstarted_count += 1
        logging.debug("Warm-starting variable: {}; prev_var_name: {}".format(
            var_name, prev_var_name or "Unchanged"))
        # Because we use a default empty list in grouped_variables, single
        # unpartitioned variables will be lists here, which we rectify in order
        # for init_from_checkpoint logic to work correctly.
        if len(variable) == 1:
          variable = variable[0]
        prev_tensor_name, var = _get_var_info(variable, prev_var_name)
        if prev_tensor_name in vocabless_vars:
          # The API for checkpoint_utils.init_from_checkpoint accepts a mapping
          # from checkpoint tensor names to model variable names, so it does not
          # support warm-starting two variables from the same tensor. Our work-
          # around is to run init_from_checkpoint multiple times, each time we
          # encounter a new variable that should be initialized by a previously-
          # used tensor.
          logging.debug("Requested prev_var_name {} initialize both {} and {}; "
                        "calling init_from_checkpoint.".format(
                            prev_tensor_name,
                            vocabless_vars[prev_tensor_name],
                            var))
          checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                                vocabless_vars)
          vocabless_vars.clear()
        vocabless_vars[prev_tensor_name] = var

  # Flush any remaining batched vocabless variables in a single
  # init_from_checkpoint call.
  if vocabless_vars:
    checkpoint_utils.init_from_checkpoint(ckpt_to_initialize_from,
                                          vocabless_vars)
  # Fail loudly if any user-supplied prev_var_name / vocab_info entries were
  # never matched against an actual variable (likely a misspelling).
  prev_var_name_not_used = set(
      var_name_to_prev_var_name.keys()) - prev_var_name_used
  vocab_info_not_used = set(var_name_to_vocab_info.keys()) - vocab_info_used

  logging.info("Warm-started %d variables.", warmstarted_count)

  if prev_var_name_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_prev_var_name that were not used: "
        "{0}. Perhaps you misspelled them? Here is the list of viable "
        "variable names: {1}".format(prev_var_name_not_used,
                                     grouped_variables.keys()))
  if vocab_info_not_used:
    raise ValueError(
        "You provided the following variables in "
        "var_name_to_vocab_info that were not used: {0}. "
        " Perhaps you misspelled them? Here is the list of viable variable "
        "names: {1}".format(vocab_info_not_used, grouped_variables.keys()))