Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/keras/optimizer_v2/utils.py: 17% of 76 statements (coverage.py v7.4.0, created at 2024-01-03 07:57 +0000)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Optimizer utilities."""

from tensorflow.python.distribute import central_storage_strategy
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.distribute import reduce_util as ds_reduce_util
from tensorflow.python.ops import clip_ops
from tensorflow.python.platform import tf_logging as logging


def all_reduce_sum_gradients(grads_and_vars):
  """Returns all-reduced gradients aggregated via summation.

  Args:
    grads_and_vars: List of (gradient, variable) pairs.

  Returns:
    List of (gradient, variable) pairs where gradients have been all-reduced.
  """
  grads_and_vars = list(grads_and_vars)
  filtered_grads_and_vars = filter_empty_gradients(grads_and_vars)
  if filtered_grads_and_vars:
    if strategy_supports_no_merge_call():
      grads = [pair[0] for pair in filtered_grads_and_vars]
      reduced = distribute_lib.get_strategy().extended._replica_ctx_all_reduce(  # pylint: disable=protected-access
          ds_reduce_util.ReduceOp.SUM, grads)
    else:
      # TODO(b/183257003): Remove this branch
      reduced = distribute_lib.get_replica_context().merge_call(
          _all_reduce_sum_fn, args=(filtered_grads_and_vars,))
  else:
    reduced = []
  # Copy 'reduced' but add None gradients back in
  reduced_with_nones = []
  reduced_pos = 0
  for g, v in grads_and_vars:
    if g is None:
      reduced_with_nones.append((None, v))
    else:
      reduced_with_nones.append((reduced[reduced_pos], v))
      reduced_pos += 1
  assert reduced_pos == len(reduced), "Failed to add all gradients"
  return reduced_with_nones
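
# Usage sketch (an editorial addition, not part of the original module): in a
# custom distributed training step run via `strategy.run(...)`, per-replica
# gradients can be pre-aggregated with this helper and then applied without
# further aggregation. `tape`, `loss`, `model`, and `optimizer` below are
# hypothetical placeholders.
#
#   grads = tape.gradient(loss, model.trainable_variables)
#   reduced = all_reduce_sum_gradients(
#       list(zip(grads, model.trainable_variables)))
#   optimizer.apply_gradients(
#       reduced, experimental_aggregate_gradients=False)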


def filter_empty_gradients(grads_and_vars):
  """Filter out `(grad, var)` pairs that have a gradient equal to `None`."""
  grads_and_vars = tuple(grads_and_vars)
  if not grads_and_vars:
    return grads_and_vars

  filtered = []
  vars_with_empty_grads = []
  for grad, var in grads_and_vars:
    if grad is None:
      vars_with_empty_grads.append(var)
    else:
      filtered.append((grad, var))
  filtered = tuple(filtered)

  if not filtered:
    raise ValueError("No gradients provided for any variable: %s." %
                     ([v.name for _, v in grads_and_vars],))
  if vars_with_empty_grads:
    logging.warning(
        ("Gradients do not exist for variables %s when minimizing the loss."),
        ([v.name for v in vars_with_empty_grads]))
  return filtered
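
# Illustrative example (editorial addition): variables whose gradient is `None`
# (e.g. not reached by the loss) are dropped from the result and only reported
# via a warning; an all-`None` input raises `ValueError`.
#
#   import tensorflow as tf
#   v_used = tf.Variable(1.0, name="used")
#   v_unused = tf.Variable(2.0, name="unused")
#   filter_empty_gradients([(tf.constant(0.5), v_used), (None, v_unused)])
#   # -> ((<tf.Tensor: 0.5>, v_used),) plus a warning mentioning "unused"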


def make_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_lib.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipnorm` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [
        (clip_ops.clip_by_norm(g, clipnorm), v) for g, v in grads_and_vars
    ]
    return clipped_grads_and_vars

  return gradient_clipnorm_fn
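
# Illustrative example (editorial addition): each gradient is clipped to the
# given norm independently of the others, preserving its direction.
#
#   import tensorflow as tf
#   clip_fn = make_gradient_clipnorm_fn(clipnorm=1.0)
#   v = tf.Variable([0.0, 0.0])
#   (clipped, _), = clip_fn([(tf.constant([3.0, 4.0]), v)])
#   # The [3, 4] gradient has norm 5, so it is rescaled to [0.6, 0.8].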


def make_global_gradient_clipnorm_fn(clipnorm):
  """Creates a gradient transformation function for clipping by global norm."""
  if clipnorm is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipnorm_fn(grads_and_vars):

    if isinstance(distribute_lib.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`global_clipnorm` is not supported with `CentralStorageStrategy`")

    grads, variables = zip(*grads_and_vars)
    clipped_grads, _ = clip_ops.clip_by_global_norm(grads, clipnorm)
    clipped_grads_and_vars = list(zip(clipped_grads, variables))
    return clipped_grads_and_vars

  return gradient_clipnorm_fn
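
# Illustrative example (editorial addition): in contrast to the per-tensor
# variant above, the norm computed over *all* gradients together is clipped,
# so the relative magnitudes between gradients are preserved.
#
#   import tensorflow as tf
#   clip_fn = make_global_gradient_clipnorm_fn(clipnorm=1.0)
#   v0, v1 = tf.Variable(0.0), tf.Variable(0.0)
#   pairs = clip_fn([(tf.constant(3.0), v0), (tf.constant(4.0), v1)])
#   # Global norm sqrt(3**2 + 4**2) = 5 is scaled to 1: gradients become 0.6 and 0.8.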


def make_gradient_clipvalue_fn(clipvalue):
  """Creates a gradient transformation function for clipping by value."""
  if clipvalue is None:
    return lambda grads_and_vars: grads_and_vars

  def gradient_clipvalue_fn(grads_and_vars):

    if isinstance(distribute_lib.get_strategy(),
                  (central_storage_strategy.CentralStorageStrategy,
                   central_storage_strategy.CentralStorageStrategyV1)):
      raise ValueError(
          "`clipvalue` is not supported with `CentralStorageStrategy`")

    clipped_grads_and_vars = [(clip_ops.clip_by_value(g, -clipvalue,
                                                      clipvalue), v)
                              for g, v in grads_and_vars]
    return clipped_grads_and_vars

  return gradient_clipvalue_fn
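
# Illustrative example (editorial addition): clipping by value bounds each
# gradient element to the closed interval [-clipvalue, clipvalue].
#
#   import tensorflow as tf
#   clip_fn = make_gradient_clipvalue_fn(clipvalue=0.5)
#   v = tf.Variable([0.0, 0.0, 0.0])
#   (clipped, _), = clip_fn([(tf.constant([-2.0, 0.1, 3.0]), v)])
#   # clipped == [-0.5, 0.1, 0.5]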


def _all_reduce_sum_fn(distribution, grads_and_vars):
  return distribution.extended.batch_reduce_to(ds_reduce_util.ReduceOp.SUM,
                                               grads_and_vars)


def strategy_supports_no_merge_call():
  """Returns True if the current Strategy can operate in pure replica context."""
  if not distribute_lib.has_strategy():
    return True
  strategy = distribute_lib.get_strategy()
  return not strategy.extended._use_merge_call()  # pylint: disable=protected-access
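
# Note (editorial addition): with no `tf.distribute` strategy active this
# returns True, so `all_reduce_sum_gradients` above takes the replica-context
# all-reduce path; strategies whose extended object still requires
# `merge_call` fall back to the `merge_call(_all_reduce_sum_fn, ...)` branch.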