Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/tensorflow/python/training/slot_creator.py: 22%
67 statements
coverage.py v7.4.0, created at 2024-01-03 07:57 +0000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
16"""Standard functions for creating slots.
18A slot is a `Variable` created with the same first m-dimension as a primary
19variable or `Tensor`. A slot is always scoped in the namespace of the primary
20object and typically has the same device and type.
22Slots are typically used as accumulators to track values associated with
23the primary object:
25```python
26# Optimizers can create a slot for each variable to track accumulators
27accumulators = {var : create_zeros_slot(var, "momentum") for var in vs}
28for var in vs:
29 apply_momentum(var, accumulators[var], lr, grad, momentum_tensor)
31# Slots can also be used for moving averages
32mavg = create_slot(var, var.initialized_value(), "exponential_moving_avg")
33update_mavg = mavg.assign_sub((mavg - var) * (1 - decay))
34```
35"""
# pylint: disable=g-bad-name

from tensorflow.python.compiler.xla.experimental import xla_sharding
from tensorflow.python.distribute import distribute_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import ref_variable
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variable_v1
from tensorflow.python.ops import variables


def _create_slot_var(primary,
                     val,
                     scope,
                     validate_shape,
                     shape,
                     dtype,
                     *,
                     copy_xla_sharding=False):
  """Helper function for creating a slot variable."""

  # TODO(lukaszkaiser): Consider allowing partitioners to be set in the current
  # scope.
  current_partitioner = variable_scope.get_variable_scope().partitioner
  variable_scope.get_variable_scope().set_partitioner(None)
  # When initializing from `val` instead of a callable initializer, the shape
  # is expected to be None, not <unknown> or any fully defined shape.
  shape = shape if callable(val) else None
  if resource_variable_ops.is_resource_variable(primary):
    use_resource = True
  elif isinstance(primary, ref_variable.RefVariable):
    use_resource = False
  else:
    use_resource = None
  slot = variable_scope.get_variable(
      scope,
      initializer=val,
      trainable=False,
      use_resource=use_resource,
      shape=shape,
      dtype=dtype,
      validate_shape=validate_shape)
  variable_scope.get_variable_scope().set_partitioner(current_partitioner)

  # pylint: disable=protected-access
  if isinstance(primary, variables.Variable) and primary._save_slice_info:
    # Primary is a partitioned variable, so we need to also indicate that
    # the slot is a partitioned variable. Slots have the same partitioning
    # as their primaries.
    # For example, when using AdamOptimizer in a linear model, slot.name
    # here can be "linear//weights/Adam:0", while primary.op.name is
    # "linear//weights". We want to get 'Adam' as real_slot_name, so we
    # remove "'linear//weights' + '/'" and ':0'.
    real_slot_name = slot.name[len(primary.op.name + "/"):-2]
    slice_info = primary._save_slice_info
    # The slot's shape does not have to match the primary's shape.
    # Example: if the primary's shape is [10, 20, 30], a slot shape of
    # None, [], [10], [10, 20] or [10, 20, 30] is allowed:
    #   shape None or [10, 20, 30]: set the slot's slice_info same as primary's
    #   shape []: don't set the slot's slice_info
    #   shape [10] or [10, 20]: set the slot's slice_info according to ndims
    n = slot.shape.ndims
    if n is None or n > 0:
      slot._set_save_slice_info(
          variables.Variable.SaveSliceInfo(
              slice_info.full_name + "/" + real_slot_name,
              slice_info.full_shape[:n], slice_info.var_offset[:n],
              slice_info.var_shape[:n]))
  # pylint: enable=protected-access

  # Copy XLA sharding attributes from the primary if the slot variable has the
  # same rank as the primary.
  def _has_same_rank(primary_shape, slot_shape):
    return (primary_shape.rank is not None and slot_shape.rank is not None and
            primary_shape.rank == slot_shape.rank)

  if copy_xla_sharding and _has_same_rank(primary.shape, slot.shape):
    slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
  return slot
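
# Illustration (not part of the original module): a minimal sketch of the
# real_slot_name computation in _create_slot_var above, using hypothetical
# names that mirror the AdamOptimizer example in the comment.
#
#   primary_op_name = "linear//weights"
#   slot_name = "linear//weights/Adam:0"
#   real_slot_name = slot_name[len(primary_op_name + "/"):-2]
#   assert real_slot_name == "Adam"  # the name prefix and ':0' are stripped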


def create_slot(primary,
                val,
                name,
                colocate_with_primary=True,
                *,
                copy_xla_sharding=False):
  """Create a slot initialized to the given value.

  The type of the slot is determined by the given value.

  Args:
    primary: The primary `Variable` or `Tensor`.
    val: A `Tensor` specifying the initial value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Use "primary's name + '/' + name" as the default scope name so the
  # optimizer's scope can be shared when reuse is True. When reuse is False
  # and the same name has been used before, an '_N' suffix is appended to
  # keep the scope name unique.
  validate_shape = val.get_shape().is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribute_lib.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            val,
            "",
            validate_shape,
            None,
            None,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          val,
          "",
          validate_shape,
          None,
          None,
          copy_xla_sharding=copy_xla_sharding)
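
# Example (not part of the original module): a minimal graph-mode sketch of
# create_slot used for a moving average, mirroring the module docstring.
# The variable `var` and the decay value are hypothetical.
#
#   import tensorflow.compat.v1 as tf
#   from tensorflow.python.training import slot_creator
#
#   var = tf.Variable([1.0, 2.0], name="weights")
#   # The slot's dtype and shape come from the initial-value tensor.
#   mavg = slot_creator.create_slot(var, var.initialized_value(), "ema")
#   update_mavg = mavg.assign_sub((mavg - var) * (1 - 0.999))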


def create_slot_with_initializer(primary,
                                 initializer,
                                 shape,
                                 dtype,
                                 name,
                                 colocate_with_primary=True,
                                 *,
                                 copy_xla_sharding=False):
  """Creates a slot initialized using an `Initializer`.

  The type of the slot is determined by the given `dtype`.

  Args:
    primary: The primary `Variable` or `Tensor`.
    initializer: An `Initializer`. The initial value of the slot.
    shape: Shape of the initial value of the slot.
    dtype: Type of the value of the slot.
    name: Name to use for the slot variable.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  # Scope the slot name in the namespace of the primary variable.
  # Use "primary.op.name + '/' + name" as the default scope name so the
  # optimizer's scope can be shared when reuse is True. When reuse is False
  # and the same name has been used before, an '_N' suffix is appended to
  # keep the scope name unique.
  validate_shape = shape.is_fully_defined()
  if isinstance(primary, variables.Variable):
    prefix = primary._shared_name  # pylint: disable=protected-access
  else:
    prefix = primary.op.name
  with variable_scope.variable_scope(None, prefix + "/" + name):
    if colocate_with_primary:
      distribution_strategy = distribute_lib.get_strategy()
      with distribution_strategy.extended.colocate_vars_with(primary):
        return _create_slot_var(
            primary,
            initializer,
            "",
            validate_shape,
            shape,
            dtype,
            copy_xla_sharding=copy_xla_sharding)
    else:
      return _create_slot_var(
          primary,
          initializer,
          "",
          validate_shape,
          shape,
          dtype,
          copy_xla_sharding=copy_xla_sharding)
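
# Example (not part of the original module): a minimal sketch of
# create_slot_with_initializer, where the slot's shape and dtype are given
# explicitly instead of being taken from a value tensor. The variable `var`
# is hypothetical.
#
#   import tensorflow.compat.v1 as tf
#   from tensorflow.python.ops import init_ops
#   from tensorflow.python.training import slot_creator
#
#   var = tf.Variable(tf.zeros([10, 20]), name="weights")
#   accum = slot_creator.create_slot_with_initializer(
#       var, init_ops.zeros_initializer(), var.shape, var.dtype.base_dtype,
#       "accumulator")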


def create_zeros_slot(primary,
                      name,
                      dtype=None,
                      colocate_with_primary=True,
                      *,
                      copy_xla_sharding=False):
  """Create a slot initialized to 0 with the same shape as the primary object.

  Args:
    primary: The primary `Variable` or `Tensor`.
    name: Name to use for the slot variable.
    dtype: Type of the slot variable. Defaults to the type of `primary`.
    colocate_with_primary: Boolean. If True the slot is located
      on the same device as `primary`.
    copy_xla_sharding: Boolean. If True also copies XLA sharding
      from primary.

  Returns:
    A `Variable` object.
  """
  if dtype is None:
    dtype = primary.dtype
  slot_shape = primary.get_shape()
  if slot_shape.is_fully_defined():
    initializer = init_ops.zeros_initializer()
    return create_slot_with_initializer(
        primary,
        initializer,
        slot_shape,
        dtype,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
  else:
    # The static shape is not fully defined: compute the shape at runtime and
    # initialize the slot from a zeros tensor of that shape.
    if isinstance(primary, variables.Variable):
      slot_shape = array_ops.shape(
          control_flow_ops.cond(
              variable_v1.is_variable_initialized(primary), primary.read_value,
              lambda: primary.initial_value))
    else:
      slot_shape = array_ops.shape(primary)
    val = array_ops.zeros(slot_shape, dtype=dtype)
    return create_slot(
        primary,
        val,
        name,
        colocate_with_primary=colocate_with_primary,
        copy_xla_sharding=copy_xla_sharding)
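
# Example (not part of the original module): a minimal sketch of
# create_zeros_slot, the helper most optimizers use. The variable `var` is
# hypothetical; the slot is zero-filled with the same shape and dtype.
#
#   import tensorflow.compat.v1 as tf
#   from tensorflow.python.training import slot_creator
#
#   var = tf.Variable([[1.0, 2.0], [3.0, 4.0]], name="weights")
#   momentum = slot_creator.create_zeros_slot(var, "momentum")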